Browse Source

proto.decode_wif(): parse version bytes more efficiently

The MMGen Project 2 years ago
parent
commit
61d79e2899
5 changed files with 28 additions and 22 deletions
  1. 1 1
      mmgen/altcoin.py
  2. 8 18
      mmgen/proto/btc.py
  3. 0 1
      mmgen/proto/xmr.py
  4. 13 2
      mmgen/proto/zec.py
  5. 6 0
      mmgen/protocol.py

+ 1 - 1
mmgen/altcoin.py

@@ -454,7 +454,7 @@ class CoinInfo(object):
 					test_equal(
 						'WIF version number',
 						e.wif_ver_num,
-						int.from_bytes(bytes.fromhex(proto.wif_ver_num['std']),'big'),
+						int.from_bytes(proto.wif_ver_bytes['std'],'big'),
 						*cdata )
 
 					test_equal(

+ 8 - 18
mmgen/proto/btc.py

@@ -52,36 +52,26 @@ class mainnet(CoinProtocol.Secp256k1): # chainparams.cpp
 
 	def encode_wif(self,privbytes,pubkey_type,compressed): # input is preprocessed hex
 		assert len(privbytes) == self.privkey_len, f'{len(privbytes)} bytes: incorrect private key length!'
-		assert pubkey_type in self.wif_ver_num, f'{pubkey_type!r}: invalid pubkey_type'
+		assert pubkey_type in self.wif_ver_bytes, f'{pubkey_type!r}: invalid pubkey_type'
 		return b58chk_encode(
-			bytes.fromhex(self.wif_ver_num[pubkey_type])
+			self.wif_ver_bytes[pubkey_type]
 			+ privbytes
 			+ (b'',b'\x01')[bool(compressed)])
 
 	def decode_wif(self,wif):
-		key = b58chk_decode(wif)
-
-		for k,v in self.wif_ver_num.items():
-			v = bytes.fromhex(v)
-			if key[:len(v)] == v:
-				pubkey_type = k
-				key = key[len(v):]
-				break
-		else:
-			raise ValueError('Invalid WIF version number')
+		key_data = b58chk_decode(wif)
+		vlen = self.wif_ver_bytes_len or self.get_wif_ver_bytes_len(key_data)
+		key = key_data[vlen:]
 
 		if len(key) == self.privkey_len + 1:
 			assert key[-1] == 0x01, f'{key[-1]!r}: invalid compressed key suffix byte'
-			compressed = True
-		elif len(key) == self.privkey_len:
-			compressed = False
-		else:
+		elif len(key) != self.privkey_len:
 			raise ValueError(f'{len(key)}: invalid key length')
 
 		return decoded_wif(
 			sec         = key[:self.privkey_len],
-			pubkey_type = pubkey_type,
-			compressed  = compressed )
+			pubkey_type = self.wif_ver_bytes_to_pubkey_type[key_data[:vlen]],
+			compressed  = len(key) == self.privkey_len + 1 )
 
 	def decode_addr(self,addr):
 

+ 0 - 1
mmgen/proto/xmr.py

@@ -25,7 +25,6 @@ class mainnet(CoinProtocol.DummyWIF,CoinProtocol.Base):
 	base_coin      = 'XMR'
 	base_proto     = 'Monero'
 	addr_ver_info  = { '12': 'monero', '2a': 'monero_sub', '13': 'monero_integrated' }
-	wif_ver_num    = {}
 	pubkey_types   = ('monero',)
 	mmtypes        = ('M',)
 	dfl_mmtype     = 'M'

+ 13 - 2
mmgen/proto/zec.py

@@ -13,7 +13,8 @@ Zcash protocol
 """
 
 from .btc import mainnet
-from ..protocol import decoded_addr
+from .common import b58chk_decode
+from ..protocol import decoded_wif,decoded_addr
 
 class mainnet(mainnet):
 	base_coin      = 'ZEC'
@@ -30,12 +31,22 @@ class mainnet(mainnet):
 		from ..opts import opt
 		self.coin_id = 'ZEC-Z' if opt.type in ('zcash_z','Z') else 'ZEC-T'
 
+	def get_wif_ver_bytes_len(self,key_data):
+		"""
+		vlen must be set dynamically since Zcash has variable-length version bytes
+		"""
+		for v in self.wif_ver_bytes.values():
+			if key_data[:len(v)] == v:
+				return len(v)
+		else:
+			raise ValueError('Invalid WIF version number')
+
 	def get_addr_len(self,addr_fmt):
 		return (20,64)[addr_fmt in ('zcash_z','viewkey')]
 
 	def decode_addr_bytes(self,addr_bytes):
 		"""
-		vlen must be set dynamically since Zcash has variable length ver_bytes
+		vlen must be set dynamically since Zcash has variable-length version bytes
 		"""
 		for ver_bytes,addr_fmt in self.addr_ver_bytes.items():
 			vlen = len(ver_bytes)

+ 6 - 0
mmgen/protocol.py

@@ -68,6 +68,12 @@ class CoinProtocol(MMGenObject):
 				'regtest': '_rt',
 			}[network]
 
+			if hasattr(self,'wif_ver_num'):
+				self.wif_ver_bytes = {k:bytes.fromhex(v) for k,v in self.wif_ver_num.items()}
+				self.wif_ver_bytes_to_pubkey_type = {v:k for k,v in self.wif_ver_bytes.items()}
+				vbs = list(self.wif_ver_bytes.values())
+				self.wif_ver_bytes_len = len(vbs[0]) if len(set(len(b) for b in vbs)) == 1 else None
+
 			if hasattr(self,'addr_ver_info'):
 				self.addr_ver_bytes = {bytes.fromhex(k):v for k,v in self.addr_ver_info.items()}
 				self.addr_fmt_to_ver_bytes = {v:k for k,v in self.addr_ver_bytes.items()}