Browse Source

PubKey: add privkey attribute, perform pubkey_type matching checks

The MMGen Project 4 years ago
parent
commit
d1d7132a7c
4 changed files with 40 additions and 30 deletions
  1. 30 20
      mmgen/addr.py
  2. 5 6
      mmgen/obj.py
  3. 3 2
      test/objtest_py_d/ot_btc_mainnet.py
  4. 2 2
      test/test_py_d/ts_main.py

+ 30 - 20
mmgen/addr.py

@@ -24,7 +24,7 @@ from hashlib import sha256,sha512
 from .common import *
 from .obj import *
 from .baseconv import *
-from .protocol import init_proto
+from .protocol import init_proto,hash160
 
 pnm = g.proj_name
 
@@ -34,30 +34,32 @@ def dmsg_sc(desc,data):
 
 class AddrGenerator(MMGenObject):
 	def __new__(cls,proto,addr_type):
-		if type(addr_type) == str: # allow override w/o check
-			gen_method = addr_type
+
+		if type(addr_type) == str:
+			addr_type = MMGenAddrType(proto=proto,id_str=addr_type)
 		elif type(addr_type) == MMGenAddrType:
 			assert addr_type in proto.mmtypes, f'{addr_type}: invalid address type for coin {proto.coin}'
-			gen_method = addr_type.gen_method
 		else:
 			raise TypeError(f'{type(addr_type)}: incorrect argument type for {cls.__name__}()')
-		gen_methods = {
+
+		addr_generators = {
 			'p2pkh':    AddrGeneratorP2PKH,
 			'segwit':   AddrGeneratorSegwit,
 			'bech32':   AddrGeneratorBech32,
 			'ethereum': AddrGeneratorEthereum,
 			'zcash_z':  AddrGeneratorZcashZ,
-			'monero':   AddrGeneratorMonero}
-		assert gen_method in gen_methods
-		me = super(cls,cls).__new__(gen_methods[gen_method])
-		me.desc = gen_methods
+			'monero':   AddrGeneratorMonero,
+		}
+		me = super(cls,cls).__new__(addr_generators[addr_type.gen_method])
+		me.desc = type(me).__name__
 		me.proto = proto
+		me.addr_type = addr_type
+		me.pubkey_type = addr_type.pubkey_type
 		return me
 
 class AddrGeneratorP2PKH(AddrGenerator):
 	def to_addr(self,pubhex):
-		from .protocol import hash160
-		assert type(pubhex) == PubKey
+		assert pubhex.privkey.pubkey_type == self.pubkey_type
 		return CoinAddr(self.proto,self.proto.pubhash2addr(hash160(pubhex),p2sh=False))
 
 	def to_segwit_redeem_script(self,pubhex):
@@ -65,6 +67,7 @@ class AddrGeneratorP2PKH(AddrGenerator):
 
 class AddrGeneratorSegwit(AddrGenerator):
 	def to_addr(self,pubhex):
+		assert pubhex.privkey.pubkey_type == self.pubkey_type
 		assert pubhex.compressed,'Uncompressed public keys incompatible with Segwit'
 		return CoinAddr(self.proto,self.proto.pubhex2segwitaddr(pubhex))
 
@@ -74,8 +77,8 @@ class AddrGeneratorSegwit(AddrGenerator):
 
 class AddrGeneratorBech32(AddrGenerator):
 	def to_addr(self,pubhex):
+		assert pubhex.privkey.pubkey_type == self.pubkey_type
 		assert pubhex.compressed,'Uncompressed public keys incompatible with Segwit'
-		from .protocol import hash160
 		return CoinAddr(self.proto,self.proto.pubhash2bech32addr(hash160(pubhex)))
 
 	def to_segwit_redeem_script(self,pubhex):
@@ -96,7 +99,7 @@ class AddrGeneratorEthereum(AddrGenerator):
 		self.hash256 = hash256
 
 	def to_addr(self,pubhex):
-		assert type(pubhex) == PubKey
+		assert pubhex.privkey.pubkey_type == self.pubkey_type
 		return CoinAddr(self.proto,self.keccak_256(bytes.fromhex(pubhex[2:])).hexdigest()[24:])
 
 	def to_wallet_passwd(self,sk_hex):
@@ -116,6 +119,7 @@ class AddrGeneratorZcashZ(AddrGenerator):
 		return Sha256(s,preprocess=False).digest()
 
 	def to_addr(self,pubhex): # pubhex is really privhex
+		assert pubhex.privkey.pubkey_type == self.pubkey_type
 		key = bytes.fromhex(pubhex)
 		assert len(key) == 32, f'{len(key)}: incorrect privkey length'
 		from nacl.bindings import crypto_scalarmult_base
@@ -173,6 +177,7 @@ class AddrGeneratorMonero(AddrGenerator):
 		return a + b
 
 	def to_addr(self,sk_hex): # sk_hex instead of pubhex
+		assert sk_hex.privkey.pubkey_type == self.pubkey_type
 
 		# Source and license for scalarmultbase function:
 		#   https://github.com/bigreddmachine/MoneroPy/blob/master/moneropy/crypto/ed25519.py
@@ -277,21 +282,26 @@ class KeyGeneratorPython(KeyGenerator):
 
 	def to_pubhex(self,privhex):
 		assert type(privhex) == PrivKey
