/*--------------------------------------------------------------------------

	CSCAccept.cpp

	Copyright (c) 2003, Raritan Computer, Inc.

	Raritan Common Socket Connection.

---------------------------------------------------------------------------*/


#ifdef OS_MFC
#include "stdafx.h"
#endif

#include	<string.h>
#include	"pp/OS_Port.h"

#include	"pp/CSCAccept.h"
#include	"pp/CSCInfo.h"
#include	"pp/CSCProtocol.h"
#include	"pp/NetConn.h"
//#define	USING_RANDOMDATA_CPP
#include	"pp/RandomData.h"

//#define	CSC_DEBUG

#ifdef	CSC_DEBUG
#include	"Debug.h"
#else
#define	DBLog(a)
#endif

#ifndef _SSL_H_
#define	_SSL_H_
#include <openssl/rsa.h>       /* SSLeay stuff */
#include <openssl/crypto.h>
#include <openssl/x509.h>
#include <openssl/pem.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/rc4.h>
#include <openssl/evp.h>
#endif


#ifdef OS_MFC
#ifdef _DEBUG
#define _CRTDBG_MAP_ALLOC
#include <stdlib.h>
#include <crtdbg.h>
#endif //_DEBUG
#else //OS_MFC
#ifdef DMALLOC
#include "dmalloc.h"
#endif //DMALLOC
#endif //OS_MFC

//----------------------------------------
//				Equates
//----------------------------------------

#define	CSC_MAX_MESSAGE			512

//----------------------------------------
//				Data Types
//----------------------------------------

//----------------------------------------
//				Function Prototypes
//----------------------------------------

// Checks a user name / password pair for a client connecting from ipAddress.
// As used by CCSCAccept::CSC_Authenticate(): returns true when the
// credentials are accepted and stores the resulting user object in
// *ppUserObject (the caller keeps that pointer in pUserObject).
// The implementation is not part of this section.
// NOTE(review): byte order of ipAddress and ownership of *ppUserObject are
// not visible here -- confirm against the implementation.
bool CSC_AuthenticateUser
(
	const char                *name,
	const char                *password,
	DWORD                     ipAddress,
	CUserObject 	**	ppUserObject
);

//----------------------------------------
//				Static Data
//----------------------------------------


//----------------------------------------
//		CSC_Ack Parse Table
//----------------------------------------

// Parse tables used with SXDB_PT_Get() (see GetMessage() and
// NegotiateProtocol()): each PT_BEGIN ... PT_END run describes one XML
// element and, through PT_ATT, which attribute is stored into which field
// of the PT_STRUCT type defined just before it.
// NOTE(review): the exact semantics of PT_NO_UNKNOWN / PT_UNKNOWN_OK,
// PT_REQUIRED and PT_STRING_PTR (presumably a pointer into the parsed XML
// rather than a copy) live with the PT macros -- confirm there.

// CSC_Ack carries no attributes; only the element itself is matched.
#define	PT_STRUCT	int
PT_BEGIN	( "CSC_Ack",		CSC_Ack_Table,		PT_NO_UNKNOWN )
PT_END
#undef	PT_STRUCT

//----------------------------------------
//		CSC_Connect Parse Table
//----------------------------------------

// No attributes are extracted; NegotiateProtocol() only uses this table to
// recognise a CSC_Connect that arrives before CSC_Ack and skip it.
#define	PT_STRUCT	int
PT_BEGIN	( "CSC_Connect",	CSC_Connect_Table,	PT_UNKNOWN_OK )
PT_END
#undef	PT_STRUCT

//----------------------------------------
//		CSC_Start_Session Parse Table
//----------------------------------------

typedef	struct
{
	char	*pProtocolID;	// protocol requested by the client; matched against Protocol "id"
	char	*pSessionID;	// present only for referral sessions
} CSC_Start_DATA;

#define	PT_STRUCT	CSC_Start_DATA
PT_BEGIN	( "CSC_Start_Session",CSC_Start_Table,	PT_NO_UNKNOWN )
PT_ATT		( "ProtocolID",		pProtocolID,	0,	PT_STRING_PTR | PT_REQUIRED )
PT_ATT		( "SessionID",		pSessionID,		0,	PT_STRING_PTR )
PT_END
#undef	PT_STRUCT

