//////////////////////////////////////////////////////////////////////////////
//
// Crytek Source File.
// Copyright (C), Crytek Studios, 2007.
// ---------------------------------------------------------------------------
// Description:
// mtrace database for storing allocation information from cryengine3
// ps3 applications
// ---------------------------------------------------------------------------
// History:
// - June 14 2009 - Created by Christopher Raine 
//
//////////////////////////////////////////////////////////////////////////////

#if defined(WIN32) || defined(WIN64)

#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN
#endif

#include "socket.h"

#include <windows.h>
#include <winsock2.h>
#include <ws2tcpip.h>
#include <Mstcpip.h>
#include <mtracedb/mtracedb.h>

#include <boost/thread.hpp>

#include <iostream>

#define DEFAULT_PORT "33333"

using mtracedb::out; 
using mtracedb::cerr;

namespace
{
	// Static initializer to ensure that WSAStartup has been called
	// before any windows winsock functions have been called 
	struct winsock_init 
	{
		// Initialize winsockets during static initialization
		winsock_init() 
		{
			WSADATA wsaData;
			if(::WSAStartup(MAKEWORD(2,2), &wsaData))
			{
				cerr() << "mtrace: WSAStartup() failed!" << std::endl; 
				EXIT();
			}
		}
		// During static destruction  call WSACleanup()
		~winsock_init()
		{	::WSACleanup();	}
	};  

	// Static instance of winsock_init 
	static winsock_init _wsa_init; 
  
  // Fixed-capacity byte ring buffer of BlockSize bytes.
  // Every public operation takes the internal mutex. write() blocks until
  // the data fits; read() never blocks and wakes blocked writers.
  template<size_t BlockSize>
  class ring_buffer_t 
  {
    boost::mutex  m_lock;                   // guards all members below
    boost::condition_variable m_condition;  // signalled by read()
    size_t m_read_count;                    // total bytes consumed so far
    size_t m_write_count;                   // total bytes produced so far
    char* m_begin;                          // read cursor into m_data
    char* m_end;                            // write cursor into m_data
    char  m_data[BlockSize];

    // Number of bytes currently stored. Caller must hold m_lock.
    size_t used() const 
    { 
      return m_write_count - m_read_count; 
    }

    // Copy length bytes behind the write cursor, wrapping around the end
    // of the storage at most once. Caller must hold m_lock and must have
    // checked that at least length bytes are free.
    void store(const char* buffer, size_t length)
    {
      const size_t tail = static_cast<size_t>((m_data+BlockSize) - m_end);
      const size_t first = (length < tail) ? length : tail;
      memcpy(m_end, buffer, first);
      m_end += first;
      if (m_end == m_data+BlockSize)
        m_end = m_data;

      const size_t rest = length - first;
      if (rest != 0)
      {
        // Only reached after wrapping, so m_end == m_data here
        memcpy(m_end, buffer+first, rest);
        m_end += rest;
        if (m_end == m_data+BlockSize)
          m_end = m_data;
      }
      m_write_count += length;
    }

  public:

    ring_buffer_t()
      : m_lock(), 
        m_condition(), 
        m_read_count(),
        m_write_count(),
        m_begin(m_data), 
        m_end(m_data)
    {}

    // Append length bytes from buffer, blocking until they fit.
    // A request of up to BlockSize bytes is stored in one piece once
    // enough space is free. Larger requests (which previously could never
    // be satisfied and blocked forever) are stored in BlockSize pieces.
    void write(const char* buffer, size_t length)
    {
      boost::unique_lock<boost::mutex> lock(m_lock);
      while (length != 0)
      {
        const size_t chunk = (length < BlockSize) ? length : BlockSize;
        // read() notifies the condition whenever it has run
        while (BlockSize - used() < chunk)
          m_condition.wait(lock);
        store(buffer, chunk);
        buffer += chunk;
        length -= chunk;
      }
    }

