#include "send.h"
#include "config.h"
#include "fileio.h"
#include "sock.h"
#include "bstream.h"
#include "hash.h"
#include "dataqueue.h"
#include "hashcache.h"
#include "bandwidthmonitor.h"
#include <time.h>

// Hash-progress sink used while hashing the file to be sent: progress is not
// reported anywhere, so every notification is simply discarded.
struct SNoCallback : public IHashProgressCallback
{
	void OnProgress()
	{
		// intentionally empty
	}
};

// Shared instance handed to GetHash() by the CSend constructor.
static SNoCallback noCallback;

// Broadcasts one file to BROADCAST_ADDR:CLIENT_PORT and services client
// requests until both packet queues are empty and the finish announcements
// have run out.
//
// Run() is a small scheduler. Each iteration asks every registered delay
// function how long it may sleep (CalculateDelay), waits on the socket for at
// most that long, then runs every registered job function once (PerformJobs).
class CSend
{
public:
	// Hashes and opens `filename` and prepares every packet for sending.
	CSend( const std::string& filename );
	// Loops until there is nothing left to send/merge and m_finishCounter is 0.
	void Run();

private:
	// Name placed in announce packets (directory part stripped by the constructor).
	std::string m_filename;
	// Hash of the file, included in announce packets.
	CHash m_hash;
	// The file being sent; data packets are read from it on demand.
	CFile m_file;
	// Socket bound in the constructor with SO_BROADCAST enabled.
	CSock m_sock;
	// Counts sent and received bytes and holds the current/max packet rate.
	CBandwidthMonitor m_bandwidthMonitor;
	// Packet indices still to be broadcast.
	CDataQueue m_sendQueue;
	// Packet indices requested by clients; folded into m_sendQueue by JobMergeQueue.
	CDataQueue m_mergeQueue;
	fpos_t m_fileLength;        // file size in bytes
	unsigned m_numPackets;      // file size in DATA_SIZE packets, rounded up
	int m_finishCounter;        // 'F' packets still to send; reset on activity
	bool m_readyToReceive;      // set by Run() when Select reports the socket readable
	bool m_needAnnounce;        // an 'A' packet should be sent
	bool m_flooding;            // set by HandleRequestData when a request HasIslands(); makes JobRecalculateDataRate cut the rate
	time_t m_lastAnnounceSend;  // time() of the last 'A' packet
	time_t m_lastAnnounceFinish; // time() of the last 'F' packet
	time_t m_lastMerge;         // time() the merge timeout is measured from
	time_t m_lastChangeBandwidth; // time() of the last rate recalculation

	// Takes the largest sleep (in seconds) allowed so far and returns it, possibly lowered.
	typedef double (CSend::*CalculateDelayFunction)( double maxDelay );
	typedef std::vector<CalculateDelayFunction> DelayFunctions;
	DelayFunctions m_delayFunctions;

	// Jobs run once per Run() iteration, in registration order.
	typedef void (CSend::*NormalJobFunction)();
	typedef std::vector<NormalJobFunction> JobFunctions;
	JobFunctions m_jobFunctions;

	// Incoming packet handlers, keyed by the packet's type byte.
	typedef void (CSend::*PacketHandlerFunction)( BInputStream& in );
	typedef std::map<char, PacketHandlerFunction> PacketHandlers;
	PacketHandlers m_packetHandlers;

	// Minimum number of seconds (as measured by time()) between repeated actions.
	static const time_t TIMEOUT_MERGE = 5;
	static const time_t TIMEOUT_ANNOUNCE_SEND = 3;
	static const time_t TIMEOUT_ANNOUNCE_FINISH = 1;
	static const time_t TIMEOUT_CHANGE_BANDWIDTH = 1;
	// Value m_finishCounter is reset to whenever there is activity.
	static const int NUM_FINISH_ANNOUNCEMENTS = 30;

	// Broadcast a single packet of the given type.
	void SendDataPacket( unsigned i );
	void SendAnnouncePacket();
	void SendFinishPacket();

	// Scheduler helpers used by Run().
	double CalculateDelay();
	void PerformJobs();

	// Delay functions (registered in the constructor).
	double CalculateDelayDataPacket( double );
	double CalculateDelayMergeQueue( double );
	double CalculateDelayAnnounceSend( double );
	double CalculateDelayAnnounceFinish( double );

	// Job functions (registered in the constructor).
	void JobDataPacket();
	void JobMergeQueue();
	void JobReceivePacket();
	void JobAnnounceSend();
	void JobAnnounceFinish();
	void JobRecalculateDataRate();