//----------------------------------------
//		CSC_Test2 Parse Table
//----------------------------------------

typedef	struct
{
	char	*pEncrypted;	// base64: our CSC_Test1 clear text, RC4-encrypted by the client
	char	*pClearText;	// base64: the client's own challenge for us to encrypt
} CSC_Test2_DATA;

#define	PT_STRUCT	CSC_Test2_DATA
PT_BEGIN	( "CSC_Test2",		CSC_Test2_Table,	PT_NO_UNKNOWN )
PT_ATT		( "ClearText",		pClearText,		0,	PT_STRING_PTR | PT_REQUIRED )
PT_ATT		( "Encrypted",		pEncrypted,		0,	PT_STRING_PTR | PT_REQUIRED )
PT_END
#undef	PT_STRUCT

//----------------------------------------
//		CSC_Auth Parse Table
//----------------------------------------

typedef	struct
{
	char	*pUserName;		// credentials passed to CSC_AuthenticateUser()
	char	*pPassword;	
} CSC_Auth_DATA;

#define	PT_STRUCT	CSC_Auth_DATA
PT_BEGIN	( "CSC_Auth",		CSC_Auth_Table,		PT_NO_UNKNOWN )
PT_ATT		( "UserName",		pUserName,		0,	PT_STRING_PTR | PT_REQUIRED )
PT_ATT		( "Password",		pPassword,		0,	PT_STRING_PTR | PT_REQUIRED )
PT_END
#undef	PT_STRUCT

//----------------------------------------
//		CSC_RADIUS_Response Parse Table
//----------------------------------------

// NOTE(review): CSC_RADIUS_Table is not referenced anywhere in this section.
// (The type name "RADUIS" is misspelled but kept to avoid breaking users.)
typedef	struct
{
	char	*pAnswer;
} CSC_RADUIS_DATA;

#define	PT_STRUCT	CSC_RADUIS_DATA
PT_BEGIN	( "CSC_RADIUS_Response", CSC_RADIUS_Table,	PT_NO_UNKNOWN )
PT_ATT		( "Answer",			pAnswer,		0,	PT_STRING_PTR | PT_REQUIRED )
PT_END
#undef	PT_STRUCT

//----------------------------------------
//		Protocol Parse Table
//----------------------------------------

// Describes one entry of pCSCInfo's protocol list.  NegotiateProtocol()
// parses each entry's GetProtocolXML() text at "/Protocol" into
// protocolInfo; every field is a fixed-size string of RDM_MAX_ID bytes.
#define	PT_STRUCT	CSC_Protocol_Info
PT_BEGIN	( "Protocol",		protocolTable,				PT_UNKNOWN_OK )
PT_ATT		( "id",				id,				RDM_MAX_ID,	PT_STRING | PT_REQUIRED )
PT_ATT		( "Encrypt",		encrypt,		RDM_MAX_ID,	PT_STRING | PT_REQUIRED )
PT_ATT		( "Version",		version,		RDM_MAX_ID,	PT_STRING | PT_REQUIRED )
PT_ATT		( "OldestVersion",	oldestVersion,	RDM_MAX_ID,	PT_STRING | PT_REQUIRED )
PT_ATT		( "Compression",	compression,	RDM_MAX_ID,	PT_STRING | PT_REQUIRED )
PT_ATT		( "Auth",			auth,			RDM_MAX_ID,	PT_STRING | PT_REQUIRED )
PT_END	
#undef	PT_STRUCT

//----------------------------------------
//				Code
//----------------------------------------

//---------------------------------------------------------------------------
//								CCSCAccept
//---------------------------------------------------------------------------

//---------------------------------------------------------------------------
//
	CCSCAccept::CCSCAccept
	(
		CSessionManager	*pNewSessionManager,	// Session lookup used by Referral_Authenticate()
		CCSCInfo		*pNewCSCInfo			// Builds CSC_Info and lists the supported protocols
	)
