#include "stdafx.h"
#include "LST.h"

namespace
{
	// Fixed-size record describing one tree node inside the "tree" archive entry.
	// Only the (currently compiled-out) tree header code in Restore/ReplayEnd
	// reads and writes it, by reinterpreting raw bytes, so field order and
	// widths are part of the stored format.
	struct LayerNodeInfo
	{
		u64 evBegin;	// copied from/to AllocSetHeader::allocEvBegin
		u64 evEnd;		// copied from/to AllocSetHeader::allocEvEnd
		s32 szCnt;		// AllocSetHeader::totalSize.count
		s32 szReq;		// AllocSetHeader::totalSize.requested
		s32 szCon;		// AllocSetHeader::totalSize.consumed
		s32 szGlo;		// AllocSetHeader::totalSize.global
	};
}

// Default constructor. The body is intentionally empty: every member is left
// to its own default construction.
AllocSet::AllocSet()
{
}

// Exposes the set's contents as a contiguous byte range.
//
// The set is first switched to its immutable representation. On return:
//   buffer    - points at the first stored entry, or NULL when there is nothing to write.
//   bufferLen - size of that range in bytes, or 0 when there is nothing to write.
// The pointer refers to storage owned by this set; nothing is copied here.
void AllocSet::Serialise(u8*& buffer, size_t& bufferLen)
{
	MakeImmutable();

	const bool hasEntries = m_immutableAllocs.IsValid() && !m_immutableAllocs->empty();
	if (!hasEntries)
	{
		buffer = NULL;
		bufferLen = 0;
		return;
	}

	bufferLen = sizeof(ImmutableAllocSet::value_type) * m_immutableAllocs->size();
	buffer = reinterpret_cast<u8*>(&m_immutableAllocs->front());
}

void AllocSet::Deserialise(const u8* buffer, size_t bufferLen)
{
	m_mutableAllocs = NULL;
	m_immutableAllocs = new ImmutableAllocSet(bufferLen / sizeof(ImmutableAllocSet::value_type));

	memcpy(&m_immutableAllocs->front(), buffer, bufferLen);
}

// Drops both representations of the set; afterwards neither the mutable nor
// the immutable container is referenced by this object.
void AllocSet::Compact()
{
	m_immutableAllocs = NULL;
	m_mutableAllocs = NULL;
}

// Merges the key-ordered range [rhs, rhsEnd) into out, walking both in step.
//
// For every address in rhs:
//   * also present in out, rhs entry is live  -> out takes the rhs entry
//   * also present in out, rhs entry is freed -> the entry is erased from out
//     (out holding a freed entry here is unexpected and asserted against)
//   * absent from out                         -> the rhs entry is inserted as a new alloc
// Addresses that only exist in out are persistent allocs and are left untouched.
template <typename InputIteratorT>
static void AddSets(AllocSet::MutableAllocSet& out, InputIteratorT rhs, InputIteratorT rhsEnd)
{
	typedef AllocSet::MutableAllocSet::iterator OutIterator;

	OutIterator cursor = out.begin();

	while (rhs != rhsEnd)
	{
		const bool outExhausted = (cursor == out.end());

		if (!outExhausted && cursor->first == rhs->first)
		{
			if (rhs->second.freed)
			{
				assert (!cursor->second.freed);

				cursor = out.erase(cursor);
			}
			else
			{
				cursor->second = rhs->second;
				++ cursor;
			}

			++ rhs;
			continue;
		}

		if (!outExhausted && cursor->first < rhs->first)
		{
			// Entry exists only in out so far: keep it and move on.
			++ cursor;
			continue;
		}

		// Either out has run dry or rhs is behind the cursor: take the rhs
		// entry and resume just past the newly inserted element.
		cursor = out.insert(*rhs).first;
		++ cursor;
		++ rhs;
	}
}

