#include "StdAfx.h"
#include "TestSuiteLoader.h"
#include "TestSuite.h"
#include <CryUnitInterfaces.h>
#include <vcclr.h>
#include <vector>

using namespace System;
using namespace System::IO;
using namespace System::Collections::Generic;
using namespace CryUnitWrapper;

// Loads "<baseDir>/<dllName>.dll", obtains its test enumerator (unit or smoke
// tests depending on loadSmokeTest) and builds the managed suite list.
//
// If any step after the DLL is loaded throws, the partially acquired native
// resources are released here before the exception propagates. A throwing
// constructor never runs the destructor, so without this the module handle
// (and an initialised enumerator) would leak.
TestSuiteLoader::TestSuiteLoader(System::String^ baseDir, System::String^ dllName, bool loadSmokeTest)
    : m_library(NULL)
    , m_testEnumerator(NULL)
{
    m_library = LoadDll(baseDir, dllName);

    bool constructed = false;
    try
    {
        m_testEnumerator = GetTestEnumerator(loadSmokeTest);
        InitializeSuiteList();
        constructed = true;
    }
    finally
    {
        if (!constructed)
        {
            // The enumerator comes from the DLL's exported factory, so shut it
            // down while the module is still loaded, then unload the module.
            if (m_testEnumerator)
            {
                m_testEnumerator->Shutdown();
                m_testEnumerator = NULL;
            }
            if (m_library)
            {
                ::FreeLibrary(m_library);
                m_library = NULL;
            }
        }
    }
}

void TestSuiteLoader::InitializeSuiteList()
{
    m_suiteList = gcnew List<TestSuite^>();

    m_testEnumerator->Initialise();
    int numberOfTestSuites = m_testEnumerator->NumberOfTestSuites();
    std::vector<CryUnit::ITestSuite*> testSuites;
    testSuites.resize(numberOfTestSuites);
    m_testEnumerator->EnumerateTestSuites(&testSuites[0]);

    for (int i = 0; i < numberOfTestSuites; i++)
    {				
        TestSuite^ testSuite = gcnew TestSuite(testSuites[i]);
        m_suiteList->Add(testSuite);
    }
}

// Maps the loader's boolean switch onto the CryUnit suite category:
// true selects smoke tests, false selects unit tests.
CryUnit::ITestSuite::Type GetSuiteType(bool loadSmokeTest)
{
    if (!loadSmokeTest)
    {
        return CryUnit::ITestSuite::UNIT_TEST;
    }
    return CryUnit::ITestSuite::SMOKE_TEST;
}

// Looks up the "GetTestEnumerator" export in m_library and calls it for the
// requested suite category.
// Throws System::Exception if the export is missing or the factory returns null.
CryUnit::ITestEnumerator* TestSuiteLoader::GetTestEnumerator(bool loadSmokeTest)
{
    CryUnit::GetTestEnumerator factory =
        reinterpret_cast<CryUnit::GetTestEnumerator>(::GetProcAddress(m_library, "GetTestEnumerator"));
    if (factory == NULL)
        throw gcnew Exception("Can't find GetTestEnumerator().");

    const CryUnit::ITestSuite::Type suiteType = GetSuiteType(loadSmokeTest);
    CryUnit::ITestEnumerator* enumerator = factory(suiteType);
    if (enumerator == NULL)
        throw gcnew Exception("Can't create TestEnumerator.");

    return enumerator;
}

// Loads "<basePath>/<dllName>.dll" with ::LoadLibrary and returns its module handle.
//
// The process current directory is temporarily switched to the DLL's folder
// while loading. NOTE(review): presumably this lets the loader resolve the test
// DLL's own dependencies from that folder; confirm. The previous directory is
// restored on every path, including failures.
//
// Throws System::Exception when the file does not exist or LoadLibrary fails.
HMODULE TestSuiteLoader::LoadDll(String^ basePath, String^ dllName)
{
    String^ dllPath = Path::Combine(basePath, dllName) + gcnew String(".dll");
    String^ filePath = Path::GetFullPath(dllPath);
    String^ fileDir = Path::GetDirectoryName(filePath);

    Console::WriteLine("Loading {0}", filePath);

    // Check before touching the current directory. A missing file (or folder)
    // then reports the intended message instead of a DirectoryNotFoundException
    // from SetCurrentDirectory.
    if (!File::Exists(filePath))
    {
        throw gcnew Exception(gcnew String("Can't find ") + filePath);
    }

    String^ previousDir = Environment::CurrentDirectory;
    HMODULE library = NULL;

    Directory::SetCurrentDirectory(fileDir);
    try
    {
        pin_ptr<const wchar_t> filePathChars = PtrToStringChars(filePath);
        library = ::LoadLibrary(filePathChars);
        if (library == NULL)
        {
            // Report the resolved absolute path, consistent with "Can't find".
            throw gcnew Exception(gcnew String("Can't load ") + filePath);
        }
    }
    finally
    {
        // The current directory is process-wide state: always put it back.
        Directory::SetCurrentDirectory(previousDir);
    }

    return library;
}

// Releases the native resources acquired in the constructor.
//
// Order matters: m_testEnumerator was produced by a factory exported from
// m_library, so its code lives in that module. It must be shut down BEFORE the
// module is unloaded; calling Shutdown() after FreeLibrary would execute
// unmapped code. Members are nulled so that running cleanup again is a no-op.
TestSuiteLoader::~TestSuiteLoader()
{
    if (m_testEnumerator)
    {
        m_testEnumerator->Shutdown();
        m_testEnumerator = NULL;
    }

    if (m_library)
    {
        ::FreeLibrary(m_library);
        m_library = NULL;
    }
}

// Property getter: exposes the suites built by InitializeSuiteList() as a
// read-through handle to the same underlying list (no copy is made).
ICollection<TestSuite^>^ TestSuiteLoader::SuiteList::get()
{
    ICollection<TestSuite^>^ suites = m_suiteList;
    return suites;
}