/**	
 *	@file	NetConn_SSL.cpp
 *	@brief	CNetConn derived class for encrypted socket communications
 * 
 * 	Implementation of the CNetConn_SSL class.
 *	Communications channel through an SSL socket
 *
 */

#include	"pp/NetConn_SSL.h"
#include	"assert.h"

/*----------------------------------------
 *	Equates
 *--------------------------------------*/

	// Timeout values (in milliseconds)

//#define NOTIMEOUT	// for debugging only

#define	MINUTES					60000
#define	SELECT_TIMEOUT			15000		// Must be evenly divisible by a thousand (converted to whole seconds for select())
#define	SSL_IO_TIMEOUT			(4*MINUTES)	// parenthesized so the macros are safe inside larger expressions
#define	SSL_SHUTDOWN_TIMEOUT	(1*MINUTES)

	// misc. portability shims: names differ between the POSIX and non-POSIX socket APIs

#ifdef OS_POSIX
#define	S_ADDR		s_addr
#define	FD_SET_T	fd_set
#define	SOCKADDR	struct sockaddr
#else
#define	S_ADDR		S_un.S_addr
#define	FD_SET_T	struct fd_set
#define	socklen_t	int
#define	SHUT_RDWR	2
#endif

	// Private Error Codes

enum
{
	NETCONN_SUCCESS			= 0,	// no error: the socket is ready
	NETCONN_ERROR_TIMEOUT	= 1,	// waiting ran out of time, or the connection stopped running
	NETCONN_ERROR_UNKNOWN	= 2		// select() failed but the OS reported no specific error
};

/*----------------------------------------
 *	Data Types
 *--------------------------------------*/

/*----------------------------------------
 *	Function Prototypes
 *--------------------------------------*/

// OpenSSL thread-support callbacks, installed by CNetConn_SSL::SSL_Library_Init()
void			ssl_locking_callback( int mode, int type, const char * file, int line );
unsigned long	ssl_thread_id( void );

/*----------------------------------------
 *	static data
 *--------------------------------------*/

extern DWORD			dataIn;			// # of bytes received
extern DWORD			dataOut;		// # of bytes sent

	// SSL Library init data
static int				sslInitialized;	// !0 = we have called SSL_Library_Init
static OS_CRITICAL_SECTION *ssllock_cs;	// used by openssl lib
static long			*	ssllock_count;	// used by openssl lib
static const char 		*	pDefaultCipherList; // Default Cipher list used by CNetConn_SSL_CTX::CNetConn_SSL_CTX

/*----------------------------------------
 *	CNetConn_SSL Class
 *--------------------------------------*/

/*  --------------------------------------------------------------------*/
/** Constructors **/

// Default construction: no socket-specific setup, and no SSL session until
// SSL_Connect() or SSL_Accept() creates one.
CNetConn_SSL::CNetConn_SSL() : CNetConn_Socket() { pSSL = NULL; }

// Wraps an existing socket handle; the SSL session is created later by
// SSL_Connect() or SSL_Accept().
CNetConn_SSL::CNetConn_SSL(SOCKET s) : CNetConn_Socket( s ) { pSSL = NULL; }

// Takes over the socket held by another connection object: the handle is
// copied into this object and the donor is left holding INVALID_SOCKET.
CNetConn_SSL::CNetConn_SSL(CNetConn_Socket * s)
	: CNetConn_Socket( s->GetSocket() )
{
	s->SetSocket( INVALID_SOCKET );	// donor no longer references the socket
	pSSL = NULL;
}

// Ends any SSL session still attached to this connection.
CNetConn_SSL::~CNetConn_SSL()
{
	this->Shutdown();
}

/*  --------------------------------------------------------------------*/
/** See NetConn_SSL.h **/

int CNetConn_SSL::SSL_Library_Init( const char * _pDefaultCipherList )
{
	// Initialize SSL

	if (!sslInitialized)
	{
		pDefaultCipherList = _pDefaultCipherList;

		// Init the library

		sslInitialized = 1;
		SSL_library_init();
		SSL_load_error_strings();

		// setup thread safe locks in SSL

		ssllock_cs= (OS_CRITICAL_SECTION *) OPENSSL_malloc(CRYPTO_num_locks() * sizeof(OS_CRITICAL_SECTION));
		ssllock_count= (long *) OPENSSL_malloc(CRYPTO_num_locks() * sizeof(long));
		for (int i=0; i<CRYPTO_num_locks(); i++)
		{
			ssllock_count[i]=0;
			ssllock_cs[i] = OS_CreateCriticalSection( OS_CRITICAL_SECTION_NORMAL );
		}

		CRYPTO_set_id_callback((unsigned long (*)())ssl_thread_id);
		CRYPTO_set_locking_callback((void (*)(int,int,const char*,int))ssl_locking_callback);
	}

	return 0;
}

