#include "DeVirt.h"
#include "Util.h"

#include <iostream>
#include <sstream>
#include <fstream>
#include <string>
#include <map>
#include <set>
#include <memory>
#include <algorithm>
#include <cstring>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <cassert>
#include <cerrno>

#include <sys/types.h>
#include <sys/stat.h>
#if defined _MSC_VER
#include <windows.h>
//#define _OPENMP_NOFORCE_MANIFEST
#include <omp.h>
namespace
{
  int getpid() { return static_cast<int>(GetCurrentProcessId()); }
}
#else
#include <unistd.h>
#endif
// NOTE(review): 'stop' is not referenced in the code visible in this file;
// it has external linkage, so it is presumably read elsewhere - confirm
// before removing or renaming it.
bool stop = true;

// The tool's global instance, defined elsewhere.  TempName() below reads the
// temporary directory from it.
extern DeVirt deVirt;

namespace
{
  // Build the path of a temporary file inside deVirt's temp directory:
  //   "<TempDir>/devirt_[<name>_]<pid>.<extension>"
  // The "<name>_" part is omitted when 'name' is NULL or empty.
  std::string TempName(const char *name, const char *extension = "cpp")
  {
    const std::string &tempDir = deVirt.TempDir();
    std::ostringstream tempNameOut;

    tempNameOut << tempDir << "/devirt_";
    if (name != NULL && name[0] != 0)
      tempNameOut << name << '_';
    tempNameOut << getpid() << '.' << extension;
    return tempNameOut.str();
  }

  // std::string convenience overload of TempName(const char *, const char *).
  std::string TempName(const std::string &name, const char *extension = "cpp")
  {
    return TempName(name.c_str(), extension);
  }

  // Owner of a temporary file name.
  //
  // In release builds (NDEBUG) the named file is deleted when the object is
  // destroyed; in debug builds it is kept for inspection.  Retain() clears
  // the stored name, so the file is never deleted (and c_str() then returns
  // an empty string).
  //
  // NOTE(review): the generated name only contains the base name and the
  // process id, so two files with the same base name processed concurrently
  // (see the OpenMP loop in DeVirt::GenMangledNames()) would share a path.
  class TempFile
  {
    std::string m_FileName;

    // Non-copyable: a copy would delete the same file a second time.
    // (Declared but intentionally not defined - C++03 idiom.)
    TempFile(const TempFile &);
    TempFile &operator= (const TempFile &);

  public:
    TempFile() { }

    explicit TempFile(const char *name, const char *extension = "cpp")
      : m_FileName(TempName(name, extension))
      { Util::ConvertPath(m_FileName); }

    explicit TempFile(const std::string &name, const char *extension = "cpp")
      : m_FileName(TempName(name, extension))
      { Util::ConvertPath(m_FileName); }

    ~TempFile()
      {
#ifdef NDEBUG
        if (!m_FileName.empty())
          std::remove(m_FileName.c_str());
#endif
      }

    // Replace the stored name (the previously named file is not touched).
    TempFile &operator= (const std::string &name)
      {
        m_FileName = name;
        return *this;
      }

    // Replace the stored name (the previously named file is not touched).
    TempFile &operator= (const char *name)
      {
        m_FileName = name;
        return *this;
      }

    const char *c_str() const { return m_FileName.c_str(); }

    operator const std::string & () const { return m_FileName; }

    // Give up ownership: the file will not be deleted on destruction.
    void Retain() { m_FileName.clear(); }
  };
}

// Build the shell command that compiles 'cppName' to assembly 'asmName'.
//
// The configured compile command is extended with "-iquote<path>" (so that
// quoted includes resolve relative to the interface file's directory), the
// source file and "-o <asmName>".  Arguments are inserted verbatim (no
// quoting).
std::string DeVirt::GenCompileCommand(
  const char *cppName,
  const char *asmName,
  const char *path
  ) const
{
  std::ostringstream command;

  command << m_CompileCommand;
  command << " -iquote" << path;
  command << ' ' << cppName;
  command << " -o " << asmName;

  return command.str();
}

