#!/usr/bin/env python3

import sys

# Positional command-line arguments, in order:
#   bits n impl implname implcompiler binary littleendian|bigendian [libs...]
# The first three are integers; the next three are kept as strings.
bits,n,impl = (int(sys.argv[pos]) for pos in (1,2,3))
implname,implcompiler,binary = (sys.argv[pos] for pos in (4,5,6))
# the endianness argument must be one of the two exact spellings (else KeyError)
littleendian = {'bigendian':False,'littleendian':True}[sys.argv[7]]
# everything after the endianness argument is a list of libraries
libs = sys.argv[8:]

import angr
import claripy

# angr state options handed to full_init_state() as add_options.
# NOTE(review): what each option does is defined by angr, not by this file;
# check angr's state-options documentation before changing the set.
add_options = {
  angr.options.LAZY_SOLVES,
  angr.options.SYMBOLIC_WRITE_ADDRESSES,
  angr.options.CONSERVATIVE_READ_STRATEGY,
  angr.options.CONSERVATIVE_WRITE_STRATEGY,
  angr.options.SYMBOL_FILL_UNCONSTRAINED_MEMORY,
  angr.options.SYMBOL_FILL_UNCONSTRAINED_REGISTERS,
}
# angr state options handed to full_init_state() as remove_options:
# every SIMPLIFY_* option listed here is taken out of the initial state.
remove_options = {
  angr.options.SIMPLIFY_CONSTRAINTS,
  angr.options.SIMPLIFY_EXPRS,
  angr.options.SIMPLIFY_MEMORY_READS,
  angr.options.SIMPLIFY_MEMORY_WRITES,
  angr.options.SIMPLIFY_REGISTER_READS,
  angr.options.SIMPLIFY_REGISTER_WRITES,
  angr.options.SIMPLIFY_RETS,
}

# Monkey-patch claripy's simplifier: replace its extract_distributable
# collection with an empty one and unregister its '__xor__' simplifier.
# NOTE(review): presumably this keeps expressions in the form the analyzed
# program computed them, so walk() prints them unrewritten -- confirm against
# the claripy version in use, since both names are claripy internals.
claripy.simplifications.extract_distributable = {}
del claripy.simplifications.simpleton._simplifiers['__xor__']

# Raise Python's recursion limit far above the default.
# NOTE(review): presumably needed because walk() (and likely angr/claripy)
# recurse once per nesting level of very deep expressions -- confirm.
sys.setrecursionlimit(1000000)

# ===== patch cpuid to enable avx2
# XXX: should share with outsim; in any case keep synchronized

# keep a handle on the stock helper so the replacement below can delegate to it
prevcpuid = angr.engines.vex.heavy.dirty.CORRECT_amd64g_dirtyhelper_CPUID_avx_and_cx16

def cpuid(state,_):
  """Replacement CPUID dirty helper: stock helper plus fixed leaf 1/7 data.

  The requested leaf is the low 32 bits of rax as it was *before* the stock
  helper runs.  After delegating to the stock helper, rax/rbx/rcx/rdx are
  overwritten with 8-byte stores that are conditional on the leaf being 1 or
  7; for any other leaf the stock helper's results are left in place.

  Returns the pair (None, []).
  """
  leaf = state.regs.rax[31:0]
  prevcpuid(state,_)
  # substitute some haswell cpuid data:
  haswell = (
    (1,(('rax',0x000306c3),('rbx',0x04100800),('rcx',0x7ffafbff),('rdx',0xbfebfbff))),
    (7,(('rax',0x00000000),('rbx',0x000027ab),('rcx',0x00000000),('rdx',0x9c000600))),
  )
  for wanted,registers in haswell:
    for regname,value in registers:
      state.registers.store(regname,value,size=8,condition=(leaf==wanted))
  return None,[]

# install the replacement under the name of angr's avx2 CPUID helper
angr.engines.vex.heavy.dirty.amd64g_dirtyhelper_CPUID_avx2 = cpuid

# ===== main

# one symbolic variable named x_<i>_<bits>, each `bits` wide, per input word
stdin = [claripy.BVS(f'x_{i}_{bits}',bits,explicit_name=True) for i in range(n)]
# for a little-endian target every variable is wrapped in claripy.Reverse
if littleendian:
  stdin = [claripy.Reverse(word) for word in stdin]

# the concatenation of all words becomes the complete content of stdin
stdin = angr.SimFile('/dev/stdin',content=claripy.Concat(*stdin),has_end=True)