/*  --------------------------------------------------------------------*/
/** See NetConn_SSL.h **/

//	Releases the lock tables created by SSL_Library_Init().  Safe to call more than
//	once: both table pointers are reset to NULL after being freed.
void CNetConn_SSL::SSL_Shutdown()
{
	// Free ssl thread safe locks

	if (ssllock_cs != NULL)
	{
		// Unhook the callback first so OpenSSL stops calling into the locks being destroyed
		CRYPTO_set_locking_callback(NULL);

		const int lockCount = CRYPTO_num_locks();
		for (int i = 0; i < lockCount; i++)
		{
			OS_DeleteCriticalSection(ssllock_cs[i]);
		}
		OPENSSL_free(ssllock_cs);
		ssllock_cs = NULL;
	}

	if (ssllock_count != NULL)
	{
		OPENSSL_free(ssllock_count);
		ssllock_count = NULL;	// previously left dangling: a second call double-freed it
	}
}

/*  --------------------------------------------------------------------*/
/** See NetConn_SSL.h **/

//	Client side of the SSL handshake on ioSocket.
//
//	Returns 1 on success; -1 if ctx has no SSL_CTX; -2 if SSL_new() fails;
//	otherwise the last SSL_connect() return value (<= 0) when the handshake fails
//	or the wait for the socket times out.  The SSL_get_error()/select() code is no
//	longer returned: values such as SSL_ERROR_SSL (1) and NETCONN_ERROR_TIMEOUT (1)
//	were indistinguishable from success.
int CNetConn_SSL::SSL_Connect( CNetConn_SSL_CTX * ctx )
{
	int	result = 0;					// last SSL_connect() return value
	int	error = NETCONN_SUCCESS;	// last SSL_get_error()/SSL_Wants_Something() code (diagnostics)

	assert(ctx != NULL && ctx->GetCTX() != NULL);

	if (ctx == NULL || ctx->GetCTX() == NULL)
		return -1;

	pSSL = SSL_new(ctx->GetCTX());

	assert(pSSL != NULL);

	if (pSSL == NULL)
		return -2;

	SSL_set_fd(pSSL, ioSocket);

	// Drive the handshake, waiting on the socket whenever OpenSSL needs
	// to read or write before it can make progress.
	for (;;)
	{
		result = SSL_connect(pSSL);
		if (result > 0)
			break;

		error = SSL_get_error(pSSL, result);
		error = SSL_Wants_Something(error, SSL_IO_TIMEOUT);
		if (error != NETCONN_SUCCESS)
			break;
	}

	if (result < 1)
	{
		fprintf(stderr,"SSL_connect failed\n");
		fprintf(stderr,"  NetConn/SSL error %d\n", error);
		ERR_print_errors_fp(stderr);	// Debug
	}

	return result;
}

/*  --------------------------------------------------------------------*/
/** See NetConn_SSL.h **/

//	Server side of the SSL handshake on ioSocket.
//
//	Returns 1 on success; -1 if ctx has no SSL_CTX or SSL_new() fails;
//	otherwise the last SSL_accept() return value (<= 0) when the handshake fails
//	or the wait for the socket times out.  The SSL_get_error()/select() code is no
//	longer returned: values such as SSL_ERROR_SSL (1) and NETCONN_ERROR_TIMEOUT (1)
//	were indistinguishable from success.
int CNetConn_SSL::SSL_Accept( CNetConn_SSL_CTX * ctx )
{
	int	result = 0;					// last SSL_accept() return value
	int	error = NETCONN_SUCCESS;	// last SSL_get_error()/SSL_Wants_Something() code (diagnostics)

	assert(ctx != NULL && ctx->GetCTX() != NULL);

	if (ctx == NULL || ctx->GetCTX() == NULL)
		return -1;

	pSSL = SSL_new(ctx->GetCTX());

	assert(pSSL != NULL);

	if (pSSL == NULL)
		return -1;

	SSL_set_fd(pSSL, ioSocket);

	// Drive the handshake, waiting on the socket whenever OpenSSL needs
	// to read or write before it can make progress.
	for (;;)
	{
		result = SSL_accept(pSSL);
		if (result > 0)
			break;

		error = SSL_get_error(pSSL, result);
		error = SSL_Wants_Something(error, SSL_IO_TIMEOUT);
		if (error != NETCONN_SUCCESS)
			break;
	}

	if (result < 1)
	{
		fprintf(stderr,"SSL_accept failed\n");
		fprintf(stderr,"  NetConn/SSL error %d\n", error);
		ERR_print_errors_fp(stderr);	// Debug
	}

	return result;
}