//
//	Records the collaborators and starts the connection in the
//	"no authenticated user, no referral session" state.
//
//---------------------------------------------------------------------------
{
	// Collaborators supplied by the creator of the connection

	pCSCInfo = pNewCSCInfo;
	pSessionManager = pNewSessionManager;

	// Per-connection state; filled in by NegotiateProtocol()/Authenticate()

	*sessionID = '\0';
	referralSession = 0;
	pUserObject = NULL;
}

//---------------------------------------------------------------------------
//
	CCSCAccept::~CCSCAccept()
//
//	Destructor.  Nothing is released here: this destructor does not free
//	pUserObject, pSessionManager or pCSCInfo.
//
//---------------------------------------------------------------------------
{
	// intentionally empty
}

//---------------------------------------------------------------------------
//
	int									// <0 = error

	CCSCAccept::NegotiateProtocol

	(

		int		*pBytesSent				// Optional parameter used to return

										// the number of bytes sent if CSC

										// failed

	)
//
//	Transacts the Protocol Selection phase of CSC
//
//---------------------------------------------------------------------------
{
	int					result;
	bool				found = false;
	int					x;
	CSC_Start_DATA		CSC_Start;
	const char				*pProtocolString;

	// Send the <CSC/> message

	result = WriteMessage("<CSC/>");

	if (result != 0)
		return -1;


	// See if the client is responding

	if (!m_pNetConn->Select(30))
	{
		if (pBytesSent != NULL)
			*pBytesSent = strlen("<CSC/>") + 5;
		return -2;
	}

	// Get the CSC_Ack message

	result = GetMessage(CSC_Ack_Table, &CSC_Start, CSC_Connect_Table, &CSC_Start);

	// See if we received a CSC_Connent

	if (result == 1)
	{
		// Ok, ignore the CSC_Connect and wait for the real CSC_Ack

		result = GetMessage(CSC_Ack_Table, &CSC_Start );
	}

	if (result != 0)
	{
		return result;
	}

	// Send the CSC_Info

	char *p = pCSCInfo->MakeCSCInfo();
	result = WriteMessage( p );
	delete [] p;

	// Get the CSC_Start_Session message

	CSC_Start.pSessionID = NULL;

	result = GetMessage(CSC_Start_Table, &CSC_Start);

	if (result != 0)
		return result;

	if (CSC_Start.pSessionID != NULL)
	{
		strncpy(sessionID,CSC_Start.pSessionID,RDM_MAX_ID-1);
		sessionID[RDM_MAX_ID-1] = 0;
		referralSession = 1;
	}
	else
		referralSession = 0;

	strncpy(protocolID,CSC_Start.pProtocolID,RDM_MAX_ID-1);
	protocolID[RDM_MAX_ID-1] = 0;

	// Find the protocol in the list of protcols

	for (x=0; (pProtocol = pCSCInfo->EnumProtocols(x)) != NULL;x++)
	{
		pProtocolString = pProtocol->GetProtocolXML();
		if (pProtocolString == NULL)
			continue;

		result = dbParser.Parse( pProtocolString, strlen(pProtocolString) );

		if (result < 0)
			continue;

		result = SXDB_PT_Get( dbParser.GetDataBase(), "/Protocol", protocolTable, &protocolInfo, 0 );

		if (result != 0)
			continue;

		if (strcmp(protocolID,protocolInfo.id) == 0)
		{
			found = true;
			break;
		};
	}

	if (!found)
		result = -1;

	return result;;
}

//---------------------------------------------------------------------------
//
	int									// <0 = error, 0 = native auth, 1 = good
	CCSCAccept::Authenticate
	(
	)
//
//	Transacts the Authentication phase of CSC, dispatching on how the
//	connection was negotiated:
//	  - referral session (SessionID given in CSC_Start_Session):
//	    Referral_Authenticate()
//	  - protocol Auth="Native": nothing to do here, returns 0
//	  - protocol Auth="CSC":    CSC_Authenticate()
//	Any other Auth value is an error (-1).
//
//---------------------------------------------------------------------------
{
	const char	*pAuthMode;

	if (referralSession)
		return Referral_Authenticate();

	pAuthMode = protocolInfo.auth;

	// Native authentication: nothing to do at the CSC level
	if (strcmp(pAuthMode,"Native") == 0)
		return 0;

	return (strcmp(pAuthMode,"CSC") == 0) ? CSC_Authenticate() : -1;
}

