# see http://en.wikipedia.org/wiki/Microsoft_Visual_C%2B%2B_Name_Mangling

class StupidDict(dict):
	"""plain dict subclass, adds no behaviour of its own"""

# Maps the single-letter calling-convention code of a mangled function type
# to its keyword; the keys are also joined into the `callingconvre` pattern.
calling_conventions = StupidDict({
	"A": "__cdecl",
	"B": "__cdecl", # according to wine's source (http://source.winehq.org/source/dlls/msvcrt/undname.c#L561) - verify?
	"G": "__stdcall",
	"I": "__fastcall",
	"E": "__thiscall",
})

# Maps a special-name code (every key includes its leading "?") to a printable
# operator / special-member string.  Operator looks its opcode up here and the
# keys are joined into the `opre` pattern, which parse_mangledsymbol matches
# right after the first "?" of a symbol.
# Values in parentheses have no "operatorX" spelling; those that also appear
# as keys of `opformatters` are rendered by the matching formatter.
# NOTE(review): "Hierachy" below is misspelled, it is a runtime string so it is left untouched here
opcodes = {
	"?0": "(constructor)",
	"?1": "(destructor)",
	"?2": "(new)",
	"?3": "(delete)",
	"?4": "=",
	"?5": ">>",
	"?6": "<<",
	"?7": "!",
	"?8": "==",
	"?9": "!=",
	"?A": "[]",
	"?B": "(conversion)",
	"?C": "->",
	"?D": "*",
	"?E": "++",
	"?F": "--",
	"?G": "-",
	"?H": "+",
	"?I": "&",
	"?J": "->*",
	"?K": "/",
	"?L": "%",
	"?M": "<",
	"?N": "<=",
	"?O": ">",
	"?P": ">=",
	"?Q": ",",
	"?R": "(call)",
	"?S": "~",
	"?T": "^",
	"?U": "|",
	"?V": "&&",
	"?W": "||",
	"?X": "*=",
	"?Y": "+=",
	"?Z": "-=",
	"?_0": "/=",
	"?_1": "%=",
	"?_2": ">>=",
	"?_3": "<<=",
	"?_4": "&=",
	"?_5": "|=",
	"?_6": "^=",
	"?_7": "(vtable)",
	"?_C": "(unknown)",
	"?_F": "(default constructor closure)",
	"?_O": "(unknown)",
	"?_R0?AV": "(RTTI Type Descriptor)",
	"?_R1A@?OA@EA@": "(RTTI Base Class Descriptor at (0,-1,0,64))",
	"?_R2": "(RTTI Base Class Array)",
	"?_R3": "(RTTI Class Hierachy Descriptor)",
	"?_R4": "(RTTI Complete Object Locator)",
	"?_U": "(new[])",
	"?_V": "(delete[])",
}

# XXX: ugly, is there any way to do this without typechecking or similar?

def getshort_template(name):
	"""strips the template arguments off `name`

	a Template such as Foo<class bar> yields its bare name (Foo), anything
	that isn't a Template is handed back untouched"""
	return name.name if isinstance(name, Template) else name

# Custom renderers keyed by the printable strings from `opcodes`.
# Operator.__str__ calls the matching entry with the Operator itself and falls
# back to "<name>::operator<op>" for strings that aren't listed here.
# op.name[-1] is the last component of the operator's name; constructors and
# destructors repeat it (minus any template arguments) as the member name.
opformatters = {
	"(constructor)": lambda op: "%s::%s" % (op.name, getshort_template(op.name[-1])),
	"(destructor)": lambda op: "%s::~%s" % (op.name, getshort_template(op.name[-1])),
	"(new)": lambda op: "%s::operator new" % op.name,
	"(delete)": lambda op: "%s::operator delete" % op.name,
	"(conversion)": lambda op: "%s::(conversion)" % op.name,
	"(call)": lambda op: "%s()()" % op.name,
	"(unknown)": lambda op: "(unknown operator %s for %s)" % (op.opcode, op.name),
	"(new[])": lambda op: "%s::operator new[]" % op.name,
	"(delete[])": lambda op: "%s::operator delete[]" % op.name,
}

