From 2388662e72fe1571877f72744526e9ca59d775c7 Mon Sep 17 00:00:00 2001 From: MMGen Date: Mon, 10 Jun 2019 09:19:13 +0000 Subject: [PATCH] SeedSplit: minor cleanups, new SeedSplitIdx,SeedSplitCount classes --- mmgen/globalvars.py | 1 - mmgen/obj.py | 20 ++++++--- mmgen/seed.py | 68 +++++++++++++++++------------ test/objtest_py_d/ot_btc_mainnet.py | 12 ++++- test/unit_tests_d/ut_seedsplit.py | 57 ++++++++++-------------- 5 files changed, 86 insertions(+), 72 deletions(-) diff --git a/mmgen/globalvars.py b/mmgen/globalvars.py index 00d40c3c..657fdc55 100755 --- a/mmgen/globalvars.py +++ b/mmgen/globalvars.py @@ -211,7 +211,6 @@ class g(object): seed_lens = 128,192,256 scramble_hash_rounds = 10 subseeds = 100 - max_seed_splits = 1024 mmenc_ext = 'mmenc' salt_len = 16 diff --git a/mmgen/obj.py b/mmgen/obj.py index 2ef3d69f..1d8a16d6 100755 --- a/mmgen/obj.py +++ b/mmgen/obj.py @@ -321,19 +321,29 @@ class MMGenListItem(MMGenObject): raise AttributeError(m.format(name,type(self))) return object.__setattr__(self,name,value) -class AddrIdx(int,InitErrors): - max_digits = 7 +class MMGenIdx(int,InitErrors): + min_val = 1 + max_val = None + max_digits = None def __new__(cls,num,on_fail='die'): cls.arg_chk(on_fail) try: assert type(num) is not float,'is float' me = int.__new__(cls,num) - assert len(str(me)) <= cls.max_digits,'is more than {} digits'.format(cls.max_digits) - assert me > 0,'is less than one' + if cls.max_digits: + assert len(str(me)) <= cls.max_digits,'has more than {} digits'.format(cls.max_digits) + if cls.max_val: + assert me <= cls.max_val,'is greater than {}'.format(cls.max_val) + assert me >= cls.min_val,'is less than {}'.format(cls.min_val) return me except Exception as e: return cls.init_fail(e,num) +class SeedSplitIdx(MMGenIdx): max_val = 1024 +class SeedSplitCount(SeedSplitIdx): min_val = 2 +class MasterSplitIdx(MMGenIdx): max_val = 1024 +class AddrIdx(MMGenIdx): max_digits = 7 + class AddrIdxList(list,InitErrors,MMGenObject): max_len = 1000000 def __init__(self,fmt_str=None,idx_list=None,on_fail='die',sep=','): @@ -862,7 +872,7 @@ class MMGenPWIDString(MMGenLabel): desc = 'password ID string' forbidden = list(' :/\\') -class MMGenSeedSplitIDString(MMGenPWIDString): +class SeedSplitIDString(MMGenPWIDString): desc = 'seed split ID string' class MMGenAddrType(str,Hilite,InitErrors,MMGenObject): diff --git a/mmgen/seed.py b/mmgen/seed.py index b78ccf91..cacd880a 100755 --- a/mmgen/seed.py +++ b/mmgen/seed.py @@ -210,26 +210,40 @@ class Seed(SeedBase): return SeedSplitList(self,count,id_str) @staticmethod - def join_splits(seed_list): # seed_list must be a generator - seed1 = next(seed_list) - length = seed1.length - ret = int(seed1.data.hex(),16) + def join_splits(seed_list): + if not hasattr(seed_list,'__next__'): # seed_list can be iterator or iterable + seed_list = iter(seed_list) + + class d(object): + slen = None + ret = 0 + count = 0 + + def add_split(ss): + if d.slen: + assert ss.length == d.slen,'Seed length mismatch! {} != {}'.format(ss.length,d.slen) + else: + d.slen = ss.length + d.ret ^= int(ss.data.hex(),16) + d.count += 1 + for ss in seed_list: - assert ss.length == length,'Seed length mismatch! {} != {}'.format(ss.length,length) - ret ^= int(ss.data.hex(),16) - return Seed(seed_bin=ret.to_bytes(length // 8,'big')) + add_split(ss) + + SeedSplitCount(d.count) + return Seed(seed_bin=d.ret.to_bytes(d.slen // 8,'big')) class SubSeed(SeedBase): idx = MMGenImmutableAttr('idx',int,typeconv=False) nonce = MMGenImmutableAttr('nonce',int,typeconv=False) - ss_idx = MMGenImmutableAttr('ss_idx',SubSeedIdx,typeconv=False) + ss_idx = MMGenImmutableAttr('ss_idx',SubSeedIdx) max_nonce = 1000 def __init__(self,parent_list,idx,nonce,length): self.idx = idx self.nonce = nonce - self.ss_idx = SubSeedIdx(str(idx) + { 'long': 'L', 'short': 'S' }[length]) + self.ss_idx = str(idx) + { 'long': 'L', 'short': 'S' }[length] SeedBase.__init__(self,seed_bin=type(self).make_subseed_bin(parent_list,idx,nonce,length)) @staticmethod @@ -246,26 +260,21 @@ class SubSeed(SeedBase): class SeedSplitList(SubSeedList): have_short = False split_type = 'N-of-N' - id_str = 'default' - count = MMGenImmutableAttr('count',int,typeconv=False) + count = MMGenImmutableAttr('count',SeedSplitCount) + id_str = MMGenImmutableAttr('id_str',SeedSplitIDString) def __init__(self,parent_seed,count,id_str=None): self.member_type = SeedSplit self.parent_seed = parent_seed - self.id_str = MMGenSeedSplitIDString(id_str if id_str is not None else type(self).id_str) - - assert issubclass(type(count),int) and count > 1,( - "{!r}: illegal value for 'count' (not a positive integer greater than one)".format(count)) - assert count <= g.max_seed_splits,( - "{!r}: illegal value for 'count' (> {})".format(count,g.max_seed_splits)) + self.id_str = id_str or 'default' self.count = count while True: self.data = { 'long': IndexedDict(), 'short': IndexedDict() } self._generate(count-1) - self.last_seed = SeedSplitLast(self) - sid = self.last_seed.sid + self.last_split = SeedSplitLast(self) + sid = self.last_split.sid if sid in self.data['long'] or sid == parent_seed.sid: # collision: throw out entire split list and redo with new start nonce if g.debug_subseed: @@ -280,28 +289,28 @@ class SeedSplitList(SubSeedList): B = self.join().data assert A == B,'Data mismatch!\noriginal seed: {!r}\nrejoined seed: {!r}'.format(A,B) - def get_split_by_idx(self,idx,print_msg=False): + def get_split_by_idx(self,idx): if idx == self.count: - return self.last_seed # TODO: msg? + return self.last_split else: ss_idx = SubSeedIdx(str(idx) + 'L') - return self.get_subseed_by_ss_idx(ss_idx,print_msg=print_msg) + return self.get_subseed_by_ss_idx(ss_idx) - def get_split_by_seed_id(self,sid,last_idx=None,print_msg=False): + def get_split_by_seed_id(self,sid,last_idx=None): if sid == self.data['long'].key(self.count-1): - return self.last_seed # TODO: msg? + return self.last_split else: - return self.get_subseed_by_seed_id(sid,last_idx=last_idx,print_msg=print_msg) + return self.get_subseed_by_seed_id(sid,last_idx=last_idx) def join(self): return Seed.join_splits(self.get_split_by_idx(i+1) for i in range(len(self))) def format(self): + assert self.split_type == 'N-of-N' fs1 = ' {}\n' fs2 = '{i:>5}: {}\n' hdr = ' {} {} ({} bits)\n'.format('Seed:',self.parent_seed.sid.hl(),self.parent_seed.length) - assert self.split_type == 'N-of-N' hdr += ' {} {c}-of-{c} (XOR)\n'.format('Split Type:',c=self.count) hdr += ' {} {}\n\n'.format('ID String:',self.id_str.hl()) hdr += fs1.format('Splits') @@ -329,15 +338,16 @@ class SeedSplit(SubSeed): class SeedSplitLast(SubSeed): + idx = MMGenImmutableAttr('idx',SeedSplitIdx) + nonce = 0 + def __init__(self,parent_list): self.idx = parent_list.count - self.nonce = 0 - self.ss_idx = SubSeedIdx(str(self.idx) + 'L') SeedBase.__init__(self,seed_bin=self.make_subseed_bin(parent_list)) @staticmethod def make_subseed_bin(parent_list): - seed_list = (parent_list.get_subseed_by_ss_idx(str(i+1)+'L') for i in range(len(parent_list))) + seed_list = (parent_list.get_split_by_idx(i+1) for i in range(len(parent_list))) seed = parent_list.parent_seed ret = int(seed.data.hex(),16) diff --git a/test/objtest_py_d/ot_btc_mainnet.py b/test/objtest_py_d/ot_btc_mainnet.py index 06c6a2f0..2702a533 100755 --- a/test/objtest_py_d/ot_btc_mainnet.py +++ b/test/objtest_py_d/ot_btc_mainnet.py @@ -15,8 +15,16 @@ from .ot_common import * tests = OrderedDict([ ('AddrIdx', { - 'bad': ('s',1.1,12345678,-1), - 'good': (('7',7),) + 'bad': ('s',1.1,10000000,-1,0), + 'good': (('7',7),(1,1),(9999999,9999999)) + }), + ('SeedSplitIdx', { + 'bad': ('s',1.1,1025,-1,0), + 'good': (('7',7),(1,1),(1024,1024)) + }), + ('SeedSplitCount', { + 'bad': ('s',2.1,1025,-1,0,1), + 'good': (('7',7),(2,2),(1024,1024)) }), ('AddrIdxList', { 'bad': ('x','5,9,1-2-3','8,-11','66,3-2'), diff --git a/test/unit_tests_d/ut_seedsplit.py b/test/unit_tests_d/ut_seedsplit.py index 81392a3b..fdcd466b 100755 --- a/test/unit_tests_d/ut_seedsplit.py +++ b/test/unit_tests_d/ut_seedsplit.py @@ -9,6 +9,7 @@ class unit_test(object): def run_test(self,name): from mmgen.seed import Seed + from mmgen.obj import SeedSplitIdx def basic_ops(): test_data = { @@ -33,44 +34,30 @@ class unit_test(object): seed = Seed(seed_bin) assert seed.sid == b, seed.sid - splitlist = seed.splitlist(2,id_str) - A = len(splitlist) - assert A == 2, A + for split_count,j,k,l in ((2,c,c,d),(5,e,f,h)): - s = splitlist.format() - vmsg_r('\n{}'.format(s)) - assert len(s.strip().split('\n')) == 8, s + splitlist = seed.splitlist(split_count,id_str) + A = len(splitlist) + assert A == split_count, A - A = splitlist.get_split_by_idx(1).sid - B = splitlist.get_split_by_seed_id(c).sid - assert A == B == c, A + s = splitlist.format() + vmsg_r('\n{}'.format(s)) + assert len(s.strip().split('\n')) == split_count+6, s - A = splitlist.get_split_by_idx(2).sid - B = splitlist.get_split_by_seed_id(d).sid - assert A == B == d, A + A = splitlist.get_split_by_idx(1).sid + B = splitlist.get_split_by_seed_id(j).sid + assert A == B == j, A - splitlist = seed.splitlist(5,id_str) - A = len(splitlist) - assert A == 5, A + A = splitlist.get_split_by_idx(split_count-1).sid + B = splitlist.get_split_by_seed_id(k).sid + assert A == B == k, A - s = splitlist.format() - vmsg_r('\n{}'.format(s)) - assert len(s.strip().split('\n')) == 11, s + A = splitlist.get_split_by_idx(split_count).sid + B = splitlist.get_split_by_seed_id(l).sid + assert A == B == l, A - A = splitlist.get_split_by_idx(1).sid - B = splitlist.get_split_by_seed_id(e).sid - assert A == B == e, A - - A = splitlist.get_split_by_idx(4).sid - B = splitlist.get_split_by_seed_id(f).sid - assert A == B == f, A - - A = splitlist.get_split_by_idx(5).sid - B = splitlist.get_split_by_seed_id(h).sid - assert A == B == h, A - - A = splitlist.join().sid - assert A == b, A + A = splitlist.join().sid + assert A == b, A msg('OK') @@ -80,7 +67,7 @@ class unit_test(object): seed_bin = bytes.fromhex('deadbeef' * 8) seed = Seed(seed_bin) - splitlist = seed.splitlist(g.max_seed_splits) + splitlist = seed.splitlist(SeedSplitIdx.max_val) s = splitlist.format() # vmsg_r('\n{}'.format(s)) assert len(s.strip().split('\n')) == 1030, s @@ -105,8 +92,8 @@ class unit_test(object): seed_bin = bytes.fromhex('1dabcdef' * 4) seed = Seed(seed_bin) - g.max_seed_splits = ss_count - splitlist = seed.splitlist(g.max_seed_splits) + SeedSplitIdx.max_val = ss_count + splitlist = seed.splitlist(ss_count) A = splitlist.get_split_by_idx(ss_count).sid B = splitlist.get_split_by_seed_id(last_sid).sid assert A == last_sid, A