    // Copy up to length stored bytes into buffer without blocking.
    // Only the contiguous run up to the end of the storage is copied per
    // call, so fewer bytes than stored may be returned. Returns the number
    // of bytes copied, 0 when the buffer is empty.
    size_t read(char* buffer, size_t length)
    {
      size_t count = 0;
      {
        boost::unique_lock<boost::mutex> lock(m_lock);
        if (used() == 0) 
          return 0; 

        const char* const limit = 
          (m_begin < m_end) ? m_end : m_data+BlockSize;
        count = static_cast<size_t>(limit - m_begin);
        if (length < count)
          count = length;
        if (count != 0) 
        {
          memcpy(buffer, m_begin, count); 
          m_begin += count; 
          m_read_count += count; 
          if (m_begin == m_data+BlockSize)
            m_begin = m_data;
        }
      }
      // Wake writers and yield only after the mutex has been released, so
      // a woken writer does not immediately block on the lock again
      m_condition.notify_all();
      boost::thread::yield();
      return count; 
    }
  }; 
  typedef ring_buffer_t<8<<20> ring_buffer;
}

namespace mtrace
{
	namespace net 
	{
		// Internal state of a single listening/accepted connection
		struct connection_t
		{
			// Socket that is bound, listened on and accepted from
			SOCKET server_socket;
			// Socket of the accepted client; carries the actual data
			SOCKET client_socket;
			// Ring buffer filled with bytes received from the client
			ring_buffer in_buffer;
			// Ring buffer holding bytes queued for sending to the client
			ring_buffer out_buffer;
			// Set by listen() before the io pump starts, cleared by the
			// pump when it has finished
			volatile bool active; 
			// Set to ask the io pump to leave its loop
			volatile bool exit_request; 
			// Thread running the io pump
			boost::thread io_pump; 

			// Start with invalid sockets, empty buffers and cleared flags
			connection_t() 
				: server_socket(INVALID_SOCKET)
				, client_socket(INVALID_SOCKET)
				, in_buffer()
				, out_buffer()
				, active(false)
				, exit_request(false)
			{
			}
		};  
	}; 
}; 

namespace 
{ 
  // The threaded buffer pump loop.
  // Moves bytes from the client socket into connection->in_buffer and
  // from connection->out_buffer to the client socket until an exit is
  // requested, no data has been received for 30 seconds, or the socket
  // fails. On the way out it drains whatever is still readable into
  // in_buffer, closes both sockets and clears connection->active.
  void pump(mtrace::net::connection_t* connection) 
  { 
    char rcv_buffer[512<<10]; 
    char send_buffer[8<<10]; 
    clock_t last_data = clock(); 
    // Bytes staged in send_buffer / how many of those are already sent
    size_t to_be_sent = 0, sent = 0; 
    const SOCKET client = connection->client_socket;

    while (connection->exit_request == false && 
      (clock() - last_data) < (CLOCKS_PER_SEC * 30)) 
    { 
      timeval timeout; 
      timeout.tv_sec = 60;
      timeout.tv_usec = 0; 

      // NOTE(review): the socket is always placed in the write set, so
      // select() presumably returns at once whenever the socket is
      // writable and this loop polls out_buffer continuously - confirm
      // that this is intended.
      fd_set socket_read, socket_write, socket_except; 
      FD_ZERO(&socket_read); 
      FD_ZERO(&socket_write);
      FD_ZERO(&socket_except);
      FD_SET(client, &socket_read); 
      FD_SET(client, &socket_write); 

      const int ready = ::select(
        (int)client+1, 
        &socket_read, 
        &socket_write,
        &socket_except, 
        &timeout);
      if (ready == SOCKET_ERROR || ready == 0)
      {
        cerr() << "mtrace: client recv error" << std::endl; 
        break;
      }

      if (FD_ISSET(client, &socket_read))
      {
        const int received = ::recv(
          client, rcv_buffer, (int)sizeof(rcv_buffer), 0);
        if (received > 0)
        {
          connection->in_buffer.write(
            rcv_buffer, static_cast<size_t>(received));
          last_data = clock();
        }
        else if (received == 0 || 
                 ::WSAGetLastError() != WSAEWOULDBLOCK)
        {
          // Orderly close (0) or a real socket error
          cerr() << "mtrace: connection lost" << std::endl; 
          break; 
        }
        // WSAEWOULDBLOCK: the socket is non-blocking and simply has
        // nothing to deliver right now; try again on the next round
      }

      if (FD_ISSET(client, &socket_write))
      {
        if (to_be_sent == 0)
        {
          // Nothing staged: fetch the next piece of outbound data
          to_be_sent = connection->out_buffer.read(
            send_buffer, sizeof(send_buffer));
          if (to_be_sent == 0)
            continue; 
        }
        const int written = ::send(
          client, 
          send_buffer+sent,
          (int)(to_be_sent-sent), 
          0);
        if (written < 0) 
        { 
          // Keep the staged bytes and retry when the non-blocking
          // socket cannot take more data at the moment
          if (::WSAGetLastError() == WSAEWOULDBLOCK)
            continue;
          cerr() << "mtrace: invalid connection during send: " 
                    << std::endl;
          break;
        } 
        // A partial send leaves the remainder staged for the next round
        sent += static_cast<size_t>(written); 
        if (sent == to_be_sent) 
          sent = to_be_sent = 0; 
      }
    }

    // Move everything that can still be received into in_buffer
    for (;;)
    {
      const int drained = ::recv(
        client, rcv_buffer, (int)sizeof(rcv_buffer), 0);
      if (drained == SOCKET_ERROR || drained == 0) 
        break; 
      connection->in_buffer.write(
        rcv_buffer, static_cast<size_t>(drained));
    }

    // Close the sockets 
    ::shutdown(connection->client_socket, SD_BOTH);
    ::closesocket(connection->client_socket);
    ::closesocket(connection->server_socket);
    connection->active = false; 
  }
}

