/*----------------------------------------------------------------------------

	CSCConnect.cpp

	Copyright (c) 2000, Raritan Computer, Inc.

	Class for connecting to CSC protocols

----------------------------------------------------------------------------*/

#include	<assert.h>
#include	<stdio.h>
#include	<string.h>
#include	<pp/syms.h>
#include	<pp/NetConn.h>
#include	<pp/CSCConnect.h>
#include	<pp/SXDB_Parse.h>
#include	<pp/SXDB_Parse_Table.h>

// Debug logging is compiled out: DBLog()/DBLog2() expand to nothing, so their
// double-parenthesized printf-style arguments are never evaluated. Switch the
// condition to 1 to use the definitions from Debug.h instead.
#if 0
#include    "Debug.h"
#else
#define		DBLog(a)
#define		DBLog2(a)
#endif

//----------------------------------------
//				Macros
//----------------------------------------

// Size of the receive buffer used for a UDP CSC_Info reply in Query_CSC_Info().
#define	MAX_CSC_INFO		1024

	// misc.
	// Portability shims so the socket code below compiles both on POSIX
	// (OS_POSIX defined) and on Winsock builds.

#ifdef OS_POSIX
#define	S_ADDR		s_addr
#define	FD_SET_T	fd_set
#define	SOCKADDR	struct sockaddr
#else
#define	S_ADDR		S_un.S_addr
#define	FD_SET_T	struct fd_set
#define	socklen_t	int
#define	SHUT_RDWR	2
#endif

//----------------------------------------
//				Data Types
//----------------------------------------

//----------------------------------------
//				Function Prototypes
//----------------------------------------

//----------------------------------------
//				Static Data
//----------------------------------------

// UDP discovery probe sent by Query_CSC_Info(): a 4-byte big-endian frame
// length (0x14 = 20, which counts the length field itself), the XML text
// "<CSC_Discover/>" and a terminating NUL.
static const char	CSC_Discover[] =
{
	0x00, 0x00, 0x00, 0x14,										// frame length = 20
	'<','C','S','C','_','D','i','s','c','o','v','e','r','/','>',	// XML payload
	0x00														// terminator
};

// Number of bytes in CSC_Discover (20), derived from the array itself so the
// two can never drift apart.
static const int	CSC_Dis_Len = (int) sizeof(CSC_Discover);

//----------------------------------------
//		CSC_Test1 Parse Table
//----------------------------------------

// Fields extracted from the server's <CSC_Test1 ClearText="..."/> message.
// CSC_Test() base64-decodes ClearText, RC4-encrypts it with the session key
// and sends it back in CSC_Test2.
typedef	struct
{
	char	*pClearText;	// ClearText attribute (base64 text); presumably points into parser-owned storage - confirm
} CSC_Test1_DATA;

// Parse table binding the ClearText attribute of the CSC_Test1 element to
// CSC_Test1_DATA::pClearText. The flags presumably make the attribute
// mandatory (PT_REQUIRED) and reject unknown attributes (PT_NO_UNKNOWN) -
// confirm against SXDB_Parse_Table.h.
#define	PT_STRUCT	CSC_Test1_DATA
PT_BEGIN	( "CSC_Test1",		CSC_Test1_Table,	PT_NO_UNKNOWN )
PT_ATT		( "ClearText",		pClearText,		0,	PT_STRING_PTR | PT_REQUIRED )
PT_END
#undef	PT_STRUCT

//----------------------------------------
//		CSC_Test3 Parse Table
//----------------------------------------

// Fields extracted from the server's <CSC_Test3 Encrypted="..."/> reply.
// CSC_Test() base64-decodes Encrypted, RC4-decrypts it and compares the
// result with the client challenge it sent in CSC_Test2.
typedef	struct
{
	char	*pEncrypted;	// Encrypted attribute (base64 text); presumably points into parser-owned storage - confirm
} CSC_Test3_DATA;