# Maps the type code that follows a symbol's qualified name to a description
# of the symbol.  "type" picks the result class in parse_mangledsymbol
# (method / function / data / compiler_data); methods also carry their
# "visibility" and a "special" flag (None, "static" or "virtual") which
# get_typecode checks to decide whether a modifier code follows.
# XXX: according to wine there's a "thunk" type, figure out what it is,
# http://source.winehq.org/source/dlls/msvcrt/undname.c#L917
typecodes = {
	"A": {"visibility": "private", "type": "method", "special": None},
	"C": {"visibility": "private", "type": "method", "special": "static"},
	"E": {"visibility": "private", "type": "method", "special": "virtual"},
	"I": {"visibility": "protected", "type": "method", "special": None},
	"K": {"visibility": "protected", "type": "method", "special": "static"},
	"M": {"visibility": "protected", "type": "method", "special": "virtual"},
	"Q": {"visibility": "public", "type": "method", "special": None},
	"S": {"visibility": "public", "type": "method", "special": "static"},
	"U": {"visibility": "public", "type": "method", "special": "virtual"},
	"Y": {"type": "function"},
	"0": {"type": "data", "name": 0},
	"1": {"type": "data", "name": 1},
	"2": {"type": "data", "name": 2},
	"3": {"type": "data", "name": 3},
	"6": {"type": "compiler_data", "name": 6},
}

# Partition of the `typecodes` keys by kind; each set becomes one named group
# of `typecodes_re`.  Whatever is neither data nor compiler data is treated as
# a function/method code.
datatypecodes = frozenset(("0", "1", "2", "3"))
compiler_types = frozenset(("6",))
functypecodes = frozenset(typecodes) - datatypecodes - compiler_types

# SimpleDataType codes: maps the code of a built-in type to its C++ spelling.
# The keys form the "simple" group of `simpledtype_re`.
sdatatypes = {
	"C": "signed char",
	"D": "char",
	"E": "unsigned char",
	"F": "short",
	"G": "unsigned short",
	"H": "int",
	"I": "unsigned int",
	"J": "long",
	"K": "unsigned long",
	"M": "float",
	"N": "double",
	"O": "long double",
	"_J": "__int64",
	"_K": "unsigned __int64",
	"_N": "bool",
	"_W": "wchar_t",
}

# The two-character ("_"-prefixed) codes from `sdatatypes`.
# NOTE(review): presumably the only simple types that may be the target of a type back-reference - confirm
backrefable_sdatatypes = set(("_J", "_K", "_N", "_W"))

# Codes that introduce a user-defined type: the code is followed by the
# qualified name of the type (see get_datatype).  The keys form the "named"
# group of `simpledtype_re` and the values are the keyword NamedDataType prints.
namedatatypes = {
	"T": "union",
	"U": "struct",
	"V": "class",
	"W4": "enum", # 4-byte
}

# ReferenceOrPointer in the BNF Syntax
# Maps the first code of a pointer/reference type to a (modifier, symbol)
# pair; Pointer.__str__ prints the symbol followed by the modifier, if any.
ptrtypes = {
	"A": (None, "&"),
	"P": (None, "*"),
	"Q": ("const", "*"),
	"R": ("volatile", "*"),
}

# Pointer in the BNF Syntax
# Two-character codes for every further level of indirection after the first
# ReferenceOrPointer code, as (modifier, symbol) pairs like `ptrtypes`.
# Note this conflicts with the "A|B|C" syntax, but it looks like the
# second char of each of these doesn't conflict with SimpleDataType/X
# this might change in the future so be wary of breakage
ptrtypes2 = {
	"AP": (None, "*"),
	"BQ": ("const", "*"),
	"CR": ("volatile", "*"),
}

# Maps the storage-class code that closes a function type or follows the
# data type of a variable to a description (see get_functype / get_typecode).
storageclasses = {
	"A": "Normal",
	"B": "Volatile",
	"C": "const",
	"Z": "Executable",
}

# XXX: modifiers: not mentioned on the wikipedia page
# Maps a cv-qualifier code to its keyword(s), None meaning unqualified.  Used
# for the qualifier of a method, of a pointee and of a "?"-prefixed return type.
modifiers = {
	"A": None,
	"B": "const",
	"C": "volatile",
	"D": "const volatile",
}

# Encoded numbers are hexadecimal written with the letters A-P standing in
# for the digits 0-9A-F, this maps every letter to the hex digit it encodes
# (get_number decodes with it, sane_y_re is built from its keys).
digits = dict(zip("ABCDEFGHIJKLMNOP", "0123456789ABCDEF"))