#if 0
// Code for matching a method signature against a flat list of mangled
// symbols.
//
// Currently unused.
//
// NOTE(review): the active code path instead pairs '#@@NAME=' markers with
// symbol labels in the generated assembly (see LoadMangledNamesFromAsm()).
namespace
{
  // Skip a mangled name component.
  //
  // The method returns a pointer to the character following the component
  // (which may be the end of the entire mangled name).
  //
  // If the mangled component is not recognized, the method returns NULL.
  //
  // NOTE(review): the <cctype> calls below receive plain char; cast to
  // unsigned char before reviving this code (negative values are UB).
  const char *SkipMangledComponent(const char *p)
  {
    bool haveDelimiter = false;
    char buffer[8];

    while (std::isupper(p[0]))
    {
      if (p[0] == 'N')
      {
        if (haveDelimiter)
          return NULL;
        else
          haveDelimiter = true;
      }
      ++p;
    }
    if (std::islower(p[0]))
      return p + 1;
    if (!std::isdigit(p[0]))
      return NULL;
    buffer[0] = p[0];
    ++p;
    char *q = buffer + 1;
    while (q < buffer + sizeof buffer && std::isdigit(p[0]))
      *q++ = *p++;
    if (q == buffer + sizeof buffer)
      return NULL;
    *q = 0;
    int len = std::atoi(buffer);
    // NOTE(review): '<=' also rejects a source name that ends the mangled
    // string exactly (e.g. a trailing class-type parameter such as "3Foo"),
    // which makes NumParamsFromMangledName() fail; '<' is probably intended.
    if (std::strlen(p) <= static_cast<size_t>(len))
      return NULL;
    p += len;
    // Skip a template argument list "I ... E".
    if (p[0] == 'I')
    {
      ++p;
      while (p[0] != 'E')
      {
        p = SkipMangledComponent(p);
        if (p == NULL)
          return NULL;
      }
      ++p;
    }
    // Skip the remaining components of a nested name "N ... E".
    if (haveDelimiter)
    {
      while (p[0] != 'E')
      {
        p = SkipMangledComponent(p);
        if (p == NULL)
          return NULL;
      }
      ++p;
    }
    return p;
  }

  // Get the number of parameters from a mangled method name.
  //
  // The function returns -1 if the mangling is not recognized.
  int NumParamsFromMangledName(const std::string &mangledName)
  {
    const char *p = mangledName.c_str();

    if (p[0] != '_' || p[1] != 'Z')
      return -1;
    p += 2;
    p = SkipMangledComponent(p);
    if (p == NULL)
      return -1;
    unsigned numParams = 0;
    // A lone 'v' (void) parameter list means "no parameters".
    if (p[0] == 'v' && p[1] == 0)
      return 0;
    while (p[0] != 0)
    {
      p = SkipMangledComponent(p);
      if (p == NULL)
        return -1;
      numParams += 1;
    }
    return numParams;
  }