// Parse table binding the Encrypted attribute of the CSC_Test3 element to
// CSC_Test3_DATA::pEncrypted. The flags presumably make the attribute
// mandatory (PT_REQUIRED) and reject unknown attributes (PT_NO_UNKNOWN) -
// confirm against SXDB_Parse_Table.h.
#define	PT_STRUCT	CSC_Test3_DATA
PT_BEGIN	( "CSC_Test3",		CSC_Test3_Table,	PT_NO_UNKNOWN )
PT_ATT		( "Encrypted",		pEncrypted,		0,	PT_STRING_PTR | PT_REQUIRED )
PT_END
#undef	PT_STRUCT

//----------------------------------------
//				Code
//----------------------------------------

//----------------------------------------------------------------------------
//						Class
//----------------------------------------------------------------------------

//----------------------------------------------------------------------------
//
	CCSCConnect::CCSCConnect
	(
	)
//
//	Initializes the SSL library (requesting SSL2_TXT_RC4_128_WITH_MD5), starts
//	in the "not connected / not authenticated" state and allocates the SSL
//	context later used by the SSL encryption phase.
//
//----------------------------------------------------------------------------
{
	// Library initialization comes first; the context is created afterwards.
	const int sslInitStatus = CNetConn_SSL::SSL_Library_Init( SSL2_TXT_RC4_128_WITH_MD5 );
	assert( sslInitStatus == 0 );
	(void) sslInitStatus;			// unused when NDEBUG removes the assert()

	netConn			= NULL;
	connected		= 0;
	authenticated	= 0;
	isSSL			= 0;
	sslCtx			= new CNetConn_SSL_CTX;
}

//----------------------------------------------------------------------------
//
	CCSCConnect::~CCSCConnect
	(
	)
//
//	Tears the object down: releases the network connection through
//	Disconnect(), then frees the SSL context allocated by the constructor.
//
//----------------------------------------------------------------------------
{
	// Connection first: it may be an SSL wrapper that was set up with sslCtx.
	this->Disconnect();
	delete this->sslCtx;
}

//----------------------------------------------------------------------------
//
	int									// 0= connected <0 = error
	CCSCConnect::Connect
	(
		int		ipAddress,
		short		tcpPort,
		const char	* NOTUSED(connectionID)
	)
//
//	Creates the TCP Connection
//
//----------------------------------------------------------------------------
{
	if (tcpPort == 0)
		tcpPort = 5000;

	CNetConn_Socket	*newNetConn = new CNetConn_Socket();
	netConn = (CNetConn *) newNetConn;
	int result = newNetConn->OpenSocket(ipAddress,tcpPort);

	if (result >= 0)
		connected = 1;

	return result;
}

//----------------------------------------------------------------------------
//
	char *								// Ptr to CSC_Info document or NULL
	CCSCConnect::GetCSCInfo
	(
	)
//
//	Does the CSC Handshake and gets the CSC_Info document.
//	delete the char * returned by this method when done with the document.
//
//----------------------------------------------------------------------------
{
	const char		*pCSC_Ack = "<CSC_Ack/>";
	char		string[200];
	int			len,nLen;
	char		*pCSC_Info = NULL;
	int			result;
	
	// ----------------------------------------
	// CSC Negotiate Protocol

	// Read CSC

	netConn->Read( &nLen, 4 );

	len = ntohl(nLen) - 4;

	if (len > 200)
		return NULL;

	netConn->Read( string, len );

	// strstr

	// Write CSC_Ack

	len = strlen(pCSC_Ack)+1;
	nLen = htonl(len+4);

	netConn->Write( &nLen, 4 );
	netConn->Write( pCSC_Ack, len );

	// Read CSC_Info

	netConn->Read( &nLen, 4 );

	len = ntohl(nLen) - 4;

	pCSC_Info = new char[len];

	if (pCSC_Info == NULL)
		return NULL;

	result = netConn->Read( pCSC_Info, len );

	if (result <0)
	{
		delete [] pCSC_Info;
		return NULL;
	}

	// done

	return pCSC_Info;
}

