Browse Source

tw.ctl: move store-related methods + init to `tw.store`

The MMGen Project 6 months ago
parent
commit
6e084d4f71
3 changed files with 189 additions and 172 deletions
  1. 1 3
      mmgen/proto/btc/tw/ctl.py
  2. 14 155
      mmgen/tw/ctl.py
  3. 174 14
      mmgen/tw/store.py

+ 1 - 3
mmgen/proto/btc/tw/ctl.py

@@ -17,12 +17,10 @@ from ....util import msg, msg_r, rmsg, die, suf, fmt_list
 
 class BitcoinTwCtl(TwCtl):
 
-	def init_empty(self):
-		self.data = {'coin': self.proto.coin, 'addresses': {}}
-
 	async def rpc_get_balance(self, addr, block='latest'):
 		raise NotImplementedError('not implemented')
 
+	# TODO: do check with check_import_mmid()
 	@write_mode
 	async def import_address(self, addr, *, label, rescan=False):
 		if (await self.rpc.walletinfo).get('descriptors'):

+ 14 - 155
mmgen/tw/ctl.py

@@ -20,11 +20,9 @@
 tw.ctl: Tracking wallet control class for the MMGen suite
 """
 
-import json
 from collections import namedtuple
-from pathlib import Path
 
-from ..util import msg, msg_r, suf, die
+from ..util import msg, msg_r, ymsg, suf, die
 from ..base_obj import AsyncInit
 from ..objmethods import MMGenObject
 from ..obj import TwComment, get_obj
@@ -50,10 +48,7 @@ class TwCtl(MMGenObject, metaclass=AsyncInit):
 
 	caps = ('rescan', 'batch')
 	data_key = 'addresses'
-	use_tw_file = False
-	aggressive_sync = False
 	importing = False
-	tw_fn = 'tracking-wallet.json'
 	use_cached_balances = False
 
 	def __new__(cls, cfg, proto, *args, **kwargs):
@@ -68,7 +63,6 @@ class TwCtl(MMGenObject, metaclass=AsyncInit):
 			mode              = 'r',
 			token_addr        = None,
 			no_rpc            = False,
-			no_wallet_init    = False,
 			rpc_ignore_wallet = False):
 
 		assert mode in ('r', 'w', 'i'), f"{mode!r}: wallet mode must be 'r', 'w' or 'i'"
@@ -80,158 +74,10 @@ class TwCtl(MMGenObject, metaclass=AsyncInit):
 		self.proto = proto
 		self.mode = mode
 		self.desc = self.base_desc = f'{self.proto.name} tracking wallet'
-		self.cur_balances = {} # cache balances to prevent repeated lookups per program invocation
-
-		if cfg.cached_balances:
-			self.use_cached_balances = True
 
 		if not no_rpc:
 			self.rpc = await rpc_init(cfg, proto, ignore_wallet=rpc_ignore_wallet)
 
-		if self.use_tw_file:
-			if self.proto.coin == 'BTC':
-				self.tw_dir = Path(self.cfg.data_dir)
-			else:
-				self.tw_dir = Path(
-					self.cfg.data_dir_root,
-					'altcoins',
-					self.proto.coin.lower(),
-					('' if self.proto.network == 'mainnet' else self.proto.network)
-				)
-			self.tw_path = self.tw_dir / self.tw_fn
-
-		if no_wallet_init:
-			return
-
-		if self.use_tw_file:
-			self.init_from_wallet_file()
-		else:
-			self.init_empty()
-
-		if self.data['coin'] != self.proto.coin: # TODO remove?
-			die('WalletFileError',
-				f'Tracking wallet coin ({self.data["coin"]}) does not match current coin ({self.proto.coin})!')
-
-		self.conv_types(self.data[self.data_key])
-
-	def upgrade_wallet_maybe(self):
-		pass
-
-	def init_from_wallet_file(self):
-		from ..fileutil import check_or_create_dir, get_data_from_file
-		check_or_create_dir(self.tw_dir)
-		try:
-			self.orig_data = get_data_from_file(self.cfg, self.tw_path, quiet=True)
-			self.data = json.loads(self.orig_data)
-		except:
-			try:
-				self.tw_path.stat()
-			except:
-				self.orig_data = ''
-				self.init_empty()
-				self.force_write()
-			else:
-				die('WalletFileError', f'File ‘{self.tw_path}’ exists but does not contain valid JSON data')
-		else:
-			self.upgrade_wallet_maybe()
-
-		# ensure that wallet file is written when user exits via KeyboardInterrupt:
-		if self.mode == 'w':
-			import atexit
-			def del_twctl(twctl):
-				self.cfg._util.dmsg(f'Running exit handler del_twctl() for {twctl!r}')
-				del twctl
-			atexit.register(del_twctl, self)
-
-	def __del__(self):
-		"""
-		TwCtl instances opened in write or import mode must be explicitly destroyed with ‘del
-		twuo.twctl’ and the like to ensure the instance is deleted and wallet is written before
-		global vars are destroyed by the interpreter at shutdown.
-
-		Not that this code can only be debugged by examining the program output, as exceptions
-		are ignored within __del__():
-
-			/usr/share/doc/python3.6-doc/html/reference/datamodel.html#object.__del__
-
-		Since no exceptions are raised, errors will not be caught by the test suite.
-		"""
-		if getattr(self, 'mode', None) == 'w': # mode attr might not exist in this state
-			self.write()
-		elif self.cfg.debug:
-			msg('read-only wallet, doing nothing')
-
-	def conv_types(self, ad):
-		for k, v in ad.items():
-			if k not in ('params', 'coin'):
-				v['mmid'] = TwMMGenID(self.proto, v['mmid'])
-				v['comment'] = TwComment(v['comment'])
-
-	@property
-	def data_root(self):
-		return self.data[self.data_key]
-
-	@property
-	def data_root_desc(self):
-		return self.data_key
-
-	def cache_balance(self, addr, bal, *, session_cache, data_root, force=False):
-		if force or addr not in session_cache:
-			session_cache[addr] = str(bal)
-			if addr in data_root:
-				data_root[addr]['balance'] = str(bal)
-				if self.aggressive_sync:
-					self.write()
-
-	def get_cached_balance(self, addr, session_cache, data_root):
-		if addr in session_cache:
-			return self.proto.coin_amt(session_cache[addr])
-		if not self.use_cached_balances:
-			return None
-		if addr in data_root and 'balance' in data_root[addr]:
-			return self.proto.coin_amt(data_root[addr]['balance'])
-
-	async def get_balance(self, addr, *, force_rpc=False, block='latest'):
-		ret = None if force_rpc else self.get_cached_balance(addr, self.cur_balances, self.data_root)
-		if ret is None:
-			ret = await self.rpc_get_balance(addr, block=block)
-			self.cache_balance(addr, ret, session_cache=self.cur_balances, data_root=self.data_root)
-		return ret
-
-	def force_write(self):
-		mode_save = self.mode
-		self.mode = 'w'
-		self.write()
-		self.mode = mode_save
-
-	@write_mode
-	def write_changed(self, data, quiet):
-		from ..fileutil import write_data_to_file
-		write_data_to_file(
-			self.cfg,
-			self.tw_path,
-			data,
-			desc              = f'{self.base_desc} data',
-			ask_overwrite     = False,
-			ignore_opt_outdir = True,
-			quiet             = quiet,
-			check_data        = True, # die if wallet has been altered by another program
-			cmp_data          = self.orig_data)
-
-		self.orig_data = data
-
-	def write(self, *, quiet=True):
-		if not self.use_tw_file:
-			self.cfg._util.dmsg("'use_tw_file' is False, doing nothing")
-			return
-		self.cfg._util.dmsg(f'write(): checking if {self.desc} data has changed')
-
-		wdata = json.dumps(self.data)
-		if self.orig_data != wdata:
-			self.write_changed(wdata, quiet=quiet)
-		elif self.cfg.debug:
-			msg('Data is unchanged\n')
-
 	async def resolve_address(self, addrspec):
 
 		twmmid, coinaddr = (None, None)
@@ -305,6 +151,19 @@ class TwCtl(MMGenObject, metaclass=AsyncInit):
 	async def remove_comment(self, mmaddr):
 		await self.set_comment(mmaddr, '')
 
+	def check_import_mmid(self, addr, old_mmid, new_mmid):
+		'returns True if mmid needs update, None otherwise'
+		if new_mmid != old_mmid:
+			if old_mmid.endswith(':' + addr):
+				ymsg(f'Warning: address {new_mmid} was previously imported as non-MMGen!')
+				return True
+			else:
+				fs = (
+					'attempting to import MMGen address {a!r} ({b}) as non-MMGen!'
+						if new_mmid.endswith(':' + addr) else
+					'imported MMGen ID {b!r} does not match tracking wallet MMGen ID {a!r}!')
+				die(2, fs.format(a=old_mmid, b=new_mmid))
+
 	async def import_address_common(self, data, *, batch=False, gather=False):
 
 		async def do_import(address, comment, message):

+ 174 - 14
mmgen/tw/store.py

@@ -12,17 +12,86 @@
 tw.store: Tracking wallet control class with store
 """
 
