Browse Source

rpc.py: improve tracking wallet checking/creation/loading

The MMGen Project 2 years ago
parent
commit
db20b2f34a
4 changed files with 53 additions and 20 deletions
  1. 43 13
      mmgen/base_proto/bitcoin/rpc.py
  2. 1 1
      mmgen/base_proto/ethereum/rpc.py
  3. 6 4
      mmgen/rpc.py
  4. 3 2
      mmgen/tw/ctl.py

+ 43 - 13
mmgen/base_proto/bitcoin/rpc.py

@@ -79,7 +79,7 @@ class BitcoinRPCClient(RPCClient,metaclass=AsyncInit):
 	has_auth_cookie = True
 	wallet_path = '/'
 
-	async def __init__(self,proto,daemon,backend):
+	async def __init__(self,proto,daemon,backend,ignore_wallet):
 
 		self.proto = proto
 		self.daemon = daemon
@@ -148,8 +148,8 @@ class BitcoinRPCClient(RPCClient,metaclass=AsyncInit):
 		if self.chain == 'mainnet': # skip this for testnet, as Genesis block may change
 			await check_chainfork_mismatch(block0)
 
-		if not self.chain == 'regtest':
-			await self.check_tracking_wallet()
+		if not ignore_wallet:
+			await self.check_or_create_daemon_wallet()
 
 		# for regtest, wallet path must remain '/' until Carol’s user wallet has been created
 		if g.regtest_user:
@@ -158,16 +158,46 @@ class BitcoinRPCClient(RPCClient,metaclass=AsyncInit):
 	def make_host_path(self,wallet):
 		return f'/wallet/{wallet}' if wallet else self.wallet_path
 
-	async def check_tracking_wallet(self,wallet_checked=[]):
-		if not wallet_checked:
-			wallets = await self.call('listwallets')
-			if len(wallets) == 0:
-				wname = self.daemon.tracking_wallet_name
-				await self.icall('createwallet',wallet_name=wname)
-				ymsg(f'Created {self.daemon.coind_name} wallet {wname!r}')
-			elif len(wallets) > 1: # support only one loaded wallet for now
-				die(4,f'ERROR: more than one {self.daemon.coind_name} wallet loaded: {wallets}')
-			wallet_checked.append(True)
+	async def check_or_create_daemon_wallet(self,called=[],wallet_create=True):
+		"""
+		Returns True if the correct tracking wallet is currently loaded or if a new one
+		is created, False otherwise
+		"""
+
+		if called or self.chain == 'regtest':
+			return False
+
+		twname = self.daemon.tracking_wallet_name
+		loaded_wnames = await self.call('listwallets')
+		wnames = [i['name'] for i in (await self.call('listwalletdir'))['wallets']]
+		m = f'Please fix your {self.daemon.desc} wallet installation or cmdline options'
+		ret = False
+
+		if len(loaded_wnames) == 1:
+			loaded_wname = loaded_wnames[0]
+			if twname in wnames and loaded_wname != twname:
+				await self.call('unloadwallet',loaded_wname)
+				await self.call('loadwallet',twname)
+			elif loaded_wname == '':
+				ymsg(f'WARNING: use of default wallet as tracking wallet is not recommended!\n{m}')
+			elif loaded_wname != twname:
+				ymsg(f'WARNING: loaded wallet {loaded_wname!r} is not {twname!r}\n{m}')
+			ret = True
+		elif len(loaded_wnames) == 0:
+			if twname in wnames:
+				await self.call('loadwallet',twname)
+				ret = True
+			elif wallet_create:
+				await self.icall('createwallet',wallet_name=twname)
+				ymsg(f'Created {self.daemon.coind_name} wallet {twname!r}')
+				ret = True
+		else: # support only one loaded wallet for now
+			die(4,f'ERROR: more than one {self.daemon.coind_name} wallet loaded: {loaded_wnames}')
+
+		if wallet_create:
+			called.append(True)
+
+		return ret
 
 	def get_daemon_cfg_fn(self):
 		# Use dirname() to remove 'bob' or 'alice' component

+ 1 - 1
mmgen/base_proto/ethereum/rpc.py

@@ -35,7 +35,7 @@ class CallSigs:
 
 class EthereumRPCClient(RPCClient,metaclass=AsyncInit):
 
-	async def __init__(self,proto,daemon,backend):
+	async def __init__(self,proto,daemon,backend,ignore_wallet):
 		self.proto = proto
 		self.daemon = daemon
 		self.call_sigs = getattr(CallSigs,daemon.id,None)

+ 6 - 4
mmgen/rpc.py

@@ -464,7 +464,8 @@ async def rpc_init(
 		proto,
 		backend               = None,
 		daemon                = None,
-		ignore_daemon_version = False ):
+		ignore_daemon_version = False,
+		ignore_wallet         = False ):
 
 	if not 'rpc_init' in proto.mmcaps:
 		die(1,f'rpc_init() not supported for {proto.name} protocol!')
@@ -475,9 +476,10 @@ async def rpc_init(
 
 	from .daemon import CoinDaemon
 	rpc = await cls(
-		proto   = proto,
-		daemon  = daemon or CoinDaemon(proto=proto,test_suite=g.test_suite),
-		backend = backend or opt.rpc_backend )
+		proto         = proto,
+		daemon        = daemon or CoinDaemon(proto=proto,test_suite=g.test_suite),
+		backend       = backend or opt.rpc_backend,
+		ignore_wallet = ignore_wallet )
 
 	if rpc.daemon_version > rpc.daemon.coind_version:
 		handle_unsupported_daemon_version(

+ 3 - 2
mmgen/tw/ctl.py

@@ -42,7 +42,7 @@ class TrackingWallet(MMGenObject,metaclass=AsyncInit):
 	def __new__(cls,proto,*args,**kwargs):
 		return MMGenObject.__new__(base_proto_subclass(cls,proto,'tw','ctl'))
 
-	async def __init__(self,proto,mode='r',token_addr=None):
+	async def __init__(self,proto,mode='r',token_addr=None,rpc_ignore_wallet=False):
 
 		assert mode in ('r','w','i'), f"{mode!r}: wallet mode must be 'r','w' or 'i'"
 		if mode == 'i':
@@ -52,7 +52,8 @@ class TrackingWallet(MMGenObject,metaclass=AsyncInit):
 		if g.debug:
 			print_stack_trace(f'TW INIT {mode!r} {self!r}')
 
-		self.rpc = await rpc_init(proto) # TODO: create on demand - only certain ops require RPC
+		# TODO: create on demand - only certain ops require RPC
+		self.rpc = await rpc_init( proto, ignore_wallet=rpc_ignore_wallet )
 		self.proto = proto
 		self.mode = mode
 		self.desc = self.base_desc = f'{self.proto.name} tracking wallet'