  // Find the mangled name for a method in a sorted vector of mangled names.
  //
  // Returns a reference to the matching entry of 'symbolList', or to a static
  // empty string when there is no unambiguous match.
  const std::string &FindMangledName(
    const std::string &className,
    const Method *method,
    const std::vector<std::string> &symbolList
    )
  {
    static const std::string empty;
    const std::string &methodName = method->Name();
    std::stringstream symbolPrefixOut;

    symbolPrefixOut << "_ZN";
    if (method->IsConst())
      symbolPrefixOut << 'K';
    symbolPrefixOut << className.length() << className
                    << methodName.length() << methodName << 'E';
    const std::string &symbolPrefix = symbolPrefixOut.str();
    std::vector<std::string>::const_iterator it = std::lower_bound(
      symbolList.begin(), symbolList.end(), symbolPrefix);
    if (it == symbolList.end())
      return empty;
    const std::string &symbol = *it;
    if (symbol.length() < symbolPrefix.length()
        || symbol.compare(0, symbolPrefix.length(), symbolPrefix) != 0)
      return empty;

    // Now 'it' refers to the first symbol in a sequence of symbols matching
    // the specified prefix.
    std::vector<std::string>::const_iterator itNext = it;
    ++itNext;
    if (itNext == symbolList.end())
      return symbol;
    const std::string &symbolNext = *itNext;
    if (symbolNext.length() < symbolPrefix.length()
        || symbolNext.compare(0, symbolPrefix.length(), symbolPrefix) != 0)
      return symbol;

    // There are in fact multiple matches, so we're now in the unfortunate
    // situation where we have to figure out which symbol we need.
    std::vector<std::string> matchList;
    matchList.push_back(symbol);
    matchList.push_back(symbolNext);
    while (true)
    {
      ++itNext;
      if (itNext == symbolList.end())
        break;
      const std::string &symbolNext = *itNext;
      if (symbolNext.length() < symbolPrefix.length()
          || symbolNext.compare(0, symbolPrefix.length(), symbolPrefix) != 0)
        break;
      matchList.push_back(symbolNext);
    }
    // Disambiguate overloads by parameter count only.
    size_t numParams = method->Params().size();
    const std::string *match = NULL;
    for (
      std::vector<std::string>::const_iterator
        it = matchList.begin(), itEnd = matchList.end();
      it != itEnd;
      ++it)
    {
      const std::string &symbol = *it;
      int symbolNumParams = NumParamsFromMangledName(symbol);
      if (symbolNumParams == -1)
      {
        DeVirt::Error(
          "internal error: "
          "can not determine number of parameters from mangled symbol '%s'",
          symbol.c_str());
        std::abort();
      }
      if (static_cast<size_t>(symbolNumParams) == numParams)
      {
        if (match != NULL)
        {
          DeVirt::Error(
            "can not deduce mangled name for method '%s'",
            method->QualifiedName().c_str());
          DeVirt::Error("candidate: %s", match->c_str());
          DeVirt::Error("candidate: %s", symbol.c_str());
          return empty;
        }
        else
          match = &symbol;
      }
    }
    if (match != NULL)
      return *match;
    else
      return empty;
  }
}
#endif

bool DeVirt::LoadMangledNames(
  const std::map<std::string, std::string> &symbolMap,
  File *file,
  bool reportMissing
  )
{
  bool missingMangledNames = false;

  for (
    std::vector<Iface *>::const_iterator
      it = file->m_IfaceList.begin(), itEnd = file->m_IfaceList.end();
    it != itEnd;
    ++it)
  {
    const Iface *iface = *it;
    if (iface->IsEmpty())
      continue;
    const std::vector<Method *> &methods = iface->Methods();
    for (
      std::vector<Method *>::const_iterator
        it = methods.begin(), itEnd = methods.end();
      it != itEnd;
      ++it)
    {
      Method *ifaceMethod = *it;
			if (!ifaceMethod->HaveImpl() || !ifaceMethod->DefaultImpl().empty() )
        continue;
      Method *implMethod = ifaceMethod->GetImplMethod();
      if (implMethod == NULL)
        continue;
      if (!implMethod->DefaultImpl().empty())
      {
        // Inline implementations are handled via wrappers.
        continue;
      }
      if (ifaceMethod->MangledName().empty())
      {
        const std::string &fullName = ifaceMethod->FullName();
        std::map<std::string, std::string>::const_iterator it
          = symbolMap.find(fullName);
        if (it == symbolMap.end())
        {
					if (reportMissing)
            Error(
              "can not find mangled name for method '%s' in ASM file",
              ifaceMethod->QualifiedName().c_str());
          missingMangledNames = true;
        }
        else
        {
          const std::string &mangledName = it->second;
          ifaceMethod->m_MangledName = mangledName;
        }
      }
      if (implMethod->MangledName().empty())
      {
        const std::string &fullName = implMethod->FullName();
        std::map<std::string, std::string>::const_iterator it
          = symbolMap.find(fullName);
        if (it == symbolMap.end())
        {
          if (reportMissing)
            Error(
              "can not find mangled name for method '%s' in ASM file",
              implMethod->QualifiedName().c_str());
          missingMangledNames = true;
        }
        else
        {
          const std::string &mangledName = it->second;
          implMethod->m_MangledName = mangledName;
        }
      }
    }
  }
  if (missingMangledNames)
    return false;
  return true;
}