using namespace mtrace::net; 

// Create a connection whose server socket is bound to DEFAULT_PORT.
// Returns a heap-allocated connection (release it with
// destroy_connection()); the socket is bound but not yet listening.
// Returns NULL when address resolution, socket creation or bind fails.
connection_t* mtrace::net::create_connection()
{
  connection_t* connection = new connection_t(); 

  struct addrinfo hints;
  ZeroMemory(&hints, sizeof (hints));
  hints.ai_family = AF_INET;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_protocol = IPPROTO_TCP;
  hints.ai_flags = AI_PASSIVE;

  // Resolve the local address and port to be used by the server.
  // The error paths below do not call WSACleanup(): winsock startup and
  // cleanup belong to the static winsock_init instance of this file, and
  // an extra cleanup here would unbalance that pairing.
  struct addrinfo* result = NULL;
  const int iResult = getaddrinfo(NULL, DEFAULT_PORT, &hints, &result);
  if (iResult != 0) 
  {
    cerr() << "mtrace: getaddrinfo failed: " << iResult <<
      std::endl;
    delete connection;
    return NULL;
  }

  // Create a SOCKET for the server to listen for client connections
  connection->server_socket = 
    socket(result->ai_family, result->ai_socktype,
           result->ai_protocol);
  if (connection->server_socket == INVALID_SOCKET) 
  {
    // Read the error code before any other call can overwrite it
    const int error = WSAGetLastError();
    cerr() << "mtrace: error at socket(): " << error <<
      std::endl;
    freeaddrinfo(result);
    delete connection;
    return NULL;
  }

  // Setup the TCP listening socket
  const int bound = ::bind(
    connection->server_socket, result->ai_addr,
    (int)result->ai_addrlen);
  if (bound == SOCKET_ERROR) 
  {
    const int error = WSAGetLastError();
    cerr() << "mtrace: bind failed: " << error <<
      std::endl;
    freeaddrinfo(result);
    closesocket(connection->server_socket);
    delete connection;
    return NULL;
  }

  freeaddrinfo(result);
  return connection; 
}

// Destroy the connection. A NULL pointer is accepted and ignored.
void mtrace::net::destroy_connection(connection_t* connection)
{
  // Stop and join a pump that is still running before the memory it
  // works on is released
  const bool pump_running = (connection != NULL) && connection->active;
  if (pump_running)
  {
    shutdown(connection);
  }
  // Deleting a NULL pointer is a no-op
  delete connection; 
}

