Browse Source

mmgen.tw.ctl: resolve MMGen address via get_addr_label_pairs()

The MMGen Project 2 years ago
parent
commit
8e04c21271

+ 33 - 30
mmgen/proto/btc/tw/common.py

@@ -13,13 +13,13 @@ proto.btc.tw.common: Bitcoin base protocol tracking wallet dependency classes
 """
 """
 
 
 from ....addr import CoinAddr
 from ....addr import CoinAddr
-from ....util import die
+from ....util import die,msg,rmsg
 from ....obj import MMGenList
 from ....obj import MMGenList
 from ....tw.common import get_tw_label
 from ....tw.common import get_tw_label
 
 
 class BitcoinTwCommon:
 class BitcoinTwCommon:
 
 
-	async def get_addr_label_pairs(self):
+	async def get_addr_label_pairs(self,twmmid=None):
 		"""
 		"""
 		Get all the accounts in the tracking wallet and their associated addresses.
 		Get all the accounts in the tracking wallet and their associated addresses.
 		Returns list of (label,address) tuples.
 		Returns list of (label,address) tuples.
@@ -34,36 +34,39 @@ class BitcoinTwCommon:
 			if err:
 			if err:
 				die(4,'Tracking wallet is corrupted!')
 				die(4,'Tracking wallet is corrupted!')
 
 
-		def check_addr_array_lens(acct_pairs):
-			err = False
-			for label,addrs in acct_pairs:
-				if not label:
-					continue
-				if len(addrs) != 1:
-					err = True
-					if len(addrs) == 0:
-						msg(f'Label {label!r}: has no associated address!')
-					else:
-						msg(f'{addrs!r}: more than one {self.proto.coin} address in account!')
-			if err:
-				die(4,'Tracking wallet is corrupted!')
+		async def get_acct_list():
+			if 'label_api' in self.rpc.caps:
+				return await self.rpc.call('listlabels')
+			else:
+				return (await self.rpc.call('listaccounts',0,True)).keys()
+
+		async def get_acct_addrs(acct_list):
+			if 'label_api' in self.rpc.caps:
+				return [list(a.keys())
+					for a in await self.rpc.batch_call('getaddressesbylabel',[(k,) for k in acct_list])]
+			else:
+				return await self.rpc.batch_call('getaddressesbyaccount',[(a,) for a in acct_list])
+
+		acct_labels = [get_tw_label(self.proto,a) for a in await get_acct_list()]
+
+		if twmmid:
+			acct_labels = [lbl for lbl in acct_labels if lbl.mmid == twmmid]
+
+		if not acct_labels:
+			return None
 
 
-		# for compatibility with old mmids, must use raw RPC rather than native data for matching
-		# args: minconf,watchonly, MUST use keys() so we get list, not dict
-		if 'label_api' in self.rpc.caps:
-			acct_list = await self.rpc.call('listlabels')
-			aa = await self.rpc.batch_call('getaddressesbylabel',[(k,) for k in acct_list])
-			acct_addrs = [list(a.keys()) for a in aa]
-		else:
-			acct_list = list((await self.rpc.call('listaccounts',0,True)).keys()) # raw list, no 'L'
-			# use raw list here
-			acct_addrs = await self.rpc.batch_call('getaddressesbyaccount',[(a,) for a in acct_list])
-		acct_labels = MMGenList([get_tw_label(self.proto,a) for a in acct_list])
 		check_dup_mmid(acct_labels)
 		check_dup_mmid(acct_labels)
-		assert len(acct_list) == len(acct_addrs), 'len(listaccounts()) != len(getaddressesbyaccount())'
-		addr_pairs = list(zip(acct_labels,acct_addrs))
-		check_addr_array_lens(addr_pairs)
-		return [(lbl,addrs[0]) for lbl,addrs in addr_pairs]
+
+		acct_addrs = await get_acct_addrs(acct_labels)
+
+		for n,a in enumerate(acct_addrs):
+			if len(a) != 1:
+				raise ValueError(f'{a}: label {acct_labels[n]!r} has != 1 associated address!')
+
+		return [(
+			label,
+			CoinAddr(self.proto,addrs[0])
+		) for label,addrs in zip(acct_labels,acct_addrs)]
 
 
 	async def get_unspent_by_mmid(self,minconf=1,mmid_filter=[]):
 	async def get_unspent_by_mmid(self,minconf=1,mmid_filter=[]):
 		"""
 		"""