//----------------------------------------------------------------------------
//
	int									// <0 = error
	CCSCConnect::StartNewCSCSession
	(
		const char	*	protocol,		// Protocol ID
		const char	*	userName,		// User name
		const char	*	password,		// password
		const char	*	encryption,		// Encryption type
		const char	*	authentication		// Authentication type
	)
//
//	Sends the CSC_Start_Session, does the authentication phase and the authentication
//	phase.
//
//----------------------------------------------------------------------------
{
	char		string[200];
	int			len,nLen;
	//char		*pCSC_Info = NULL;
	int			result;

	// ----------------------------------------
	// Write CSC_Start

	DBLog2(("CCSCConnect::StartNewCSCSession(): -- username(%s), passwd(%s)\n", userName, password));

	sprintf(string,"<CSC_Start_Session ProtocolID=\"%s\"/>",protocol);

	len = strlen(string)+1;
	nLen = htonl(len+4);

	netConn->Write( &nLen, 4 );
	netConn->Write( string, len );

	// ----------------------------------------
	// CSC Encryption Phase

	if (strcmp(encryption,"SSL")==0)
	{
		// SSL Encryption

		CNetConn_SSL * pSSL = new CNetConn_SSL( (CNetConn_Socket *) netConn);
		delete netConn;
		netConn = (CNetConn *) pSSL;
		if (netConn == NULL)
			return -1;

		result = pSSL->SSL_Connect(sslCtx);

		if (result < 0)
			return result;

	}

	// ----------------------------------------
	// Authentication Phase

	if ( strcmp(authentication,"CSC") == 0 )
	{
		// CSC User name and Password authentication

		sprintf(string,"<CSC_Auth UserName=\"%s\" Password=\"%s\"/>",userName,password);

		len = strlen(string)+1;
		nLen = htonl(len+4);

		netConn->Write( &nLen, 4 );
		netConn->Write( string, len );

		// Wait for pass/fail

		netConn->Read( &nLen, 4 );

		len = ntohl(nLen) - 4;

		if (len <200)
			netConn->Read( string, len );

		if (strcmp(string,"<CSC_Pass/>") == 0)
		{
			authenticated = 1;
			return 0;
		}
		else
		{
			DBLog(("CCSCConnect::StartNewCSCSession(): Unexpected response [%s]\n", string));
			Disconnect();
			return -1;
		}
	}
	else
	{
		authenticated = 1;
		return 0;
	}
}

//----------------------------------------------------------------------------
//
	int
	CCSCConnect::StartReferralCSCSession
	(
		const char	*	protocol,		// Protocol ID
		const char	*	sessionID,		// NULL for new session, or sessionID
		const char	*	sessionKey,		// Must be spcified if seesionID is used
		const char	*	encryption		// Encryption type
	)
//
//	Opens the read and write sockets and does the hand shake with the TR server
//
//----------------------------------------------------------------------------
{
	char		string[200];
	int			len,nLen;
	//char		*pCSC_Info = NULL;
	int			result;

	// ----------------------------------------
	// Write CSC_Start

	sprintf(string,"<CSC_Start_Session ProtocolID=\"%s\" SessionID=\"%s\"/>",
				protocol,
				sessionID );

	len = strlen(string)+1;
	nLen = htonl(len+4);

	netConn->Write( &nLen, 4 );
	netConn->Write( string, len );

	// ----------------------------------------
	// CSC Encryption Phase

	if (strcmp(encryption,"SSL")==0)
	{
		// SSL Encryption

		CNetConn_SSL * pSSL = new CNetConn_SSL( (CNetConn_Socket *) netConn);
		delete netConn;
		netConn = (CNetConn *) pSSL;
		if (netConn == NULL)
			return -1;

		result = pSSL->SSL_Connect(sslCtx);

		if (result != 0)
			return result;
	}

	// ----------------------------------------
	// Authentication Phase

	result = CSC_Test( sessionKey );

	return result;
}