// Listen and accept a single connection 
signed mtrace::net::listen(connection_t* connection)
{
	// Listen for a connection request
  int listen_result = ::listen(
    connection->server_socket, SOMAXCONN); 
	if (listen_result == SOCKET_ERROR ||
      listen_result == WSAEINTR) 
	{
		cerr() << "mtrace: error at listen(): " << WSAGetLastError() <<
			std::endl;
    closesocket(connection->server_socket);
    WSACleanup();
    return -1;
	}
  // Accept a client socket
	connection->client_socket = 
		::accept(connection->server_socket, NULL, NULL);
	if (connection->client_socket == INVALID_SOCKET || 
      connection->client_socket == WSAEINTR) 
	{
		cerr() << "mtrace: accept failed: " << WSAGetLastError() <<
			std::endl;
    closesocket(connection->server_socket);
    WSACleanup();
    return -1;
	}

  int opt_val = 1; 
  int opt_len = sizeof(opt_val);

  int result = ::setsockopt(
    connection->client_socket, 
    SOL_SOCKET, 
    SO_KEEPALIVE, 
    (const char*) &opt_val, 
    opt_len);
  if (result == SOCKET_ERROR) 
  {
		cerr() << "mtrace: could not set keepalive socket option" <<
			std::endl;
		cerr() << "mtrace: error at listen(): " << WSAGetLastError() <<
			std::endl;
    ::closesocket(connection->server_socket);
    ::shutdown(connection->client_socket, SD_BOTH);
    ::closesocket(connection->client_socket); 
    WSACleanup();
    return -1;
  } 

  u_long nonblocking = 1;
  result = ioctlsocket(
    connection->client_socket, 
    FIONBIO,
    &nonblocking);
  if (result == SOCKET_ERROR)
  {
		cerr() << "mtrace: could not set socket to nonblocking" <<
			std::endl;
		cerr() << "mtrace: error at listen(): " << WSAGetLastError() <<
			std::endl;
    ::closesocket(connection->server_socket);
    ::shutdown(connection->client_socket, SD_BOTH);
    ::closesocket(connection->client_socket); 
    WSACleanup();
  }

  result = ::getsockopt(
    connection->client_socket, 
    SOL_SOCKET, 
    SO_RCVBUF, 
    (char*) &opt_val, 
    &opt_len);
  if (result == SOCKET_ERROR)
  {
		cerr() << "mtrace: could not set retrieve socket rcvbuf size " <<
			std::endl;
		cerr() << "mtrace: error at listen(): " << WSAGetLastError() <<
			std::endl;
    ::closesocket(connection->server_socket);
    ::shutdown(connection->client_socket, SD_BOTH);
    ::closesocket(connection->client_socket); 
    WSACleanup();
    return -1;
  }
  if (opt_val < 4<<20) 
  {
    opt_val = 4 << 20; 
    opt_len = sizeof(opt_val);
    result = ::setsockopt(
      connection->client_socket, 
      SOL_SOCKET, 
      SO_RCVBUF, 
      (const char*) &opt_val, 
      opt_len);
    if (result == SOCKET_ERROR) 
    {
      cerr() << "mtrace: could not set rcvbuf socket option" <<
        std::endl;
      cerr() << "mtrace: error at listen(): " << WSAGetLastError() <<
        std::endl;
      ::closesocket(connection->server_socket);
      ::shutdown(connection->client_socket, SD_BOTH);
      ::closesocket(connection->client_socket); 
      WSACleanup();
      return -1;
    } 
    result = ::getsockopt(
      connection->client_socket, 
      SOL_SOCKET, 
      SO_RCVBUF, 
      (char*) &opt_val, 
      &opt_len);
    opt_len = sizeof (opt_val);
    out(mtracedb::LOG_INFO) << "mtrace: set rcvbuf to " << opt_val << std::endl; 
  }

  opt_val = 1; 
  result = ::setsockopt(
    connection->client_socket, 
    IPPROTO_TCP, 
    TCP_NODELAY, 
    (char*) &opt_val, 
    opt_len);
  if (result == SOCKET_ERROR) 
  {
    cerr() << "mtrace: could not set tco nodelay option for socket" <<
      std::endl;
    cerr() << "mtrace: error at listen(): " << WSAGetLastError() <<
      std::endl;
    ::closesocket(connection->server_socket);
    ::shutdown(connection->client_socket, SD_BOTH);
    ::closesocket(connection->client_socket); 
    WSACleanup();
    return -1;
  } 
  
  tcp_keepalive keep_alive; 
  keep_alive.onoff = 1; 
  keep_alive.keepalivetime = 60;
  keep_alive.keepaliveinterval = 1;

  DWORD bytes_returned; 
  result =  WSAIoctl(
    connection->client_socket,
    SIO_KEEPALIVE_VALS,
    &keep_alive,
    sizeof(keep_alive),
    NULL,
    0,
    &bytes_returned,
    NULL, 
    NULL);
  if (result == SOCKET_ERROR)
  {
		cerr() << "mtrace: could not set keepalive values" <<
			std::endl;
		cerr() << "mtrace: error at listen(): " << WSAGetLastError() <<
			std::endl;
    ::closesocket(connection->server_socket);
    ::shutdown(connection->client_socket, SD_BOTH);
    ::closesocket(connection->client_socket); 
    WSACleanup();
    return -1; 
  }

  connection->active = true; 
  connection->io_pump = boost::thread(pump, connection);
	return 0; 
}