import re
import string
import collections

def collapse_name(name, base):
	"""collapses the given name in-place based on `base`

	The leading components that `name` shares with `base` are removed, so a
	type living in the same scope as `base` is shown without repeating that
	scope.  The last component of `name` is always kept.  Anything inside
	`name` that has a `collapse` method (e.g. a Template) is then collapsed
	against `base` as well.

	Both arguments may be wrapped in Reference objects, which are unwrapped
	first; a `name` that doesn't end up being a Name is left untouched.
	Returns None."""
	# make sure base is a name and not an Operator or etc.
	while isinstance(base, Reference):
		base = base.to
	while isinstance(name, Reference):
		name = name.to
	if not isinstance(name, Name):
		return
	# Only a common *prefix* may go: stop at the first component that
	# differs and never index past the end of a `base` shorter than `name`.
	shared = 0
	for i in range(len(name) - 1):
		if i >= len(base) or base[i] != name[i]:
			break
		shared = i + 1
	del name[:shared]
	# extra pass to remove stuff in containers
	for part in name:
		if hasattr(part, "collapse"):
			part.collapse(base)

def collapsed_name(name, base):
	"""non-destructive variant of collapse_name

	deep-copies `name`, collapses the copy based on `base` and returns the
	copy, `name` itself is left as it was"""
	from copy import deepcopy
	duplicate = deepcopy(name)
	collapse_name(duplicate, base)
	return duplicate

class Args(list):
	"""list of argument types, printed as a comma separated list"""

	def __str__(self):
		return ", ".join(map(str, self))

	def collapse(self, base):
		"""collapses every argument in-place based on `base`"""
		for entry in self:
			collapse_name(entry, base)

	def get_collapsed(self, base):
		"""like collapse() but works on (and returns) a deep copy of self"""
		from copy import deepcopy
		clone = deepcopy(self)
		clone.collapse(base)
		return clone

class Reference(object):
	"""base class for thin 1:1 wrappers

	anything that mostly just holds one other object and only adds special
	formatting should derive from this, the wrapped object is kept in `to`"""

	def __init__(self, to):
		# the single wrapped object, collapse_name follows this to unwrap
		self.to = to

class Name(list):
	"""qualified name kept as a list of components, printed joined by ::"""

	def __str__(self):
		return "::".join(map(str, self))

class Template(object):
	"""a template name together with its arguments, printed as name<args>"""

	def __init__(self, name, args):
		self.name, self.args = name, args

	def collapse(self, base):
		"""collapses the template arguments in-place based on `base`"""
		self.args.collapse(base)

	def __str__(self):
		shown = (self.name, self.args)
		return "%s<%s>" % shown

class Operator(Reference):
	"""operator or special member (constructor, vtable, ...) of the scope `name`

	`opcode` has to be a key of `opcodes`, the printable string it maps to
	is stored as `opstr`"""

	def __init__(self, name, opcode):
		super(Operator, self).__init__(name)
		self.name = name
		self.opcode = opcode
		self.opstr = opcodes[opcode]

	def __str__(self):
		# special members have their own renderer, everything else
		# is printed as a plain "operatorX"
		formatter = opformatters.get(self.opstr)
		if formatter is None:
			return "%s::operator%s" % (self.name, self.opstr)
		return formatter(self)

class NamedDataType(Reference):
	"""user defined type: a keyword (union/struct/class/enum) plus its name

	`typecode` has to be a key of `namedatatypes`"""

	def __init__(self, typecode, name):
		super(NamedDataType, self).__init__(name)
		self.typecode = typecode
		self.typestr = namedatatypes[typecode]
		self.name = name

	def __str__(self):
		shown = (self.typestr, self.name)
		return "%s %s" % shown

class ModifiedDataType(Reference):
	"""a data type with a modifier (e.g. const) in front of it"""

	def __init__(self, type, modifier):
		super(ModifiedDataType, self).__init__(type)
		self.type, self.modifier = type, modifier

	def __str__(self):
		shown = (self.modifier, self.type)
		return "%s %s" % shown

