Browse Source

proto.*: use match statement where practicable (5 files)

The MMGen Project 3 months ago
parent
commit
4540ccb40a

+ 13 - 12
mmgen/proto/cosmos/tx/protobuf.py

@@ -144,15 +144,16 @@ class Tx(BaseMessage):
 		pubkey = self.authInfo.signerInfos[0].publicKey.key.data
 		msghash = sha256(bytes(sign_doc)).digest()
 
-		if backend == 'secp256k1':
-			from ...secp256k1.secp256k1 import verify_sig
-			if not verify_sig(sig, msghash, pubkey):
-				raise ValueError('signature verification failed')
-		elif backend == 'ecdsa':
-			# ecdsa.keys.VerifyingKey.verify_digest():
-			#   raises BadSignatureError if the signature is invalid or malformed
-			import ecdsa
-			ec_pubkey = ecdsa.VerifyingKey.from_string(pubkey, curve=ecdsa.curves.SECP256k1)
-			ec_pubkey.verify_digest(sig, msghash)
-		else:
-			raise ValueError(f'verify_sig(): {backend}: unrecognized backend')
+		match backend:
+			case 'secp256k1':
+				from ...secp256k1.secp256k1 import verify_sig
+				if not verify_sig(sig, msghash, pubkey):
+					raise ValueError('signature verification failed')
+			case 'ecdsa':
+				# ecdsa.keys.VerifyingKey.verify_digest():
+				#   raises BadSignatureError if the signature is invalid or malformed
+				import ecdsa
+				ec_pubkey = ecdsa.VerifyingKey.from_string(pubkey, curve=ecdsa.curves.SECP256k1)
+				ec_pubkey.verify_digest(sig, msghash)
+			case _:
+				raise ValueError(f'verify_sig(): {backend}: unrecognized backend')

+ 12 - 11
mmgen/proto/eth/rpc/local.py

@@ -76,18 +76,19 @@ class EthereumRPCClient(RPCClient, metaclass=AsyncInit):
 		self.cur_date = int(bh['timestamp'], 16)
 
 		self.caps = ()
-		if self.daemon.id in ('parity', 'openethereum'):
-			if (await self.call('parity_nodeKind'))['capability'] == 'full':
+		match self.daemon.id:
+			case 'parity' | 'openethereum':
+				if (await self.call('parity_nodeKind'))['capability'] == 'full':
+					self.caps += ('full_node',)
+				# parity/openethereum return chainID only for dev chain:
+				self.chainID = None if ci is None else Int(ci, base=16)
+				self.chain = (await self.call('parity_chain')).replace(' ', '_').replace('_testnet', '')
+			case 'geth' | 'reth' | 'erigon':
+				if self.daemon.network == 'mainnet' and hasattr(daemon_warning, self.daemon.id):
+					daemon_warning(self.daemon.id)
 				self.caps += ('full_node',)
-			# parity/openethereum return chainID only for dev chain:
-			self.chainID = None if ci is None else Int(ci, base=16)
-			self.chain = (await self.call('parity_chain')).replace(' ', '_').replace('_testnet', '')
-		elif self.daemon.id in ('geth', 'reth', 'erigon'):
-			if self.daemon.network == 'mainnet' and hasattr(daemon_warning, self.daemon.id):
-				daemon_warning(self.daemon.id)
-			self.caps += ('full_node',)
-			self.chainID = Int(ci, base=16)
-			self.chain = self.proto.chain_ids[self.chainID]
+				self.chainID = Int(ci, base=16)
+				self.chain = self.proto.chain_ids[self.chainID]
 
 	def make_host_path(self, wallet):
 		return ''

+ 10 - 9
mmgen/proto/eth/tx/status.py

@@ -31,15 +31,16 @@ class Status(TxBase.Status):
 		async def is_in_mempool():
 			if not 'full_node' in tx.rpc.caps:
 				return False