/*  --------------------------------------------------------------------*/
/** See NetConn_SSL.h **/

SSL * CNetConn_SSL::GetSSL( )
{
	// NULL until SSL_Connect()/SSL_Accept() create a session, and again after Shutdown()
	SSL * session = pSSL;
	return session;
}

/*  --------------------------------------------------------------------*/
/** See NetConn_SSL.h **/

void CNetConn_SSL::Shutdown( )
{
	if (pSSL != NULL)
	{
		Do_SSL_shutdown();
		pSSL = NULL;
	}
}

/*  --------------------------------------------------------------------*/
/** See NetConn_Socket.h **/

BOOL CNetConn_SSL::Write(const void * _pData, int count)
{
	// Sends all 'count' bytes from _pData through SSL_write(), looping over
	// partial writes.  When a wait for the socket fails, the write is retried
	// up to 6 more times before giving up.  A 0 return from SSL_write() ends
	// the loop immediately.  Returns TRUE only if every byte was sent.
	const BYTE *	pSrc = static_cast<const BYTE *>(_pData);
	int				remaining = count;
	int				retriesLeft = 6;

	assert(pSSL != NULL);

	while (remaining != 0 && running)
	{
		const int sent = SSL_write( pSSL, (const char *) pSrc, remaining );

		if (sent > 0)
		{
			// Account for the bytes just written
			dataOut += sent;
			remaining -= sent;
			pSrc += sent;
		}
		else if (sent == 0)
		{
			break;		// Socket shutdown.
		}
		else if (SSL_Wants_Something( SSL_get_error( pSSL, sent ), SSL_IO_TIMEOUT ) == NETCONN_SUCCESS)
		{
			continue;	// socket ready: repeat the same SSL_write()
		}
		else if (retriesLeft > 0)
		{
			--retriesLeft;
		}
		else
		{
			break;
		}
	}

	return remaining == 0;
}

/*  --------------------------------------------------------------------*/
/** See NetConn_Socket.h **/

BOOL CNetConn_SSL::Read(void * _pData, int count)
{
	// Reads exactly 'count' bytes into _pData through SSL_read(), looping over
	// short reads.  When a wait for the socket fails, the read is retried up
	// to 4 more times before giving up.  A 0 return from SSL_read() ends the
	// loop immediately.  Returns TRUE only if every requested byte arrived.
	BYTE *		pDest = static_cast<BYTE *>(_pData);
	int			remaining = count;
	int			retriesLeft = 4;

	assert(pSSL != NULL);

	while (remaining != 0 && running)
	{
		const int got = SSL_read( pSSL, (char *) pDest, remaining );

		if (got > 0)
		{
			// Account for the bytes just received
			dataIn += got;
			remaining -= got;
			pDest += got;
		}
		else if (got == 0)
		{
			break;		// Socket shutdown.
		}
		else if (SSL_Wants_Something( SSL_get_error( pSSL, got ), SSL_IO_TIMEOUT ) == NETCONN_SUCCESS)
		{
			continue;	// socket ready: repeat the same SSL_read()
		}
		else if (retriesLeft > 0)
		{
			--retriesLeft;
		}
		else
		{
			break;
		}
	}

	return remaining == 0;
}



/*  --------------------------------------------------------------------*/
/** 
 *	An SSL call has returned an error code indicating that it needs to wait for
 *	a read or write... here we will use select() to wait for what it wants
 *
 *	@param	error		  	The error returned by the SSL_xxx function
 *	@param	timeout			The time out period
 *	@return					Returns an error code, 0 = success
 *							other errors are OS_GetLastSocketError values or NETCONN_ERROR_xxx
 */