//---------------------------------------------------------------------------
//	Base64-decodes the NUL-terminated string pIn into pOut, which holds
//	outSize bytes, including the bytes produced by EVP_DecodeFinal().
//	Returns the number of decoded bytes, or -1 in three cases: pIn is NULL,
//	pIn is long enough that it might not fit in pOut, or OpenSSL reports a
//	decoding error.
//---------------------------------------------------------------------------
static int CSC_DecodeBase64( const char *pIn, unsigned char *pOut, int outSize )
{
	EVP_ENCODE_CTX	ctx;
	int				inLen;
	int				outLen = 0;
	int				finalLen = 0;

	if (pIn == NULL)
		return -1;

	inLen = (int) strlen(pIn);

	// Every 4 base64 characters decode to at most 3 bytes. Refuse input that
	// could decode to more than outSize bytes instead of overflowing pOut.
	if ((inLen + 3) / 4 > outSize / 3)
		return -1;

	EVP_DecodeInit(&ctx);
	if (EVP_DecodeUpdate(&ctx, pOut, &outLen, (const unsigned char *) pIn, inLen) < 0)
		return -1;
	if (EVP_DecodeFinal(&ctx, &pOut[outLen], &finalLen) < 0)
		return -1;

	return outLen + finalLen;
}

//---------------------------------------------------------------------------
//	Base64-encodes inLen bytes of pIn into pOut as NUL-terminated text.
//	The OpenSSL encoder also inserts line breaks (see EVP_ENCODE_LENGTH), so
//	pOut must have room for that. The callers below use 512-byte buffers
//	for at most 255 input bytes.
//---------------------------------------------------------------------------
static void CSC_EncodeBase64( const unsigned char *pIn, int inLen, char *pOut )
{
	EVP_ENCODE_CTX	ctx;
	int				outLen = 0;
	int				finalLen = 0;

	EVP_EncodeInit(&ctx);
	EVP_EncodeUpdate(&ctx, (unsigned char *) pOut, &outLen, (unsigned char *) pIn, inLen);
	EVP_EncodeFinal(&ctx, (unsigned char *) &pOut[outLen], &finalLen);
}

//---------------------------------------------------------------------------
//
	int									// <0 = error
	CCSCConnect::CSC_Test
	(
		const char *pKey					// Key in Base 64
	)