class Pointer(Reference):
	"""pointer or reference to `to`

	`types` holds one (modifier, symbol) pair per level of indirection, in
	the order they were read from the mangled name"""

	def __init__(self, to, types):
		super(Pointer, self).__init__(to)
		self.types = types

	def __str__(self):
		# XXX: what's the proper pointer syntax? - Think I got it right now (http://articles.techrepublic.com.com/5100-22-1052161.html)
		# XXX: should these be reversed here or when constructing the object?
		pieces = []
		for modifier, ptrtype in reversed(self.types):
			if modifier:
				# the modifier trails the symbol it applies to
				pieces.append("%s%s " % (ptrtype, modifier))
			else:
				pieces.append("%s" % ptrtype)
		suffix = "".join(pieces).rstrip()
		return "%s %s" % (self.to, suffix)

class Array(Reference):
	"""array of `of`, `dimensions` lists the size of every dimension"""

	def __init__(self, of, dimensions):
		super(Array, self).__init__(of)
		self.of, self.dimensions = of, dimensions

	def __str__(self):
		brackets = ["[%d]" % dim for dim in self.dimensions]
		return "%s %s" % (self.of, "".join(brackets))

class Function(object):
	"""a demangled free function, also the base for Method

	keeps the name, the `typecodes` entry (`type`), the calling convention,
	the return type, the argument list and the storage class as attributes"""

	def __init__(self, name, type, callingconv, ret_type, args, storage):
		self.name, self.type = name, type
		self.callingconv = callingconv
		self.ret_type, self.args = ret_type, args
		self.storage = storage

	def __str__(self):
		# XXX: calling convention, storage class, static/visibility?
		# XXX: do collapsing, should probably be done optionally in __format__ later
		# arguments and return type are shown relative to the function's own scope
		shown_args = self.args.get_collapsed(self.name)
		shown_ret = collapsed_name(self.ret_type, self.name)
		return "%s %s(%s)" % (shown_ret, self.name, shown_args)

class Method(Function):
	"""a member function: a Function plus the modifier of the method itself"""

	def __init__(self, name, type, modifier, callingconv, ret_type, args, storage):
		Function.__init__(self, name, type, callingconv, ret_type, args, storage)
		self.modifier = modifier

	def __str__(self):
		# the method's modifier (e.g. const) trails the signature
		if self.modifier:
			postfix = " %s" % self.modifier
		else:
			postfix = ""
		# static/virtual goes in front of it
		if self.type["special"]:
			prefix = "%s " % self.type["special"]
		else:
			prefix = ""
		signature = Function.__str__(self)
		return "%s%s%s" % (prefix, signature, postfix)

class Variable(object):
	"""a demangled data symbol, printed as "<datatype> <name>" """

	def __init__(self, name, type, datatype, storage):
		self.name, self.type = name, type
		self.datatype, self.storage = datatype, storage

	def __str__(self):
		shown = (self.datatype, self.name)
		return "%s %s" % shown

class CompilerData(object):
	"""compiler generated data, printed as its name with the modifier (if any) in front"""

	def __init__(self, name, type, modifier, extname):
		self.name, self.type = name, type
		self.modifier, self.extname = modifier, extname

	def __str__(self):
		if self.modifier:
			prefix = "%s " % self.modifier
		else:
			prefix = ""
		return "%s%s" % (prefix, self.name)

class FuncPointer(Pointer):
	"""pointer to a function, `to` is the Function that's pointed at"""

	def __str__(self):
		# TODO: is there any reasonable case where you'd have a pointer to a func pointer?
		# XXX: include Calling Convention?
		func = self.to
		# only the symbols of the (modifier, symbol) pairs are shown
		symbols = "".join([entry[1] for entry in self.types])
		return "%s (%s%s)(%s)" % (func.ret_type, symbols, func.name, func.args)

