util.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. #!/usr/bin/env python3
  2. #
  3. # MMGen Wallet, a terminal-based cryptocurrency wallet
  4. # Copyright (C)2013-2025 The MMGen Project <mmgen@tuta.io>
  5. #
  6. # This program is free software: you can redistribute it and/or modify
  7. # it under the terms of the GNU General Public License as published by
  8. # the Free Software Foundation, either version 3 of the License, or
  9. # (at your option) any later version.
  10. #
  11. # This program is distributed in the hope that it will be useful,
  12. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  13. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  14. # GNU General Public License for more details.
  15. #
  16. # You should have received a copy of the GNU General Public License
  17. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  18. """
  19. util: Frequently-used variables, classes and utility functions for the MMGen suite
  20. """
  21. import sys, os, time, re
  22. from .color import red, yellow, green, blue, purple
  23. from .cfg import gv
  24. ascii_lowercase = 'abcdefghijklmnopqrstuvwxyz'
  25. digits = '0123456789'
  26. hexdigits = '0123456789abcdefABCDEF'
  27. hexdigits_uc = '0123456789ABCDEF'
  28. hexdigits_lc = '0123456789abcdef'
  29. def noop(*args, **kwargs):
  30. pass
  31. class Util:
  32. def __init__(self, cfg):
  33. self.cfg = cfg
  34. if cfg.quiet:
  35. self.qmsg = self.qmsg_r = noop
  36. else:
  37. self.qmsg = msg
  38. self.qmsg_r = msg_r
  39. if cfg.verbose:
  40. self.vmsg = msg
  41. self.vmsg_r = msg_r
  42. self.Vmsg = Msg
  43. self.Vmsg_r = Msg_r
  44. else:
  45. self.vmsg = self.vmsg_r = self.Vmsg = self.Vmsg_r = noop
  46. self.dmsg = msg if cfg.debug else noop
  47. if cfg.pager:
  48. from .ui import do_pager
  49. self.stdout_or_pager = do_pager
  50. else:
  51. self.stdout_or_pager = Msg_r
  52. def compare_chksums(
  53. self,
  54. chk1,
  55. desc1,
  56. chk2,
  57. desc2,
  58. *,
  59. hdr = '',
  60. die_on_fail = False,
  61. verbose = False):
  62. if not chk1 == chk2:
  63. fs = "{} ERROR: {} checksum ({}) doesn't match {} checksum ({})"
  64. m = fs.format((hdr+':\n ' if hdr else 'CHECKSUM'), desc2, chk2, desc1, chk1)
  65. if die_on_fail:
  66. die(3, m)
  67. else:
  68. if verbose or self.cfg.verbose:
  69. msg(m)
  70. return False
  71. if self.cfg.verbose:
  72. msg(f'{capfirst(desc1)} checksum OK ({chk1})')
  73. return True
  74. def compare_or_die(self, val1, desc1, val2, desc2, *, e='Error'):
  75. if val1 != val2:
  76. die(3, f"{e}: {desc2} ({val2}) doesn't match {desc1} ({val1})")
  77. if self.cfg.debug:
  78. msg(f'{capfirst(desc2)} OK ({val2})')
  79. return True
  80. if sys.platform == 'win32':
  81. def msg_r(s):
  82. try:
  83. gv.stderr.write(s)
  84. gv.stderr.flush()
  85. except:
  86. os.write(2, s.encode())
  87. def msg(s):
  88. msg_r(s + '\n')
  89. def Msg_r(s):
  90. try:
  91. gv.stdout.write(s)
  92. gv.stdout.flush()
  93. except:
  94. os.write(1, s.encode())
  95. def Msg(s):
  96. Msg_r(s + '\n')
  97. else:
  98. def msg(s):
  99. gv.stderr.write(s + '\n')
  100. def msg_r(s):
  101. gv.stderr.write(s)
  102. gv.stderr.flush()
  103. def Msg(s):
  104. gv.stdout.write(s + '\n')
  105. def Msg_r(s):
  106. gv.stdout.write(s)
  107. gv.stdout.flush()
  108. def rmsg(s):
  109. msg(red(s))
  110. def ymsg(s):
  111. msg(yellow(s))
  112. def gmsg(s):
  113. msg(green(s))
  114. def gmsg_r(s):
  115. msg_r(green(s))
  116. def bmsg(s):
  117. msg(blue(s))
  118. def pumsg(s):
  119. msg(purple(s))
  120. def mmsg(*args):
  121. for d in args:
  122. Msg(repr(d))
  123. def mdie(*args):
  124. mmsg(*args)
  125. sys.exit(0)
  126. def die(ev, s='', *, stdout=False):
  127. match ev:
  128. case int():
  129. from .exception import MMGenSystemExit, MMGenError
  130. if ev <= 2:
  131. raise MMGenSystemExit(ev, s, stdout)
  132. else:
  133. raise MMGenError(ev, s, stdout)
  134. case str():
  135. from . import exception
  136. raise getattr(exception, ev)(s)
  137. case _:
  138. raise ValueError(f'{ev}: exit value must be string or integer')
  139. def Die(ev=0, s=''):
  140. die(ev=ev, s=s, stdout=True)
  141. def pp_fmt(d):
  142. import pprint
  143. return pprint.PrettyPrinter().pformat(d)
  144. def pp_msg(d):
  145. msg(pp_fmt(d))
  146. def indent(s, *, indent=' ', append='\n'):
  147. "indent multiple lines of text with specified string"
  148. return indent + ('\n'+indent).join(s.strip().splitlines()) + append
  149. def fmt(s, *, indent='', strip_char=None, append='\n'):
  150. "de-indent multiple lines of text, or indent with specified string"
  151. return indent + ('\n'+indent).join([l.lstrip(strip_char) for l in s.strip().splitlines()]) + append
  152. def fmt_list(iterable, *, fmt='dfl', indent='', conv=None):
  153. "pretty-format a list"
  154. _conv, sep, lq, rq = {
  155. 'dfl': (str, ", ", "'", "'"),
  156. 'utf8': (str, ", ", "“", "”"),
  157. 'bare': (repr, " ", "", ""),
  158. 'barest': (str, " ", "", ""),
  159. 'fancy': (str, " ", "‘", "’"),
  160. 'no_quotes': (str, ", ", "", ""),
  161. 'compact': (str, ",", "", ""),
  162. 'no_spc': (str, ",", "'", "'"),
  163. 'min': (str, ",", "", ""),
  164. 'repr': (repr, ", ", "", ""),
  165. 'csv': (repr, ",", "", ""),
  166. 'col': (str, "\n", "", ""),
  167. }[fmt]
  168. conv = conv or _conv
  169. return indent + (sep+indent).join(lq+conv(e)+rq for e in iterable)
  170. def fmt_dict(mapping, *, fmt='dfl', kconv=None, vconv=None):
  171. "pretty-format a dict"
  172. kc, vc, sep, fs = {
  173. 'dfl': (str, str, ", ", "'{}' ({})"),
  174. 'dfl_compact': (str, str, " ", "{} ({})"),
  175. 'square': (str, str, ", ", "'{}' [{}]"),
  176. 'square_compact':(str, str, " ", "{} [{}]"),
  177. 'equal': (str, str, ", ", "'{}'={}"),
  178. 'equal_spaced': (str, str, ", ", "'{}' = {}"),
  179. 'equal_compact': (str, str, " ", "{}={}"),
  180. 'kwargs': (str, repr, ", ", "{}={}"),
  181. 'colon': (str, repr, ", ", "{}:{}"),
  182. 'colon_compact': (str, str, " ", "{}:{}"),
  183. }[fmt]
  184. kconv = kconv or kc
  185. vconv = vconv or vc
  186. return sep.join(fs.format(kconv(k), vconv(v)) for k, v in mapping.items())
  187. def list_gen(*data):
  188. """
  189. Generate a list from an arg tuple of sublists
  190. - The last element of each sublist is a condition. If it evaluates to true, the preceding
  191. elements of the sublist are included in the result. Otherwise the sublist is skipped.
  192. - If a sublist contains only one element, the condition defaults to true.
  193. """
  194. assert type(data) in (list, tuple), f'{type(data).__name__} not in (list, tuple)'
  195. def gen():
  196. for d in data:
  197. match d:
  198. case [a]:
  199. yield a
  200. case [*a, b]:
  201. if b:
  202. yield from a
  203. case _:
  204. die(2, f'list_gen(): {d} (type {type(d).__name__}) is not an iterable')
  205. return list(gen())
  206. def remove_dups(iterable, *, edesc='element', desc='list', quiet=False, hide=False):
  207. """
  208. Remove duplicate occurrences of iterable elements, preserving first occurrence
  209. If iterable is a generator, return a list, else type(iterable)
  210. """
  211. ret = []
  212. for e in iterable:
  213. if e in ret:
  214. if not quiet:
  215. ymsg(f'Warning: removing duplicate {edesc} {"(hidden)" if hide else e} in {desc}')
  216. else:
  217. ret.append(e)
  218. return ret if type(iterable).__name__ == 'generator' else type(iterable)(ret)
  219. def contains_any(target_list, source_list):
  220. return any(map(target_list.count, source_list))
  221. def suf(arg, suf_type='s', *, verb='none'):
  222. suf_types = {
  223. 'none': {
  224. 's': ('s', ''),
  225. 'es': ('es', ''),
  226. 'ies': ('ies', 'y')},
  227. 'is': {
  228. 's': ('s are', ' is'),
  229. 'es': ('es are', ' is'),
  230. 'ies': ('ies are', 'y is')},
  231. 'has': {
  232. 's': ('s have', ' has'),
  233. 'es': ('es have', ' has'),
  234. 'ies': ('ies have', 'y has')}}
  235. match arg:
  236. case int():
  237. return suf_types[verb][suf_type][arg == 1]
  238. case list() | tuple() | set() | dict():
  239. return suf_types[verb][suf_type][len(arg) == 1]
  240. case _:
  241. die(2, f'{arg}: invalid parameter for suf()')
  242. def get_extension(fn):
  243. return os.path.splitext(fn)[1][1:]
  244. def remove_extension(fn, ext):
  245. a, b = os.path.splitext(fn)
  246. return a if b[1:] == ext else fn
  247. def make_chksum_N(s, nchars, *, sep=False, rounds=2, upper=True):
  248. if isinstance(s, str):
  249. s = s.encode()
  250. from hashlib import sha256
  251. for i in range(rounds):
  252. s = sha256(s).digest()
  253. ret = s.hex()[:nchars]
  254. if sep:
  255. assert 4 <= nchars <= 64 and (not nchars % 4), 'illegal ‘nchars’ value'
  256. ret = ' '.join(ret[i:i+4] for i in range(0, nchars, 4))
  257. else:
  258. assert 4 <= nchars <= 64, 'illegal ‘nchars’ value'
  259. return ret.upper() if upper else ret
  260. def make_chksum_8(s, *, sep=False):
  261. from .obj import HexStr
  262. from hashlib import sha256
  263. s = HexStr(sha256(sha256(s).digest()).hexdigest()[:8].upper(), case='upper')
  264. return '{} {}'.format(s[:4], s[4:]) if sep else s
  265. def make_chksum_6(s):
  266. from .obj import HexStr
  267. from hashlib import sha256
  268. if isinstance(s, str):
  269. s = s.encode()
  270. return HexStr(sha256(s).hexdigest()[:6])
  271. def is_chksum_6(s):
  272. return len(s) == 6 and set(s) <= set(hexdigits_lc)
  273. def split_into_cols(col_wid, s):
  274. return ' '.join([s[col_wid*i:col_wid*(i+1)] for i in range(len(s)//col_wid+1)]).rstrip()
  275. def capfirst(s): # different from str.capitalize() - doesn't downcase any uc in string
  276. return s if len(s) == 0 else s[0].upper() + s[1:]
  277. def decode_timestamp(s):
  278. # tz_save = open('/etc/timezone').read().rstrip()
  279. os.environ['TZ'] = 'UTC'
  280. # os.environ['TZ'] = tz_save
  281. return int(time.mktime(time.strptime(s, '%Y%m%d_%H%M%S')))
  282. def make_timestamp(secs=None):
  283. return '{:04d}{:02d}{:02d}_{:02d}{:02d}{:02d}'.format(*time.gmtime(
  284. int(secs) if secs is not None else time.time())[:6])
  285. def make_timestr(secs=None):
  286. return '{}-{:02d}-{:02d} {:02d}:{:02d}:{:02d}'.format(*time.gmtime(
  287. int(secs) if secs is not None else time.time())[:6])
  288. def secs_to_dhms(secs):
  289. hrs = secs // 3600
  290. return '{}{:02d}:{:02d}:{:02d} h/m/s'.format(
  291. ('{} day{}, '.format(hrs//24, suf(hrs//24)) if hrs > 24 else ''),
  292. hrs % 24,
  293. (secs // 60) % 60,
  294. secs % 60
  295. )
  296. def secs_to_hms(secs):
  297. return '{:02d}:{:02d}:{:02d}'.format(secs//3600, (secs//60) % 60, secs % 60)
  298. def secs_to_ms(secs):
  299. return '{:02d}:{:02d}'.format(secs//60, secs % 60)
  300. def is_int(s): # actually is_nonnegative_int()
  301. return set(str(s) or 'x') <= set(digits)
  302. def check_int_between(val, imin, imax, *, desc):
  303. if not imin <= int(val) <= imax:
  304. die(1, f'{val}: invalid value for {desc} (must be between {imin} and {imax})')
  305. return int(val)
  306. def check_member(e, iterable, desc, message='unsupported'):
  307. if e not in iterable:
  308. from mmgen.color import yellow
  309. die(1, yellow(f'{e}: {message} {desc} (must be one of {fmt_list(iterable)})'))
  310. def is_hex_str(s):
  311. return set(s) <= set(hexdigits)
  312. def is_hex_str_lc(s):
  313. return set(s) <= set(hexdigits_lc)
  314. def is_utf8(s):
  315. try:
  316. s.decode('utf8')
  317. except:
  318. return False
  319. else:
  320. return True
  321. def remove_whitespace(s, *, ws='\t\r\n '):
  322. return s.translate(dict((ord(e), None) for e in ws))
  323. def strip_comment(line):
  324. return re.sub('#.*', '', line).rstrip()
  325. def strip_comments(lines):
  326. pat = re.compile('#.*')
  327. return [m for m in [pat.sub('', l).rstrip() for l in lines] if m != '']
  328. def make_full_path(outdir, outfile):
  329. return os.path.normpath(os.path.join(outdir, os.path.basename(outfile)))
  330. class oneshot_warning:
  331. color = 'nocolor'
  332. def __init__(self, *, div=None, fmt_args=[], reverse=False):
  333. self.do(type(self), div, fmt_args, reverse)
  334. def do(self, wcls, div, fmt_args, reverse):
  335. def do_warning():
  336. from . import color
  337. msg(getattr(color, getattr(wcls, 'color'))('WARNING: ' + getattr(wcls, 'message').format(*fmt_args)))
  338. if not hasattr(wcls, 'data'):
  339. setattr(wcls, 'data', [])
  340. data = getattr(wcls, 'data')
  341. condition = (div in data) if reverse else (not div in data)
  342. if not div in data:
  343. data.append(div)
  344. if condition:
  345. do_warning()
  346. self.warning_shown = True
  347. else:
  348. self.warning_shown = False
  349. class oneshot_warning_group(oneshot_warning):
  350. def __init__(self, wcls, *, div=None, fmt_args=[], reverse=False):
  351. self.do(getattr(self, wcls), div, fmt_args, reverse)
  352. def get_subclasses(cls, *, names=False):
  353. def gen(cls):
  354. for i in cls.__subclasses__():
  355. yield i
  356. yield from gen(i)
  357. return tuple((c.__name__ for c in gen(cls)) if names else gen(cls))
  358. def async_run(cfg, func, *, args=(), kwargs={}):
  359. import asyncio
  360. if cfg.rpc_backend == 'aiohttp':
  361. async def func2():
  362. import aiohttp
  363. connector = aiohttp.TCPConnector(limit_per_host=cfg.aiohttp_rpc_queue_len)
  364. async with aiohttp.ClientSession(
  365. headers = {'Content-Type': 'application/json'},
  366. connector = connector) as cfg.aiohttp_session:
  367. return await func(*args, **kwargs)
  368. return asyncio.run(func2())
  369. else:
  370. return asyncio.run(func(*args, **kwargs))
  371. def wrap_ripemd160(called=[]):
  372. if not called:
  373. try:
  374. import hashlib
  375. hashlib.new('ripemd160')
  376. except ValueError:
  377. def hashlib_new_wrapper(name, *args, **kwargs):
  378. if name == 'ripemd160':
  379. return ripemd160(*args, **kwargs)
  380. else:
  381. return hashlib_new(name, *args, **kwargs)
  382. from .contrib.ripemd160 import ripemd160
  383. hashlib_new = hashlib.new
  384. hashlib.new = hashlib_new_wrapper
  385. called.append(True)
  386. def exit_if_mswin(feature):
  387. if sys.platform == 'win32':
  388. die(2, capfirst(feature) + ' not supported on the MSWin / MSYS2 platform')
  389. def have_sudo(*, silent=False):
  390. from subprocess import run, DEVNULL
  391. redir = DEVNULL if silent else None
  392. try:
  393. run(['sudo', '--non-interactive', 'true'], stdout=redir, stderr=redir, check=True)
  394. return True
  395. except:
  396. return False
  397. def in_nix_environment():
  398. for path in os.getenv('PATH').split(':'):
  399. if path.startswith('/nix/store/'):
  400. return True
  401. def cached_property(orig_func):
  402. @property
  403. def new_func(self):
  404. attr_name = '_' + orig_func.__name__
  405. if not hasattr(self, attr_name):
  406. setattr(self, attr_name, orig_func(self))
  407. return getattr(self, attr_name)
  408. return new_func
  409. def isAsync(func):
  410. return bool(func.__code__.co_flags & 128)