proj = angr.Project(binary,auto_load_libs=False,force_load_libs=libs)
# the analyzed program receives n, impl, implname and implcompiler as its argv
target_args = [binary,str(n),str(impl),implname,implcompiler]
state = proj.factory.full_init_state(
  args=target_args,
  add_options=add_options,
  remove_options=remove_options,
  stdin=stdin,
)
simgr = proj.factory.simgr(state)
simgr.run()
# no path may have ended in an error
assert not simgr.errored

# fold all terminated paths into a single state, each path guarded by its own
# constraints; with exactly one path that state is used unchanged
exits = simgr.deadended
mergedexit = exits[0]
if len(exits) > 1:
  path_constraints = [e2.solver.constraints for e2 in exits]
  mergedexit,_,_ = mergedexit.merge(*exits[1:],merge_conditions=path_constraints)

# what the program wrote to stdout, as recorded on the merged state
packets = mergedexit.posix.stdout.content

# expression operation name -> name used in the printed output
_RENAMED_OPS = {
  '__add__': 'add',
  '__sub__': 'sub',
  '__mul__': 'mul',
  '__or__': 'or',
  '__xor__': 'xor',
  '__and__': 'and',
  '__invert__': 'invert',
  '__eq__': 'equal',
  '__ge__': 'unsignedge',
  '__gt__': 'unsignedgt',
  '__le__': 'unsignedle',
  '__lt__': 'unsignedlt',
  '__lshift__': 'lshift',
  '__rshift__': 'signedrshift',
  'LShR': 'unsignedrshift',
  'SLE': 'signedle',
  'SLT': 'signedlt',
}

def rename(op):
  """Translate operation name `op` into its printed name.

  Names without an entry in the table are returned unchanged.
  """
  return _RENAMED_OPS.get(op,op)

# memo of expressions already printed: node -> index N of its 'vN = ...' line
walked = {}
# index used by the most recently printed 'vN = ...' line (0 = none yet)
walknext = 0

def walk(t):
  """Print expression `t` as straight-line assignments; return its index.

  Each distinct node is printed once as a line 'vN = ...' whose operands
  refer to earlier lines; the memo `walked` makes repeated nodes reuse their
  existing index.  Operands are always printed before the node itself.

  `t` must be hashable and provide `.op`, `.args` and, for 'BVV', `.size()`.
  The node is only read, never modified.

  Returns the integer N such that 'vN' holds the value of `t`.
  """
  global walknext
  if t in walked: return walked[t]
  if t.op == '__xor__':
    # an xor of several operands is printed as a left-to-right chain of
    # two-input xors; the last link (or the lone operand) is the result
    inputs = [walk(a) for a in t.args]
    result = inputs[0]
    for x in inputs[1:]:
      walknext += 1
      print('v%d = xor(v%d,v%d)' % (walknext,result,x))
      result = walknext
    walked[t] = result
    return result
  if t.op == 'BVV':
    walknext += 1
    print('v%d = constant(%d,%d)' % (walknext,t.size(),t.args[0]))
  elif t.op == 'BVS':
    walknext += 1
    print('v%d = %s' % (walknext,t.args[0]))
  elif t.op == 'Extract':
    # args are (first number, second number, expression); the numbers are
    # printed after the operand in the same order
    assert len(t.args) == 3
    operand = 'v%d' % walk(t.args[2])
    walknext += 1
    print('v%d = Extract(%s,%d,%d)' % (walknext,operand,t.args[0],t.args[1]))
  elif t.op in ('SignExt','ZeroExt'):
    # args are (number, expression); the number is printed after the operand
    assert len(t.args) == 2
    operand = 'v%d' % walk(t.args[1])
    walknext += 1
    print('v%d = %s(%s,%d)' % (walknext,t.op,operand,t.args[0]))
  else:
    inputs = ['v%d' % walk(a) for a in t.args]
    walknext += 1
    op = t.op
    # signed >= and > are printed as signed <= and < with the operands
    # swapped.  Work on local copies: the node may be shared with other
    # users, and rewriting its op without its args would corrupt it.
    if op in ('SGE','SGT'):
      op = {'SGE':'SLE','SGT':'SLT'}[op]
      inputs.reverse()
    print('v%d = %s(%s)' % (walknext,rename(op),','.join(inputs)))
  walked[t] = walknext
  return walknext

# the program must have produced exactly one stdout packet per input word
assert len(packets) == n
for ppos,packet in enumerate(packets):
  # presumably each packet is a (data, size) pair -- only items 0 and 1 are used
  data,size = packet[0],packet[1]
  # the second item must be a concrete value equal to bits/8
  assert size.op == 'BVV' and 8*size.args[0] == bits
  # mirror the input convention: reverse the data for little-endian targets
  if littleendian: data = claripy.Reverse(data)
  print(f'y_{ppos}_{bits} = v{walk(data)}')