+import json
+from pathlib import Path
+
+from ..base_obj import AsyncInit
+from ..obj import TwComment
 from ..util import msg, ymsg, die, cached_property
 from ..addr import is_coin_addr, is_mmgen_id, CoinAddr
 
-from .shared import TwLabel
+from .shared import TwMMGenID, TwLabel
 from .ctl import TwCtl, write_mode, label_addr_pair
 
-class TwCtlWithStore(TwCtl):
+class TwCtlWithStore(TwCtl, metaclass=AsyncInit):
 
 	caps = ('batch',)
-	data_key = 'addresses'
-	use_tw_file = True
+	tw_fn = 'tracking-wallet.json'
+	aggressive_sync = False
+
+	async def __init__(
+			self,
+			cfg,
+			proto,
+			*,
+			mode              = 'r',
+			token_addr        = None,
+			no_rpc            = False,
+			no_wallet_init    = False,
+			rpc_ignore_wallet = False):
+
+		await super().__init__(cfg, proto, mode=mode, no_rpc=no_rpc, rpc_ignore_wallet=rpc_ignore_wallet)
+
+		self.cur_balances = {} # cache balances to prevent repeated lookups per program invocation
+
+		if cfg.cached_balances:
+			self.use_cached_balances = True
+
+		self.tw_dir = Path(
+			self.cfg.data_dir_root,
+			'altcoins',
+			self.proto.coin.lower(),
+			('' if self.proto.network == 'mainnet' else self.proto.network)
+		)
+		self.tw_path = self.tw_dir / self.tw_fn
+
+		if no_wallet_init:
+			return
+
+		self.init_from_wallet_file()
+
+		if self.data['coin'] != self.proto.coin:
+			fs = 'Tracking wallet coin ({}) does not match current coin ({})!'
+			die('WalletFileError', fs.format(self.data['coin'], self.proto.coin))
+
+		self.conv_types(self.data[self.data_key])
+
+	def __del__(self):
+		"""
+		TwCtl instances opened in write or import mode must be explicitly destroyed with ‘del
+		twuo.twctl’ and the like to ensure the instance is deleted and wallet is written before
+		global vars are destroyed by the interpreter at shutdown.
+
+		Not that this code can only be debugged by examining the program output, as exceptions
+		are ignored within __del__():
+
+			/usr/share/doc/python3.6-doc/html/reference/datamodel.html#object.__del__
+
+		Since no exceptions are raised, errors will not be caught by the test suite.
+		"""
+		if getattr(self, 'mode', None) == 'w': # mode attr might not exist in this state
+			self.write()
+		elif self.cfg.debug:
+			msg('read-only wallet, doing nothing')
+
+	def upgrade_wallet_maybe(self):
+		pass
+
+	def conv_types(self, ad):
+		for k, v in ad.items():
+			if k not in ('params', 'coin'):
+				v['mmid'] = TwMMGenID(self.proto, v['mmid'])
+				v['comment'] = TwComment(v['comment'])
 
 	def init_empty(self):
 		self.data = {
@@ -31,6 +100,32 @@ class TwCtlWithStore(TwCtl):
 			'addresses': {},
 		}
 