-		return PubKey(self.privnum2pubhex(
-			int(privhex,16),compressed=privhex.compressed),compressed=privhex.compressed)
+		return PubKey(
+			s       = self.privnum2pubhex(int(privhex,16),compressed=privhex.compressed),
+			privkey = privhex )
 
 class KeyGeneratorSecp256k1(KeyGenerator):
 	desc = 'mmgen-secp256k1'
 	def to_pubhex(self,privhex):
 		assert type(privhex) == PrivKey
 		from .secp256k1 import priv2pub
-		return PubKey(priv2pub(bytes.fromhex(privhex),int(privhex.compressed)).hex(),compressed=privhex.compressed)
+		return PubKey(
+			s       = priv2pub(bytes.fromhex(privhex),int(privhex.compressed)).hex(),
+			privkey = privhex )
 
 class KeyGeneratorDummy(KeyGenerator):
 	desc = 'mmgen-dummy'
 	def to_pubhex(self,privhex):
 		assert type(privhex) == PrivKey
-		return PubKey(privhex,compressed=privhex.compressed)
+		return PubKey(
+			s       = privhex,
+			privkey = privhex )
 
 class AddrListEntryBase(MMGenListItem):
 	invalid_attrs = {'proto'}
@@ -605,9 +615,9 @@ Removed {{}} duplicate WIF key{{}} from keylist (also in {pnm} key-address file
 
 	def generate_addrs_from_keys(self):
 		# assume that the first listed mmtype is valid for flat key list
-		t = self.proto.addr_type(self.proto.mmtypes[0])
-		kg = KeyGenerator(self.proto,t.pubkey_type)
-		ag = AddrGenerator(self.proto,t.gen_method)
+		at = self.proto.addr_type(self.proto.mmtypes[0])
+		kg = KeyGenerator(self.proto,at.pubkey_type)
+		ag = AddrGenerator(self.proto,at)
 		d = self.data
 		for n,e in enumerate(d,1):
 			qmsg_r('\rGenerating addresses from keylist: {}/{}'.format(n,len(d)))

+ 5 - 6
mmgen/obj.py

@@ -746,15 +746,14 @@ class WifKey(str,Hilite,InitErrors):
 			return cls.init_fail(e,wif)
 
 class PubKey(HexStr,MMGenObject): # TODO: add some real checks
-	def __new__(cls,s,compressed):
+	def __new__(cls,s,privkey):
 		try:
-			assert type(compressed) == bool,"'compressed' must be of type bool"
+			me = HexStr.__new__(cls,s,case='lower')
+			me.privkey = privkey
+			me.compressed = privkey.compressed
+			return me
 		except Exception as e:
 			return cls.init_fail(e,s)
-		me = HexStr.__new__(cls,s,case='lower')
-		if me:
-			me.compressed = compressed
-			return me
 
 class PrivKey(str,Hilite,InitErrors,MMGenObject):
 	"""

+ 3 - 2
test/objtest_py_d/ot_btc_mainnet.py

@@ -16,6 +16,7 @@ proto = init_proto('btc')
 tw_pfx = proto.base_coin.lower() + ':'
 
 ssm = str(SeedShareCount.max_val)
+privkey = PrivKey(proto=proto,s=bytes.fromhex('deadbeef'*8),compressed=True,pubkey_type='std')
 
 tests = {
 	'Int': {
@@ -224,8 +225,8 @@ tests = {
 	},
 	'PubKey': {
 		'arg1': 's',
-		'bad':  ({'arg':1,'compressed':False},{'arg':'F00BAA12','compressed':False},),
-		'good': ({'arg':'deadbeef','compressed':True},) # TODO: add real pubkeys
+		'bad':  ({'s':1,'privkey':False},{'s':'F00BAA12','privkey':False},),
+		'good': ({'s':'deadbeef','privkey':privkey},) # TODO: add real pubkeys
 	},
 	'PrivKey': {
 		'arg1': 'proto',

+ 2 - 2
test/test_py_d/ts_main.py

@@ -339,7 +339,7 @@ class TestSuiteMain(TestSuiteBase,TestSuiteShared):
 			from mmgen.addr import AddrGenerator,KeyGenerator
 			rand_coinaddr = AddrGenerator(
 				self.proto,
-				'p2pkh'
+				'compressed'
 				).to_addr(KeyGenerator(self.proto,'std').to_pubhex(privkey))
 			of = joinpath(self.cfgs[non_mmgen_input]['tmpdir'],non_mmgen_fn)
 			write_data_to_file(
@@ -376,7 +376,7 @@ class TestSuiteMain(TestSuiteBase,TestSuiteShared):
 	def _make_txcreate_cmdline(self,tx_data):
 		from mmgen.obj import PrivKey
 		privkey = PrivKey(self.proto,os.urandom(32),compressed=True,pubkey_type='std')
-		t = ('p2pkh','segwit')['S' in self.proto.mmtypes]
+		t = ('compressed','segwit')['S' in self.proto.mmtypes]
 		from mmgen.addr import AddrGenerator,KeyGenerator
 		rand_coinaddr = AddrGenerator(self.proto,t).to_addr(KeyGenerator(self.proto,'std').to_pubhex(privkey))