//---------------------------------------------------------------------------
//
	int									// <0 = error, 1 = good
	CCSCAccept::CSC_Authenticate
	(
	)
//
//	Performs CSC user/password authentication: reads a CSC_Auth message,
//	checks the credentials with CSC_AuthenticateUser() and answers with
//	<CSC_Pass/> (keeping the resulting user object in pUserObject) or with
//	<CSC_Fail/>.
//
//---------------------------------------------------------------------------
{
	int						result;
	CSC_Auth_DATA			CSC_Auth;
	CUserObject	*		pAccount = NULL;

	// Start from NULL so the check below is meaningful even if the parser
	// leaves a field untouched.

	CSC_Auth.pUserName = NULL;
	CSC_Auth.pPassword = NULL;

	// Get the CSC_Auth message

	result = GetMessage( CSC_Auth_Table, &CSC_Auth );

	if (result != 0)
	{
		DBLog(("%s: GetMessage() failed \n", __PRETTY_FUNCTION__));
		// Any read/parse failure is an error; never return it as non-negative
		return (result < 0) ? result : -1;
	}

	if (CSC_Auth.pUserName == NULL || CSC_Auth.pPassword == NULL)
	{
		DBLog(("%s: NULL user name(%p) or password(%p) \n",
				__PRETTY_FUNCTION__, (void *)CSC_Auth.pUserName, (void *)CSC_Auth.pPassword));
		return -1;
	}

	// Authenticate user

	if (!CSC_AuthenticateUser( CSC_Auth.pUserName, CSC_Auth.pPassword, this->ipAddress, &pAccount ))
	{
		WriteMessage("<CSC_Fail/>");
		return -1;
	}

	this->pUserObject = pAccount;

	WriteMessage("<CSC_Pass/>");

	return 1;
}

//---------------------------------------------------------------------------
//
	int									// <0 = error, 1 = good
	CCSCAccept::Referral_Authenticate
	(
	)
//
//	Authenticates a referral connection: looks up the session named in
//	CSC_Start_Session, runs the CSC_Test challenge with that session's key
//	and, if the challenge succeeds, adopts the session's user object.
//
//---------------------------------------------------------------------------
{
	int			result;
	CSession	*pSession;

	// Get the session

	pSession = pSessionManager->GetSession( sessionID );

	if (pSession == NULL)
		return -1;

	// Challenge the client with the session's key.  CSC_Test() returns 0
	// only when the whole exchange succeeded; any other value is a failure.

	result = CSC_Test( pSession->GetSessionKey() );

	// referral session fix for PS-A: adopt the session's user object.  This
	// must happen before Release() -- after it our reference on the session
	// is gone and pSession must not be touched.
	// NOTE(review): pUserObject still refers to the session's user object
	// after Release(); confirm the session manager keeps it alive.

	if (result == 0)
		this->pUserObject = pSession->GetUserObject();

	pSession->Release();

	return (result == 0) ? 1 : -1;
}

//---------------------------------------------------------------------------
//
	int									// <0 = error
	CCSCAccept::CSC_Test
	(
		const char *pKey					// Key in Base 64
	)
//
//	Transacts the Authentication Selection phase of CSC
//
//---------------------------------------------------------------------------

{
	int				result;
	CSC_Test2_DATA	CSC_Test2;
	unsigned char	serverData[256];
	unsigned char	clientData[256];
	char			base64[512];
	char			msg[512];
	RC4_KEY			key;
	int				keyLength;
	unsigned char	keyData[16];
	int				x;
	EVP_ENCODE_CTX	ctx;
	int				serverDataLen = 16;//151 + (OS_GetTickCount() & 0x3f);

	memset(serverData, 0, 256);

	EVP_DecodeInit(&ctx);
	EVP_DecodeUpdate(&ctx, keyData, &x, (const unsigned char *) pKey, strlen(pKey));
	keyLength = x;
	EVP_DecodeFinal(&ctx, &keyData[x], &x);
	keyLength += x; 

	// Send the CSC_Test1 message

	GetRandomData( serverData, serverDataLen );
	EVP_EncodeInit(&ctx);
	EVP_EncodeUpdate(&ctx, (unsigned char *)base64, &result, (unsigned char *) serverData, serverDataLen);
	EVP_EncodeFinal(&ctx, (unsigned char *)&base64[result], &result); 

	sprintf(msg,"<CSC_Test1 ClearText=\"%s\"/>",base64);

	result = WriteMessage(msg);

	if (result < 0)
		return result;

	// Get CSC_Test2

	result = GetMessage(CSC_Test2_Table, &CSC_Test2);

	if (result != 0)
		return result;

	// Verify that the caller encrypted the data correctly

	EVP_DecodeInit(&ctx);
	EVP_DecodeUpdate(&ctx, clientData, &x, (unsigned char *) CSC_Test2.pEncrypted, strlen(CSC_Test2.pEncrypted));
	result = x;
	EVP_DecodeFinal(&ctx, &clientData[x], &x); 
	result += x; 

	RC4_set_key( &key, keyLength, keyData);
	RC4( &key, result, clientData, clientData );

	if (result != serverDataLen)
		return -1;

	for (x=0;x<serverDataLen;x++)
	{
		if (clientData[x] != serverData[x])
			return -1;
	}

	// Encrypt the client data

	EVP_DecodeInit(&ctx);
	EVP_DecodeUpdate(&ctx, clientData, &x, (unsigned char *) CSC_Test2.pClearText, strlen(CSC_Test2.pClearText));
	result = x;
	EVP_DecodeFinal(&ctx, &clientData[x], &x); 
	result += x; 

	RC4_set_key( &key, keyLength, keyData);
	RC4( &key, result, clientData, clientData );

	x = result;

	EVP_EncodeInit(&ctx);
	EVP_EncodeUpdate(&ctx, (unsigned char *)base64, &result, (unsigned char *) clientData, x);
	EVP_EncodeFinal(&ctx, (unsigned char *)&base64[result], &result); 

	// Send CSC_Test3

	sprintf(msg,"<CSC_Test3 Encrypted=\"%s\"/>",base64);

	result = WriteMessage(msg);

	return result;
}	


//---------------------------------------------------------------------------
//
	int								// 0 = no error
	CCSCAccept::GetMessage
	(
		SXDB_PT		*parseTable,	// PT Parse table for the message
		void		*pStruct,		// Where to put the data
		SXDB_PT		*altParseTable,	// Alternate PT Parse table for the message
		void		*pAltStruct		// Where to put the Alternate data
	)