// Load mangled names from a symbol cache stream written by SaveMangledNames().
//
// Each line has the form "<mangled symbol> <full method name>".  Trailing
// whitespace is stripped from the name; a symbol not followed by a single
// space is ignored.  Missing names are not reported here (the caller
// regenerates them), so this always returns true.
//
// Names are read into std::string, so lines of any length are handled.  A
// fixed-size getline() buffer would leave the stream failed (but not at EOF)
// on an overlong line, and an eof()-driven loop would then never end.
bool DeVirt::LoadMangledNames(std::istream &in, File *file)
{
  // The characters std::isspace() accepts in the "C" locale.
  static const char whitespace[] = " \t\n\v\f\r";

  std::map<std::string, std::string> symbolMap;
  std::string symbol;
  std::string fullName;

  // Stops at end of input or on any stream error.
  while (in >> symbol)
  {
    if (in.get() != ' ')
      continue;
    std::getline(in, fullName);
    const std::string::size_type last = fullName.find_last_not_of(whitespace);
    fullName.erase(last == std::string::npos ? 0 : last + 1);
    symbolMap[fullName] = symbol;
  }
  LoadMangledNames(symbolMap, file, false);
  return true;
}

void DeVirt::SaveMangledNames(
  std::ostream &out,
  const File *file
  ) const
{
  for (        
    std::vector<Iface *>::const_iterator
      it = file->m_IfaceList.begin(), itEnd = file->m_IfaceList.end();
    it != itEnd;
    ++it)
  {
    const Iface *iface = *it;
    if (iface->IsEmpty())
      continue;
    const std::vector<Method *> &methods = iface->Methods();
    for (
			std::vector<Method *>::const_iterator
        it = methods.begin(), itEnd = methods.end();
			it != itEnd;
			++it)
    {
      Method *ifaceMethod = *it;
      if (!ifaceMethod->HaveImpl())
				continue;
      Method *implMethod = ifaceMethod->GetImplMethod();
      
	    if (implMethod == NULL)
				continue;
      
	    if (!implMethod->DefaultImpl().empty())
				continue;
      
	    if (!ifaceMethod->MangledName().empty())
				out
					<< ifaceMethod->MangledName() << ' '
					<< ifaceMethod->FullName() << std::endl;
      
	    if (!implMethod->MangledName().empty())
				out
					<< implMethod->MangledName() << ' '
					<< implMethod->FullName() << std::endl;
    }
  }
}

