# XXX: this from __future__ import is just for the stuff in if __name__ == "__main__"
from __future__ import print_function
import struct
import uuid
from functools import partial
from tfplib import psycowrap as psyco

class InvalidPackage(Exception):
	"""Error raised when a file cannot be accepted as a package
	(for example when its signature magic does not match).
	"""

def read_compact_int(f):
	"""Decode one UT-package compact integer (1 to 5 bytes) from stream f.

	Layout of the first byte: bit 7 is the sign, bit 6 says another byte
	follows, bits 0-5 are the lowest six value bits.  Bytes two to four
	each carry seven more value bits with bit 7 as the "more" marker.  A
	fifth byte, when present, contributes its low five bits and always
	ends the number.

	f -- binary file-like object; exactly the bytes of the number are read

	Returns the decoded value as a (possibly negative) int.
	"""
	first = struct.unpack("B", f.read(1))[0]
	is_negative = bool(first & 0x80)
	value = first & 0x3F
	more = first & 0x40
	shift = 6 # the first byte supplied six value bits
	for position in (1, 2, 3, 4):
		if not more:
			break
		chunk = struct.unpack("B", f.read(1))[0]
		if position == 4:
			# final byte: only five bits are used
			value |= (chunk & 0x1F) << shift
		else:
			value |= (chunk & 0x7F) << shift
		more = chunk & 0x80
		shift += 7
	if is_negative:
		return -value
	return value

# Hand read_compact_int to psyco (tfplib.psycowrap, imported above) via bind().
# NOTE(review): presumably a speed-up/JIT hook -- confirm in tfplib.psycowrap.
psyco.bind(read_compact_int)

def read_name(f, version):
	"""Read one name string from the binary stream f.

	For package versions below 64 the name is stored as raw bytes that end
	with a NUL byte.  From version 64 on, a one-byte length precedes the
	data and any trailing NUL bytes are stripped.

	f       -- binary file-like object positioned at the start of the name
	version -- package version; selects which of the two layouts is read

	Returns the name decoded from ASCII.  Raises EOFError when the stream
	ends before the terminating NUL of a pre-64 name has been seen.
	"""
	if version < 64:
		# Accumulate bytes, not text: mixing str and bytes raises
		# TypeError on Python 3.
		pieces = []
		while True:
			byte = f.read(1)
			if not byte:
				# An empty read means end of file; without this check the
				# loop would spin forever waiting for a NUL byte.
				raise EOFError("unterminated name: stream ended before NUL")
			if byte == b"\x00":
				break
			pieces.append(byte)
		raw = b"".join(pieces)
	else:
		length = struct.unpack("B", f.read(1))[0]
		raw = f.read(length).rstrip(b"\x00")
	return raw.decode("ascii")

# Hand read_name to psyco (tfplib.psycowrap, imported above) via bind().
# NOTE(review): presumably a speed-up/JIT hook -- confirm in tfplib.psycowrap.
psyco.bind(read_name)

class Name(str):
	"""A str that additionally carries a ``flags`` attribute.

	s     -- the text of the name
	flags -- value stored unchanged on the instance as ``flags``
	"""
	def __new__(cls, s, flags):
		# str is immutable, so the extra attribute has to be attached in
		# __new__ right after the string object has been created.
		instance = super(Name, cls).__new__(cls, s)
		instance.flags = flags
		return instance

class Import(object):
	"""One import entry belonging to a package.

	class_package  -- stored as ``class_package``
	class_name     -- stored as ``class_name``
	package_index  -- object reference index, resolved lazily by ``package``
	object_name    -- stored as ``object_name``
	package_object -- owner providing ``GetObjectRef(index)``
	"""
	def __init__(self, class_package, class_name, package_index, object_name, package_object):
		self._package_object = package_object
		self._package_index = package_index
		self.class_package = class_package
		self.class_name = class_name
		self.object_name = object_name

	def _get_package(self):
		# Resolved on every access through the owning package object.
		owner = self._package_object
		return owner.GetObjectRef(self._package_index)

	package = property(_get_package)

	def __str__(self):
		fields = (
			self.class_package,
			self.class_name,
			self._package_index,
			self.object_name,
		)
		return str(fields)

	def __repr__(self):
		return str(self)