	// Packet handlers: 'R', 'W' and 'K' respectively (see constructor).
	void HandleRequestData( BInputStream& );
	void HandleRequestAnnounce( BInputStream& );
	void HandleKeepAlive( BInputStream& );
};

// Prepares a transfer of `filename`. It:
//  - hashes the file (no progress reporting) and opens it;
//  - binds a socket with broadcasting enabled;
//  - computes the file size in bytes and in DATA_SIZE packets;
//  - queues every packet for sending;
//  - registers the delay functions, jobs and packet handlers that Run() drives.
CSend::CSend( const std::string& filename ) :
	m_filename(filename),
	m_hash(GetHash(filename, &noCallback)),
	m_file(filename, true),
	// Use the named constant, not a duplicated literal, so the initial value
	// always matches the value every later reset uses.
	m_finishCounter(NUM_FINISH_ANNOUNCEMENTS),
	m_readyToReceive(false),
	m_needAnnounce(true),
	m_flooding(false),
	m_lastAnnounceSend(0),
	m_lastAnnounceFinish(0),
	m_lastMerge(0),
	m_lastChangeBandwidth(0)
{
	m_sock.Bind(0);
	m_sock.SetSockOpt(SOL_SOCKET, SO_BROADCAST, 1);

	// get file size in bytes
	m_file.Seek(0, SEEK_END);
	m_fileLength = m_file.GetPos();
	m_file.Seek(0, SEEK_SET);
	// and in packets, rounding up (the last packet may be short)
	m_numPackets = unsigned( m_fileLength / DATA_SIZE + (m_fileLength % DATA_SIZE != 0) );

	// Strip the directory part so only the bare file name is announced.
	// Accept both separators: Windows accepts '\\' and '/', POSIX uses '/'.
	// Otherwise a '/'-separated path would leak into the announced name.
	std::string::size_type lastSeparator = filename.find_last_of("\\/");
	if (lastSeparator != std::string::npos)
		m_filename = m_filename.substr( lastSeparator+1 );

	// fill the send queue with every packet of the file
	m_sendQueue.Insert( CDataRange(0, m_numPackets) );

	// consulted in this order by CalculateDelay()
	m_delayFunctions.push_back( &CSend::CalculateDelayDataPacket );
	m_delayFunctions.push_back( &CSend::CalculateDelayMergeQueue );
	m_delayFunctions.push_back( &CSend::CalculateDelayAnnounceSend );
	m_delayFunctions.push_back( &CSend::CalculateDelayAnnounceFinish );

	// run once per loop iteration, in this order, by PerformJobs()
	m_jobFunctions.push_back( &CSend::JobMergeQueue );
	m_jobFunctions.push_back( &CSend::JobAnnounceSend );
	m_jobFunctions.push_back( &CSend::JobAnnounceFinish );
	m_jobFunctions.push_back( &CSend::JobReceivePacket );
	m_jobFunctions.push_back( &CSend::JobDataPacket );
	m_jobFunctions.push_back( &CSend::JobRecalculateDataRate );

	// dispatch table for incoming packets, keyed by their type byte
	m_packetHandlers['R'] = &CSend::HandleRequestData;
	m_packetHandlers['W'] = &CSend::HandleRequestAnnounce;
	m_packetHandlers['K'] = &CSend::HandleKeepAlive;
}

void CSend::SendDataPacket( unsigned i )
{
	fpos_t startPos = fpos_t(i) * DATA_SIZE;
	fpos_t endPos = min( startPos + DATA_SIZE, m_fileLength );
	size_t length = size_t( endPos - startPos );

	BOutputStream stm;
	stm << 'D' << PROTOCOL_VERSION << i;
	m_file.SetPos( startPos );
	m_file.Read( stm.Put(length), length );

	m_sock.Send( BROADCAST_ADDR, CLIENT_PORT, stm.GetPtr(), stm.GetSize() );
	m_bandwidthMonitor.Add( stm.GetSize() );
}

void CSend::SendAnnouncePacket()
{
	std::cout << "Send announce" << std::endl;

	BOutputStream stm;
	stm << 'A' << PROTOCOL_VERSION << m_filename << m_numPackets << m_hash;

	m_sock.Send( BROADCAST_ADDR, CLIENT_PORT, stm.GetPtr(), stm.GetSize() );
	m_bandwidthMonitor.Add( stm.GetSize() );
}

void CSend::SendFinishPacket()
{
	std::cout << "Send finish " << m_finishCounter << std::endl;
	BOutputStream out;
	out << 'F' << PROTOCOL_VERSION;
	m_sock.Send( BROADCAST_ADDR, CLIENT_PORT, out.GetPtr(), out.GetSize() );
	m_bandwidthMonitor.Add( out.GetSize() );
}