//
//	Reads one framed message from the socket and parses it according to the
//	PT table.  Frame layout: a 4-byte length in network byte order that
//	counts itself, followed by length-4 bytes of XML text whose last byte
//	must be NUL (the framing WriteMessage() produces).
//
//	Note: If altParseTable is not NULL, then altParseTable & pAltStruct are
//		  used in the case the first parseTable did not match the incoming
//		  msg.  This feature is used when we are waiting for CSC_ACK, but
//		  might receive a CSC_Connect message.
//		  If the altParseTable is matched, then 1 is returned.
//
//	Returns 0 (or 1, see above) on success, -1 on framing/read errors, and
//	otherwise the result of the XML parser or SXDB_PT_Get().
//
//---------------------------------------------------------------------------
{
	int			result;
	int			length;
	char		data[CSC_MAX_MESSAGE];
	CSXDB_Node	*pNode;

	// Read the data length

	result = m_pNetConn->Read( (BYTE*) &length, 4);

	if (!result)
	{
		DBLog(("%s: Error -- FAILED to receive lenth field (4 bytes)\n", __PRETTY_FUNCTION__));
		return -1;
	}

	length = ntohl(length);

	// The length comes from the network.  The payload (length-4 bytes) must
	// hold at least the terminating NUL and must fit in data[].  Testing the
	// lower bound first also rejects negative values, so length-4 and
	// data[length-5] below stay in range.

	if (length < 5 || length - 4 > CSC_MAX_MESSAGE)
	{
		DBLog(("%s: Error -- INVALID length 0x%X\n", __PRETTY_FUNCTION__, length));
		return -1;
	}

	// Read the data

	result = m_pNetConn->Read( (BYTE*) data, length-4);

	if (!result)
	{
		DBLog(("%s: Error --FAILED to receive message of (%d) bytes\n", __PRETTY_FUNCTION__, length));
		return -1;
	}

	// The sender includes the NUL; insisting on it makes data a C string

	if (data[length-5] != 0)
	{
		DBLog(("%s: Error -- Invalid terminal byte (0x%X)\n", __PRETTY_FUNCTION__, data[length-5]));
		return -1;
	}

	// debug purpose; remove it in the future.

	DBLog(("%s: length(%d) data: -----Begin--------\n%s\n------End---------\n", __PRETTY_FUNCTION__, length, data));

	// Parse the XML

	result = dbParser.Parse( data, strlen(data) );

	if (result != 0)
	{
		DBLog(("%s: Error -- Fail to parse XML, result[%d]; [%s]\n", __PRETTY_FUNCTION__, result, data));
		return result;
	}

	// Parse the data from the XML database

	pNode = dbParser.GetDataBase()->Root();

	if (pNode != NULL)
		pNode = pNode->Child();

	if (pNode == NULL)
	{
		DBLog(("%s: Error -- No child node in XML[%s]\n", __PRETTY_FUNCTION__, data));
		return -1;
	}

	result = SXDB_PT_Get( pNode, parseTable, pStruct, 0 );

	if (result < 0 && altParseTable != NULL && pAltStruct != NULL)
	{
		// try the alternate message format

		DBLog(("%s: Fail to parse parseTable, try altParseTable. result[%d], XML[%s]\n",
				__PRETTY_FUNCTION__, result, data));

		result = SXDB_PT_Get( pNode, altParseTable, pAltStruct, 0 );

		if (result >= 0)
			result = 1;
		else
		{
			DBLog(("%s: Fail to parse altParseTable. result[%d]\n",
				__PRETTY_FUNCTION__, result));
		}
	}

	return result;
}

//---------------------------------------------------------------------------
//
	int								// 0 = no error, -1 = error
	CCSCAccept::WriteMessage
	(
		const char	*pData				// Ptr to the message (NUL-terminated)
	)
//
//	Writes one framed message to the socket: a 4-byte length in network
//	byte order (4 + message text + terminating NUL), then the text together
//	with its NUL.  This is the framing GetMessage() expects on receipt.
//
//---------------------------------------------------------------------------
{
	int		textLength;
	int		length;
	int		result;

	if (pData == NULL)
		return -1;

	textLength = (int) strlen(pData);
	length = htonl(textLength + 5);

	result = m_pNetConn->Write( (BYTE*) &length, 4);

	if (!result)
	{
		//DEBUG_TRACE("FAILED to send length field (4 bytes)!\n");
		return -1;
	}

	result = m_pNetConn->Write( (const BYTE*) pData, textLength + 1);

	if (!result)
	{
		//DEBUG_TRACE("FAILED to send (%d bytes) ... %s\n", ntohl(length), pData);
		return -1;
	}

	//DEBUG_TRACE("SENT (%d bytes) ... %s\n", ntohl(length), pData);

	return 0;
}

/*
//---------------------------------------------------------------------------
//
	int								// < 0 = error

	CCSCAccept::EncodeBase64

	(
		void	*pData,				// Input data
		int		length,				// length of the input data
		char	*pOutput			// ptr to the output string
		int		max					// max output length
	)
//
//	Encodes data into base64
//
//---------------------------------------------------------------------------
{
	EVP_ENCODE_CTX	ctx;
	unsigned char *p = (unsigned char *) pData;
	int				totalLen,len;
	unsigned char base64[256];

	EVP_EncodeInit(&ctx);
	EVP_EncodeUpdate(&ctx, base64, &len, p, length);
	EVP_EncodeFinal(&ctx, &base64[len], &len); 


	EVP_DecodeInit(&ctx);
	EVP_DecodeUpdate(&ctx, code, &len, p, strlen(pKey));
	keyLength = x;
	EVP_DecodeFinal(&ctx, &keyData[x], &x);
	keyLength += x; 

}
*/