// Merges the contents of other into this set using AddSets.
//
// This set is made mutable first. The entries are taken from whichever
// representation other currently holds (mutable preferred); when other holds
// neither, there is nothing to merge.
void AllocSet::Add(const AllocSet& other)
{
	EnsureMutable();

	if (other.m_mutableAllocs.IsValid())
	{
		AddSets(*m_mutableAllocs, other.m_mutableAllocs->begin(), other.m_mutableAllocs->end());
		return;
	}

	if (other.m_immutableAllocs.IsValid())
	{
		AddSets(*m_mutableAllocs, other.m_immutableAllocs->begin(), other.m_immutableAllocs->end());
	}
}
	
// Records an allocation at ptr. An entry already stored for that address is
// overwritten with ai; otherwise a new entry is inserted.
void AllocSet::AddAlloc(TAddress ptr, const AllocInfo& ai)
{
	EnsureMutable();

	MutableAllocSet& allocs = *m_mutableAllocs;
	MutableAllocSet::iterator existing = allocs.find(ptr);

	if (existing == allocs.end())
	{
		allocs.insert(std::make_pair(ptr, ai));
		return;
	}

	existing->second = ai;
}

// Records a free of ptr.
//
// If this set holds an entry for the address, that entry is simply removed.
// Otherwise a copy of ai with its freed flag set is stored, so that a later
// merge (see AddSets) can cancel the matching entry in an earlier set.
void AllocSet::FreeAlloc(TAddress ptr, const AllocInfo& ai)
{
	EnsureMutable();

	MutableAllocSet& allocs = *m_mutableAllocs;
	MutableAllocSet::iterator existing = allocs.find(ptr);

	if (existing != allocs.end())
	{
		allocs.erase(existing);
		return;
	}

	AllocInfo freedMarker(ai);
	freedMarker.freed = 1;

	allocs.insert(std::make_pair(ptr, freedMarker));
}

// Switches the set to its mutable representation.
//
// A fresh mutable container is always created. When an immutable container
// exists, its entries are inserted into the new one and it is then released.
void AllocSet::MakeMutable()
{
	m_mutableAllocs = new MutableAllocSet();

	if (!m_immutableAllocs.IsValid())
		return;

	m_mutableAllocs->insert(m_immutableAllocs->begin(), m_immutableAllocs->end());

	m_immutableAllocs = NULL;
}

void AllocSet::MakeImmutable()
{
	if (m_immutableAllocs.IsValid())
		return;
	if (!m_mutableAllocs.IsValid())
		return;

	m_immutableAllocs = new ImmutableAllocSet();
	m_immutableAllocs->reserve(m_mutableAllocs->size());

	std::copy(m_mutableAllocs->begin(), m_mutableAllocs->end(), std::back_inserter(*m_immutableAllocs));

	m_mutableAllocs = NULL;
}


// Stores filename (the archive path used by the compiled-out zip code in this
// file) and starts the alloc/free event counter at zero.
// The zip handle initialiser is commented out together with that code.
AllocSetLST::AllocSetLST(const TCHAR* filename)
	: m_filename(filename)
	//, m_zip(NULL)
	, m_currentAllocEvIdx(0)
{
}

// Destructor. Performs no explicit work in this build: closing the zip handle
// is commented out below.
AllocSetLST::~AllocSetLST()
{
	/*
	if (m_zip)
		CloseZip(m_zip);
		*/
}

