Browse Source

main_addrimport.py: cleanups, refactor some code to tw.ctl

The MMGen Project 2 years ago
parent
commit
4631ba7e58
2 changed files with 60 additions and 43 deletions
  1. 10 40
      mmgen/main_addrimport.py
  2. 50 3
      mmgen/tw/ctl.py

+ 10 - 40
mmgen/main_addrimport.py

@@ -131,39 +131,6 @@ def check_opts(tw):
 
 	return batch,rescan
 
-async def import_address(args):
-	try:
-		res = await args.tw.import_address( args.addr, args.lbl )
-		qmsg(args.msg)
-		return res
-	except Exception as e:
-		die(2,f'\nImport of address {args.addr!r} failed: {e.args[0]!r}')
-
-def gen_args_list(tw,al,batch):
-
-	fs = '{:%s} {:34} {:%s} - OK' % (
-		len(str(al.num_addrs)) * 2 + 2,
-		1 if opt.addrlist or opt.address else len(str(max(al.idxs()))) + 13 )
-
-	ad = namedtuple('args_list_data',['addr','lbl','tw','msg'])
-
-	for num,e in enumerate(al.data,1):
-		if e.idx:
-			label = f'{al.al_id}:{e.idx}' + (' ' + e.label if e.label else '')
-			add_msg = label
-		else:
-			label = f'{proto.base_coin.lower()}:{e.addr}'
-			add_msg = 'non-'+g.proj_name
-
-		if batch:
-			yield ad( e.addr, TwLabel(proto,label), None, None )
-		else:
-			yield ad(
-				addr = e.addr,
-				lbl  = TwLabel(proto,label),
-				tw   = tw,
-				msg  = fs.format(f'{num}/{al.num_addrs}:', e.addr, f'({add_msg})') )
-
 async def main():
 	from .tw.ctl import TrackingWallet
 	if opt.token_addr:
@@ -195,14 +162,17 @@ async def main():
 
 	batch,rescan = check_opts(tw)
 
-	args_list = list(gen_args_list(tw,al,batch))
+	def gen_args_list(al):
+		_d = namedtuple('import_data',['addr','twmmid','comment'])
+		for num,e in enumerate(al.data,1):
+			yield _d(
+				addr    = e.addr,
+				twmmid  = f'{al.al_id}:{e.idx}' if e.idx else f'{proto.base_coin.lower()}:{e.addr}',
+				comment = e.label )
 
-	if batch:
-		ret = await tw.batch_import_address([ (a.addr,a.lbl) for a in args_list ])
-		msg(f'OK: {len(ret)} addresses imported')
-	else:
-		await asyncio.gather(*(import_address(a) for a in args_list))
-		msg('Address import completed OK')
+	args_list = list(gen_args_list(al))
+
+	await tw.import_address_common( args_list, batch=batch )
 
 	if rescan:
 		await tw.rescan_addresses({a.addr for a in args_list})

+ 50 - 3
mmgen/tw/ctl.py

@@ -20,10 +20,19 @@
 twctl: Tracking wallet control class for the MMGen suite
 """
 
+import asyncio,json
 from collections import namedtuple
 
 from ..globalvars import g
-from ..util import msg,dmsg,write_mode,base_proto_subclass,die
+from ..util import (
+	msg,
+	msg_r,
+	qmsg,
+	dmsg,
+	suf,
+	write_mode,
+	base_proto_subclass,
+	die )
 from ..base_obj import AsyncInit
 from ..objmethods import MMGenObject
 from ..obj import TwComment,get_obj
@@ -73,7 +82,7 @@ class TrackingWallet(MMGenObject,metaclass=AsyncInit):
 		self.cur_balances = {} # cache balances to prevent repeated lookups per program invocation
 
 	def init_from_wallet_file(self):
-		import os,json
+		import os
 		tw_dir = (
 			os.path.join(g.data_dir) if self.proto.coin == 'BTC' else
 			os.path.join(
@@ -207,7 +216,6 @@ class TrackingWallet(MMGenObject,metaclass=AsyncInit):
 			return
 		dmsg(f'write(): checking if {self.desc} data has changed')
 
-		import json
 		wdata = json.dumps(self.data)
 
 		if self.orig_data != wdata:
@@ -293,3 +301,42 @@ class TrackingWallet(MMGenObject,metaclass=AsyncInit):
 	@write_mode
 	async def remove_label(self,mmaddr):
 		await self.add_label(mmaddr,'')
+
+	async def import_address_common(self,data,batch=False,gather=False):
+
+		async def do_import(address,label,message):
+			try:
+				res = await self.import_address( address, label )
+				qmsg(message)
+				return res
+			except Exception as e:
+				die(2,f'\nImport of address {address!r} failed: {e.args[0]!r}')
+
+		_d = namedtuple( 'formatted_import_data', data[0]._fields + ('mmid_disp',))
+		pfx = self.proto.base_coin.lower() + ':'
+		fdata = [ _d(*d, 'non-MMGen' if d.twmmid.startswith(pfx) else d.twmmid ) for d in data ]
+
+		fs = '{:%s}: {:%s} {:%s} - OK' % (
+			len(str(len(fdata))) * 2 + 1,
+			max(len(d.addr) for d in fdata),
+			max(len(d.mmid_disp) for d in fdata) + 2
+		)
+
+		nAddrs = len(data)
+		out = [( # create list, not generator, so we know data is valid before starting import
+				CoinAddr( self.proto, d.addr ),
+				TwLabel( self.proto, d.twmmid + (f' {d.comment}' if d.comment else '') ),
+				fs.format( f'{n}/{nAddrs}', d.addr, f'({d.mmid_disp})' )
+			) for n,d in enumerate(fdata,1)]
+
+		if batch:
+			msg_r(f'Batch importing {len(out)} address{suf(data,"es")}...')
+			ret = await self.batch_import_address((a,b) for a,b,c in out)
+			msg(f'done\n{len(ret)} addresses imported')
+		else:
+			if gather: # this seems to provide little performance benefit
+				await asyncio.gather(*(do_import(*d) for d in out))
+			else:
+				for d in out:
+					await do_import(*d)
+			msg('Address import completed OK')