-			if tx.rpc.daemon.id in ('parity', 'openethereum'):
-				return coin_txid in [x['hash'] for x in await tx.rpc.call('parity_pendingTransactions')]
-			elif tx.rpc.daemon.id in ('geth', 'reth', 'erigon'):
-				def gen(key):
-					for e in res[key].values():
-						for v in e.values():
-							yield v['hash']
-				res = await tx.rpc.call('txpool_content')
-				return coin_txid in list(gen('queued')) + list(gen('pending'))
+			match tx.rpc.daemon.id:
+				case 'parity' | 'openethereum':
+					return coin_txid in [x['hash'] for x in await tx.rpc.call('parity_pendingTransactions')]
+				case 'geth' | 'reth' | 'erigon':
+					def gen(key):
+						for e in res[key].values():
+							for v in e.values():
+								yield v['hash']
+					res = await tx.rpc.call('txpool_content')
+					return coin_txid in list(gen('queued')) + list(gen('pending'))
 
 		async def is_in_wallet():
 			d = await tx.rpc.call('eth_getTransactionReceipt', coin_txid)

+ 13 - 13
mmgen/proto/rune/tx/protobuf.py

@@ -159,19 +159,19 @@ def base_unit_to_amt(n, *, decimals):
 def tx_info(tx, proto):
 	b = tx.body.messages[0].body
 	s = tx.authInfo.signerInfos[0]
-	msg_type = tx.body.messages[0].id.removeprefix('/types.')
-	if msg_type == 'MsgSend':
-		from_addr = proto.encode_addr_bech32x(b.fromAddress)
-		to_addr   = proto.encode_addr_bech32x(b.toAddress)
-		asset     = b.amount[0].denom.upper()
-		memo      = tx.body.memo
-		amt       = base_unit_to_amt(int(b.amount[0].amount), decimals=8)
-	elif msg_type == 'MsgDeposit':
-		from_addr = proto.encode_addr_bech32x(b.signer)
-		to_addr = 'None'
-		asset     = b.coins[0].asset.symbol
-		memo      = b.memo
-		amt       = base_unit_to_amt(int(b.coins[0].amount), decimals=b.coins[0].decimals or 8)
+	match msg_type := tx.body.messages[0].id.removeprefix('/types.'):
+		case 'MsgSend':
+			from_addr = proto.encode_addr_bech32x(b.fromAddress)
+			to_addr   = proto.encode_addr_bech32x(b.toAddress)
+			asset     = b.amount[0].denom.upper()
+			memo      = tx.body.memo
+			amt       = base_unit_to_amt(int(b.amount[0].amount), decimals=8)
+		case 'MsgDeposit':
+			from_addr = proto.encode_addr_bech32x(b.signer)
+			to_addr = 'None'
+			asset     = b.coins[0].asset.symbol
+			memo      = b.memo
+			amt       = base_unit_to_amt(int(b.coins[0].amount), decimals=b.coins[0].decimals or 8)
 	yield f'TxID:      {tx.txid}'
 	yield f'Type:      {msg_type}'
 	yield f'From:      {from_addr}'

+ 7 - 7
mmgen/proto/zec/params.py

@@ -64,13 +64,13 @@ class mainnet(mainnet):
 			return super().preprocess_key(sec, pubkey_type)
 
 	def pubhash2addr(self, pubhash, addr_type):
-		hash_len = len(pubhash)
-		if hash_len == 20:
-			return super().pubhash2addr(pubhash, addr_type)
-		elif hash_len == 64:
-			raise NotImplementedError('Zcash z-addresses do not support pubhash2addr()')
-		else:
-			raise ValueError(f'{hash_len}: incorrect pubkey hash length')
+		match len(pubhash):
+			case 20:
+				return super().pubhash2addr(pubhash, addr_type)
+			case 64:
+				raise NotImplementedError('Zcash z-addresses do not support pubhash2addr()')
+			case x:
+				raise ValueError(f'{x}: incorrect pubkey hash length')
 
 	def viewkey(self, viewkey_str):
 		return ZcashViewKey.__new__(ZcashViewKey, self, viewkey_str)