// Restores the tree headers (and callstack table) from the archive at m_filename.
//
// The whole implementation is compiled out (#if 0), so in this build the
// function does nothing and always returns true.
//
// When enabled, the code below:
//   1. opens the zip and indexes its entries by name -> (item index, uncompressed size);
//   2. loads the callstack table from the optional "ct_fn" entry;
//   3. parses the mandatory "tree" entry: u32 layer count, one u32 node count
//      per layer, then one LayerNodeInfo per node, layer by layer;
//   4. links every node to its "lst_<layer>_<node>" entry, if one exists, so
//      the node's set can be loaded later.
// It returns false when the zip cannot be opened or has no "tree" entry.
bool AllocSetLST::Restore()
{
#if 0
	typedef std::basic_string<TCHAR> tstring;
	typedef std::map<tstring, std::pair<int, int> > ZipEntryMap;

	assert (m_zip == NULL);

	m_zip = OpenZip(m_filename.c_str(), NULL);
	if (!m_zip)
		return false;

	ZipEntryMap zipEntries;

	for (int zidx = 0; ; ++ zidx)
	{
		ZIPENTRY ze;

		if (GetZipItem(m_zip, zidx, &ze) != ZR_OK)
			break;

		zipEntries.insert(std::make_pair(tstring(ze.name), std::make_pair(zidx, (int) ze.unc_size)));
	}

	// Load callstack table.

	{
		ZipEntryMap::iterator it = zipEntries.find(_T("ct_fn"));
		if (it != zipEntries.end())
		{
			m_callstackTable = new CallstackTable();

			// NOTE(review): &data[0] is only valid when the entry's uncompressed size is non-zero.
			std::vector<u8> data(it->second.second);
			UnzipItem(m_zip, it->second.first, &data[0], data.size());

			m_callstackTable->Deserialise(&data[0], data.size());
		}
	}

	// Load tree.

	{
		ZipEntryMap::iterator it = zipEntries.find(_T("tree"));
		if (it == zipEntries.end())
			return false;

		std::vector<u8> data(it->second.second);
		UnzipItem(m_zip, it->second.first, &data[0], data.size());

		const u8* raw = &data[0];

		// NOTE(review): layerCount and layerLengths are read straight from the
		// blob; nothing here checks them against data.size().
		u32 layerCount = *reinterpret_cast<const u32*>(raw);
		raw += sizeof(u32);

		const u32* layerLengths = reinterpret_cast<const u32*>(raw);
		raw += sizeof(u32) * layerCount;

		m_tree.reserve(layerCount);
		for (u32 i = 0; i < layerCount; ++ i)
		{
			m_tree.push_back(new std::vector<AllocSetHeader>());
			m_tree.back()->reserve(layerLengths[i]);
		}

		for (u32 i = 0; i < layerCount; ++ i)
		{
			std::vector<AllocSetHeader>& nodes = *m_tree[i];

			const LayerNodeInfo* nodeInfos = reinterpret_cast<const LayerNodeInfo*>(raw);
			raw += sizeof(LayerNodeInfo) * layerLengths[i];

			for (u32 ni = 0; ni < layerLengths[i]; ++ ni)
			{
				AllocSetHeader header(nodeInfos[ni].evBegin, nodeInfos[ni].evEnd);

				TCHAR fn[128];
				_stprintf_s(fn, 128, _T("lst_%u_%u"), i, ni);

				// Nodes without a matching archive entry keep the header's
				// default serialisedIndex/serialisedSize.
				ZipEntryMap::iterator itHeader = zipEntries.find(fn);
				if (itHeader != zipEntries.end())
				{
					header.serialisedIndex = itHeader->second.first;
					header.serialisedSize = itHeader->second.second;
				}

				header.totalSize.count = nodeInfos[ni].szCnt;
				header.totalSize.requested = nodeInfos[ni].szReq;
				header.totalSize.consumed = nodeInfos[ni].szCon;
				header.totalSize.global = nodeInfos[ni].szGlo;

				nodes.push_back(header);
			}
		}
	}
#endif

	return true;
}

// Prepares the object for a new replay pass.
//
// The implementation is compiled out (#if 0), so in this build the call does
// nothing. When enabled it creates a fresh callstack table, clears the tree,
// resets the event counter, creates the output zip, seeds layer 0 with a single
// node covering events [0, SplitAlignment) and empties the live allocation index.
void AllocSetLST::ReplayBegin()
{
#if 0
	using std::swap;
	
	m_callstackTable = new CallstackTable();
	m_tree.clear();
	m_currentAllocEvIdx = 0;

	m_zip = CreateZip(m_filename.c_str(), NULL);

	m_tree.push_back(new std::vector<AllocSetHeader>());
	m_tree.front()->push_back(AllocSetHeader(0, SplitAlignment));

	{
		// Swapping with an empty temporary empties m_allocs; the old contents
		// are destroyed with tmp at the end of this scope.
		AllocIndex tmp;
		swap(m_allocs, tmp);
	}
#endif
}

