Browse Source

mmgen-addrimport: reimplement using TrackingWallet

MMGen 6 years ago
parent
commit
556d1ca937
7 changed files with 133 additions and 70 deletions
  1. 48 28
      mmgen/altcoins/eth/tw.py
  2. 29 25
      mmgen/main_addrimport.py
  3. 3 2
      mmgen/main_tool.py
  4. 0 11
      mmgen/obj.py
  5. 9 1
      mmgen/tool.py
  6. 34 3
      mmgen/tw.py
  7. 10 0
      mmgen/util.py

+ 48 - 28
mmgen/altcoins/eth/tw.py

@@ -30,11 +30,13 @@ from mmgen.addr import AddrData
 class EthereumTrackingWallet(TrackingWallet):
 
 	desc = 'Ethereum tracking wallet'
+	caps = ()
 
 	data_dir = os.path.join(g.altcoin_data_dir,'eth',g.proto.data_subdir)
 	tw_file = os.path.join(data_dir,'tracking-wallet.json')
 
-	def __init__(self):
+	def __init__(self,mode='r'):
+		TrackingWallet.__init__(self,mode=mode)
 		check_or_create_dir(self.data_dir)
 		try:
 			self.orig_data = get_data_from_file(self.tw_file,silent=True)
@@ -43,14 +45,19 @@ class EthereumTrackingWallet(TrackingWallet):
 			try: os.stat(self.tw_file)
 			except:
 				self.orig_data = ''
-				self.data = {'accounts':{}}
+				self.data = {'accounts':{},'tokens':{}}
 			else: die(2,"File '{}' exists but does not contain valid json data")
 		else:
 			self.upgrade_wallet_maybe()
-			ad = self.data['accounts']
-			for v in ad.values():
-				v['mmid'] = TwMMGenID(v['mmid'],on_fail='raise')
-				v['comment'] = TwComment(v['comment'],on_fail='raise')
+			if not 'tokens' in self.data:
+				self.data['tokens'] = {}
+			def conv_types(ad):
+				for v in ad.values():
+					v['mmid'] = TwMMGenID(v['mmid'],on_fail='raise')
+					v['comment'] = TwComment(v['comment'],on_fail='raise')
+			conv_types(self.data['accounts'])
+			for v in self.data['tokens'].values():
+				conv_types(v)
 
 	def upgrade_wallet_maybe(self):
 		if not 'accounts' in self.data:
@@ -61,8 +68,12 @@ class EthereumTrackingWallet(TrackingWallet):
 			self.orig_data = json.dumps(self.data)
 			msg('{} upgraded successfully!'.format(self.desc))
 
-	def import_address(self,addr,label):
-		ad = self.data['accounts']
+	def data_root(self): return self.data['accounts']
+	def data_root_desc(self): return 'accounts'
+
+	@write_mode
+	def import_address(self,addr,label,foo):
+		ad = self.data_root()
 		if addr in ad:
 			if not ad[addr]['mmid'] and label.mmid:
 				msg("Warning: MMGen ID '{}' was missing in tracking wallet!".format(label.mmid))
@@ -70,55 +81,64 @@ class EthereumTrackingWallet(TrackingWallet):
 				die(3,"MMGen ID '{}' does not match tracking wallet!".format(label.mmid))
 		ad[addr] = { 'mmid': label.mmid, 'comment': label.comment }
 
-	# use 'check_data' to make sure wallet hasn't been altered by another program
-	def write(self):
+	@write_mode
+	def write(self): # use 'check_data' to check wallet hasn't been altered by another program
 		write_data_to_file( self.tw_file,
 							json.dumps(self.data),'Ethereum tracking wallet data',
 							ask_overwrite=False,ignore_opt_outdir=True,silent=True,
 							check_data=True,cmp_data=self.orig_data)
 
+	@write_mode
 	def delete_all(self):
 		self.data = {}
 		self.write()
 
