/*
 * Dropbear - a SSH2 server
 * 
 * Copyright (c) 2002,2003 Matt Johnston
 * All rights reserved.
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE. */

#include "includes.h"
#include "session.h"
#include "dbutil.h"
#include "packet.h"
#include "algo.h"
#include "buffer.h"
#include "dss.h"
#include "ssh.h"
#include "random.h"
#include "kex.h"
#include "channel.h"
#include "atomicio.h"

/* Forward declarations for the file-local helpers defined further below. */
static void checktimeouts();
static int ident_readln(int fd, char* buf, int count);

/* Head of the singly-linked ("chained") list of per-thread session globals.
 * Nodes are appended by new_session_globals() and unlinked by
 * delete_session_globals(); session_globals_deinit() frees whatever remains. */
session_globals_t *session_globals = NULL;
/* Key for thread-specific variable access; created by session_globals_init()
 * and deleted by session_globals_deinit(). */
pthread_key_t session_globals_key;

void session_globals_init()
{
	/* Create the process-wide key used for thread-specific session-globals
	 * access (presumably what GLOBALS(g) reads - confirm). No destructor is
	 * registered, so pthreads never frees per-thread values; list nodes are
	 * released explicitly by delete_session_globals() and
	 * session_globals_deinit(). The result is deliberately not checked. */
	(void)pthread_key_create(&session_globals_key, NULL);
}

void session_globals_deinit()
{
	/* Pop and free every node of the chained list. When the loop ends the
	 * head is NULL, leaving the list empty. Then drop the thread-specific
	 * key itself. */
	while (session_globals != NULL) {
		session_globals_t *victim = session_globals;

		session_globals = victim->next;
		free(victim);
	}

	pthread_key_delete(session_globals_key);
}

/* Create a set of session globals for thread 'thread_id' and append it to
 * the end of the session_globals list.
 *
 * The node is zero-filled. sessinitdone, authenticated and terminated start
 * at 0, and next is NULL.
 *
 * Returns the new node, or NULL if memory could not be allocated. In that
 * case the list is left unchanged, so callers should check the result.
 *
 * NOTE(review): the list is not locked here. Presumably this is only called
 * from a single (listener) thread. Confirm against callers. */
session_globals_t *new_session_globals(pthread_t thread_id)
{
	session_globals_t *g;
	session_globals_t **tail;

	/* calloc rather than malloc, so that every field not explicitly set
	 * below starts zeroed instead of holding indeterminate heap contents */
	g = calloc(1, sizeof(*g));
	if (g == NULL) {
		/* don't link anything into the list on failure */
		return NULL;
	}

	g->thread_id = thread_id;
	g->sessinitdone = 0;
	g->authenticated = 0;
	g->terminated = 0;
	g->next = NULL;

	/* walk to the terminating NULL link and hook the fully initialised
	 * node there; this covers the empty-list case without a special branch */
	tail = &session_globals;
	while (*tail != NULL) {
		tail = &(*tail)->next;
	}
	*tail = g;

	return g;
}

/* Remove the session globals for thread 'thread_id' from the list and free
 * them.
 *
 * The list is expected to be non-empty and to contain an entry for
 * 'thread_id'; both are asserted. If assertions are compiled out (NDEBUG),
 * a missing entry is ignored rather than dereferencing a NULL pointer. */
void delete_session_globals(pthread_t thread_id)
{
	session_globals_t **link;
	session_globals_t *g;

	assert(session_globals);

	/* 'link' always points at the pointer that refers to the current node
	 * (the list head or a predecessor's next field). Unlinking is then a
	 * single store, whether or not the match is the first node. */
	for (link = &session_globals; *link != NULL; link = &(*link)->next) {
		if (pthread_equal(thread_id, (*link)->thread_id)) {
			break;
		}
	}

	g = *link;
	assert(g);
	if (g == NULL) {
		/* no entry for this thread: nothing to remove */
		return;
	}

	*link = g->next;
	free(g);
}

/* Called only at the start of a session to set up the initial state of the
 * session globals obtained through GLOBALS(g) (presumably the calling
 * thread's entry created by new_session_globals() - TODO confirm).
 *
 * sock:       the connected session socket. It is also the initial highest
 *             fd handed to select() by session_loop().
 * remotehost: stored as-is in ses.remotehost. The pointer is kept, not
 *             copied, so it must outlive the session.
 *
 * Statement order matters here: kexfirstinitialise() and chaninitialise()
 * run after sock/maxfd/connecttimeout are set, but before the rest of the
 * fields below are initialised. */