// Feeds a range of replay events into the tree.
//
//   range - event stream to consume; it is read until ReadNext reports no more events.
//
// RE_Alloc3 events are recorded in the active leaf node (its set and running
// totalSize), in the callstack table and in the live allocation index
// m_allocs. RE_Free3 / RE_Free4 events are forwarded to ApplyFree. Every other
// event id is ignored.
//
// Each time the event counter reaches a value whose low bits (masked with
// SplitAlignment - 1) are all zero, the current leaf is closed via
// FinaliseTree and a new active leaf is picked up.
//
// If the tree has no active leaf yet (no layer, or an empty first layer) the
// call returns without reading any event, instead of dereferencing
// m_tree.front()->back() on an empty container.
void AllocSetLST::Replay(ReplayRange range)
{
	using namespace ReplayEventIds;

	// The only code in this file that seeds layer 0 lives in ReplayBegin; without
	// an active leaf there is nothing to record into.
	if (m_tree.size() == 0 || m_tree.front()->size() == 0)
		return;

	Ids id;

	AllocSetHeader* activeSet = &m_tree.front()->back();

	while (range.ReadNext(id))
	{
		switch (id)
		{
		case RE_Alloc3:
			{
				const ReplayAlloc3Event& ev = range.Get<ReplayAlloc3Event>();

				SizeInfo sz(1, 0, ev.sizeRequested, ev.sizeConsumed, ev.sizeGlobal);

				size_t csId = m_callstackTable->AddCallstack(ev);

				AllocInfo ai(ev.threadId, csId, sz.requested, sz.consumed, sz.global);
				activeSet->set.AddAlloc(ev.ptr, ai);

				activeSet->totalSize += sz;

				// Keep the live index in step with AllocSet::AddAlloc: when the
				// address is already tracked (its free was never seen) the newest
				// alloc wins. A plain insert would silently keep the stale entry,
				// and a later free would then subtract the wrong sizes.
				AllocIndex::iterator known = m_allocs.find(ev.ptr);
				if (known != m_allocs.end())
					known->second = ai;
				else
					m_allocs.insert(std::make_pair(ev.ptr, ai));

				++ m_currentAllocEvIdx;
				if ((m_currentAllocEvIdx & (SplitAlignment - 1)) == 0)
				{
					// FinaliseTree appends to layer 0, so the active leaf
					// pointer has to be fetched again.
					FinaliseTree();
					activeSet = &m_tree.front()->back();
				}
			}
			break;

		case RE_Free3:
			{
				const ReplayFree3Event& ev = range.Get<ReplayFree3Event>();
				ApplyFree(activeSet, ev.ptr, ev.sizeGlobal);
			}
			break;

		case RE_Free4:
			{
				const ReplayFree4Event& ev = range.Get<ReplayFree4Event>();
				ApplyFree(activeSet, ev.ptr, ev.sizeGlobal);
			}
			break;

		default:
			// Event kinds that do not affect the allocation sets.
			break;
		}
	}
}

