Browse Source

improve load_cryptodome() monkey-patch function

The MMGen Project 2 months ago
parent
commit
51e456a4e5
6 changed files with 32 additions and 26 deletions
  1. 1 1
      alt-requirements.txt
  2. 23 18
      mmgen/util2.py
  3. 1 0
      pyproject.toml
  4. 3 4
      test/gentest.py
  5. 2 1
      test/modtest_d/ut_dep.py
  6. 2 2
      test/modtest_d/ut_testdep.py

+ 1 - 1
alt-requirements.txt

@@ -1 +1 @@
-pycryptodomex
+pycryptodome

+ 23 - 18
mmgen/util2.py

@@ -30,18 +30,30 @@ def die_pause(ev=0, s=''):
 	input('Press ENTER to exit')
 	sys.exit(ev)
 
-# monkey-patch function for monero-python: permits its use with pycryptodome (e.g. MSYS2)
-# instead of the expected pycryptodomex
-def load_cryptodomex():
-	try:
-		import Cryptodome # cryptodomex
-	except ImportError:
+def cffi_override_fixup():
+	from cffi import FFI
+	class FFI_override:
+		def cdef(self, csource, override=False, packed=False, pack=None):
+			self._cdef(csource, override=True, packed=packed, pack=pack)
+	FFI.cdef = FFI_override.cdef
+
+# monkey-patch function: makes modules pycryptodome and pycryptodomex available to packages that
+# expect them (monero-python, eth-keys), regardless of which one is installed on system
+def load_cryptodome(called=[]):
+	if not called:
+		cffi_override_fixup()
 		try:
-			import Crypto # cryptodome
+			import Crypto # Crypto == pycryptodome
 		except ImportError:
-			die(2, 'Unable to import either the ‘pycryptodomex’ or ‘pycryptodome’ package')
+			try:
+				import Cryptodome # Crypto == pycryptodome
+			except ImportError:
+				die(2, 'Unable to import the ‘pycryptodome’ or ‘pycryptodomex’ package')
+			else:
+				sys.modules['Crypto'] = Cryptodome # Crypto == pycryptodome
 		else:
-			sys.modules['Cryptodome'] = Crypto
+			sys.modules['Cryptodome'] = Crypto # Cryptodome == pycryptodomex
+		called.append(True)
 
 # called with no arguments by pyethereum.utils:
 def get_keccak(cfg=None, cached_ret=[]):
@@ -51,15 +63,8 @@ def get_keccak(cfg=None, cached_ret=[]):
 			cfg._util.qmsg('Using internal keccak module by user request')
 			from .contrib.keccak import keccak_256
 		else:
-			try:
-				from Cryptodome.Hash import keccak
-			except ImportError as e:
-				try:
-					from Crypto.Hash import keccak
-				except ImportError as e2:
-					msg(f'{e2} and {e}')
-					die('MMGenImportError',
-						'Please install the ‘pycryptodome’ or ‘pycryptodomex’ package on your system')
+			load_cryptodome()
+			from Crypto.Hash import keccak
 			def keccak_256(data):
 				return keccak.new(data=data, digest_bytes=32)
 		cached_ret.append(keccak_256)

+ 1 - 0
pyproject.toml

@@ -91,4 +91,5 @@ ignored-classes = [ # ignored for no-member, otherwise checked
 	"SwapMgrBase",
 	"Opts",
 	"Help",
+	"FFI_override",
 ]

+ 3 - 4
test/gentest.py

@@ -565,7 +565,6 @@ from mmgen.key import PrivKey
 from mmgen.addr import MMGenAddrType
 from mmgen.addrgen import KeyGenerator, AddrGenerator
 from mmgen.keygen import get_backends
-from mmgen.util2 import load_cryptodomex
 from test.include.common import getrand, get_ethkey, set_globals
 
 gtr = namedtuple('gen_tool_result', ['wif', 'addr', 'viewkey'])
@@ -582,9 +581,9 @@ vmsg = cfg._util.vmsg
 
 proto = cfg._proto
 
-if proto.coin == 'XMR':
-	# This must be done at top level, not in monero tool __init__
-	load_cryptodomex()
+if proto.coin in ('XMR', 'ETH', 'ETC'):
+	from mmgen.util2 import load_cryptodome
+	load_cryptodome()
 
 if __name__ == '__main__':
 	from mmgen.main import launch

+ 2 - 1
test/modtest_d/ut_dep.py

@@ -32,7 +32,8 @@ class unit_tests:
 	def keccak(self, name, ut): # used by ETH, XMR
 		from mmgen.util2 import get_keccak
 		try:
-			get_keccak()
+			keccak_256 = get_keccak()
+			keccak_256(b'abc')
 		except Exception as e:
 			rmsg(str(e))
 			return False

+ 2 - 2
test/modtest_d/ut_testdep.py

@@ -47,8 +47,8 @@ class unit_tests:
 		return True
 
 	def monero_python(self, name, ut):
-		from mmgen.util2 import load_cryptodomex
-		load_cryptodomex()
+		from mmgen.util2 import load_cryptodome
+		load_cryptodome()
 		from monero.seed import Seed
 		Seed('deadbeef' * 8).public_address()
 		return True