Browse Source

tx.new_swap: add `get_swap_proto_mod()`, `init_proto_from_coin()`

The MMGen Project 1 week ago
parent
commit
bf6af7273b
3 changed files with 23 additions and 30 deletions
  1. 2 8
      mmgen/proto/btc/tx/new_swap.py
  2. 0 3
      mmgen/tx/bump.py
  3. 21 19
      mmgen/tx/new_swap.py

+ 2 - 8
mmgen/proto/btc/tx/new_swap.py

@@ -13,13 +13,14 @@ proto.btc.tx.new_swap: Bitcoin new swap transaction class
 """
 
 from ....tx.new_swap import NewSwap as TxNewSwap
+from ....tx.new_swap import get_swap_proto_mod
 from .new import New
 
 class NewSwap(New, TxNewSwap):
 	desc = 'Bitcoin swap transaction'
 
 	def update_data_output(self, trade_limit):
-		sp = self.swap_proto_mod
+		sp = get_swap_proto_mod(self.swap_proto)
 		o = self.data_output._asdict()
 		parsed_memo = sp.data.parse(o['data'].decode())
 		memo = sp.data(
@@ -29,13 +30,6 @@ class NewSwap(New, TxNewSwap):
 		o['data'] = f'data:{memo}'
 		self.data_output = self.Output(self.proto, **o)
 
-	def update_vault_addr(self, addr):
-		vault_idx = self.vault_idx
-		assert vault_idx == 0, f'{vault_idx}: vault index is not zero!'
-		o = self.outputs[vault_idx]._asdict()
-		o['addr'] = addr
-		self.outputs[vault_idx] = self.Output(self.proto, **o)
-
 	@property
 	def vault_idx(self):
 		return self._chg_output_ops('idx', 'is_vault')

+ 0 - 3
mmgen/tx/bump.py

@@ -42,9 +42,6 @@ class Bump(Completed, NewSwap):
 					setattr(self, attr, getattr(Base, attr))
 			self.outputs = self.OutputList(self)
 			self.cfg = kwargs['cfg'] # must use current cfg opts, not those from orig_tx
-		elif self.is_swap:
-			import importlib
-			self.swap_proto_mod = importlib.import_module(f'mmgen.swap.proto.{self.swap_proto}')
 
 		if not self.is_replaceable():
 			die(1, f'Transaction {self.txid} is not replaceable')

+ 21 - 19
mmgen/tx/new_swap.py

@@ -18,14 +18,22 @@ from ..cfg import gc
 from .new import New
 from ..amt import UniAmt
 
+def get_swap_proto_mod(swap_proto_name):
+	import importlib
+	return importlib.import_module(f'mmgen.swap.proto.{swap_proto_name}')
+
+def init_proto_from_coin(cfg, sp, coin, desc):
+	if coin not in sp.params.coins[desc]:
+		raise ValueError(f'{coin!r}: unsupported {desc} coin for {gc.proj_name} {sp.name} swap')
+	from ..protocol import init_proto
+	return init_proto(cfg, coin, network=cfg._proto.network, need_amt=True)
+
 class NewSwap(New):
 	desc = 'swap transaction'
 
 	def __init__(self, *args, **kwargs):
-		import importlib
 		self.is_swap = True
 		self.swap_proto = kwargs['cfg'].swap_proto
-		self.swap_proto_mod = importlib.import_module(f'mmgen.swap.proto.{self.swap_proto}')
 		New.__init__(self, *args, **kwargs)
 
 	def check_addr_is_wallet_addr(self, output, *, message):
@@ -62,31 +70,19 @@ class NewSwap(New):
 			# recv_coin      # required: uppercase coin symbol
 			recv_spec = None # optional: destination address spec. Same rules as for chg_spec
 
-		def check_coin_arg(coin, desc):
-			if coin not in sp.params.coins[desc]:
-				raise ValueError(f'{coin!r}: unsupported {desc} coin for {gc.proj_name} {sp.name} swap')
-			return coin
-
 		def get_arg():
 			try:
 				return args_in.pop(0)
 			except:
 				self.cfg._usage()
 
-		def init_proto_from_coin(coinsym, desc):
-			return init_proto(
-				self.cfg,
-				check_coin_arg(coinsym, desc),
-				network = self.proto.network,
-				need_amt = True)
-
 		def parse():
 
 			from ..amt import is_coin_amt
 			arg = get_arg()
 
 			# arg 1: send_coin
-			self.send_proto = init_proto_from_coin(arg, 'send')
+			self.send_proto = init_proto_from_coin(self.cfg, sp, arg, 'send')
 			arg = get_arg()
 
 			# arg 2: amt
@@ -101,7 +97,7 @@ class NewSwap(New):
 					arg = get_arg()
 
 			# arg 4: recv_coin
-			self.recv_proto = init_proto_from_coin(arg, 'receive')
+			self.recv_proto = init_proto_from_coin(self.cfg, sp, arg, 'receive')
 
 			# arg 5: recv_spec (receive address spec)
 			if args_in:
@@ -110,8 +106,7 @@ class NewSwap(New):
 			if args_in: # done parsing, all args consumed
 				self.cfg._usage()
 
-		from ..protocol import init_proto
-		sp = self.swap_proto_mod
+		sp = get_swap_proto_mod(self.swap_proto)
 		args_in = list(cmd_args)
 		args = CmdlineArgs()
 		parse()
@@ -155,8 +150,15 @@ class NewSwap(New):
 		else:
 			self.usr_trade_limit = None
 
+	def update_vault_addr(self, addr):
+		vault_idx = self.vault_idx
+		assert vault_idx == 0, f'{vault_idx}: vault index is not zero!'
+		o = self.outputs[vault_idx]._asdict()
+		o['addr'] = addr
+		self.outputs[vault_idx] = self.Output(self.proto, **o)
+
 	def update_vault_output(self, amt, *, deduct_est_fee=False):
-		sp = self.swap_proto_mod
+		sp = get_swap_proto_mod(self.swap_proto)
 		c = sp.rpc_client(self, amt)
 
 		from ..util import msg