Browse Source

gentest.py: cleanups, improve arg parsing and data caching

The MMGen Project 3 years ago
parent
commit
83859767aa
1 changed files with 130 additions and 87 deletions
  1. 130 87
      test/gentest.py

+ 130 - 87
test/gentest.py

@@ -133,7 +133,8 @@ SUPPORTED EXTERNAL TOOLS:
 	}
 }
 
-gtr = namedtuple('gen_tool_result',['wif','addr','vk'])
+gtr = namedtuple('gen_tool_result',['wif','addr','viewkey'])
+sd = namedtuple('saved_data_item',['reduced','wif','addr','viewkey'])
 
 def get_cmd_output(cmd,input=None):
 	return run(cmd,input=input,stdout=PIPE,stderr=DEVNULL).stdout.decode().splitlines()
@@ -150,31 +151,42 @@ class GenTool(object):
 	def __del__(self):
 		if opt.save_results:
 			key = f'{self.proto.coin}-{self.proto.network}-{self.addr_type.name}-{self.desc}'.lower()
-			saved_results[key] = self.data
+			saved_results[key] = {k.hex():v._asdict() for k,v in self.data.items()}
 
-	def run_tool(self,sec):
+	def run_tool(self,sec,cache_data):
 		vcoin = 'BTC' if self.proto.coin == 'BCH' else self.proto.coin
-		ret = self.run(sec,vcoin)
-		self.data[sec.hex()] = ret._asdict()
-		return ret
+		key = sec.orig_bytes
+		if key in self.data:
+			return self.data[key]
+		else:
+			ret = self.run(sec,vcoin)
+			if cache_data:
+				self.data[key] = sd( **{'reduced':sec.hex()}, **ret._asdict() )
+			return ret
 
 class GenToolEthkey(GenTool):
 	desc = 'ethkey'
 	def run(self,sec,vcoin):
 		o = get_cmd_output(['ethkey','info',sec.hex()])
-		return gtr(o[0].split()[1],o[-1].split()[1],None)
+		return gtr(
+			o[0].split()[1],
+			o[-1].split()[1],
+			None )
 
 class GenToolKeyconv(GenTool):
 	desc = 'keyconv'
 	def run(self,sec,vcoin):
 		o = get_cmd_output(['keyconv','-C',vcoin,sec.wif])
-		return gtr(o[1].split()[1],o[0].split()[1],None)
+		return gtr(
+			o[1].split()[1],
+			o[0].split()[1],
+			None )
 
 class GenToolZcash_mini(GenTool):
 	desc = 'zcash-mini'
 	def run(self,sec,vcoin):
 		o = get_cmd_output(['zcash-mini','-key','-simple'],input=(sec.wif+'\n').encode())