def get_name(s, backrefs, reverse=True):
	"""parses a (qualified) name from the start of `s`

	Components are read one after another until something that is neither
	an "identifier@", a single digit back-reference nor a "?$" template
	shows up.  New identifiers are appended to backrefs["name"], digits are
	resolved against that list.

	`reverse` controls whether the components are returned in reversed
	(the default) or in reading order.

	returns a (Name, rest of the string) tuple, the terminator that ended
	the name is left in the rest"""
	nbackrefs = backrefs["name"]
	validc = r"[a-zA-Z_]\w*"
	# either a one digit back-reference or an identifier with its closing "@"
	identifier = re.compile(r"(?P<backref>\d)|(?P<name>%s)@" % validc)
	def identname(match):
		"""returns (should_add_to_backrefs, name) for a match of `identifier`"""
		gd = match.groupdict()
		name = None
		backref = gd.get("backref", None)
		if backref:
			# resolved from the table, must not be recorded a second time
			name = nbackrefs[int(backref)]
		else:
			name = gd["name"]
		return (False if backref else True, name)
	curr = s
	identifiers = []
	while True:
		if curr[:2] == "?$":
			# template: "?$" + template name + argument list + rest of the name
			curr = curr[2:]
			unqmatch = identifier.match(curr)
			add, unQualifiedName = identname(unqmatch)
			if add:
				nbackrefs.append(unQualifiedName)
			curr = curr[unqmatch.end():]
			# the template arguments are parsed with a fresh back-reference table
			tbr = collections.defaultdict(list)
			(args, curr) = get_arglist(curr, tbr)
			# the recursive call reads whatever components are left, unreversed
			# so the ordering is only fixed up once at the end of this call
			(name, curr) = get_name(curr, backrefs, reverse=False)
			identifiers.append(Template(unQualifiedName, args))
			identifiers.extend(name)
			continue
		m = identifier.match(curr)
		if not m:
			# end of the name
			break
		add, name = identname(m)
		if add:
			nbackrefs.append(name)
		identifiers.append(name)
		curr = curr[m.end():]
	if reverse:
		# reverse to get (namespace, namespace, class, whatever, etc.)
		name = list(reversed(identifiers))
	else:
		name = identifiers
	return (Name(name), curr)

# Patterns used by get_datatype, all built from the lookup tables above.
# simpledtype_re: a built-in type code ("simple" group) or the code that
# introduces a user-defined type ("named" group).
sdatatype_p = "|".join(re.escape(dt) for dt in sdatatypes)
ndatatype_p = "|".join(re.escape(dt) for dt in namedatatypes)
simpledtype_re = re.compile(r"(?P<simple>%s)|(?P<named>%s)" % (sdatatype_p, ndatatype_p))

# rop_re: the first code of a pointer/reference, ptrref_re: every following level
rop_re = re.compile("|".join(re.escape(ptype) for ptype in ptrtypes))
ptrref_re = re.compile("|".join(re.escape(ptype) for ptype in ptrtypes2))

# start of an array type: "Y", one decimal digit, then a letter from `digits`
sane_y_re = re.compile(r"Y\d[%s]" % re.escape("".join(digits)))

class InvalidDataType(Exception):
	"""raised when a string doesn't start with a known data type, `s` is that string"""

	def __init__(self, s):
		self.s = s

	def __str__(self):
		template = "Invalid Data Type at the start of %r"
		return template % self.s

class CantParse(Exception):
	"""raised when `symbol` can't be parsed with the mangling scheme `scheme`"""

	def __init__(self, symbol, scheme):
		self.symbol, self.scheme = symbol, scheme

	def __str__(self):
		details = (self.symbol, self.scheme)
		return "Can't parse %r with the %s mangling scheme" % details

def get_number(curr, backrefs):
	"""parses an encoded number from the start of `curr`

	Two encodings are understood, both optionally preceded by "?" for a
	negative value:
	- a single decimal digit, standing for the values 1 to 10 (digit + 1)
	- hexadecimal written with the letters A-P (see `digits`) and closed
	  by "@", e.g. "A@" is 0 and "BA@" is 16

	`backrefs` is unused, it's only accepted for symmetry with the other
	get_* functions.

	returns a (number, rest of the string) tuple; raises KeyError for a
	character that isn't a valid digit and ValueError when no digits
	precede the "@\""""
	negative = curr[:1] == "?"
	if negative:
		curr = curr[1:]
	if curr and curr[0] in "0123456789":
		# short form, takes no terminator
		value, curr = int(curr[0]) + 1, curr[1:]
	else:
		nstr, sep, curr = curr.partition("@")
		value = int("".join(digits[c] for c in nstr), 16)
	return (-value if negative else value, curr)