-	def delete(self,addr):
+	@write_mode
+	def remove_address(self,addr):
+		root = self.data_root()
+
+		from mmgen.obj import is_coin_addr,is_mmgen_id
 		if is_coin_addr(addr):
 			have_match = lambda k: k == addr
 		elif is_mmgen_id(addr):
-			have_match = lambda k: self.data['accounts'][k]['mmid'] == addr
+			have_match = lambda k: root[k]['mmid'] == addr
 		else:
 			die(1,"'{}' is not an Ethereum address or MMGen ID".format(addr))
 
-		for k in self.data['accounts']:
+		for k in root:
 			if have_match(k):
-				del self.data['accounts'][k]
-				break
+				# return the addr resolved to mmid if possible
+				ret = root[k]['mmid'] if is_mmgen_id(root[k]['mmid']) else addr
+				del root[k]
+				self.write()
+				return ret
 		else:
-			die(1,"Address '{}' not found in tracking wallet".format(addr))
-		self.write()
+			m = "Address '{}' not found in '{}' section of tracking wallet"
+			msg(m.format(addr,self.data_root_desc()))
+			return None
 
 	def is_in_wallet(self,addr):
-		return addr in self.data['accounts']
+		return addr in self.data_root()
 
 	def sorted_list(self):
 		return sorted(
-			map(lambda x: {'addr':x[0], 'mmid':x[1]['mmid'], 'comment':x[1]['comment'] },
-								self.data['accounts'].items()),
-			key=lambda x: x['mmid'].sort_key+x['addr']
-			)
+			map(lambda x: {'addr':x[0],'mmid':x[1]['mmid'],'comment':x[1]['comment']},self.data_root().items()),
+			key=lambda x: x['mmid'].sort_key+x['addr'] )
 
 	def mmid_ordered_dict(self):
 		from collections import OrderedDict
 		return OrderedDict(map(lambda x: (x['mmid'],{'addr':x['addr'],'comment':x['comment']}), self.sorted_list()))
 
+	@write_mode
 	def import_label(self,coinaddr,lbl):
-		for addr,d in self.data['accounts'].items():
+		for addr,d in self.data_root().items():
 			if addr == coinaddr:
 				d['comment'] = lbl.comment
 				self.write()
 				return None
 		else: # emulate RPC library
-			return ('rpcfail',(None,2,"Address '{}' not found in tracking wallet".format(coinaddr)))
+			m = "Address '{}' not found in '{}' section of tracking wallet"
+			return ('rpcfail',(None,2,m.format(coinaddr,self.data_root_desc())))
 
 # Use consistent naming, even though Ethereum doesn't have unspent outputs
 class EthereumTwUnspentOutputs(TwUnspentOutputs):
@@ -144,12 +164,12 @@ Display options: show [D]ays, show [m]mgen addr, r[e]draw screen
 				'address': d['addr'],
 				'amount': ETHAmt(int(g.rpch.eth_getBalance('0x'+d['addr']),16),'wei'),
 				'confirmations': 0, # TODO
-				}, EthereumTrackingWallet().sorted_list())
+				}, TrackingWallet().sorted_list())
 
 class EthereumTwAddrList(TwAddrList):
 
 	def __init__(self,usr_addr_list,minconf,showempty,showbtcaddrs,all_labels):
-		tw = EthereumTrackingWallet().mmid_ordered_dict()
+		tw = TrackingWallet().mmid_ordered_dict()
 		self.total = g.proto.coin_amt('0')
 
 		rpc_init()
@@ -177,7 +197,7 @@ class EthereumTwGetBalance(TwGetBalance):
 	fs = '{w:13} {c}\n' # TODO - for now, just suppress display of meaningless data
 
 	def create_data(self):
