Browse Source

gentest.py: refactor compare_test()

The MMGen Project 5 years ago
parent
commit
d20d10b9a5
1 changed files with 31 additions and 26 deletions
  1. 31 26
      test/gentest.py

+ 31 - 26
test/gentest.py

@@ -157,14 +157,15 @@ def init_external_prog():
 	b_desc = ext_prog
 	b = 'ext'
 
-def match_error(sec,wif,a_addr,b_addr,a,b):
-	qmsg_r(red('\nERROR: Values do not match!'))
-	die(3,"""
-  sec key   : {}
-  WIF key   : {}
-  {a:10}: {}
-  {b:10}: {}
-""".format(sec,wif,a_addr,b_addr,pnm=g.proj_name,a=kg_a.desc,b=b_desc).rstrip())
+def test_equal(a_addr,b_addr,sec,wif,a,b):
+	if a_addr != b_addr:
+		qmsg_r(red('\nERROR: Values do not match!'))
+		die(3,"""
+	  sec key   : {}
+	  WIF key   : {}
+	  {a:10}: {}
+	  {b:10}: {}
+	""".format(sec,wif,a_addr,b_addr,pnm=g.proj_name,a=kg_a.desc,b=b_desc).rstrip())
 
 def compare_test():
 	for k in ('segwit','compressed'):
@@ -176,6 +177,7 @@ def compare_test():
 		if g.coin not in ci.external_tests[('mainnet','testnet')[g.testnet]][ext_prog]:
 			msg("Coin '{}' incompatible with external generator '{}'".format(g.coin,ext_prog))
 			return
+	global last_t
 	last_t = time.time()
 	A = kg_a.desc
 	B = ext_prog if b == 'ext' else kg_b.desc
@@ -185,32 +187,36 @@ def compare_test():
 	m = "Comparing address generators '{}' and '{}' for coin {}, addrtype {!r}"
 	qmsg(green(m.format(A,B,g.coin,addr_type.name)))
 
-	for i in range(rounds):
+	def do_compare_test(n,trounds,in_bytes):
+		global last_t
 		if opt.verbose or time.time() - last_t >= 0.1:
-			qmsg_r('\rRound {}/{} '.format(i+1,rounds))
+			qmsg_r('\rRound {}/{} '.format(i+1,trounds))
 			last_t = time.time()
-		sec = PrivKey(os.urandom(32),compressed=addr_type.compressed,pubkey_type=addr_type.pubkey_type)
+		sec = PrivKey(in_bytes,compressed=addr_type.compressed,pubkey_type=addr_type.pubkey_type)
 		ph = kg_a.to_pubhex(sec)
 		a_addr = ag.to_addr(ph)
-		if addr_type.name == 'zcash_z':
-			a_vk = ag.to_viewkey(ph)
+		a_vk = ag.to_viewkey(ph) if 'viewkey' in addr_type.extra_attrs else None
 		if b == 'ext':
-			if addr_type.name == 'zcash_z':
+			if 'viewkey' in addr_type.extra_attrs:
 				b_wif,b_addr,b_vk = ext_sec2addr(sec)
-				vmsg_r('\nvkey: {}'.format(b_vk))
-				if b_vk != a_vk:
-					match_error(sec,sec.wif,a_vk,b_vk,a,b)
+				test_equal(a_vk,b_vk,sec,sec.wif,a,b)
 			else:
 				b_wif,b_addr = ext_sec2addr(sec)
-			if b_wif != sec.wif:
-				match_error(sec,sec.wif,sec.wif,b_wif,a,b)
+			test_equal(sec.wif,b_wif,sec,sec.wif,a,b)
 		else:
 			b_addr = ag.to_addr(kg_b.to_pubhex(sec))
-		vmsg('\nkey:  {}\naddr: {}\n'.format(sec.wif,a_addr))
-		if a_addr != b_addr:
-			match_error(sec,sec.wif,a_addr,b_addr,a,ext_prog if b == 'ext' else b)
-	qmsg_r('\rRound {}/{} '.format(i+1,rounds))
-	qmsg(green(('\n','')[bool(opt.verbose)] + 'OK'))
+		vmsg(ct_fs.format(b=in_bytes.hex(),k=sec.wif,v=a_vk,a=a_addr))
+		test_equal(a_addr,b_addr,sec,sec.wif,a,ext_prog if b == 'ext' else b)
+		qmsg_r('\rRound {}/{} '.format(n+1,trounds))
+
+	ct_fs   = ( '\ninput:    {b}\n%-9s {k}\naddr:     {a}\n',
+				'\ninput:    {b}\n%-9s {k}\nvkey:     {v}\naddr:     {a}\n')[
+					'viewkey' in addr_type.extra_attrs] % (addr_type.wif_label + ':')
+
+	qmsg(purple('random input:'))
+	for i in range(rounds):
+		do_compare_test(i,rounds,os.urandom(32))
+	qmsg(green('\rOK            ' if opt.verbose else 'OK'))
 
 def speed_test():
 	m = "Testing speed of address generator '{}' for coin {}"
@@ -243,8 +249,7 @@ def dump_test():
 			die(2,'\nInvalid {}net WIF address in dump file: {}'.format(('main','test')[g.testnet],wif))
 		b_addr = ag.to_addr(kg_a.to_pubhex(sec))
 		vmsg('\nwif: {}\naddr: {}\n'.format(wif,b_addr))
-		if a_addr != b_addr:
-			match_error(sec,wif,a_addr,b_addr,3,a)
+		test_equal(a_addr,b_addr,sec,wif,3,a)
 	qmsg(green(('\n','')[bool(opt.verbose)] + 'OK'))
 
 # begin execution