+ 2 - 1
mmgen/proto/btc/tw/ctl.py

@@ -15,8 +15,9 @@ proto.btc.twctl: Bitcoin base protocol tracking wallet control class
 from ....globalvars import g
 from ....globalvars import g
 from ....tw.ctl import TrackingWallet,write_mode
 from ....tw.ctl import TrackingWallet,write_mode
 from ....util import msg,msg_r,rmsg,vmsg,die,suf,fmt_list
 from ....util import msg,msg_r,rmsg,vmsg,die,suf,fmt_list
+from .common import BitcoinTwCommon
 
 
-class BitcoinTrackingWallet(TrackingWallet):
+class BitcoinTrackingWallet(TrackingWallet,BitcoinTwCommon):
 
 
 	def init_empty(self):
 	def init_empty(self):
 		self.data = { 'coin': self.proto.coin, 'addresses': {} }
 		self.data = { 'coin': self.proto.coin, 'addresses': {} }

+ 2 - 2
mmgen/proto/btc/tw/txhistory.py

@@ -310,8 +310,8 @@ Actions: [q]uit, r[e]draw:
 
 
 		if self.sinceblock: # mapping data may be incomplete for inputs, so update from 'listlabels'
 		if self.sinceblock: # mapping data may be incomplete for inputs, so update from 'listlabels'
 			mm_map.update(
 			mm_map.update(
-				{ addr: _mmp(lbl.mmid, lbl.comment) if lbl else _mmp(None,None) for lbl,addr in
-					[(get_tw_label(self.proto,a), b) for a,b in await self.get_addr_label_pairs()] }
+				{ addr: _mmp(label.mmid, label.comment) if label else _mmp(None,None)
+					for label,addr in await self.get_addr_label_pairs() }
 			)
 			)
 
 
 		msg_r('Getting wallet transactions...')
 		msg_r('Getting wallet transactions...')

+ 38 - 0
mmgen/proto/eth/tw/common.py

@@ -0,0 +1,38 @@
+#!/usr/bin/env python3
+#
+# mmgen = Multi-Mode GENerator, a command-line cryptocurrency wallet
+# Copyright (C)2013-2022 The MMGen Project <mmgen@tuta.io>
+# Licensed under the GNU General Public License, Version 3:
+#   https://www.gnu.org/licenses
+# Public project repositories:
+#   https://github.com/mmgen/mmgen
+#   https://gitlab.com/mmgen/mmgen
+
+"""
+proto.eth.tw.common: Ethereum base protocol tracking wallet dependency classes
+"""
+
+from ....tw.ctl import TrackingWallet
+from ....addr import CoinAddr
+from ....tw.common import TwLabel
+
+class EthereumTwCommon:
+
+	async def get_addr_label_pairs(self,twmmid=None):
+		wallet = (
+			self if isinstance(self,TrackingWallet) else
+			(self.wallet or await TrackingWallet(self.proto,mode='w'))
+		)
+
+		ret = [(
+				TwLabel( self.proto, mmid + ' ' + d['comment'] ),
+				CoinAddr( self.proto, d['addr'] )
+			) for mmid,d in wallet.mmid_ordered_dict.items() ]
+
+		if wallet is not self:
+			del wallet
+
+		if twmmid:
+			ret = [e for e in ret if e[0].mmid == twmmid]
+
+		return ret or None

+ 2 - 7
mmgen/proto/eth/tw/ctl.py

@@ -25,16 +25,14 @@ from ....tw.ctl import TrackingWallet,write_mode
 from ....addr import is_coin_addr,is_mmgen_id
 from ....addr import is_coin_addr,is_mmgen_id
 from ....amt import ETHAmt
 from ....amt import ETHAmt
 from ..contract import Token,TokenResolve
 from ..contract import Token,TokenResolve