class Export(object):
	"""One export entry belonging to a package.

	package       -- owner providing ``GetObjectRef(index)`` and
	                 ``GetSerial(size, offset)``
	Class_index   -- object reference index resolved by ``class_``
	super_index   -- object reference index resolved by ``super``
	group_index   -- object reference index resolved by ``group``
	object_name   -- stored as ``object_name``
	object_flags  -- stored as ``object_flags``
	serial_size   -- size passed to ``GetSerial`` by ``serial``
	serial_offset -- offset passed to ``GetSerial`` by ``serial``

	The three reference properties are resolved through the package on
	every access; ``serial`` is fetched once and then cached.
	"""
	def __init__(self, package, Class_index, super_index, group_index, object_name, object_flags, serial_size, serial_offset):
		self._package_object = package
		self.object_name = object_name
		self.object_flags = object_flags
		self._cached_serial = None
		self._class_index = Class_index
		self._super_index = super_index
		self._group_index = group_index
		self._serial_size = serial_size
		self._serial_offset = serial_offset

	def _get_class(self):
		return self._package_object.GetObjectRef(self._class_index)

	def _get_super(self):
		return self._package_object.GetObjectRef(self._super_index)

	def _get_group(self):
		return self._package_object.GetObjectRef(self._group_index)

	def _get_serial(self):
		# Identity test against None: an empty (falsy) result still counts
		# as cached and is not fetched again.
		if self._cached_serial is not None:
			return self._cached_serial
		self._cached_serial = self._package_object.GetSerial(self._serial_size, self._serial_offset)
		return self._cached_serial

	class_ = property(_get_class)
	super = property(_get_super)
	group = property(_get_group)
	serial = property(_get_serial)