int CNetConn_SSL::SSL_Wants_Something( int error, int timeout )
{
	int				result = 0;
	FD_SET_T		f_set;
	FD_SET_T		*f_read;
	FD_SET_T		*f_write;
	struct	timeval	selectTimeOut;

	// Compute time out

	timeout = timeout / SELECT_TIMEOUT;

//	DBLog2(("SSL_Wants %d %d %d\n", (int) s, error, (int) timeout )); // debug

	// Handle the SSL request

	if (error == SSL_ERROR_WANT_READ || error == SSL_ERROR_WANT_WRITE)
	{
		while (result == 0 && timeout && running)
		{
			FD_ZERO(&f_set);
			FD_SET( ioSocket, &f_set);

			selectTimeOut.tv_usec = 0;
			selectTimeOut.tv_sec  = SELECT_TIMEOUT / 1000;

			if (error == SSL_ERROR_WANT_READ)
			{
				f_read = &f_set;
				f_write = NULL;
			}
			else
			{
				f_write = &f_set;
				f_read = NULL;
			}

			// Wait for something to happen

//			DBLog2(("SSL_Select\n"));

			result = select( ioSocket+1, f_read, f_write, NULL, &selectTimeOut );

//			DBLog2(("SSL_Select done %d\n",result));

			timeout--;

		}

		if (result > 0)
		{
				error = NETCONN_SUCCESS;
		}
		else if (timeout == 0 || !running)
		{
			error = NETCONN_ERROR_TIMEOUT;
		}
		else
		{
			error = OS_GetLastSocketError();
			if (error == NETCONN_SUCCESS)
				error =  NETCONN_ERROR_UNKNOWN;
		}
	}

	return error;
}

/*  --------------------------------------------------------------------*/
/** 
 *	@brief	Utility function to take care of the details of calling SSL_shutdown
 *
 *	@return					SSL Error or NETCONN_SUCCESS
 */
int CNetConn_SSL::Do_SSL_shutdown()
{
	int	error = 0;
	int	result;

	while (1)
	{
		result = SSL_shutdown(pSSL);

		if (result <= 0)
		{
			error = SSL_get_error(pSSL,result);
			error = SSL_Wants_Something(error,SSL_SHUTDOWN_TIMEOUT);
			if (error == NETCONN_SUCCESS)
				continue;
		}

		break;
	}

	return error;
}

/*  --------------------------------------------------------------------*/
/** 
 *	@brief	Callback function used by OpenSSL to control locks
 *
 *	Installed by CNetConn_SSL::SSL_Library_Init().  OpenSSL calls it to acquire
 *	(mode & CRYPTO_LOCK) or release one of its CRYPTO_num_locks() locks.
 *
 *	@param	mode	CRYPTO_LOCK / CRYPTO_UNLOCK, combined with CRYPTO_READ or CRYPTO_WRITE
 *	@param	type	Index of the lock; used directly into ssllock_cs / ssllock_count
 *	@param	file	Source file of the caller inside OpenSSL (debug output only)
 *	@param	line	Source line of the caller inside OpenSSL (debug output only)
 */
void ssl_locking_callback ( int mode, int type, const char *file, int line )
{
#ifdef OPENSSL_DEBUG
    // CRYPTO_thread_id() yields an unsigned long, so it needs %lu (not %d).
    // The read/write flag is carried in 'mode'; 'type' is only the lock index.
    fprintf(stderr,"thread=%4lu mode=%s lock=%s %s:%d\n",
        (unsigned long) CRYPTO_thread_id(),
        (mode&CRYPTO_LOCK)?"l":"u",
        (mode&CRYPTO_READ)?"r":"w",file,line);
#else
    (void)file;
    (void)line;
#endif

    if (mode & CRYPTO_LOCK)
    {
        OS_EnterCriticalSection(ssllock_cs[type]);
        ssllock_count[type]++;	// acquisition counter; not read in this file
    }
    else
    {
        OS_LeaveCriticalSection(ssllock_cs[type]);
    }
}

/*  --------------------------------------------------------------------*/
/** 
 *	@brief	Thread-id callback used by OpenSSL (installed via CRYPTO_set_id_callback)
 *
 *	Reports the calling thread as whatever OS_GetCurrentThread() returns,
 *	converted to unsigned long.
 *	NOTE(review): OpenSSL needs a value that differs between live threads.
 *	Confirm that OS_GetCurrentThread() returns a per-thread value on every
 *	platform (e.g. not a shared pseudo-handle).
 */