// Recover mangled names from the assembly produced for the dummy code
// written by GenMangledNames(File *, int &).
//
// Two kinds of lines matter:
//  - a label starting in column 0 with '_' (e.g. "_ZN3Foo3barEv:"); its
//    first token, without a trailing ':', becomes the current symbol;
//  - a statement whose first non-blank text is "#@@NAME=<full name>" (emitted
//    by WriteMethodBody()); the name, with trailing whitespace stripped, is
//    mapped to the current symbol.
// The collected map is then applied with LoadMangledNames(..., true), whose
// result is returned.
//
// Parsing is line based with std::string, so lines of any length are
// handled.  With a fixed getline() buffer, an overlong line left the stream
// failed (but not at EOF) and the eof()-driven loop never ended.  Reading
// whole lines also keeps the line after a marker from losing its first
// character.
bool DeVirt::LoadMangledNamesFromAsm(std::istream &in, File *file)
{
  // The characters std::isspace() accepts in the "C" locale.
  static const char whitespace[] = " \t\n\v\f\r";
  static const char marker[] = "#@@NAME=";
  static const std::string::size_type markerLength = sizeof marker - 1;

  std::map<std::string, std::string> symbolMap;
  std::string symbol;   // most recent label
  std::string line;

  while (std::getline(in, line))
  {
    if (!line.empty() && line[0] == '_')
    {
      symbol = line.substr(0, line.find_first_of(whitespace));
      if (!symbol.empty() && symbol[symbol.length() - 1] == ':')
        symbol.erase(symbol.length() - 1);
      continue;
    }

    const std::string::size_type stmtBegin = line.find_first_not_of(whitespace);
    if (stmtBegin == std::string::npos
        || line.compare(stmtBegin, markerLength, marker) != 0)
      continue;

    const std::string::size_type nameBegin = stmtBegin + markerLength;
    const std::string::size_type nameLast = line.find_last_not_of(whitespace);
    std::string name;
    if (nameLast != std::string::npos && nameLast >= nameBegin)
      name = line.substr(nameBegin, nameLast + 1 - nameBegin);
    symbolMap[name] = symbol;
  }
  return LoadMangledNames(symbolMap, file, true);
}

// Return the path of the symbol cache file for 'file':
//   "<cache dir><path separator><base name>.sym"
//
// An empty result means "no cache"; LoadMangledNamesFromCache() and
// SaveMangledNamesToCache() skip caching in that case.  It is returned when
// no cache directory is configured.  Otherwise the path would be
// root-relative (e.g. "/Foo.sym").
std::string DeVirt::CacheFileName(const File *file) const
{
#if 0
  static const std::string empty;
  const std::string &cacheDir = m_CacheDir;
  struct stat sbuf;

  if (stat(cacheDir.c_str(), &sbuf) == 0)
  {
    if (!S_ISDIR(sbuf.st_mode))
    {
      Warning(
        "cache directory path '%s' is not a directory",
        cacheDir.c_str());
      return empty;
    }
  }
  else if (mkdir(cacheDir.c_str(), 0755) == -1)
  {
    Warning(
      "can not create cache directory '%s': %s",
      cacheDir.c_str(), std::strerror(errno));
    return empty;
  }
#endif

  if (m_CacheDir.empty())
    return std::string();

  std::string cacheFile = m_CacheDir + Util::PathSeperator + file->BaseName() + ".sym";
  return cacheFile;
}

// Load previously generated mangled names for 'file' from its symbol cache.
//
// Does nothing when caching is disabled (empty cache path), when the cache
// file does not exist or can not be opened, or when it is older than the
// interface file.  An interface file that can not be stat()ed is reported
// via Error().
void DeVirt::LoadMangledNamesFromCache(File *file)
{
  const std::string cacheFile = CacheFileName(file);
  if (cacheFile.empty())
    return;

  struct stat cacheStat;
  if (stat(cacheFile.c_str(), &cacheStat) == -1)
    return;

  struct stat ifaceStat;
  if (stat(file->FileName().c_str(), &ifaceStat) == -1)
  {
    Error(
      "error loading symbol cache file '%s': "
      "can not access interface file '%s' for up-to-date check",
      cacheFile.c_str(), file->FileName().c_str());
    return;
  }

  // A cache older than its interface file is stale - ignore it.
  const bool stale = std::difftime(cacheStat.st_mtime, ifaceStat.st_mtime) < 0.0;
  if (stale)
    return;

  std::ifstream in(cacheFile.c_str());
  if (in.is_open())
    LoadMangledNames(in, file);
}