// Applies one free event to the active leaf node.
//
//   activeSet  - in/out: the leaf currently being filled; re-pointed at the new
//                last leaf of layer 0 whenever this event completes a leaf.
//   ptr        - address being freed.
//   sizeGlobal - global size value carried by the event, added to the leaf's totalSize.
//
// When ptr is tracked in m_allocs, the free is recorded in the leaf's set, the
// tracked requested/consumed sizes are subtracted from the leaf's totalSize and
// the address is dropped from m_allocs. An untracked address only contributes
// the free count and sizeGlobal.
void AllocSetLST::ApplyFree(AllocSetHeader*& activeSet, TAddress ptr, ptrdiff_t sizeGlobal)
{
	AllocIndex::iterator tracked = m_allocs.find(ptr);

	if (tracked == m_allocs.end())
	{
		SizeInfo delta(0, 1, 0, 0, sizeGlobal);
		activeSet->totalSize += delta;
	}
	else
	{
		activeSet->set.FreeAlloc(ptr, tracked->second);

		SizeInfo delta(0, 1, -tracked->second.sizeRequested, -tracked->second.sizeConsumed, sizeGlobal);
		activeSet->totalSize += delta;

		m_allocs.erase(tracked);
	}

	++ m_currentAllocEvIdx;

	const bool leafComplete = (m_currentAllocEvIdx & (SplitAlignment - 1)) == 0;
	if (leafComplete)
	{
		// FinaliseTree appends to layer 0, so the previous pointer is stale.
		FinaliseTree();
		activeSet = &m_tree.front()->back();
	}
}

// Finishes a replay pass and writes the tree headers to the archive.
//
// The implementation is compiled out (#if 0), so in this build the call does
// nothing. The position parameter is not referenced by the body, enabled or not.
//
// When enabled, the code pads layer 0 until its node count is a power of two,
// finalises the tree once more, stores the callstack table as "ct_fn" and then
// stores a "tree" entry laid out as:
//   u32                 layer count
//   u32 * layerCount    node count of each layer
//   LayerNodeInfo * N   one record per node, layer by layer
// Finally the zip is closed.
void AllocSetLST::ReplayEnd(u64 position)
{
#if 0
	// Add empty trees until the pyramid is POT aligned.
	while (m_tree[0]->size() & (m_tree[0]->size() - 1))
	{
		FinaliseTree();
	}
	FinaliseTree();

	m_callstackTable->Serialise(m_zip, _T("ct_fn"));


	// Serialise the tree headers

	{
		// Total size of the blob, computed up front so it can be filled in place.
		size_t len = 
			sizeof(u32) // layer count
			+ sizeof(u32) * m_tree.size() // layer node counts
			;

		for (size_t layer = 0, layerCount = m_tree.size(); layer < layerCount; ++ layer)
			len += m_tree[layer]->size() * sizeof(LayerNodeInfo);

		std::vector<u8> raw(len);
		u8* rawEnd = &raw[0];

		*reinterpret_cast<u32*>(rawEnd) = m_tree.size();
		rawEnd += sizeof(u32);

		for (size_t layer = 0, layerCount = m_tree.size(); layer < layerCount; ++ layer)
		{
			*reinterpret_cast<u32*>(rawEnd) = m_tree[layer]->size();
			rawEnd += sizeof(u32);
		}

		for (size_t layer = 0, layerCount = m_tree.size(); layer < layerCount; ++ layer)
		{
			std::vector<AllocSetHeader>& nodes = *m_tree[layer];

			for (size_t node = 0, nodeCount = nodes.size(); node < nodeCount; ++ node)
			{
				LayerNodeInfo& info = *reinterpret_cast<LayerNodeInfo*>(rawEnd);
				rawEnd += sizeof(LayerNodeInfo);

				info.evBegin = nodes[node].allocEvBegin;
				info.evEnd = nodes[node].allocEvEnd;
				info.szCnt = nodes[node].totalSize.count;
				info.szReq = nodes[node].totalSize.requested;
				info.szCon = nodes[node].totalSize.consumed;
				info.szGlo = nodes[node].totalSize.global;
			}
		}

		ZipAdd(m_zip, _T("tree"), &raw[0], raw.size());
	}

	CloseZip(m_zip);
	m_zip = NULL;
#endif
}