void common_session_init(int sock, char* remotehost) {
	GLOBALS(g);

	TRACE(("enter session_init"))
	g->ses.remotehost = remotehost;

	g->ses.sock = sock;
	/* highest fd in use; session_loop() passes maxfd+1 to select() */
	g->ses.maxfd = sock;

	/* 0 means no pre-auth deadline; checktimeouts() skips the check then */
	g->ses.connecttimeout = 0;
	
	kexfirstinitialise(); /* initialise the kex state */
	chaninitialise(); /* initialise the channel state */

	/* outgoing side: payload buffer and transmit sequence number */
	g->ses.writepayload = buf_new(MAX_TRANS_PAYLOAD_LEN);
	g->ses.transseq = 0;

	/* incoming side: nothing read, decrypted or pending yet */
	g->ses.readbuf = NULL;
	g->ses.decryptreadbuf = NULL;
	g->ses.payload = NULL;
	g->ses.recvseq = 0;

	initqueue(&g->ses.writequeue);

	/* the next message expected from the peer is KEXINIT (presumably
	 * enforced by the packet code - not visible here) */
	g->ses.requirenext = SSH_MSG_KEXINIT;
	g->ses.dataallowed = 0; /* don't send data yet, we'll wait until after kex */
	g->ses.ignorenext = 0;
	g->ses.lastpacket = 0;

	/* set all the algos to none */
	g->ses.keys = (struct key_context*)m_malloc(sizeof(struct key_context));
	g->ses.newkeys = NULL;
	g->ses.keys->recv_algo_crypt = &dropbear_nocipher;
	g->ses.keys->trans_algo_crypt = &dropbear_nocipher;
	
	g->ses.keys->recv_algo_mac = &dropbear_nohash;
	g->ses.keys->trans_algo_mac = &dropbear_nohash;

	/* -1: no kex / hostkey algorithm negotiated yet */
	g->ses.keys->algo_kex = -1;
	g->ses.keys->algo_hostkey = -1;
	g->ses.keys->recv_algo_comp = DROPBEAR_COMP_NONE;
	g->ses.keys->trans_algo_comp = DROPBEAR_COMP_NONE;

#ifndef DISABLE_ZLIB
	g->ses.keys->recv_zstream = NULL;
	g->ses.keys->trans_zstream = NULL;
#endif

	/* key exchange buffers */
	g->ses.session_id = NULL;
	g->ses.kexhashbuf = NULL;
	g->ses.transkexinit = NULL;
	g->ses.dh_K = NULL;
	/* filled in by session_identification(); while NULL,
	 * checktimeouts() will not start a rekey */
	g->ses.remoteident = NULL;

	g->ses.chantypes = NULL;

	g->ses.allowprivport = 0;
	g->ses.authstate.pw = NULL;

	TRACE(("leave session_init"))
}

/* Main event loop for a session.
 *
 * Each pass waits in select() (bounded by SELECT_TIMEOUT seconds) on the
 * session socket and, while data is allowed, on the channel fds. It then
 * runs the timeout checks, services the socket (write, read, dispatch one
 * decrypted packet) and the channels, and finally calls 'loophandler' if
 * one was given. The loop has no normal exit; it never returns. */