double CSend::CalculateDelay()
{
	double delay = 10.0;
	for (size_t i=0; i<m_delayFunctions.size(); i++)
		delay = (this->*(m_delayFunctions[i]))(delay);

	return delay;
}

void CSend::PerformJobs()
{
	for (size_t i=0; i<m_jobFunctions.size(); i++)
		(this->*(m_jobFunctions[i]))();
}

// While packets are queued, wake up in time to send the next one at the
// maximum packet rate (one packet every 1/rate seconds).
// Imposes no deadline when the queue is empty or the max rate is zero; the
// zero check also avoids a division by zero.
double CSend::CalculateDelayDataPacket( double delay )
{
	if (m_sendQueue.Empty() || !m_bandwidthMonitor.GetMaxPacketRate())
		return delay;
	return min( delay, 1.0 / m_bandwidthMonitor.GetMaxPacketRate() );
}

// While finish announcements remain and nothing else is outstanding (both
// queues empty, no announce pending), wake up at least once per second.
double CSend::CalculateDelayAnnounceFinish( double delay )
{
	const bool finishing = m_finishCounter && m_mergeQueue.Empty() && m_sendQueue.Empty() && !m_needAnnounce;
	return finishing ? min( 1.0, delay ) : delay;
}

// While an announce is pending, wake up at least once per second so that
// JobAnnounceSend can check its timeout.
double CSend::CalculateDelayAnnounceSend( double delay )
{
	return m_needAnnounce ? min( 1.0, delay ) : delay;
}

// While client requests wait in the merge queue, wake up at least every
// 3 seconds so that JobMergeQueue can check its timeout.
double CSend::CalculateDelayMergeQueue( double delay )
{
	return m_mergeQueue.Empty() ? delay : min( 3.0, delay );
}

// Sends the next queued packet if the bandwidth monitor allows it. Sending
// data re-arms the full run of finish announcements.
void CSend::JobDataPacket()
{
	if (m_sendQueue.Empty() || !m_bandwidthMonitor.AllowData())
		return;

	SendDataPacket( m_sendQueue.Next() );
	m_finishCounter = NUM_FINISH_ANNOUNCEMENTS;
}

// Once nothing is left to send or merge and no announce is pending, sends one
// 'F' packet per TIMEOUT_ANNOUNCE_FINISH seconds and counts m_finishCounter
// down.
void CSend::JobAnnounceFinish()
{
	if (!m_finishCounter || !m_sendQueue.Empty() || !m_mergeQueue.Empty() || m_needAnnounce)
		return;
	if (time(NULL) - m_lastAnnounceFinish < TIMEOUT_ANNOUNCE_FINISH)
		return;

	SendFinishPacket();
	--m_finishCounter;
	m_lastAnnounceFinish = time(NULL);
}

// Sends a pending announce, at most once per TIMEOUT_ANNOUNCE_SEND seconds.
// It then clears the request and re-arms the finish announcements.
void CSend::JobAnnounceSend()
{
	if (!m_needAnnounce)
		return;
	if (time(NULL) - m_lastAnnounceSend < TIMEOUT_ANNOUNCE_SEND)
		return;

	SendAnnouncePacket();
	m_finishCounter = NUM_FINISH_ANNOUNCEMENTS;
	m_needAnnounce = false;
	m_lastAnnounceSend = time(NULL);
}

// Every TIMEOUT_MERGE seconds, folds accumulated client requests into the
// send queue, empties the merge queue and logs how many packets were added.
void CSend::JobMergeQueue()
{
	if (m_mergeQueue.Empty() || time(NULL) - m_lastMerge < TIMEOUT_MERGE)
		return;

	const unsigned queuedBefore = m_sendQueue.Count();
	m_sendQueue.Merge( m_mergeQueue );
	m_lastMerge = time(NULL);
	m_mergeQueue.Clear();
	const unsigned queuedAfter = m_sendQueue.Count();

	const unsigned added = queuedAfter - queuedBefore;
	std::cout << "Resend " << added << " packets" << std::endl;
}