unsigned long ssl_thread_id(void)
{
    return (unsigned long) OS_GetCurrentThread();
}


/*  --------------------------------------------------------------------*/
/** See NetConn_SSL.h **/

//	Server-side context: loads the certificate and private key from PEM files.
//	Failures are printed (and asserted in debug builds).  sslCtx stays NULL only
//	if the context itself could not be created; as before, a context whose
//	certificate or key failed to load is kept.
CNetConn_SSL_CTX::CNetConn_SSL_CTX( const char * pPrivateKeyFile, const char * pCertFile, const char * pCipherList )
{
	sslCtx = NULL;
	sslMethod = NULL;

	if (pCipherList == NULL)
		pCipherList = pDefaultCipherList;

	// Create the context.
	// SSLv23_server_method() negotiates the highest protocol version both peers
	// support (TLS 1.x where available).  SSLv3_server_method() pinned every
	// connection to SSLv3.  Clients that only speak SSLv3 can still connect.
	// NOTE(review): also set SSL_OP_NO_SSLv3 once no SSLv3-only peers remain (RFC 7568).

	sslMethod = SSLv23_server_method();
	sslCtx = SSL_CTX_new(sslMethod);
	if (sslCtx == NULL)
	{
		ERR_print_errors_fp(stderr);	// Debug
		assert(0);
		return;
	}

	SSL_CTX_set_options(sslCtx, SSL_OP_NO_SSLv2);	// never negotiate SSLv2

	// SSL_CTX_set_cipher_list() returns 0 when no cipher in the list is usable;
	// that used to be ignored silently.
	if (pCipherList != NULL && !SSL_CTX_set_cipher_list(sslCtx,pCipherList))
		ERR_print_errors_fp(stderr);

	// Setup our certificates.  Errors are printed too, so release builds
	// (where assert() is a no-op) still report them.

	if (SSL_CTX_use_certificate_file(sslCtx, pCertFile, SSL_FILETYPE_PEM) <= 0)
	{
		ERR_print_errors_fp(stderr);
		assert(0);
		return;
	}
	if (SSL_CTX_use_PrivateKey_file(sslCtx, pPrivateKeyFile, SSL_FILETYPE_PEM) <= 0)
	{
		ERR_print_errors_fp(stderr);
		assert(0);
		return;
	}
	if (!SSL_CTX_check_private_key(sslCtx))
	{
		ERR_print_errors_fp(stderr);
		assert(0);
		return;
	}
}

/*  --------------------------------------------------------------------*/
/** See NetConn_SSL.h **/

//	Client-side context.  sslCtx stays NULL if the context could not be created.
//	NOTE(review): no certificate verification is configured here
//	(SSL_CTX_set_verify() is never called, so OpenSSL's default SSL_VERIFY_NONE
//	applies), leaving clients open to man-in-the-middle attacks.  Consider
//	SSL_VERIFY_PEER plus a trust store.
CNetConn_SSL_CTX::CNetConn_SSL_CTX( const char * pCipherList )
{
	sslCtx = NULL;
	sslMethod = NULL;

	if (pCipherList == NULL)
	{
		assert(pDefaultCipherList != NULL);
		pCipherList = pDefaultCipherList;
	}

	// Create the context.
	// SSLv23_client_method() offers the highest protocol version and still falls
	// back to SSLv3 for SSLv3-only servers.  SSLv3_client_method() could never
	// negotiate TLS.
	// NOTE(review): also set SSL_OP_NO_SSLv3 once no SSLv3-only peers remain (RFC 7568).

	sslMethod = SSLv23_client_method();
	sslCtx = SSL_CTX_new(sslMethod);
	if (sslCtx == NULL)
	{
		ERR_print_errors_fp(stderr);	// Debug
		assert(0);
		return;
	}

	SSL_CTX_set_options(sslCtx, SSL_OP_NO_SSLv2);	// never negotiate SSLv2

	// SSL_CTX_set_cipher_list() returns 0 when no cipher in the list is usable;
	// that used to be ignored silently.
	if (pCipherList != NULL && !SSL_CTX_set_cipher_list(sslCtx,pCipherList))
		ERR_print_errors_fp(stderr);
}

/*  --------------------------------------------------------------------*/
/** See NetConn_SSL.h **/

SSL_CTX * CNetConn_SSL_CTX::GetCTX(  )
{
	// NULL when the constructor could not create the context
	SSL_CTX * ctx = sslCtx;
	return ctx;
}