def get_datatype(s, backrefs, record=True):
	"""parses one data type from the start of `s`

	Recognised, in this order: simple and named types, pointers and
	references (function pointers included), single digit back-references
	to an earlier type and "Y" arrays.

	`backrefs` is the shared back-reference table, its "pointer" list is
	what type back-references index into.  Named types, pointers and the
	simple types listed in `backrefable_sdatatypes` are appended to it;
	single-letter simple types never are, they can't be referred back to.
	`record` set to False keeps a simple or named type from being recorded
	(used for the pointee of a pointer).

	returns a (datatype, rest of the string) tuple and raises
	InvalidDataType when `s` doesn't start with a type that's understood"""
	pbackrefs = backrefs["pointer"]
	curr = s
	if not curr:
		raise InvalidDataType(curr)
	# SimpleDataType stuff first
	m = simpledtype_re.match(curr)
	if m:
		gd = m.groupdict()
		curr = curr[m.end():]
		assert gd.get("simple") or gd.get("named")
		if gd.get("simple"):
			code = gd["simple"]
			dtype = sdatatypes[code]
			# recording the one-letter types as well would shift the index
			# of every later back-reference
			if record and code in backrefable_sdatatypes:
				pbackrefs.append(dtype)
			return (dtype, curr)
		elif gd.get("named"):
			name, curr = get_name(curr, backrefs)
			assert curr[0] == "@"
			curr = curr[1:]
			dtype = NamedDataType(gd["named"], Name(name))
			if record:
				pbackrefs.append(dtype)
			return (dtype, curr)
	# XXX: Pointers?
	m = rop_re.match(curr)
	if m:
		types = [ptrtypes[m.group(0)]]
		curr = curr[m.end():]
		# one two-character code per additional level of indirection
		while True:
			m = ptrref_re.match(curr)
			if not m:
				break
			types.append(ptrtypes2[m.group(0)])
			curr = curr[m.end():]
		if curr[:1] in ("A", "B", "C"):
			# XXX: looks like A == normal, B == const, C == volatile
			modifier = modifiers[curr[0]]
			curr = curr[1:]
			if curr[:1] == "X":
				base, curr = "void", curr[1:]
			else:
				base, curr = get_datatype(curr, backrefs, False)
				if modifier:
					base = ModifiedDataType(base, modifier)
			dtype = Pointer(base, types)
		elif curr[:1] in ("6", "8"):
			curr = curr[1:]
			fdict, curr = get_functype(curr, backrefs)
			# XXX: anonymous and no typecode, what to do?
			func = Function("", typecodes["Y"], fdict["cconv"], fdict["ret"], fdict["args"], fdict["storage_class"])
			dtype = FuncPointer(func, types)
		else:
			# pointer to something that isn't understood
			raise InvalidDataType(s)
		pbackrefs.append(dtype)
		return dtype, curr
	if curr[0] in string.digits:
		# back-reference to a type that was recorded earlier
		dtype = pbackrefs[int(curr[0])]
		curr = curr[1:]
		return dtype, curr
	# XXX: wtf, looks like there's a second syntax for Y: Y<amount><firstnum in base10>$$C<modifier><type>
	if sane_y_re.match(curr):
		curr = curr[1:]
		# the digit is the number of dimensions minus one
		amount = int(curr[0]) + 1
		curr = curr[1:]
		dimensions = []
		for i in range(amount):
			dimension, curr = get_number(curr, backrefs)
			dimensions.append(dimension)
		base, curr = get_datatype(curr, backrefs)
		dtype = Array(base, dimensions)
		return (dtype, curr)
	raise InvalidDataType(curr)

# this is in a separate function than get_datatype because it seems it only applies to the return value
def get_modifieddatatype(s, backrefs):
	"""parses "?" + modifier code + data type from the start of `s`

	returns a (datatype, rest of the string) tuple, the data type is wrapped
	in a ModifiedDataType unless the modifier code stands for "no modifier\""""
	assert s[0] == "?"
	modifier = modifiers[s[1]]
	dtype, rest = get_datatype(s[2:], backrefs)
	if not modifier:
		return (dtype, rest)
	return (ModifiedDataType(dtype, modifier), rest)

def get_arglist(s, backrefs):
	"""parses an argument list from the start of `s`

	A lone "X" is an empty list.  Otherwise data types are read until a "@"
	or "Z" terminator (which is consumed) or until something that isn't a
	data type shows up (which is left in place).

	returns an (Args, rest of the string) tuple"""
	args = Args()
	if s[0] == "X":
		return (args, s[1:])
	rest = s
	while rest[0] not in ("@", "Z"):
		try:
			dt, rest = get_datatype(rest, backrefs)
		except InvalidDataType:
			# not a data type, hand the remainder back untouched
			return (args, rest)
		args.append(dt)
	# skip the terminator
	return (args, rest[1:])