-		return gtr(o[1],o[0],o[-1])
+		return gtr( o[1], o[0], o[-1] )
 
 class GenToolPycoin(GenTool):
 	"""
@@ -191,7 +203,7 @@ class GenToolPycoin(GenTool):
 
 	def run(self,sec,vcoin):
 		if self.proto.testnet:
-			vcoin = ci.external_tests['testnet']['pycoin'][vcoin]
+			vcoin = cinfo.external_tests['testnet']['pycoin'][vcoin]
 		network = self.nfnc(vcoin)
 		key = network.keys.private(
 			secret_exponent = int(sec.hex(),16),
@@ -207,7 +219,7 @@ class GenToolPycoin(GenTool):
 				addr = network.address.for_p2pkh_wit(hash160_c)
 		else:
 			addr = key.address()
-		return gtr(key.wif(),addr,None)
+		return gtr( key.wif(), addr, None )
 
 class GenToolMoneropy(GenTool):
 	desc = 'moneropy'
@@ -221,15 +233,12 @@ class GenToolMoneropy(GenTool):
 		self.mpa = moneropy.account
 
 	def run(self,sec,vcoin):
-		if sec.hex() in self.data:
-			return gtr(**self.data[sec.hex()])
-		else:
-			sk,vk,addr = self.mpa.account_from_spend_key(sec.hex()) # VERY slow!
-			return gtr(sk,addr,vk)
+		sk,vk,addr = self.mpa.account_from_spend_key(sec.hex()) # VERY slow!
+		return gtr( sk, addr, vk )
 
 def find_or_check_tool(proto,addr_type,toolname):
 
-	ext_progs = list(ci.external_tests[proto.network])
+	ext_progs = list(cinfo.external_tests[proto.network])
 
 	if toolname not in ext_progs + ['ext']:
 		die(1,f'{toolname!r}: unsupported tool for network {proto.network}')
@@ -237,7 +246,7 @@ def find_or_check_tool(proto,addr_type,toolname):
 	if opt.all_coins and toolname == 'ext':
 		die(1,"'--all-coins' must be combined with a specific external testing tool")
 	else:
-		tool = ci.get_test_support(
+		tool = cinfo.get_test_support(
 			proto.coin,
 			addr_type.name,
 			proto.network,
@@ -266,47 +275,55 @@ def test_equal(desc,a_val,b_val,in_bytes,sec,wif,a_desc,b_desc):
 				w=max(len(e) for e in (a_desc,b_desc)) + 1
 		).rstrip())
 
-def do_ab_test(proto,addr_type,kg_b,rounds,backend_num):
+def do_ab_test(proto,cfg,addr_type,gen1,kg2,ag,tool,cache_data):
 
-	def run_ab_inner(n,trounds,in_bytes):
+	def do_ab_inner(n,trounds,in_bytes):
 		global last_t
 		if opt.verbose or time.time() - last_t >= 0.1:
 			qmsg_r(f'\rRound {i+1}/{trounds} ')
 			last_t = time.time()
 		sec = PrivKey(proto,in_bytes,compressed=addr_type.compressed,pubkey_type=addr_type.pubkey_type)
-		data = kg_a.gen_data(sec)
-		ag = AddrGenerator(proto,addr_type)
-		a_addr = ag.to_addr(data)
-		tinfo = (in_bytes,sec,sec.wif,type(kg_a).__name__,type(kg_b).__name__)
-		a_vk = None
+		data = kg1.gen_data(sec)
+		addr1 = ag.to_addr(data)
+		tinfo = ( in_bytes, sec, sec.wif, type(kg1).__name__, type(kg2).__name__ if kg2 else tool.desc )
 
 		def do_msg():
-			vmsg( fs.format( b=in_bytes.hex(), r=sec.hex(), k=sec.wif, v=a_vk, a=a_addr ))
+			if opt.verbose:
+				msg( fs.format( b=in_bytes.hex(), r=sec.hex(), k=sec.wif, v=vk2, a=addr1 ))
 
-		if isinstance(kg_b,GenTool):
+		if tool:
 			def run_tool():
-				b = kg_b.run_tool(sec)
-				test_equal('WIF keys',sec.wif,b.wif,*tinfo)
-				test_equal('addresses',a_addr,b.addr,*tinfo)
-				if b.vk:
-					test_equal( 'view keys', ag.to_viewkey(data), b.vk, *tinfo )
-				return b.vk
-			a_vk = run_tool()
+				o = tool.run_tool(sec,cache_data)
+				test_equal( 'WIF keys', sec.wif, o.wif, *tinfo )
+				test_equal( 'addresses', addr1, o.addr, *tinfo )
+				if o.viewkey:
+					test_equal( 'view keys', ag.to_viewkey(data), o.viewkey, *tinfo )
+				return o.viewkey
+			vk2 = run_tool()
 			do_msg()
 		else:
-			test_equal( 'addresses', a_addr, ag.to_addr(kg_b.gen_data(sec)), *tinfo )
+			test_equal( 'addresses', addr1, ag.to_addr(kg2.gen_data(sec)), *tinfo )
+			vk2 = None
 			do_msg()
 
 		qmsg_r(f'\rRound {n+1}/{trounds} ')
 
-	kg_a = KeyGenerator(proto,addr_type.pubkey_type,backend_num)
-	if type(kg_a) == type(kg_b):
+	def get_randbytes():
+		if tool and len(tool.data) > len(edgecase_sks):
+			for privbytes in tuple(tool.data)[len(edgecase_sks):]:
+				yield privbytes
+		else:
+			for i in range(cfg.rounds):
+				yield getrand(32)
+
+	kg1 = KeyGenerator( proto, addr_type.pubkey_type, gen1 )
+	if type(kg1) == type(kg2):
 		rdie(1,'Key generators are the same!')
 
-	e = ci.get_entry(proto.coin,proto.network)
+	e = cinfo.get_entry(proto.coin,proto.network)
 	qmsg(green("Comparing address generators '{A}' and '{B}' for {N} {c} ({n}), addrtype {a!r}".format(
-		A = type(kg_a).__name__,
-		B = type(kg_b).__name__.replace('GenTool','').replace('_','-').lower(),
+		A = type(kg1).__name__.replace('_','-'),
+		B = type(kg2).__name__.replace('_','-') if kg2 else tool.desc,
 		N = proto.network,
 		c = proto.coin,
 		n = e.name if e else '---',
@@ -337,37 +354,41 @@ def do_ab_test(proto,addr_type,kg_b,rounds,backend_num):
 	)
 
 	qmsg(purple('edge cases:'))
-	for i,in_bytes in enumerate(edgecase_sks):
-		run_ab_inner(i,len(edgecase_sks),in_bytes)
+	for i,privbytes in enumerate(edgecase_sks):
+		do_ab_inner(i,len(edgecase_sks),privbytes)
 	qmsg(green('\rOK            ' if opt.verbose else 'OK'))
 
 	qmsg(purple('random input:'))
-	for i in range(rounds):
-		run_ab_inner(i,rounds,getrand(32))
+	for i,privbytes in enumerate(get_randbytes()):
+		do_ab_inner(i,cfg.rounds,privbytes)
 	qmsg(green('\rOK            ' if opt.verbose else 'OK'))
 
 def init_tool(proto,addr_type,toolname):
 	return globals()['GenTool'+capfirst(toolname.replace('-','_'))](proto,addr_type)
 
-def ab_test(proto,gen_num,rounds,toolname_or_gen2_num):
+def ab_test(proto,cfg):
 
 	addr_type = MMGenAddrType( proto=proto, id_str=opt.type or proto.dfl_mmtype )
 
-	if is_int(toolname_or_gen2_num):
-		assert gen_num != 'all', "'all' must be used only with external tool"
-		tool = KeyGenerator( proto, addr_type.pubkey_type, int(toolname_or_gen2_num) )
+	if cfg.gen2:
+		assert cfg.gen1 != 'all', "'all' must be used only with external tool"
+		kg2 = KeyGenerator( proto, addr_type.pubkey_type, cfg.gen2 )
+		tool = None
 	else:
-		toolname = find_or_check_tool( proto, addr_type, toolname_or_gen2_num )
+		toolname = find_or_check_tool( proto, addr_type, cfg.tool )
 		if toolname == None:
-			ymsg(f'Warning: skipping tool {toolname_or_gen2_num!r} for {proto.coin} {addr_type.name}')
+			ymsg(f'Warning: skipping tool {cfg.tool!r} for {proto.coin} {addr_type.name}')
 			return
 		tool = init_tool( proto, addr_type, toolname )
+		kg2 = None
+
+	ag = AddrGenerator( proto, addr_type )
 
-	if gen_num == 'all': # check all backends against external tool
+	if cfg.all_backends: # check all backends against external tool
 		for n in range(len(get_backends(addr_type.pubkey_type))):
-			do_ab_test( proto, addr_type, tool, rounds, n+1 )
+			do_ab_test( proto, cfg, addr_type, gen1=n+1, kg2=kg2, ag=ag, tool=tool, cache_data=not n )
 	else:                # check specific backend against external tool or another backend
-		do_ab_test( proto, addr_type, tool, rounds, gen_num )
+		do_ab_test( proto, cfg, addr_type, gen1=cfg.gen1, kg2=kg2, ag=ag, tool=tool, cache_data=False )
 
 def speed_test(proto,kg,ag,rounds):
 	qmsg(green('Testing speed of address generator {!r} for coin {}'.format(
@@ -423,7 +444,7 @@ def get_protos(proto,addr_type,toolname):
 
 	init_genonly_altcoins(testnet=proto.testnet)
 
-	for coin in ci.external_tests[proto.network][toolname]:
+	for coin in cinfo.external_tests[proto.network][toolname]:
 		if coin.lower() not in CoinProtocol.coins:
 			continue
 		ret = init_proto(coin,testnet=proto.testnet)
@@ -437,53 +458,75 @@ def parse_args():
 		opts.usage()
 
 	arg1,arg2 = cmd_args
-	pa = namedtuple('parsed_args',['test','gen_num','rounds','arg'])
+	cfg = namedtuple('parsed_args',['test','gen1','gen2','rounds','tool','all_backends','dumpfile'])
+	gen1,gen2,rounds = (0,0,0)
+	tool,all_backends,dumpfile = (None,None,None)
 
 	if is_int(arg1) and is_int(arg2):
-		return pa( test='speed', gen_num=arg1, rounds=int(arg2), arg=None )
-
-	if is_int(arg1) and os.access(arg2,os.R_OK):
-		return pa( test='dump', gen_num=arg1, rounds=None, arg=arg2 )
-
-	if not is_int(arg2):
-		die(1,'Second argument must be dump filename or integer rounds specification')
+		test = 'speed'
+		gen1 = arg1
+		rounds = arg2
+	elif is_int(arg1) and os.access(arg2,os.R_OK):
+		test = 'dump'
+		gen1 = arg1
+		dumpfile = arg2
+	else:
+		test = 'ab'
+		rounds = arg2
 
-	try:
-		a,b = arg1.split(':')
-	except:
-		die(1,'First argument must be a generator backend number or two colon-separated arguments')
+		if not is_int(arg2):
+			die(1,'Second argument must be dump filename or integer rounds specification')
 
-	if not is_int(a) and a != 'all':
-		die(1,"First part of first argument must be a generator backend number or 'all'")
+		try:
+			a,b = arg1.split(':')
+		except:
+			die(1,'First argument must be a generator backend number or two colon-separated arguments')
 
-	if is_int(b):
-		if opt.all_coins:
-			die(1,'--all-coins must be used with external tool only')
-	else:
-		proto = init_proto_from_opts()
-		ext_progs = list(ci.external_tests[proto.network]) + ['ext']
-		if b not in ext_progs:
-			die(1,f'Second part of first argument must be a generator backend number or one of {ext_progs}')
+		if is_int(a):
+			gen1 = a
+		else:
+			if a == 'all':
+				all_backends = True
+			else:
+				die(1,"First part of first argument must be a generator backend number or 'all'")
 
-	return pa( test='ab', gen_num=a, rounds=int(arg2), arg=b )
+		if is_int(b):
+			if opt.all_coins:
+				die(1,'--all-coins must be used with external tool only')
+			gen2 = b
+		else:
+			tool = b
+			proto = init_proto_from_opts()
+			ext_progs = list(cinfo.external_tests[proto.network]) + ['ext']
+			if b not in ext_progs:
+				die(1,f'Second part of first argument must be a generator backend number or one of {ext_progs}')
+
+	return cfg(
+		test,
+		int(gen1) or None,
+		int(gen2) or None,
+		int(rounds) or None,
+		tool,
+		all_backends,
+		dumpfile )
 
 def main():
 
-	pa = parse_args()
+	cfg = parse_args()
 	proto = init_proto_from_opts()
 	addr_type = MMGenAddrType( proto=proto, id_str=opt.type or proto.dfl_mmtype )
 
-	if pa.test == 'ab':
-		protos = get_protos(proto,addr_type,pa.arg) if opt.all_coins else [proto]
+	if cfg.test == 'ab':
+		protos = get_protos(proto,addr_type,cfg.tool) if opt.all_coins else [proto]
 		for proto in protos:
-			ab_test( proto, pa.gen_num, pa.rounds, toolname_or_gen2_num=pa.arg )
+			ab_test( proto, cfg )
 	else:
-		kg = KeyGenerator( proto, addr_type.pubkey_type, pa.gen_num )
+		kg = KeyGenerator( proto, addr_type.pubkey_type, cfg.gen1 )
 		ag = AddrGenerator( proto, addr_type )
-		if pa.test == 'speed':
-			speed_test( proto, kg, ag, pa.rounds )
-		elif pa.test == 'dump':
-			dump_test( proto, kg, ag, filename=pa.arg )
+		if cfg.test == 'speed':
+			speed_test( proto, kg, ag, cfg.rounds )
+		elif cfg.test == 'dump':
+			dump_test( proto, kg, ag, cfg.dumpfile )
 
 	if saved_results:
 		import json
@@ -493,7 +536,7 @@ def main():
 from subprocess import run,PIPE,DEVNULL
 from collections import namedtuple
 from mmgen.protocol import init_proto,init_proto_from_opts,CoinProtocol,init_genonly_altcoins
-from mmgen.altcoin import CoinInfo as ci
+from mmgen.altcoin import CoinInfo as cinfo
 from mmgen.key import PrivKey
 from mmgen.addr import KeyGenerator,AddrGenerator,MMGenAddrType
 from mmgen.keygen import get_backends