// Accumulates into set the stored node sets that cover the event range [begin, end).
//
// The implementation is compiled out (#if 0), so in this build the call does
// nothing and set is left untouched.
//
// When enabled, begin and end are first rounded down to a multiple of
// (1 << SplitShift). The range is then covered left to right: at each step the
// highest layer is chosen whose node starts exactly at begin and does not
// extend past end; that node's set is unzipped, merged into set and compacted
// again, and begin advances to the end of the node. Nodes whose
// serialisedSize is not positive contribute nothing.
void AllocSetLST::Add(AllocSet& set, u64 begin, u64 end)
{
#if 0
	begin &= ~ ((1<<SplitShift)-1);
	end &= ~ ((1<<SplitShift)-1);

	while (begin < end)
	{
		int layer = (int) m_tree.size() - 1;

		while (layer > 0)
		{
			// NOTE(review): the shift below is evaluated on a plain int literal
			// before widening to u64 - confirm layer + SplitShift stays small enough.
			u64 mask = (1 << (layer + SplitShift)) - 1;
			if (begin & mask)
			{
				-- layer;
				continue;
			}
			else
			{
				u64 algnBegin = (begin & ~mask) + (1ULL << (layer + SplitShift));
				if (algnBegin > end)
				{
					-- layer;
					continue;
				}
			}

			break;
		}

		u32 index = static_cast<u32>(begin >> (layer + SplitShift));
		// NOTE(review): index is u32, so this shift happens in 32 bits before
		// the result is stored in a u64 - large event indices may wrap.
		u64 algnBegin = (index + 1) << (layer + SplitShift);

		AllocSetHeader& header = m_tree[layer]->at(index);

		if (header.serialisedSize > 0)
		{
			std::vector<u8> raw(header.serialisedSize);
			UnzipItem(m_zip, header.serialisedIndex, &raw[0], raw.size());
			header.set.Deserialise(&raw[0], raw.size());

			set.Add(header.set);

			header.set.Compact();
		}

		begin = algnBegin;
	}
#endif
}

void AllocSetLST::FinaliseTree()
{
	size_t layerId = 0;

	while ((m_tree[layerId]->size() & 0x1) == 0)
	{
		// Need to propogate the tree down the pyramid.
		if (m_tree.size() == layerId + 1)
		{
			m_tree.push_back(new std::vector<AllocSetHeader>());
		}

		std::vector<AllocSetHeader>& destLayer = *m_tree[layerId + 1];
		std::vector<AllocSetHeader>& sourceLayer = *m_tree[layerId];

		AllocSetHeader& lhs = *(sourceLayer.end() - 2);
		AllocSetHeader& rhs = *(sourceLayer.end() - 1);

		AllocSetHeader foldedHeader(lhs.allocEvBegin, rhs.allocEvEnd);
		foldedHeader.totalSize = lhs.totalSize + rhs.totalSize;

		foldedHeader.set.Add(lhs.set);
		foldedHeader.set.Add(rhs.set);

		Serialise(layerId, &lhs - &*sourceLayer.begin());
		Serialise(layerId, &rhs - &*sourceLayer.begin());

		destLayer.push_back(foldedHeader);

		++ layerId;
	}

	m_tree.front()->push_back(AllocSetHeader(
		m_tree.front()->back().allocEvEnd,
		m_tree.front()->back().allocEvEnd + SplitAlignment));
}

// Writes the set of node (layer, index) to the archive and releases its
// in-memory contents.
//
// The implementation is compiled out (#if 0), so in this build the call does
// nothing and both parameters are unused. When enabled, the node's set is
// serialised, stored in the zip under the name "lst_<layer>_<index>" and then
// compacted.
void AllocSetLST::Serialise(size_t layer, size_t index)
{
#if 0
	AllocSetHeader& s = m_tree[layer]->at(index);

	u8* buffer;
	size_t bufferLen;
	s.set.Serialise(buffer, bufferLen);

	TCHAR fn[256];
	// NOTE(review): the format string is a narrow literal while fn is TCHAR
	// (Restore wraps its format in _T()) - confirm this builds when TCHAR is wide.
	_stprintf_s(fn, 256, "lst_%Iu_%Iu", layer, index);

	ZipAdd(m_zip, fn, buffer, bufferLen);

	s.set.Compact();
#endif
}