// Write the mangled names of 'file' to its symbol cache file.
//
// Does nothing when caching is disabled (empty cache path).  If the file can
// not be opened, or writing/closing it fails, an error is reported.  In the
// latter case the incomplete file is removed.  Its mtime would otherwise make
// it look up to date on the next run.
void DeVirt::SaveMangledNamesToCache(const File *file) const
{
  const std::string &cacheFile = CacheFileName(file);

  if (cacheFile.empty())
    return;
  std::ofstream out(cacheFile.c_str());
  if (!out)
  {
    Error(
      "can not open symbol cache file '%s' for writing",
      cacheFile.c_str());
    return;
  }
  SaveMangledNames(out, file);

  // close() flushes; a failed write or flush leaves the stream failed.
  out.close();
  if (!out)
  {
    Error(
      "can not write symbol cache file '%s'",
      cacheFile.c_str());
    std::remove(cacheFile.c_str());
  }
}

namespace
{
  // Write the body of a dummy method definition to 'out' and return 'out'.
  //
  // On PS3 builds without CRYCG_CM the body holds an inline-asm comment
  // "#@@NAME=<full method name>".  LoadMangledNamesFromAsm() pairs that
  // comment with the preceding symbol label in the compiler's assembly.
  std::ostream &WriteMethodBody(std::ostream &out, const Method *method)
  {
    static const char bodyHead[] =
      "\n"
      "{\n"
      "#if defined(PS3) && !defined(CRYCG_CM)\n"
      "  __asm__ __volatile__ (\n"
      "    \"#@@NAME=";
    static const char bodyTail[] =
      "\");\n"
      "#endif\n"
      "}\n\n";

    out << bodyHead;
    out << method->FullName();
    out << bodyTail;
    return out;
  }
}

// functor to compare a method against all in a container
struct MethodCompare
{
	MethodCompare(const Method *m) : m_Meth(m){}
	bool operator()( const Method *other ) const { return (*m_Meth) == (*other); } 
	const Method *m_Meth;
};

namespace
{
  // Everything needed to write one dummy implementation class into the
  // temporary C++ file generated by DeVirt::GenMangledNames(File *, int &).
  //
  // File-local (anonymous namespace) so the type can not clash with an
  // identically named type in another translation unit.
  struct OutputClass
  {
    OutputClass() { }
    explicit OutputClass(const std::string &n) : name(n) { }

    std::string name;                 // implementation class name
    std::vector<std::string> bases;   // interface names, derived publicly
    std::vector<Method *> methods;    // methods declared in the class body
  };

  // Dummy classes keyed by class name.
  typedef std::map<std::string, OutputClass> ClassMap;
  typedef ClassMap::iterator ClassMapIter;
}


bool DeVirt::GenMangledNames()
{
  // Collect all relevant interface header files.
  std::vector<File *> files;
  for (
    std::list<File>::iterator
      it = m_FileList.begin(), itEnd = m_FileList.end();
    it != itEnd;
    ++it)
  {
    File &file = *it;
    if (file.GetType() != File::IFACE_FILE || file.m_IfaceList.empty())
      continue;

    // Load the mangled names from the cache.
    LoadMangledNamesFromCache(&file);

    bool fileRequired = false;
    for (
      std::vector<Iface *>::const_iterator
        it = file.m_IfaceList.begin(), itEnd = file.m_IfaceList.end();
      it != itEnd;
      ++it)
    {
      const Iface *iface = *it;
      if (iface->IsEmpty())
        continue;
      const std::vector<Method *> &methods = iface->Methods();
      for (
        std::vector<Method *>::const_iterator
          it = methods.begin(), itEnd = methods.end();
        it != itEnd;
        ++it)
      {
        Method *method = *it;
        if (!method->HaveImpl())
          continue;
        
        Method *implMethod = method->GetImplMethod();
        
				// don't need mangle names for functions which have default implementation
				if( !method->DefaultImpl().empty() )
					continue;

        if (implMethod == NULL)
          continue;

        if (!implMethod->DefaultImpl().empty())
        {
          // Inline implementations are handled via wrappers.
          continue;
        }

        if (method->MangledName().empty()
            && implMethod->MangledName().empty())
        {
          fileRequired = true;
          break;
        }
      }

      if (fileRequired)
        break;
    }
    if (fileRequired)
      files.push_back(&file);
  }

  if (files.empty())
    return true;

	int nErrors = 0;
	int nNumFiles = (int)files.size();
#if defined(_OPENMP)
	#pragma omp parallel for shared(nErrors)
#endif
	for( int i = 0 ; i < nNumFiles ; ++i )
	{
		if( nErrors == 0 )
			GenMangledNames( files[i], nErrors );
	}
	
	return nErrors == 0;
 
}