# typecodes_re sorts a type code into one of three named groups (function,
# datatype, compilertype), get_typecode branches on whichever group matched.
ftypecode_p = "|".join(re.escape(typecode) for typecode in functypecodes)
dtypecode_p = "|".join(re.escape(typecode) for typecode in datatypecodes)
ctypecode_p = "|".join(re.escape(typecode) for typecode in compiler_types)
typecodes_re = re.compile(r"(?P<function>%s)|(?P<datatype>%s)|(?P<compilertype>%s)" % (ftypecode_p, dtypecode_p, ctypecode_p))

def get_typecode(s, backrefs):
	"""parses the type code at the start of `s` and everything it governs

	returns a (dict, rest of the string) tuple, the dict always has a "type"
	key (the `typecodes` entry) and, depending on the kind of symbol:
	- function/method: the keys of get_functype(), methods get a "modifier" too
	- data: "datatype" and "storage_class"
	- compiler data: "modifier" and "name\""""
	match = typecodes_re.match(s)
	rest = s[match.end():]
	groups = match.groupdict()
	func_code = groups["function"]
	data_code = groups["datatype"]
	comp_code = groups["compilertype"]
	assert func_code or data_code or comp_code
	parsed = {}
	if func_code:
		info = typecodes[func_code]
		parsed["type"] = info
		if info["type"] == "method":
			if info["special"] in ("static", "thunk"):
				# static method, has no modifier, fake it though
				parsed["modifier"] = None
			else:
				# read the modifier
				parsed["modifier"] = modifiers[rest[0]]
				rest = rest[1:]
		functype, rest = get_functype(rest, backrefs)
		parsed.update(functype)
	elif data_code:
		parsed["type"] = typecodes[data_code]
		datatype, rest = get_datatype(rest, backrefs)
		sclass, rest = storageclasses[rest[0]], rest[1:]
		parsed["datatype"] = datatype
		parsed["storage_class"] = sclass
	elif comp_code:
		parsed["type"] = typecodes[comp_code]
		parsed["modifier"] = modifiers[rest[0]]
		parsed["name"], rest = get_name(rest[1:], backrefs)
	return parsed, rest

# matches any calling-convention code, get_functype expects one at the start of a function type
callingconvre = re.compile("|".join(re.escape(convention) for convention in calling_conventions))

def get_functype(s, backrefs):
	"""parses a function type: calling convention, return type, arguments, storage class

	returns a (dict, rest of the string) tuple, the dict has the keys
	"cconv", "ret", "args" and "storage_class\""""
	match = callingconvre.match(s)
	rest = s[match.end():]
	cconv = calling_conventions[match.group(0)]
	ret = None
	lead = rest[0]
	if lead in ("X", "@"):
		# no return type given, shown as void
		ret, rest = "void", rest[1:]
	elif lead == "?":
		ret, rest = get_modifieddatatype(rest, backrefs)
	if ret is None:
		ret, rest = get_datatype(rest, backrefs)
	args, rest = get_arglist(rest, backrefs)
	sclass = storageclasses[rest[0]]
	return {
		"cconv": cconv,
		"ret": ret,
		"args": args,
		"storage_class": sclass,
	}, rest[1:]

# matches any key of `opcodes`, tried by parse_mangledsymbol right after the leading "?"
opre = re.compile("|".join(re.escape(opcode) for opcode in opcodes))