void session_loop(void(*loophandler)()) {

	fd_set readfd, writefd;
	struct timeval timeout;
	int ready;
	GLOBALS(g);

	for (;;) {
		/* a decrypted packet is never left pending between passes */
		assert(g->ses.payload == NULL);

		FD_ZERO(&readfd);
		FD_ZERO(&writefd);
		timeout.tv_usec = 0;
		timeout.tv_sec = SELECT_TIMEOUT;

		if (g->ses.sock != -1) {
			/* always watch for input; only ask for writability when
			 * there is queued output waiting to go out */
			FD_SET(g->ses.sock, &readfd);
			if (!isempty(&g->ses.writequeue)) {
				FD_SET(g->ses.sock, &writefd);
			}
		}

		/* channel fds only take part while data is allowed */
		if (g->ses.dataallowed) {
			setchannelfds(&readfd, &writefd);
		}

		ready = select(g->ses.maxfd+1, &readfd, &writefd, NULL, &timeout);

		if (ready < 0 && errno == EINTR) {
			/* Run the handler even when interrupted, so that variables
			 * changed by signal handlers take effect */
			if (loophandler) {
				loophandler();
			}
			continue;
		}
		if (ready < 0) {
			dropbear_exit("Error in select");
		}

		/* auth timeout, rekey-needed etc are checked on every wakeup,
		 * including a plain select() timeout */
		checktimeouts();

		if (ready == 0) {
			/* timeout */
			TRACE(("select timeout"))
			continue;
		}

		if (g->ses.sock != -1) {
			/* flush queued output first, then pull in new input */
			if (FD_ISSET(g->ses.sock, &writefd)
					&& !isempty(&g->ses.writequeue)) {
				write_packet();
			}
			if (FD_ISSET(g->ses.sock, &readfd)) {
				read_packet();
			}
			/* dispatch a fully decrypted packet straight away, which
			 * frees the read buffer for the next one */
			if (g->ses.payload != NULL) {
				process_packet();
			}
		}

		/* ses.dataallowed is 0 during rekeying, so channel I/O pauses */
		if (g->ses.dataallowed) {
			channelio(&readfd, &writefd);
		}

		if (loophandler) {
			loophandler();
		}
	}

	/* Not reached */
}

/* Clean up a session on exit.
 *
 * Does nothing unless sessinitdone is set, because until then the session
 * fields cannot be trusted. Otherwise it frees the password entry, the
 * session id and the active key context (burned before it is freed), cleans
 * up the channels, and frees the remote identification string. */
void common_session_cleanup() {
	GLOBALS(g);
	TRACE(("enter session_cleanup"))
	
	/* we can't cleanup if we don't know the session state */
	if (!g->sessinitdone) {
		TRACE(("leave session_cleanup: !sessinitdone"))
		return;
	}
	
	m_free(g->ses.authstate.pw);
	m_free(g->ses.session_id);
	/* burn (presumably zero) the key context before releasing it */
	m_burn(g->ses.keys, sizeof(struct key_context));
	m_free(g->ses.keys);

	chancleanup();

	/* session_identification() allocates remoteident, and it stays in use
	 * for the whole session (checktimeouts() tests it). Nothing else frees
	 * it, and with one session per thread every finished session would
	 * leak it. It is still NULL if no ident was ever received, which is
	 * harmless to free. */
	m_free(g->ses.remoteident);

	TRACE(("leave session_cleanup"))
}


/* Exchange SSH identification strings with the peer.
 *
 * Writes LOCAL_IDENT followed by "\r\n" to the session socket (blocking).
 * It then reads up to 10 lines, looking for the first one that starts with
 * "SSH-". That line, with its CR/LF stripped by ident_readln(), is copied
 * into a newly allocated, NUL-terminated ses.remoteident. Calls
 * dropbear_exit() if the write fails or no version line is found. */
void session_identification() {

	/* max length of 255 chars. The buffer is fully zeroed so it always holds
	 * a terminated string, even if a read fails part-way through a line:
	 * ident_readln() only ever writes a terminator into the final byte. */
	char linebuf[256] = {0};
	int len = 0;
	char done = 0;
	int i;
	GLOBALS(g);

	/* write our version string, this blocks */
	if (atomicio(write, g->ses.sock, LOCAL_IDENT "\r\n",
				strlen(LOCAL_IDENT "\r\n")) == DROPBEAR_FAILURE) {
		dropbear_exit("Error writing ident string");
	}

	/* We allow up to 9 lines before the actual version string, to
	 * account for wrappers/cruft etc. According to the spec only the client
	 * needs to handle this, but no harm in letting the server handle it too */
	for (i = 0; i < 10; i++) {
		len = ident_readln(g->ses.sock, linebuf, sizeof(linebuf));

		if (len < 0) {
			/* ident_readln() already retries EINTR internally, so a
			 * negative return is always a real error or EOF. errno is not
			 * consulted: EOF leaves it unchanged, so it may be stale. */
			break;
		}

		if (len >= 4 && memcmp(linebuf, "SSH-", 4) == 0) {
			/* start of line matches */
			done = 1;
			break;
		}
	}

	if (!done) {
		TRACE(("err: %s for '%s'\n", strerror(errno), linebuf))
		dropbear_exit("Failed to get remote version");
	} else {
		/* linebuf is already null terminated, and len counts the NUL */
		g->ses.remoteident = m_malloc(len);
		memcpy(g->ses.remoteident, linebuf, len);
	}

	TRACE(("remoteident: %s", g->ses.remoteident))

}