namespace
{
  // Count one failure in the error counter shared by the OpenMP threads of
  // DeVirt::GenMangledNames(); the update is atomic under OpenMP.
  void IncrementErrorCount(int &nErrors)
  {
#if defined(_OPENMP)
    #pragma omp atomic
#endif
    nErrors += 1;
  }
}

// Generate mangled names for the methods of the interfaces in 'file' that
// still lack them.
//
//  1. Write a temporary C++ file.  It includes the interface file and declares
//     a dummy class, publicly derived from its interfaces, for each
//     implementation class defined in another file.  It then defines each
//     method lacking a mangled name with a body from WriteMethodBody().
//  2. Compile it to assembly with the command from GenCompileCommand().
//  3. Read the names back with LoadMangledNamesFromAsm() and store them in
//     the symbol cache.
//
// May run concurrently for different files.  On any failure an error is
// reported, 'nErrors' is incremented atomically and the function returns.
void DeVirt::GenMangledNames( File *file, int &nErrors )
{
  ClassMap classMap;

  TempFile tempFile(file->BaseName());
  std::remove(tempFile.c_str());
  std::ofstream tempOut(tempFile.c_str());
  if (!tempOut.is_open())
  {
    Error("can not open output file '%s' for writing", tempFile.c_str());
    IncrementErrorCount(nErrors);
    return;
  }

  std::vector<Method *> ifaceMethods;
  std::vector<Method *> implMethods;

  // Required includes - hard coded.
  tempOut
    << "#include <CryModuleDefs.h>\n"
    << "#define eCryModule eCryM_Launcher\n"
    << "#include <platform.h>\n"
    << "#include <stdlib.h>\n"
    << "#include <smartptr.h>\n"
    << "#include <Cry_Math.h>\n"
    << "#include <Cry_XOptimise.h>\n"
    << "#include <Cry_Geo.h>\n"
    << "#include <ISerialize.h>\n"
    << "\n";

  if (!WriteFile(tempOut, *file))
  {
    IncrementErrorCount(nErrors);
    return;
  }

  for (
    std::vector<Iface *>::const_iterator
      ifaceIt = file->m_IfaceList.begin(), ifaceEnd = file->m_IfaceList.end();
    ifaceIt != ifaceEnd;
    ++ifaceIt)
  {
    const Iface *iface = *ifaceIt;
    if (iface->IsEmpty())
      continue;
    const IClass *iclass = iface->GetIClass();
    if (iclass == NULL)
      continue;

    // Don't add a dummy class when interface and implementation are in the
    // same header (the implementation class is then already declared).
    OutputClass *dummyClass = NULL;
    if (iclass->GetFile()->FileName() != iface->GetFile()->FileName())
    {
      // insert() returns the existing entry if the class is already known.
      ClassMapIter classIter = classMap.insert(
        std::make_pair(iclass->Name(), OutputClass(iclass->Name()))).first;
      dummyClass = &classIter->second;
      dummyClass->bases.push_back(iface->Name());
    }

    // Collect the methods that need dummy definitions.
    const std::vector<Method *> &methods = iface->Methods();
    for (
      std::vector<Method *>::const_iterator
        methodIt = methods.begin(), methodEnd = methods.end();
      methodIt != methodEnd;
      ++methodIt)
    {
      Method *ifaceMethod = *methodIt;
      if (!ifaceMethod->HaveImpl())
        continue;

      Method *implMethod = ifaceMethod->GetImplMethod();
      if (implMethod == NULL)
        continue;
      if (!implMethod->DefaultImpl().empty() || !ifaceMethod->DefaultImpl().empty())
        continue;

      // Check that we haven't already written this function.
      // NOTE(review): this also skips 'ifaceMethod', whose own mangled name
      // is then never generated; LoadMangledNamesFromAsm() would report it as
      // missing.  Confirm whether only the implementation should be skipped.
      if (std::find_if(implMethods.begin(), implMethods.end(),
                       MethodCompare(implMethod)) != implMethods.end())
        continue;

      if (ifaceMethod->MangledName().empty())
        ifaceMethods.push_back(ifaceMethod);

      if (implMethod->MangledName().empty())
      {
        if (dummyClass != NULL)
          dummyClass->methods.push_back(implMethod);
        implMethods.push_back(implMethod);
      }
    }
  }

  // Write the dummy implementation classes.
  for (ClassMapIter classIter = classMap.begin();
       classIter != classMap.end(); ++classIter)
  {
    const OutputClass &outClass = classIter->second;

    tempOut << "class " << outClass.name << " : ";
    for (
      std::vector<std::string>::const_iterator
        baseIt = outClass.bases.begin(), baseEnd = outClass.bases.end();
      baseIt != baseEnd;
      ++baseIt)
    {
      if (baseIt != outClass.bases.begin())
        tempOut << " , ";
      tempOut << " public " << *baseIt;
    }
    tempOut << "{\n";

    for (
      std::vector<Method *>::const_iterator
        methodIt = outClass.methods.begin(), methodEnd = outClass.methods.end();
      methodIt != methodEnd;
      ++methodIt)
      (*methodIt)->WriteDecl(tempOut);

    tempOut << "\n};\n";
  }

  // Write out dummy implementation of interface and implementation
  // methods.
  for (
    std::vector<Method *>::const_iterator
      it = ifaceMethods.begin(), itEnd = ifaceMethods.end();
    it != itEnd;
    ++it)
  {
    Method *method = *it;
    method->WriteDef(tempOut);
    WriteMethodBody(tempOut, method);
  }

  for (
    std::vector<Method *>::const_iterator
      it = implMethods.begin(), itEnd = implMethods.end();
    it != itEnd;
    ++it)
  {
    Method *method = *it;
    method->WriteDef(tempOut);
    WriteMethodBody(tempOut, method);
  }

  // Don't compile a truncated file (e.g. disk full): the compiler errors
  // would hide the real cause.
  tempOut.close();
  if (!tempOut)
  {
    Error("can not write output file '%s'", tempFile.c_str());
    IncrementErrorCount(nErrors);
    return;
  }

  // Compile the temporary file to ASM.
  const TempFile tempAsmFile(file->BaseName(), "s");
  std::string compileCommand = GenCompileCommand(tempFile, tempAsmFile, Util::GetPath( file->FileName() ) );
  int compileStatus = std::system(compileCommand.c_str());
  if (compileStatus != 0)
  {
#ifdef WIFSIGNALED
    if (WIFSIGNALED(compileStatus))
    {
      int signal = WTERMSIG(compileStatus);
      Error("compiler terminated on signal %d", signal);
    }
    else
    {
      int code = WEXITSTATUS(compileStatus);
      Error("compiler terminated with exit code %d", code);
    }
#else
    Error("compiler terminated with status code %d\nCompile cmd was: %s", compileStatus, compileCommand.c_str());
#endif
    // On error, retain the temp C++ file.
    tempFile.Retain();
    IncrementErrorCount(nErrors);
    return;
  }

  // Read back the ASM and locate the mangled names.
  assert(Util::CheckPath(tempAsmFile));
  std::ifstream in(tempAsmFile.c_str());
  if (!in)
  {
    Error(
      "can not open generated ASM file '%s' for reading",
      tempAsmFile.c_str());
    IncrementErrorCount(nErrors);
    return;
  }
  if (!LoadMangledNamesFromAsm(in, file))
  {
    IncrementErrorCount(nErrors);
    return;
  }
  SaveMangledNamesToCache(file);
}