def parse_mangledsymbol(s):
	"""parses the MSVC++ mangled symbol `s`

	The layout that's expected is "?" + optional operator code + qualified
	name + "@" + type code and whatever the type code requires.

	returns a Variable, Method, CompilerData or Function depending on the
	type code; when an operator code is present the object's name is an
	Operator wrapping the qualified name.

	raises CantParse when `s` is empty, doesn't start with "?" or carries a
	kind of type code that isn't handled here"""
	# s[:1] rather than s[0] so an empty string is rejected cleanly as well
	if s[:1] != "?":
		# XXX: Except here?
		raise CantParse(s, "MSVC++")
	backrefs = collections.defaultdict(list)
	curr = s[1:]
	opmatch = opre.match(curr)
	if opmatch:
		curr = curr[opmatch.end():]
	name, curr = get_name(curr, backrefs)
	if opmatch:
		name = Operator(name, opmatch.group(0))
	assert curr[0] == "@"
	curr = curr[1:]
	typecode, curr = get_typecode(curr, backrefs)
	kind = typecode["type"]["type"]
	if kind == "data":
		return Variable(name, typecode["type"], typecode["datatype"], typecode["storage_class"])
	if kind == "method":
		return Method(name, typecode["type"], typecode["modifier"], typecode["cconv"], typecode["ret"], typecode["args"], typecode["storage_class"])
	if kind == "compiler_data":
		return CompilerData(name, typecode["type"], typecode["modifier"], typecode["name"])
	if kind == "function":
		return Function(name, typecode["type"], typecode["cconv"], typecode["ret"], typecode["args"], typecode["storage_class"])
	# a kind was added to `typecodes` without being handled above
	raise CantParse(s, "MSVC++")

# Just do the basic Separation of parts with pyparsing, don't try to
# understand it with it (too hard, it seems to be against me whatever
# I try)
#from pyparsing import Or, Literal, Word, OneOrMore, alphas, Regex, Optional, ZeroOrMore, Group
#from string import digits
#ops = Or([Literal(op) for op in opcodes])("opcode")
#del op

#validc = Regex(r"[a-zA-Z_]\w*")

#anydigit = Or([Literal(digit) for digit in digits])
#backref = anydigit
#identifier = (validc + "@") ^ backref
# XXX: template handling, verify this is correct
#name = ("?$" + identifier + (lambda: args)() + OneOrMore(identifier)) ^ OneOrMore(identifier)

#cconv = Or([Literal(convention) for convention in calling_conventions])
#del convention

#ftypecode = Or([Literal(typecode) for typecode in functypecodes])
#datatypecode = Or([Literal(typecode) for typecode in datatypecodes])
#del typecode

#simpledtype = Or([Literal(dtype) for dtype in sdatatypes])
#namedtype = Or([Literal(dtype) for dtype in namedatatypes])
#simpletype = simpledtype ^ (namedtype + name + "@")

#rop = Or([Literal(ptype) for ptype in ptrtypes]) # reference or pointer
#pointerrefs = ZeroOrMore(Or([Literal(ptype) for ptype in ptrtypes2]))
#del ptype
#ptrref = anydigit
# XXX: FunctionType forms of ptrtype
#ptrtype = (rop + pointerrefs + (Literal("A") ^ "B" ^ "C") + (simpletype ^ "X")) ^ (ptrref)

#dtype = simpletype ^ ptrtype

#storageclass = Or([Literal(stype) for stype in storageclasses])
#del stype


#args = Group("X" ^ (OneOrMore(dtype) + Optional(Literal("@") ^ "Z")))
# Note: looks like ReturnValue can be @ for constructors atleast
#functype = cconv + (Literal("X") ^ "@" ^ dtype)  + args + storageclass

#typecode = (ftypecode + functype) ^ (datatypecode + dtype + storageclass)

#mangledsymbol = "?" + Optional(ops) + name.copy()("name") + "@" + typecode.copy()("TypeCode")

# Sample mangled symbols to try the parser on: a method taking a class
# reference, a constructor, a constructor of a template class and a
# constructor whose arguments use type back-references.
tests = [
	"?xorWith@BitSet@xercesc_2_8@@QAEXABV12@@Z",
	"??0ASCIIRangeFactory@xercesc_2_8@@QAE@XZ",
	"??0?$XMLHolder@U_RTL_CRITICAL_SECTION@@@xercesc_2_8@@QAE@XZ",
	"??0ArrayIndexOutOfBoundsException@xercesc_2_8@@QAE@QBDIW4Codes@XMLExcepts@1@QB_W222PAVMemoryManager@1@@Z"
]

def reorder_ident(parsetree):
	"""returns the identifiers of `parsetree` as a tuple in reversed order

	`parsetree` alternates identifiers and "@" separators, only the elements
	at even positions (the identifiers) are kept; reversing them gives
	(namespace, namespace, class, whatever, etc.)"""
	identifiers = parsetree[::2]
	return tuple(reversed(identifiers))