-		data = EthereumTrackingWallet().mmid_ordered_dict()
+		data = TrackingWallet().mmid_ordered_dict()
 		for d in data:
 			keys = ['TOTAL']
 			keys += [str(d.obj.sid)] if d.type == 'mmgen' else ['Non-MMGen']
@@ -194,6 +214,6 @@ class EthereumAddrData(AddrData):
 	@classmethod
 	def get_tw_data(cls):
 		vmsg('Getting address data from tracking wallet')
-		tw = EthereumTrackingWallet().mmid_ordered_dict()
+		tw = TrackingWallet().mmid_ordered_dict()
 		# emulate the output of RPC 'listaccounts' and 'getaddressesbyaccount'
 		return [(mmid+' '+d['comment'],[d['addr']]) for mmid,d in tw.items()]

+ 29 - 25
mmgen/main_addrimport.py

@@ -103,23 +103,30 @@ qmsg('OK. {} addresses{}'.format(al.num_addrs,m))
 
 err_flag = False
 
-if g.coin == 'ETH':
-	if opt.rescan:
-		die('--rescan option meaningless for coin {}'.format(g.coin))
-	from mmgen.altcoins.eth.tw import EthereumTrackingWallet
-	eth_tw = EthereumTrackingWallet()
-
-	def import_address(addr,label,rescan):
-		eth_tw.import_address(addr,label)
-else:
-	if not opt.quiet: confirm_or_exit(ai_msgs('rescan'),'continue',expect='YES')
-
-	def import_address(addr,label,rescan):
-		try:
-			g.rpch.importaddress(addr,label,rescan,timeout=(False,3600)[rescan])
-		except:
-			global err_flag
-			err_flag = True
+from mmgen.tw import TrackingWallet
+try:
+	tw = TrackingWallet(mode='w')
+except UnrecognizedTokenSymbolError as e:
+	m1 = "Note: when importing addresses for a new token, the token must be specified"
+	m2 = "by address, not symbol."
+	die(1,'{}\n{}\n{}'.format(e[0],m1,m2))
+
+if opt.rescan and not 'rescan' in tw.caps:
+	msg("'--rescan' ignored: not supported by {}".format(type(tw).__name__))
+	opt.rescan = False
+
+if opt.rescan and not opt.quiet:
+	confirm_or_exit(ai_msgs('rescan'),'continue',expect='YES')
+
+if opt.batch and not 'batch' in tw.caps:
+	msg("'--batch' ignored: not supported by {}".format(type(tw).__name__))
+	opt.batch = False
+
+def import_address(addr,label,rescan):
+	try: tw.import_address(addr,label,rescan)
+	except:
+		global err_flag
+		err_flag = True
 
 w_n_of_m = len(str(al.num_addrs)) * 2 + 2
 w_mmid = 1 if opt.addrlist or opt.address else len(str(max(al.idxs()))) + 13
@@ -127,11 +134,9 @@ msg_fmt = '{{:{}}} {{:34}} {{:{}}}'.format(w_n_of_m,w_mmid)
 
 if opt.rescan: import threading
 
-msg(u'Importing {} address{} from {}{}'.format(
-		len(al.data),
-		suf(al.data,'es'),
-		infile,
-		('',' (batch mode)')[bool(opt.batch)]))
+fs = u'Importing {} address{} from {}{}'
+bm =' (batch mode)' if opt.batch else ''
+msg(fs.format(len(al.data),suf(al.data,'es'),infile,bm))
 
 if not al.data[0].addr.is_for_chain(g.chain):
 	die(2,'Address{} not compatible with {} chain!'.format((' list','')[bool(opt.address)],g.chain))
@@ -175,8 +180,7 @@ for n,e in enumerate(al.data):
 		msg(' - OK')
 
 if opt.batch:
-	ret = g.rpch.importaddress(arg_list,batch=True)
+	ret = tw.batch_import_address(arg_list)
 	msg('OK: {} addresses imported'.format(len(ret)))
 
-if g.coin == 'ETH':
-	eth_tw.write()
+tw.write()