// Once per TIMEOUT_CHANGE_BANDWIDTH seconds, adapts the bandwidth monitor's
// maximum packet rate, then logs the max rate, current rate and bandwidth.
//
// If a client request flagged flooding (m_flooding): quarter the limit,
// keeping it at least 1, then clear the flag.
//
// Otherwise, compare the current rate with the limit:
//  - above ~2/3 of the limit: raise the limit by 10;
//  - between ~1/2 and ~2/3 of the limit: set the limit to the current rate;
//  - at or below ~1/2 of the limit: raise the limit by 1.
void CSend::JobRecalculateDataRate()
{
	if ((time(NULL) - m_lastChangeBandwidth) < TIMEOUT_CHANGE_BANDWIDTH)
		return;
	if (m_flooding)
	{
		// Back off hard, but keep the limit positive. A zero max rate makes
		// CalculateDelayDataPacket drop its pacing deadline entirely.
		size_t reducedPacketRate = m_bandwidthMonitor.GetMaxPacketRate() / 4;
		if (reducedPacketRate == 0)
			reducedPacketRate = 1;
		m_bandwidthMonitor.SetMaxPacketRate( reducedPacketRate );
		// m_flooding is only ever raised (in HandleRequestData). Clear it so one
		// report costs one cut, instead of quartering the rate every period forever.
		m_flooding = false;
	}
	else
	{
		size_t curPacketRate = m_bandwidthMonitor.GetCurPacketRate();
		size_t maxPacketRate = m_bandwidthMonitor.GetMaxPacketRate();
		if (curPacketRate*3/2 > maxPacketRate)
			m_bandwidthMonitor.SetMaxPacketRate( maxPacketRate + 10 );
		else if (maxPacketRate/2 < curPacketRate)
			m_bandwidthMonitor.SetMaxPacketRate( curPacketRate );
		else
			m_bandwidthMonitor.SetMaxPacketRate( maxPacketRate + 1 );
	}
	m_lastChangeBandwidth = time(NULL);

	std::cout << "Max packet rate: " << unsigned(m_bandwidthMonitor.GetMaxPacketRate()) 
		<< " Current: " << unsigned(m_bandwidthMonitor.GetCurPacketRate()) 
		<< " Bandwidth: " << unsigned(m_bandwidthMonitor.Bandwidth()) << std::endl;
}

void CSend::JobReceivePacket()
{
	if (m_readyToReceive)
	{
		static const size_t LENGTH = 65536;
		char buf[LENGTH];
		unsigned addr; unsigned short port;
		size_t length = m_sock.Receive( addr, port, buf, LENGTH );
		// input bandwidth usage is just as bad as output
		m_bandwidthMonitor.Add( length );
		BInputStream input(buf, length);
		try
		{
			char packetType; input >> packetType;
			unsigned short version; input >> version;
			if (version != PROTOCOL_VERSION)
			{
				std::cout << "Invalid protocol version " << version << " (should be " << PROTOCOL_VERSION << ")" << std::endl;
				return;
			}
			PacketHandlers::iterator iter = m_packetHandlers.find(packetType);
			if (iter == m_packetHandlers.end())
				std::cout << "Ignored packet of type " << packetType << std::endl;
			else
				(this->*(iter->second))(input);
		}
		catch (std::exception& e)
		{
			std::cout << "Exception occured on input: " << e.what() << std::endl;
		}
	}
}

// 'K' packet: a client is still listening, so restart the finish countdown.
// The packet body is not read.
void CSend::HandleKeepAlive(BInputStream&)
{
	m_finishCounter = NUM_FINISH_ANNOUNCEMENTS;
}

// 'W' packet: a client asked for the announce; JobAnnounceSend sends it.
// The packet body is not read.
void CSend::HandleRequestAnnounce(BInputStream&)
{
	m_needAnnounce = true;
}

// 'R' packet: a client lists the packets it still needs.
// The request is accumulated in the merge queue, which JobMergeQueue folds
// into the send queue later. A request whose ranges HasIslands() raises
// m_flooding. Any request restarts the finish countdown.
void CSend::HandleRequestData(BInputStream& in)
{
	// Start the merge timer from the first request that finds the merge queue
	// empty, so TIMEOUT_MERGE counts from the oldest pending request.
	const bool nothingPending = m_mergeQueue.Empty();
	if (nothingPending)
		m_lastMerge = time(NULL);

	CDataQueue requested;
	in >> requested;
	if (requested.HasIslands(m_numPackets))
		m_flooding = true;

	m_mergeQueue.Merge( requested );
	m_finishCounter = NUM_FINISH_ANNOUNCEMENTS;
}

void CSend::Run()
{
	while (!m_sendQueue.Empty() || !m_mergeQueue.Empty() || m_finishCounter)
	{
		double delay = CalculateDelay();

		timeval tv;
		tv.tv_sec = long(delay);
		tv.tv_usec = long( 1e6 * (delay - long(delay)) );
		m_readyToReceive = m_sock.Select( &tv, true, false );

		PerformJobs();
	}
}

// Public entry point: broadcasts `filename` to listening clients and returns
// once CSend::Run() has finished.
void DoSend( const std::string& filename )
{
	CSend( filename ).Run();
}