/* Read one identification line from the non-blocking fd 'fd' into 'buf'.
 *
 * Bytes are read one at a time so that nothing past the end of the line is
 * consumed. '\r' characters are dropped, and '\n' ends the line without
 * being stored. At most count-1 characters are stored. A longer line is cut
 * off there, and its remainder is left for the next call. checktimeouts()
 * runs after every select() wakeup, so the pre-auth timeout still applies
 * while waiting here.
 *
 * Returns the length including the NUL terminator on success. Returns -1 if
 * count < 1, on a select()/read() error, or on EOF. EINTR, and
 * EAGAIN/EWOULDBLOCK from read() after a spurious readiness report, are
 * retried rather than treated as failures. */
static int ident_readln(int fd, char* buf, int count) {
	
	char in;
	int pos = 0;
	int num = 0;
	fd_set fds;
	struct timeval timeout;

	TRACE(("enter ident_readln"))

	if (count < 1) {
		return -1;
	}

	/* leave space to null-terminate */
	while (pos < count-1) {

		/* select since it's a non-blocking fd. Rebuild the set each pass,
		 * because select() rewrites it. */
		FD_ZERO(&fds);
		FD_SET(fd, &fds);

		timeout.tv_sec = 1;
		timeout.tv_usec = 0;
		if (select(fd+1, &fds, NULL, NULL, &timeout) < 0) {
			if (errno == EINTR) {
				continue;
			}
			TRACE(("leave ident_readln: select error"))
			return -1;
		}

		checktimeouts();

		if (!FD_ISSET(fd, &fds)) {
			/* select() timed out; go round and wait again */
			continue;
		}

		/* Have to go one byte at a time, since we don't want to read past
		 * the end, and have to somehow shove bytes back into the normal
		 * packet reader */
		num = read(fd, &in, 1);
		if (num < 0) {
			if (errno == EINTR || errno == EAGAIN
					|| errno == EWOULDBLOCK) {
				/* interrupted, or select() reported readiness that
				 * vanished (possible on a non-blocking fd): not a real
				 * error, so try again */
				continue;
			}
			TRACE(("leave ident_readln: read error"))
			return -1;
		}
		if (num == 0) {
			/* EOF */
			TRACE(("leave ident_readln: EOF"))
			return -1;
		}
		if (in == '\n') {
			/* end of ident string */
			break;
		}
		/* a '\r' is consumed but not stored, so it won't be read as part
		 * of the next line */
		if (in != '\r') {
			buf[pos] = in;
			pos++;
		}
	}

	buf[pos] = '\0';
	TRACE(("leave ident_readln: return %d", pos+1))
	return pos+1;
}

/* Check all timeouts which are required. Currently these are the pre-auth
 * connection deadline (ses.connecttimeout, an absolute time in seconds,
 * where 0 means unset) and automatic rekeying after KEX_REKEY_TIMEOUT
 * seconds or KEX_REKEY_DATA bytes of traffic. Called from session_loop()
 * and from ident_readln(). */
static void checktimeouts() {

	struct timeval now;
	long nowsecs;
	GLOBALS(g);

	if (gettimeofday(&now, 0) < 0) {
		dropbear_exit("Error getting time");
	}
	nowsecs = now.tv_sec;

	/* pre-auth deadline, only when one has been armed */
	if (g->ses.connecttimeout != 0 && nowsecs > g->ses.connecttimeout) {
		dropbear_close("Timeout before auth");
	}

	/* no rekeying until the remote ident exchange has happened */
	if (g->ses.remoteident == NULL) {
		return;
	}

	/* a KEXINIT of ours is already outstanding */
	if (g->ses.kexstate.sentkexinit) {
		return;
	}

	/* rekey once enough time has elapsed or enough data has flowed */
	if (nowsecs - g->ses.kexstate.lastkextime >= KEX_REKEY_TIMEOUT
			|| g->ses.kexstate.datarecv + g->ses.kexstate.datatrans >= KEX_REKEY_DATA) {
		TRACE(("rekeying after timeout or max data reached"))
		send_msg_kexinit();
	}
}