+from .common import EthereumTwCommon
 
 
-class EthereumTrackingWallet(TrackingWallet):
+class EthereumTrackingWallet(TrackingWallet,EthereumTwCommon):
 
 
 	caps = ('batch',)
 	caps = ('batch',)
 	data_key = 'accounts'
 	data_key = 'accounts'
 	use_tw_file = True
 	use_tw_file = True
 
 
-	async def is_in_wallet(self,addr):
-		return addr in self.data_root
-
 	def init_empty(self):
 	def init_empty(self):
 		self.data = {
 		self.data = {
 			'coin': self.proto.coin,
 			'coin': self.proto.coin,
@@ -205,9 +203,6 @@ class EthereumTokenTrackingWallet(EthereumTrackingWallet):
 
 
 		proto.tokensym = self.symbol
 		proto.tokensym = self.symbol
 
 
-	async def is_in_wallet(self,addr):
-		return addr in self.data['tokens'][self.token]
-
 	@property
 	@property
 	def data_root(self):
 	def data_root(self):
 		return self.data['tokens'][self.token]
 		return self.data['tokens'][self.token]

+ 1 - 1
mmgen/tw/common.py

@@ -445,7 +445,7 @@ class TwCommon:
 
 
 			async def do_comment_add(comment):
 			async def do_comment_add(comment):
 				if await parent.wallet.set_comment( entry.twmmid, comment, entry.addr ):
 				if await parent.wallet.set_comment( entry.twmmid, comment, entry.addr ):
-					await parent.get_data()
+					entry.comment = comment
 					parent.oneshot_msg = yellow('Label {a} {b}{c}\n\n'.format(
 					parent.oneshot_msg = yellow('Label {a} {b}{c}\n\n'.format(
 						a = 'for' if cur_comment and comment else 'added to' if comment else 'removed from',
 						a = 'for' if cur_comment and comment else 'added to' if comment else 'removed from',
 						b = desc,
 						b = desc,

+ 29 - 27
mmgen/tw/ctl.py

@@ -214,43 +214,41 @@ class TrackingWallet(MMGenObject,metaclass=AsyncInit):
 		elif g.debug:
 		elif g.debug:
 			msg('Data is unchanged\n')
 			msg('Data is unchanged\n')
 
 
-	async def is_in_wallet(self,addr):
-		from .addrs import TwAddrList
-		return addr in (await TwAddrList(self.proto,[],0,True,True,True,wallet=self)).coinaddr_list()
+	async def resolve_address(self,addrspec):
 
 
-	async def resolve_address(self,addrspec,usr_coinaddr=None):
+		twmmid,coinaddr = (None,None)
 
 
-		mmaddr,coinaddr = None,None
+		if is_coin_addr(self.proto,addrspec):
+			coinaddr = get_obj(CoinAddr,proto=self.proto,addr=addrspec)
+		elif is_mmgen_id(self.proto,addrspec):
+			twmmid = TwMMGenID(self.proto,addrspec)
+		else:
+			msg(f'{addrspec!r}: invalid address for this network')
+			return None
 
 
-		if is_coin_addr(self.proto,usr_coinaddr or addrspec):
-			coinaddr = get_obj(CoinAddr,proto=self.proto,addr=usr_coinaddr or addrspec)
+		pairs = await self.get_addr_label_pairs(twmmid)
 
 
-		if is_mmgen_id(self.proto,addrspec):
-			mmaddr = TwMMGenID(self.proto,addrspec)
+		if not pairs:
+			msg(f'MMGen address {twmmid!r} not found in tracking wallet')
+			return None
 
 
-		if mmaddr and not coinaddr:
-			from ..addrdata import TwAddrData
-			coinaddr = (await TwAddrData(self.proto)).mmaddr2coinaddr(mmaddr)
+		pairs_data = dict((label.mmid,addr) for label,addr in pairs)
 
 
-		try:
-			assert coinaddr, (
-				f'{g.proj_name} address {mmaddr!r} not found in tracking wallet' if mmaddr else
-				f'Invalid coin address for this chain: {addrspec}' )
-			assert await self.is_in_wallet(coinaddr), f'Address {coinaddr!r} not found in tracking wallet'
-		except Exception as e:
-			msg(str(e))
-			return None
+		if twmmid and not coinaddr:
+			coinaddr = pairs_data[twmmid]
 
 
 		# Allow for the possibility that BTC addr of MMGen addr was entered.
 		# Allow for the possibility that BTC addr of MMGen addr was entered.
 		# Do reverse lookup, so that MMGen addr will not be marked as non-MMGen.
 		# Do reverse lookup, so that MMGen addr will not be marked as non-MMGen.
-		if not mmaddr:
-			from ..addrdata import TwAddrData
-			mmaddr = (await TwAddrData(proto=self.proto)).coinaddr2mmaddr(coinaddr)
-
-		if not mmaddr:
-			mmaddr = f'{self.proto.base_coin.lower()}:{coinaddr}'
+		if not twmmid:
+			for mmid,addr in pairs_data.items():
+				if coinaddr == addr:
+					twmmid = mmid
+					break
+			else:
+				msg(f'Coin address {addrspec!r} not found in tracking wallet')
+				return None
 
 
-		return addr_info( TwMMGenID(self.proto,mmaddr), coinaddr )
+		return addr_info(twmmid,coinaddr)
 
 
 	# returns on failure
 	# returns on failure
 	@write_mode
 	@write_mode
@@ -277,6 +275,10 @@ class TrackingWallet(MMGenObject,metaclass=AsyncInit):
 			return False
 			return False
 
 
 		if await self.set_label(res.coinaddr,lbl):
 		if await self.set_label(res.coinaddr,lbl):
+			# redundant paranoia step:
+			pairs = await self.get_addr_label_pairs(res.twmmid)
+			assert pairs[0][0].comment == comment, f'{pairs[0][0].comment!r} != {comment!r}'
+
 			desc = '{} address {} in tracking wallet'.format(
 			desc = '{} address {} in tracking wallet'.format(
 				res.twmmid.type.replace('mmgen','MMGen'),
 				res.twmmid.type.replace('mmgen','MMGen'),
 				res.twmmid.addr.hl() )
 				res.twmmid.addr.hl() )

+ 3 - 3
test/test_py_d/ts_regtest.py

@@ -1157,11 +1157,11 @@ class TestSuiteRegtest(TestSuiteBase,TestSuiteShared):
 		return t
 		return t
 
 
 	def alice_add_comment_badaddr1(self):
 	def alice_add_comment_badaddr1(self):
-		return self.alice_add_comment_badaddr( rt_pw,'Invalid coin address for this chain: ', 2)
+		return self.alice_add_comment_badaddr( rt_pw, 'invalid address', 2 )
 
 
 	def alice_add_comment_badaddr2(self):
 	def alice_add_comment_badaddr2(self):
 		addr = init_proto(self.proto.coin,network='mainnet').pubhash2addr(bytes(20),False) # mainnet zero address
 		addr = init_proto(self.proto.coin,network='mainnet').pubhash2addr(bytes(20),False) # mainnet zero address
-		return self.alice_add_comment_badaddr( addr, f'Invalid coin address for this chain: {addr}', 2 )
+		return self.alice_add_comment_badaddr( addr, 'invalid address', 2 )
 
 
 	def alice_add_comment_badaddr3(self):
 	def alice_add_comment_badaddr3(self):
 		addr = self._user_sid('alice') + ':C:123'
 		addr = self._user_sid('alice') + ':C:123'
@@ -1169,7 +1169,7 @@ class TestSuiteRegtest(TestSuiteBase,TestSuiteShared):
 
 
 	def alice_add_comment_badaddr4(self):
 	def alice_add_comment_badaddr4(self):
 		addr = self.proto.pubhash2addr(bytes(20),False) # regtest (testnet) zero address
 		addr = self.proto.pubhash2addr(bytes(20),False) # regtest (testnet) zero address
-		return self.alice_add_comment_badaddr( addr, f'Address {addr!r} not found in tracking wallet', 2 )
+		return self.alice_add_comment_badaddr( addr, f'Coin address {addr!r} not found in tracking wallet', 2 )
 
 
 	def alice_remove_comment1(self):
 	def alice_remove_comment1(self):
 		sid = self._user_sid('alice')
 		sid = self._user_sid('alice')