+	def init_from_wallet_file(self):
+		from ..fileutil import check_or_create_dir, get_data_from_file
+		check_or_create_dir(self.tw_dir)
+		try:
+			self.orig_data = get_data_from_file(self.cfg, self.tw_path, quiet=True)
+			self.data = json.loads(self.orig_data)
+		except:
+			try:
+				self.tw_path.stat()
+			except:
+				self.orig_data = ''
+				self.init_empty()
+				self.force_write()
+			else:
+				die('WalletFileError', f'File ‘{self.tw_path}’ exists but does not contain valid JSON data')
+		else:
+			self.upgrade_wallet_maybe()
+
+		# ensure that wallet file is written when user exits via KeyboardInterrupt:
+		if self.mode == 'w':
+			import atexit
+			def del_twctl(twctl):
+				self.cfg._util.dmsg(f'Running exit handler del_twctl() for {twctl!r}')
+				del twctl
+			atexit.register(del_twctl, self)
+
 	@write_mode
 	async def batch_import_address(self, args_list):
 		return [await self.import_address(a, label=b, rescan=c) for a, b, c in args_list]
@@ -42,17 +137,9 @@ class TwCtlWithStore(TwCtl):
 	async def import_address(self, addr, *, label, rescan=False):
 		r = self.data_root
 		if addr in r:
-			if r[addr]['mmid']:
-				if r[addr]['mmid'] != label.mmid:
-					fs = 'imported MMGen ID {!r} does not match tracking wallet MMGen ID {!r}!'
-					die(3, fs.format(label.mmid, r[addr]['mmid']))
-			elif label.mmid:
-				ymsg(f'Warning: MMGen ID {label.mmid!r} was missing in tracking wallet!')
+			if self.check_import_mmid(addr, r[addr]['mmid'], label.mmid):
 				r[addr]['mmid'] = label.mmid
-			if not 'comment' in r[addr]:
-				ymsg(f'Warning: Label for MMGen ID {label.mmid!r} was missing in tracking wallet!')
-				r[addr]['comment'] = label.comment
-			elif label.comment: # overwrite existing comment only if new comment not empty
+			if label.comment: # overwrite existing comment only if new comment not empty
 				r[addr]['comment'] = label.comment
 		else:
 			r[addr] = {'mmid': label.mmid, 'comment': label.comment}
@@ -112,3 +199,76 @@ class TwCtlWithStore(TwCtl):
 		from decimal import Decimal
 		# TODO: for now, consider used addrs to be addrs with balance
 		return ({k for k, v in self.data['addresses'].items() if Decimal(v.get('balance', 0))})
+
+	@property
+	def data_root(self):
+		return self.data[self.data_key]
+
+	@property
+	def data_root_desc(self):
+		return self.data_key
+
+	def cache_balance(self, addr, bal, *, session_cache, data_root, force=False):
+		if force or addr not in session_cache:
+			session_cache[addr] = str(bal)
+			if addr in data_root:
+				data_root[addr]['balance'] = str(bal)
+				if self.aggressive_sync:
+					self.write()
+
+	async def rpc_get_balance(self, addr, block='latest'):
+		assert self.rpc.is_remote, 'tw.store.rpc_get_balance(): RPC is not remote!'
+		try:
+			return self.rpc.get_balance(addr, block=block)
+		except Exception as e:
+			ymsg(f'{type(e).__name__}: {e}')
+			ymsg(f'Unable to get balance for address ‘{addr}’')
+			import asyncio
+			await asyncio.sleep(3)
+
+	def get_cached_balance(self, addr, session_cache, data_root):
+		if addr in session_cache:
+			return self.proto.coin_amt(session_cache[addr])
+		if self.use_cached_balances:
+			return self.proto.coin_amt(
+				data_root[addr]['balance'] if addr in data_root and 'balance' in data_root[addr]
+				else '0')
+
+	async def get_balance(self, addr, *, force_rpc=False, block='latest'):
+		ret = None if force_rpc else self.get_cached_balance(addr, self.cur_balances, self.data_root)
+		if ret is None:
+			ret = await self.rpc_get_balance(addr, block=block)
+			if ret is not None:
+				self.cache_balance(addr, ret, session_cache=self.cur_balances, data_root=self.data_root)
+		return ret
+
+	def force_write(self):
+		mode_save = self.mode
+		self.mode = 'w'
+		self.write()
+		self.mode = mode_save
+
+	@write_mode
+	def write_changed(self, data, quiet):
+		from ..fileutil import write_data_to_file
+		write_data_to_file(
+			self.cfg,
+			self.tw_path,
+			data,
+			desc              = f'{self.base_desc} data',
+			ask_overwrite     = False,
+			ignore_opt_outdir = True,
+			quiet             = quiet,
+			check_data        = True, # die if wallet has been altered by another program
+			cmp_data          = self.orig_data)
+
+		self.orig_data = data
+
+	def write(self, *, quiet=True):
+		self.cfg._util.dmsg(f'write(): checking if {self.desc} data has changed')
+
+		wdata = json.dumps(self.data)
+		if self.orig_data != wdata:
+			self.write_changed(wdata, quiet=quiet)
+		elif self.cfg.debug:
+			msg('Data is unchanged\n')