class Package(object):
	"""Reader for a UT package file.

	f -- either a path (str), which is opened in binary mode, or an
	     already open binary file object supporting read/seek/tell

	The fixed header is parsed on construction and exposed as the dict
	``headers``.  The name, import and export tables are parsed lazily on
	first access of ``names``, ``imports`` and ``exports`` and then cached.
	Table parsing saves and restores the file position, so a caller-owned
	file object is left where it was.

	Raises InvalidPackage when the header is truncated or its signature
	magic does not match.

	A file opened here from a path can be released with ``close()`` or by
	using the instance as a context manager; a file object supplied by the
	caller is never closed by this class.
	"""

	def __init__(self, f):
		# XXX: indeed, not using basestring here for 2.x is ugly, but meh
		opened_here = isinstance(f, str)
		if opened_here:
			f = open(f, "rb")
		self._file = f
		self._owns_file = opened_here
		try:
			# "<" pins little-endian byte order and standard field sizes
			# so parsing does not depend on the host's native layout.
			hformat = "<IHHIIIIIII"
			hlen = struct.calcsize(hformat)
			raw = f.read(hlen)
			if len(raw) != hlen:
				raise InvalidPackage("File too short to contain a package header")
			headers = dict(zip(("Signature", "PackageVersion", "LicenseMode", "Flags", "Name Count", "Name Offset", "Export Count", "Export Offset", "Import Count", "Import Offset"), struct.unpack(hformat, raw)))
			if headers["Signature"] != 0x9E2A83C1:
				raise InvalidPackage("Signature magic doesn't match")
			if headers["PackageVersion"] >= 68:
				headers["GUID"] = uuid.UUID(bytes=f.read(16))
		except Exception:
			# Do not leak a handle we opened ourselves when the file turns
			# out not to be a usable package.
			if opened_here:
				f.close()
			raise
		self.headers = headers
		self._names = None
		self._imports = None
		self._exports = None

	def close(self):
		"""Close the underlying file if this object opened it from a path."""
		if self._owns_file:
			self._file.close()

	def __enter__(self):
		return self

	def __exit__(self, exc_type, exc_value, traceback):
		self.close()
		return False

	def _get_names(self):
		"""Return the list of Name objects, parsing the table on first use."""
		if self._names is not None:
			return self._names
		f = self._file
		version = self.headers["PackageVersion"]
		# Build into a local list and publish it only once complete, so a
		# failed parse never leaves a partial table cached.
		names = []
		origpos = f.tell()
		try:
			f.seek(self.headers["Name Offset"])
			for _ in range(self.headers["Name Count"]):
				name = read_name(f, version)
				flags = struct.unpack("<I", f.read(4))[0]
				names.append(Name(name, flags))
		finally:
			f.seek(origpos)
		self._names = names
		return names

	psyco.bind(_get_names)

	def _get_imports(self):
		"""Return the list of Import objects, parsing the table on first use."""
		if self._imports is not None:
			return self._imports
		f = self._file
		nt = self.names
		rf = partial(read_compact_int, f)
		imports = []
		origpos = f.tell()
		try:
			f.seek(self.headers["Import Offset"])
			for _ in range(self.headers["Import Count"]):
				# Field order on disk: class package name index, class name
				# index, 32-bit signed package reference, object name index.
				cp = nt[rf()]
				cn = nt[rf()]
				pindex = struct.unpack("<i", f.read(4))[0]
				oname = nt[rf()]
				imports.append(Import(cp, cn, pindex, oname, self))
		finally:
			f.seek(origpos)
		self._imports = imports
		return imports

	def _get_exports(self):
		"""Return the list of Export objects, parsing the table on first use."""
		if self._exports is not None:
			return self._exports
		f = self._file
		nt = self.names
		rf = partial(read_compact_int, f)
		exports = []
		origpos = f.tell()
		try:
			f.seek(self.headers["Export Offset"])
			for _ in range(self.headers["Export Count"]):
				clindex = rf()
				sindex = rf()
				gindex = struct.unpack("<i", f.read(4))[0]
				oname = nt[rf()]
				oflags = struct.unpack("<I", f.read(4))[0]
				ssize = rf()
				# The serial offset is only stored for non-empty objects.
				soffset = rf() if ssize else None
				exports.append(Export(self, clindex, sindex, gindex, oname, oflags, ssize, soffset))
		finally:
			f.seek(origpos)
		self._exports = exports
		return exports

	psyco.bind(_get_exports)

	names = property(_get_names)
	imports = property(_get_imports)
	exports = property(_get_exports)

	def GetObjectRef(self, index):
		"""Resolve an object reference index.

		Negative indexes select from the import table, positive ones from
		the export table (both 1-based in magnitude); 0 yields None.
		"""
		if index == 0:
			return None
		table = self.imports if index < 0 else self.exports
		return table[abs(index) - 1]

	def GetSerial(self, size, offset):
		"""Return ``size`` raw bytes read at ``offset``, keeping the file position.

		A falsy size yields empty bytes without touching the file.
		"""
		if not size:
			# bytes, matching what read() returns for non-empty objects
			return b""
		f = self._file
		origpos = f.tell()
		try:
			f.seek(offset)
			return f.read(size)
		finally:
			f.seek(origpos)

if __name__ == "__main__":
	# Smoke test: print the table sizes of a package.  The path may be given
	# as the first command-line argument; without one the original default
	# location of Engine.u is used.
	import sys
	path = sys.argv[1] if len(sys.argv) > 1 else "/opt/ut2004/System/Engine.u"
	p = Package(path)
	names = p.names
	# Entry 7 is shown only when the table is long enough, so a small
	# package does not end in an IndexError.
	print(len(names), names[7] if len(names) > 7 else None)
	print(len(p.exports))
	print(len(p.imports))

# vim: set noexpandtab:
