////////////////////////////////////////////////////////////////////////////
//
//  Crytek Source File.
//  Copyright (C), Crytek Studios, 2009.
// -------------------------------------------------------------------------
//  File name:   ProtocolBuilder.cpp
//  Version:     v1.00
//  Created:     13/08/2009 by Younggi Lim
//  Compilers:   Visual Studio.NET
//  Description: builds and parses a simple key/value string protocol
// -------------------------------------------------------------------------
//  History:
//
////////////////////////////////////////////////////////////////////////////

#include "StdAfx.h"
#ifdef USING_LICENSE_PROTECTION

#include "ProtocolBuilder.h"
#include "AesCryptography.h"

#include <vector>

// Type tag stamped into SProtocolHeader::type by MakeHeader() and required
// by GetHeader() when reading a packet back.
const uint8 ProtocolBuilderType(0xD2);

// Creates an empty builder with encryption switched off.
// EncryptKey() installs a key and turns encryption on.
CProtocolBuilder::CProtocolBuilder()
	: m_encrypt(false)
	, m_cryptObj(new CAesCryptography())
{
}

// Copy constructor: duplicates every key/value entry of 'other'.
//
// The encryption state is NOT copied. The copy gets its own, un-keyed
// CAesCryptography object and starts with encryption disabled; call
// EncryptKey() on it to enable encryption. Previously both members were left
// uninitialized, so the destructor deleted an indeterminate pointer.
//
// NOTE(review): if the header does not define or delete operator=, the
// compiler-generated assignment would share m_cryptObj between two objects
// and double-delete it. Verify this against ProtocolBuilder.h.
CProtocolBuilder::CProtocolBuilder(const CProtocolBuilder& other)
	: m_encrypt(false)
	, m_cryptObj(new CAesCryptography())
{
	StringMapConstIterator endIter = other.m_container.end();
	for (StringMapConstIterator iter = other.m_container.begin(); iter != endIter; ++iter)
		Add(iter->first, iter->second);
}

// Releases the owned cryptography object.
CProtocolBuilder::~CProtocolBuilder()
{
	// 'delete' on a null pointer is a no-op, so no explicit guard is required.
	delete m_cryptObj;
	m_cryptObj = NULL;
}

// Inserts key -> value into the container.
// Returns true if the key was new, or false if it already existed. In that
// case the stored value is left unchanged.
bool CProtocolBuilder::AddImpl(const string& key, const string& value)
{
	return m_container.insert(std::pair<string, string>(key, value)).second;
}

bool CProtocolBuilder::Remove(const string& key)
{
	StringMapIterator iter = m_container.find(key);
	if (iter == m_container.end())
		return false;
	m_container.erase(iter);
	return true;
}

bool CProtocolBuilder::Parse(const string& source)
{
	if (-1 == source.find("=") || -1 == source.find(";"))
		return false;

	string subSource = source;
	string key, value;

	while(subSource.length() > 0)
	{
		int keyPos = (int)subSource.find("=");
		if (keyPos<0)
			break;	

		key = subSource.substr(0, keyPos);
		subSource = subSource.substr(keyPos+1);
		int valuePos = (int)subSource.find(";");
		if (valuePos<0)
			break;
		value = subSource.substr(0, valuePos);
		Add(key, value);
		subSource = subSource.substr(valuePos+1);
	}
	return true;
}

bool CProtocolBuilder::Parse( char* buffer, uint32 &offset, const uint32 bufferLength)
{
	uint32 protocolHeaderSize = sizeof(SProtocolHeader);
	if (bufferLength < protocolHeaderSize)
		return false;

	SProtocolHeader header;
	if (false == GetHeader(buffer, offset, bufferLength, header))
		return false;
	if (bufferLength < header.totalLength)
		return false;
	
	uint32 payloadSize = header.totalLength-protocolHeaderSize;
	if (m_encrypt)
	{
		uint32 bufferLength = payloadSize;
		char* cryptBuffer = new char[bufferLength];
		m_cryptObj->DecryptBuffer((const unsigned char*)&buffer[offset+protocolHeaderSize], bufferLength, (unsigned char*)cryptBuffer);
		memcpy(&buffer[offset+protocolHeaderSize], cryptBuffer, bufferLength);
		delete [] cryptBuffer;
		cryptBuffer = NULL;
	}
	string payload(&buffer[offset+protocolHeaderSize], payloadSize);
	
	if (false == Parse(payload))
		return false;

	offset += header.totalLength;
	return true;
}

// Returns the value stored for 'key', or an empty string if the key is absent.
string CProtocolBuilder::GetValueImpl(const string& key) const
{
	const StringMapConstIterator found = m_container.find(key);
	if (found == m_container.end())
		return "";
	return found->second;
}

// Removes every key/value entry. The encryption settings are not touched.
void CProtocolBuilder::Clear()
{
	m_container.erase(m_container.begin(), m_container.end());
}

// Serializes all entries as "key=value;" pairs, walking the container from
// rbegin() to rend(). If StringMap is an ordered map, this is descending key
// order (presumably std::map; verify in the header).
//
// Each piece is appended directly to 'result'. The previous form
// (first + "=" + second + ";") built three temporary strings per entry.
string CProtocolBuilder::MakeString() const
{
	string result;
	StringMapReverseIterator endIter = m_container.rend();
	for (StringMapReverseIterator iter = m_container.rbegin(); iter != endIter; ++iter)
	{
		result += iter->first;
		result += "=";
		result += iter->second;
		result += ";";
	}
	return result;
}

// Writes a complete packet (SProtocolHeader + serialized key/value payload)
// into 'buffer'. If encryption is enabled, the payload is encrypted.
//
// @param buffer  Destination for the packet.
// @param length  Capacity of 'buffer' in bytes. This was previously ignored,
//                so an undersized buffer was silently overrun.
// @return Number of bytes written, or 0 if 'buffer' is NULL or too small.
//         In that case nothing is written.
uint32 CProtocolBuilder::MakeBuffer(char* buffer, uint32 length) const
{
	string strResult = MakeString();
	// Short payloads are padded with a dummy entry to reach CryptBlockSize.
	// The receiver's Parse() will see it as an ordinary "forsmalldata" key.
	if (strResult.length() < CryptBlockSize)
		strResult += "forsmalldata=dummy;";

	const uint32 payloadLength = (uint32)strResult.length();
	const uint32 protocolHeaderSize = sizeof(SProtocolHeader);
	// Written as a subtraction so that header + payload cannot overflow uint32.
	if (NULL == buffer || length < protocolHeaderSize || length - protocolHeaderSize < payloadLength)
		return 0;

	const uint32 checkSum = 0; // checksum not computed yet (GetHeader does not verify it either)
	const uint32 headerLength = MakeHeader(buffer, payloadLength, checkSum);
	char* payloadStart = &buffer[headerLength];
	memcpy(payloadStart, strResult.c_str(), payloadLength);
	if (m_encrypt && payloadLength > 0)
	{
		// Encrypt into a scratch buffer, then copy the ciphertext back in place.
		std::vector<char> cipherBuffer(payloadLength);
		m_cryptObj->EncryptBuffer(reinterpret_cast<const unsigned char*>(payloadStart), payloadLength, reinterpret_cast<unsigned char*>(&cipherBuffer[0]));
		memcpy(payloadStart, &cipherBuffer[0], payloadLength);
	}
	return headerLength + payloadLength;
}

// Writes an SProtocolHeader describing a payload of 'dataLength' bytes to the
// start of 'buffer'.
//
// @param buffer      Destination. Must have room for sizeof(SProtocolHeader) bytes.
// @param dataLength  Payload size. The header's totalLength is this plus the header size.
// @param checkSum    Stored as-is in the header.
// @return Number of bytes written, which is sizeof(SProtocolHeader).
uint32 CProtocolBuilder::MakeHeader( char* buffer, uint32 dataLength, uint32 checkSum ) const
{
	const uint32 protocolHeaderSize = sizeof(SProtocolHeader);
	SProtocolHeader header;
	// The whole struct is memcpy'd onto the wire, so zero it first. Otherwise
	// any padding bytes (or fields not assigned below) would carry
	// uninitialized stack memory. Whether padding exists depends on the
	// SProtocolHeader declaration, which is not visible here.
	memset(&header, 0, protocolHeaderSize);
	header.type = ProtocolBuilderType;
	header.totalLength = dataLength + protocolHeaderSize;
	header.checkSum = checkSum;
	memcpy(buffer, &header, protocolHeaderSize);
	return protocolHeaderSize;
}

// Copies the SProtocolHeader at buffer[offset] into 'header' and validates it.
//
// @param bufferLength  Total valid bytes in 'buffer'. This is assumed to be
//                      measured from buffer[0], not from 'offset' (TODO
//                      confirm against callers).
// @return false if fewer than sizeof(SProtocolHeader) bytes remain after
//         'offset', if the type tag is not ProtocolBuilderType, or if
//         totalLength is smaller than the header itself.
//
// The size check previously ignored 'offset', so it allowed a memcpy past the
// end of the buffer.
bool CProtocolBuilder::GetHeader( const char* buffer, uint32 offset, uint32 bufferLength, SProtocolHeader& header ) const
{
	const uint32 protocolHeaderSize = sizeof(SProtocolHeader);
	if (NULL == buffer || offset > bufferLength || bufferLength - offset < protocolHeaderSize)
		return false;

	memcpy(&header, &buffer[offset], protocolHeaderSize);
	if (ProtocolBuilderType != header.type)
		return false;

	if (protocolHeaderSize > header.totalLength)
		return false;

	// todo : checksum check

	return true;
}

// Installs 'key' (keyLen bytes) in the cryptography object and enables
// encryption for MakeBuffer() and decryption for Parse(char*, ...).
void CProtocolBuilder::EncryptKey(unsigned char* key, unsigned char keyLen)
{
	m_cryptObj->SetKeyValue(key, keyLen);
	m_encrypt = true;
}

#endif // USING_LICENSE_PROTECTION