//
//	Transacts the authentication phase of CSC as a challenge/response using
//	the RC4 key given (base64-encoded) in pKey:
//	  1. receive CSC_Test1 and RC4-encrypt its base64 ClearText with the key,
//	  2. send CSC_Test2 with that ciphertext plus a client challenge,
//	  3. receive CSC_Test3 and check that its Encrypted attribute RC4-decrypts
//	     to the client challenge.
//	Returns 0 when the peer answered correctly, <0 otherwise.
//
//---------------------------------------------------------------------------
{
	// Owns the messages returned by ReadMessage() (allocated with new[]).
	// Declared before the parser so it is destroyed after it, in case the
	// parser or the parsed string pointers refer into these buffers.
	struct MessageBuffers
	{
		char	*pTest1;
		char	*pTest3;

		MessageBuffers() : pTest1(NULL), pTest3(NULL) {}
		~MessageBuffers() { delete [] pTest1; delete [] pTest3; }
	} messages;

	int				result;
	unsigned char	keyData[256];		// RC4 uses at most 256 key bytes
	int				keyLength;
	unsigned char	serverData[256];
	int				serverLength;
	unsigned char	clientData[256];
	int				testLength;
	char			base64[512];
	char			base64_2[512];
	char			msg[512];
	RC4_KEY			key;
	int				seed;
	int				i;
	const char		*pTestData = "1234567890"; // "This is the client challenge SAFGWAGQ@#EARG43";
	CSXDB_Parse		parser;
	CSC_Test1_DATA	CSC_Test1;
	CSC_Test3_DATA	CSC_Test3;

	keyLength = CSC_DecodeBase64( pKey, keyData, (int) sizeof(keyData) );
	if (keyLength < 0)
		return -1;

	// --------------------------------
	// Receive the CSC_Test1 message

	result = ReadMessage( &messages.pTest1 );
	if (result < 0)
		return result;

	parser.Parse( messages.pTest1, strlen(messages.pTest1) );

	result = SXDB_PT_Get( parser.GetDataBase(), "/CSC_Test1", CSC_Test1_Table, &CSC_Test1 );
	if (result < 0)
		return result;

	// --------------------------------
	// Send the CSC_Test2 message

	// Encrypt the server's clear-text challenge with the key
	serverLength = CSC_DecodeBase64( CSC_Test1.pClearText, serverData, (int) sizeof(serverData) );
	if (serverLength < 0)
		return -1;

	RC4_set_key( &key, keyLength, keyData );
	RC4( &key, serverLength, serverData, serverData );
	CSC_EncodeBase64( serverData, serverLength, base64 );

	// Create the client challenge: the test string XOR-ed with bits of the
	// tick counter.
	// NOTE(review): tick-count based and therefore predictable; a CSPRNG
	// (e.g. OpenSSL RAND_bytes) would give a stronger challenge.
	testLength = (int) strlen(pTestData);
	seed = (int) OS_GetTickCount();
	for (i = 0; i < testLength; i++)
	{
		clientData[i] = pTestData[i] ^ (char) seed;
		seed >>= 3;
		seed ^= OS_GetTickCount();
	}
	CSC_EncodeBase64( clientData, testLength, base64_2 );

	result = snprintf( msg, sizeof(msg), "<CSC_Test2 Encrypted=\"%s\" ClearText=\"%s\"/>", base64, base64_2 );
	if (result < 0 || result >= (int) sizeof(msg))
		return -1;

	result = WriteMessage( msg );
	if (result < 0)
		return result;

	// --------------------------------
	// Receive the CSC_Test3 message

	result = ReadMessage( &messages.pTest3 );
	if (result < 0)
		return result;

	parser.Parse( messages.pTest3, strlen(messages.pTest3) );

	result = SXDB_PT_Get( parser.GetDataBase(), "/CSC_Test3", CSC_Test3_Table, &CSC_Test3 );
	if (result < 0)
		return result;

	// Verify that the peer encrypted the client challenge with the same key.
	// The decoded length now includes the bytes from EVP_DecodeFinal(), as
	// the CSC_Test1 path above already did.
	result = CSC_DecodeBase64( CSC_Test3.pEncrypted, serverData, (int) sizeof(serverData) );
	if (result != testLength)
		return -1;

	RC4_set_key( &key, keyLength, keyData );
	RC4( &key, result, serverData, serverData );

	if (memcmp( clientData, serverData, testLength ) != 0)
		return -1;

	return 0;
}


//----------------------------------------------------------------------------
//
	void
	CCSCConnect::Disconnect
	(
	)
//
//	Opens the read and write sockets and does the hand shake with the TR server
//
//----------------------------------------------------------------------------
{
	if (connected)
	{
		delete netConn;
		connected = 0;
	}
}

//----------------------------------------------------------------------------
//
	int									// 0 if no error
	CCSCConnect::WriteMessage
	(
		const char	*	pMessage			// Message to send
	)
//
//	Writes an XML message to the connection as one frame:
//		length   - 4 bytes, network byte order, = 4 + xml_length + 1
//		XML data - the variable-length XML text
//		null     - 1 byte of 0x00
//	Returns the result of the failing Write() (<0), or otherwise the result
//	of the final Write().
//
//----------------------------------------------------------------------------
{
	const int	payloadLength  = (int) strlen(pMessage) + 1;	// XML text plus its NUL
	int			netFrameLength = htonl( 4 + payloadLength );	// length field counts itself
	int			status;

	status = netConn->Write( &netFrameLength, 4 );
	if (status >= 0)
		status = netConn->Write( pMessage, payloadLength );

	return status;
}

//----------------------------------------------------------------------------
//
	int									// 0 if no error
	CCSCConnect::ReadMessage
	(
		char	**	pMessage,			// receives the ptr to the received msg
		int		maxLength				// Optional max length, 0 = no limit
	)