+ 3 - 2
mmgen/main_tool.py

@@ -85,8 +85,9 @@ File encryption:
       * The encrypted file is indistinguishable from random data
 
 {pnm}-specific operations:
-  add_label    - add descriptive label for {pnm} address in tracking wallet
-  remove_label - remove descriptive label for {pnm} address in tracking wallet
+  remove_address - remove an address from tracking wallet
+  add_label      - add descriptive label for {pnm} address in tracking wallet
+  remove_label   - remove descriptive label for {pnm} address in tracking wallet
   addrfile_chksum    - compute checksum for {pnm} address file
   keyaddrfile_chksum - compute checksum for {pnm} key-address file
   passwdfile_chksum  - compute checksum for {pnm} password file

+ 0 - 11
mmgen/obj.py

@@ -443,17 +443,6 @@ class CoinAddr(str,Hilite,InitErrors,MMGenObject):
 		else:
 			return pfx_ok(vn[self.addr_fmt][1])
 
-	def is_in_tracking_wallet(self):
-
-		from mmgen.globalvars import g
-		if g.coin in ('ETH','ETC'):
-			from mmgen.altcoins.eth.tw import EthereumTrackingWallet
-			return EthereumTrackingWallet().is_in_wallet(self)
-
-		from mmgen.rpc import rpc_init
-		d = rpc_init().validateaddress(self)
-		return d['iswatchonly'] and 'account' in d
-
 class ViewKey(object):
 	def __new__(cls,s,on_fail='die'):
 		from mmgen.globalvars import g

+ 9 - 1
mmgen/tool.py