// Report whether the connection is currently flagged as active
bool mtrace::net::is_active(connection_t* connection) 
{
  const bool currently_active = connection->active;
  return currently_active;
}

// Queue len bytes from buf for transmission to the attached client.
// Returns 0 without queuing anything when the connection is not active,
// otherwise len. The call blocks while the outbound ring buffer has no
// room for the data.
std::size_t mtrace::net::send(
  connection_t* connection, 
  const void* buf, 
  size_t len)
{
  if (!connection->active)
  {
    return 0;
  }
  const char* const bytes = (const char*)buf;
  connection->out_buffer.write(bytes, len);
  return len;
}

// Receive data from the attached client.
// Copies up to len bytes from the inbound ring buffer into buf, polling
// the buffer until len bytes have arrived. Once the connection is no
// longer active, the bytes still buffered are handed out first and the
// call returns as soon as the buffer is empty.
// Returns the number of bytes copied, which is less than len only after
// the connection has gone inactive.
std::size_t mtrace::net::recv(
  connection_t* connection, 
  void* buf, 
  size_t len)
{
  char* const out_bytes = static_cast<char*>(buf);
  std::size_t read = 0; 
  while (read < len) 
  {
    // Sample the flag BEFORE reading: the pump clears it only after its
    // final write to in_buffer, so "inactive and nothing read" means no
    // more data can ever arrive. Checking the flag alone (as before)
    // dropped the bytes the pump drained while shutting down.
    const bool was_active = connection->active;
    const std::size_t got = 
      connection->in_buffer.read(out_bytes+read, len-read);
    read += got;
    if (got == 0 && !was_active) 
      break;
  }
  return read; 
}

// Shutdown the connection: ask the io pump to stop and wait until its
// thread has finished. Always returns 0.
signed mtrace::net::shutdown(connection_t* connection)
{
  boost::thread& pump_thread = connection->io_pump;
  // The pump checks this flag in its loop condition
  connection->exit_request = true; 
  pump_thread.join();
  return 0;
}

// Retrieve the ipaddress of the current host
const char* mtrace::net::ipconfig()
{
  char hostName[255];
  gethostname(hostName, 255);

  hostent *host = gethostbyname(hostName);
  if (host && host->h_addr_list[0])
  { 
    struct in_addr addr;
    addr.s_addr = *(u_long *) host->h_addr_list[0];
    return inet_ntoa(addr);
  } 
  return NULL;
}


#endif // defined(WIN32) || defined(WIN64)