//
//	Receives one framed message from the connection. The frame is a 4-byte
//	length in network byte order (counting the length field itself) followed
//	by the XML data and its NUL. maxLength, when non-zero, limits that total
//	length. Frames shorter than 9 bytes are rejected.
//
//	On success *pMessage points to a buffer allocated with new[] that is
//	always NUL-terminated; the caller must release it with delete [].
//	On failure *pMessage is NULL and <0 is returned.
//
//----------------------------------------------------------------------------
{
	int	result;
	int	length;
	int netLength;
	char * pMsg;

	*pMessage = NULL;

	// Read length

	result = netConn->Read( &netLength, 4 );

	if (result < 0)
		return result;

	length = ntohl(netLength);

	if (maxLength && length > maxLength)
		return -1;

	if (length < 9)
		return -1;

	// Read XML Data + null. One extra byte is allocated so the buffer is
	// NUL-terminated even if the peer omitted the terminator, because callers
	// strlen() the result.

	pMsg = new char[length - 4 + 1];

	result = netConn->Read( pMsg, length - 4 );

	if (result < 0)
	{
		delete [] pMsg;			// allocated with new[]
		return result;
	}

	pMsg[length - 4] = '\0';
	*pMessage = pMsg;

	return result;
}

//----------------------------------------------------------------------------
//
	char *								// Ptr to CSC_Info document
	Query_CSC_Info
	(
		int			ipAddress,			// The IP Address
		short		tcpPort,			// Which port (0 = default = 5000)
		int			seconds				// # of seconds to wait (0 = default = 10)
	)
//
//	Queries a server's CSC_Info document using the UDP discovery method.
//
//----------------------------------------------------------------------------
{
	SOCKET					s;
	struct	sockaddr_in		sa;
	struct	sockaddr_in		from;
	FD_SET_T		        f_read;
	struct	timeval			selectTimeOut;
	int						result;
	int						length;
	char			*		pBuffer = new char[MAX_CSC_INFO];
	int						x;

	if (pBuffer == NULL)
		return NULL;

	// Create a new socket

	s = socket( AF_INET, SOCK_DGRAM, 0 );

	if (s == INVALID_SOCKET)
		return NULL;

	// Get the port

	if (tcpPort == 0)
		tcpPort = 5000;
	tcpPort = htons(tcpPort);

	// Set the address info

	sa.sin_family = AF_INET;
	sa.sin_addr.S_ADDR = htonl(ipAddress);

	// Send the request

	if (seconds == 0)
		seconds = 10;

	seconds = (seconds * 1000) + OS_GetTickCount();

	do
	{
		// send the command

		sa.sin_port = tcpPort;
		result = sendto( s, CSC_Discover, CSC_Dis_Len, 0, (sockaddr *) &sa, sizeof(sa) );

		// Wait for something to read

		FD_ZERO(&f_read);
		FD_SET( (SOCKET) s, &f_read);

		selectTimeOut.tv_usec = 0;
		selectTimeOut.tv_sec  = 1;

		result = select( s+1, &f_read, NULL, NULL, &selectTimeOut );

		if (result == 0)
			continue;

		// Read the data

		from = sa;
		length = sizeof(from);
		result = recvfrom( s, pBuffer, MAX_CSC_INFO, 0, (sockaddr *) &from, (socklen_t*) &length );

		if (result < 4)
		{
			result = WSAGetLastError();
			continue;
		}

		// Look for CSC_Info

		length = * (int *) pBuffer;
		length = ntohl( length );

		if (	length >= 11 || length <= MAX_CSC_INFO && 
				length <= result && 
				pBuffer[length-1] == 0x00 )
		{
			// Scan for CSC_Info tag

			if (strstr(&pBuffer[4],"CSC_Info") != NULL)
			{
				// Found a CSC_Info

				for (x=0;x<length-4;x++)
					pBuffer[x] = pBuffer[x+4];
				
				closesocket( s );
				return pBuffer;
			}
		}

	} while ((DWORD) seconds >= OS_GetTickCount());

	delete pBuffer;
	closesocket(s);
	return NULL;
}