@@ -88,6 +88,7 @@ cmd_data = OrderedDict([
 
 	('Add_label',       ['<{} or coin address> [str]'.format(pnm),'<label> [str]']),
 	('Remove_label',    ['<{} or coin address> [str]'.format(pnm)]),
+	('Remove_address',  ['<{} or coin address> [str]'.format(pnm)]),
 	('Addrfile_chksum', ['<{} addr file> [str]'.format(pnm),"mmtype [str='']"]),
 	('Keyaddrfile_chksum', ['<{} addr file> [str]'.format(pnm),"mmtype [str='']"]),
 	('Passwdfile_chksum', ['<{} password file> [str]'.format(pnm)]),
@@ -714,7 +715,14 @@ def Twview(pager=False,reverse=False,wide=False,minconf=1,sort='age',show_days=T
 def Add_label(mmaddr_or_coin_addr,label):
 	rpc_init()
 	from mmgen.tw import TrackingWallet
-	TrackingWallet().add_label(mmaddr_or_coin_addr,label,on_fail='raise')
+	TrackingWallet(mode='w').add_label(mmaddr_or_coin_addr,label,on_fail='raise')
 
 def Remove_label(mmaddr_or_coin_addr):
 	Add_label(mmaddr_or_coin_addr,'')
+
+def Remove_address(mmaddr_or_coin_addr):
+	from mmgen.tw import TrackingWallet
+	tw = TrackingWallet(mode='w')
+	ret = tw.remove_address(mmaddr_or_coin_addr)
+	if ret:
+		msg("Address '{}' deleted from tracking wallet".format(ret))

+ 34 - 3
mmgen/tw.py

@@ -313,7 +313,7 @@ watch-only wallet using '{}-addrimport' and then re-run this program.
 				idx,lbl = self.get_idx_and_label_from_user()
 				if idx:
 					e = self.unspent[idx-1]
-					if TrackingWallet().add_label(e.twmmid,lbl,addr=e.addr):
+					if TrackingWallet(mode='w').add_label(e.twmmid,lbl,addr=e.addr):
 						self.get_unspent_data()
 						self.do_sort()
 						msg(u'{}\n{}\n{}'.format(self.fmt_display,prompt,p))
@@ -472,6 +472,30 @@ class TrackingWallet(MMGenObject):
 	def __new__(cls,*args,**kwargs):
 		return MMGenObject.__new__(altcoin_subclass(cls,'tw','TrackingWallet'),*args,**kwargs)
 
+	mode = 'r'
+	caps = ('rescan','batch')
+
+	def __init__(self,mode='r'):
+		m = "'{}': invalid 'mode' parameter for {} constructor"
+		assert mode in ('r','w'),m.format(mode,type(self).__name__)
+		self.mode = mode
+
+	@write_mode
+	def import_address(self,addr,label,rescan):
+		return g.rpch.importaddress(addr,label,rescan,timeout=(False,3600)[rescan])
+
+	@write_mode
+	def batch_import_address(self,arg_list):
+		return g.rpch.importaddress(arg_list,batch=True)
+
+	@write_mode
+	def write(self): pass
+
+	def is_in_wallet(self,addr):
+		d = g.rpch.validateaddress(addr)
+		return d['iswatchonly'] and 'account' in d
+
+	@write_mode
 	def import_label(self,coinaddr,lbl):
 		# NOTE: this works because importaddress() removes the old account before
 		# associating the new account with the address.
@@ -480,6 +504,7 @@ class TrackingWallet(MMGenObject):
 		return g.rpch.importaddress(coinaddr,lbl,False,on_fail='return')
 
 	# returns on failure
+	@write_mode
 	def add_label(self,arg1,label='',addr=None,silent=False,on_fail='return'):
 		from mmgen.tx import is_mmgen_id,is_coin_addr
 		mmaddr,coinaddr = None,None
@@ -496,7 +521,7 @@ class TrackingWallet(MMGenObject):
 			if not is_mmgen_id(arg1):
 				assert coinaddr,u"Invalid coin address for this chain: {}".format(arg1)
 			assert coinaddr,u"{pn} address '{ma}' not found in tracking wallet"
-			assert coinaddr.is_in_tracking_wallet(),u"Address '{ca}' not found in tracking wallet"
+			assert self.is_in_wallet(coinaddr),u"Address '{ca}' not found in tracking wallet"
 		except Exception as e:
 			msg(e[0].format(pn=g.proj_name,ma=mmaddr,ca=coinaddr))
 			return False
@@ -532,7 +557,13 @@ class TrackingWallet(MMGenObject):
 			else:     msg(u'Removed label from {}'.format(s))
 			return True
 
-	def remove_label(self,mmaddr): self.add_label(mmaddr,'')
+	@write_mode
+	def remove_label(self,mmaddr):
+		self.add_label(mmaddr,'')
+
+	@write_mode
+	def remove_address(self,addr):
+		raise NotImplementedError,'address removal not implemented for coin {}'.format(g.coin)
 
 class TwGetBalance(MMGenObject):
 

+ 10 - 0
mmgen/util.py

@@ -892,6 +892,7 @@ def format_par(s,indent=0,width=80,as_list=False):
 		lines.append(' '*indent + line)
 	return lines if as_list else '\n'.join(lines) + '\n'
 
+# module loading magic for tx.py and tw.py
 def altcoin_subclass(cls,mod_id,cls_name):
 	if cls.__name__ != cls_name: return cls
 	pn = capfirst(g.proto.name)
@@ -900,3 +901,12 @@ def altcoin_subclass(cls,mod_id,cls_name):
 	e2 = 'cls = {}{}{}'.format(pn,tn,cls_name)
 	try: exec e1; exec e2; return cls
 	except ImportError: return cls
+
+# decorator for TrackingWallet
+def write_mode(orig_func):
+	def f(self,*args,**kwargs):
+		if self.mode != 'w':
+			m = '{} opened in read-only mode: cannot execute method {}()'
+			die(1,m.format(type(self).__name__,locals()['orig_func'].__name__))
+		return orig_func(self,*args,**kwargs)
+	return f