Browse Source

SeedSplit: minor cleanups, new SeedSplitIdx,SeedSplitCount classes

MMGen 5 years ago
parent
commit
2388662e72
5 changed files with 86 additions and 72 deletions
  1. 0 1
      mmgen/globalvars.py
  2. 15 5
      mmgen/obj.py
  3. 39 29
      mmgen/seed.py
  4. 10 2
      test/objtest_py_d/ot_btc_mainnet.py
  5. 22 35
      test/unit_tests_d/ut_seedsplit.py

+ 0 - 1
mmgen/globalvars.py

@@ -211,7 +211,6 @@ class g(object):
 	seed_lens = 128,192,256
 	scramble_hash_rounds = 10
 	subseeds = 100
-	max_seed_splits = 1024
 
 	mmenc_ext      = 'mmenc'
 	salt_len       = 16

+ 15 - 5
mmgen/obj.py

@@ -321,19 +321,29 @@ class MMGenListItem(MMGenObject):
 			raise AttributeError(m.format(name,type(self)))
 		return object.__setattr__(self,name,value)
 
-class AddrIdx(int,InitErrors):
-	max_digits = 7
+class MMGenIdx(int,InitErrors):
+	min_val = 1
+	max_val = None
+	max_digits = None
 	def __new__(cls,num,on_fail='die'):
 		cls.arg_chk(on_fail)
 		try:
 			assert type(num) is not float,'is float'
 			me = int.__new__(cls,num)
-			assert len(str(me)) <= cls.max_digits,'is more than {} digits'.format(cls.max_digits)
-			assert me > 0,'is less than one'
+			if cls.max_digits:
+				assert len(str(me)) <= cls.max_digits,'has more than {} digits'.format(cls.max_digits)
+			if cls.max_val:
+				assert me <= cls.max_val,'is greater than {}'.format(cls.max_val)
+			assert me >= cls.min_val,'is less than {}'.format(cls.min_val)
 			return me
 		except Exception as e:
 			return cls.init_fail(e,num)
 
+class SeedSplitIdx(MMGenIdx): max_val = 1024
+class SeedSplitCount(SeedSplitIdx): min_val = 2
+class MasterSplitIdx(MMGenIdx): max_val = 1024
+class AddrIdx(MMGenIdx): max_digits = 7
+
 class AddrIdxList(list,InitErrors,MMGenObject):
 	max_len = 1000000
 	def __init__(self,fmt_str=None,idx_list=None,on_fail='die',sep=','):
@@ -862,7 +872,7 @@ class MMGenPWIDString(MMGenLabel):
 	desc = 'password ID string'
 	forbidden = list(' :/\\')
 
-class MMGenSeedSplitIDString(MMGenPWIDString):
+class SeedSplitIDString(MMGenPWIDString):
 	desc = 'seed split ID string'
 
 class MMGenAddrType(str,Hilite,InitErrors,MMGenObject):

+ 39 - 29
mmgen/seed.py

@@ -210,26 +210,40 @@ class Seed(SeedBase):
 		return SeedSplitList(self,count,id_str)
 
 	@staticmethod
-	def join_splits(seed_list): # seed_list must be a generator
-		seed1 = next(seed_list)
-		length = seed1.length
-		ret = int(seed1.data.hex(),16)
+	def join_splits(seed_list):
+		if not hasattr(seed_list,'__next__'): # seed_list can be iterator or iterable
+			seed_list = iter(seed_list)
+
+		class d(object):
+			slen = None
+			ret = 0
+			count = 0
+
+		def add_split(ss):
+			if d.slen:
+				assert ss.length == d.slen,'Seed length mismatch! {} != {}'.format(ss.length,d.slen)
+			else:
+				d.slen = ss.length
+			d.ret ^= int(ss.data.hex(),16)
+			d.count += 1
+
 		for ss in seed_list:
-			assert ss.length == length,'Seed length mismatch! {} != {}'.format(ss.length,length)
-			ret ^= int(ss.data.hex(),16)
-		return Seed(seed_bin=ret.to_bytes(length // 8,'big'))
+			add_split(ss)
+
+		SeedSplitCount(d.count)
+		return Seed(seed_bin=d.ret.to_bytes(d.slen // 8,'big'))
 
 class SubSeed(SeedBase):
 
 	idx    = MMGenImmutableAttr('idx',int,typeconv=False)
 	nonce  = MMGenImmutableAttr('nonce',int,typeconv=False)
-	ss_idx = MMGenImmutableAttr('ss_idx',SubSeedIdx,typeconv=False)
+	ss_idx = MMGenImmutableAttr('ss_idx',SubSeedIdx)
 	max_nonce = 1000
 
 	def __init__(self,parent_list,idx,nonce,length):
 		self.idx = idx
 		self.nonce = nonce
-		self.ss_idx = SubSeedIdx(str(idx) + { 'long': 'L', 'short': 'S' }[length])
+		self.ss_idx = str(idx) + { 'long': 'L', 'short': 'S' }[length]
 		SeedBase.__init__(self,seed_bin=type(self).make_subseed_bin(parent_list,idx,nonce,length))
 
 	@staticmethod
@@ -246,26 +260,21 @@ class SubSeed(SeedBase):
 class SeedSplitList(SubSeedList):
 	have_short = False
 	split_type = 'N-of-N'
-	id_str = 'default'
 
-	count = MMGenImmutableAttr('count',int,typeconv=False)
+	count = MMGenImmutableAttr('count',SeedSplitCount)
+	id_str = MMGenImmutableAttr('id_str',SeedSplitIDString)
 
 	def __init__(self,parent_seed,count,id_str=None):
 		self.member_type = SeedSplit
 		self.parent_seed = parent_seed
-		self.id_str = MMGenSeedSplitIDString(id_str if id_str is not None else type(self).id_str)
-
-		assert issubclass(type(count),int) and count > 1,(
-			"{!r}: illegal value for 'count' (not a positive integer greater than one)".format(count))
-		assert count <= g.max_seed_splits,(
-			"{!r}: illegal value for 'count' (> {})".format(count,g.max_seed_splits))
+		self.id_str = id_str or 'default'
 		self.count = count
 
 		while True:
 			self.data = { 'long': IndexedDict(), 'short': IndexedDict() }
 			self._generate(count-1)
-			self.last_seed = SeedSplitLast(self)
-			sid = self.last_seed.sid
+			self.last_split = SeedSplitLast(self)
+			sid = self.last_split.sid
 			if sid in self.data['long'] or sid == parent_seed.sid:
 				# collision: throw out entire split list and redo with new start nonce
 				if g.debug_subseed:
@@ -280,28 +289,28 @@ class SeedSplitList(SubSeedList):
 			B = self.join().data
 			assert A == B,'Data mismatch!\noriginal seed: {!r}\nrejoined seed: {!r}'.format(A,B)
 
-	def get_split_by_idx(self,idx,print_msg=False):
+	def get_split_by_idx(self,idx):
 		if idx == self.count:
-			return self.last_seed # TODO: msg?
+			return self.last_split
 		else:
 			ss_idx = SubSeedIdx(str(idx) + 'L')
-			return self.get_subseed_by_ss_idx(ss_idx,print_msg=print_msg)
+			return self.get_subseed_by_ss_idx(ss_idx)
 
-	def get_split_by_seed_id(self,sid,last_idx=None,print_msg=False):
+	def get_split_by_seed_id(self,sid,last_idx=None):
 		if sid == self.data['long'].key(self.count-1):
-			return self.last_seed # TODO: msg?
+			return self.last_split
 		else:
-			return self.get_subseed_by_seed_id(sid,last_idx=last_idx,print_msg=print_msg)
+			return self.get_subseed_by_seed_id(sid,last_idx=last_idx)
 
 	def join(self):
 		return Seed.join_splits(self.get_split_by_idx(i+1) for i in range(len(self)))
 
 	def format(self):
+		assert self.split_type == 'N-of-N'
 		fs1 = '    {}\n'
 		fs2 = '{i:>5}: {}\n'
 
 		hdr  = '    {} {} ({} bits)\n'.format('Seed:',self.parent_seed.sid.hl(),self.parent_seed.length)
-		assert self.split_type == 'N-of-N'
 		hdr += '    {} {c}-of-{c} (XOR)\n'.format('Split Type:',c=self.count)
 		hdr += '    {} {}\n\n'.format('ID String:',self.id_str.hl())
 		hdr += fs1.format('Splits')
@@ -329,15 +338,16 @@ class SeedSplit(SubSeed):
 
 class SeedSplitLast(SubSeed):
 
+	idx = MMGenImmutableAttr('idx',SeedSplitIdx)
+	nonce = 0
+
 	def __init__(self,parent_list):
 		self.idx = parent_list.count
-		self.nonce = 0
-		self.ss_idx = SubSeedIdx(str(self.idx) + 'L')
 		SeedBase.__init__(self,seed_bin=self.make_subseed_bin(parent_list))
 
 	@staticmethod
 	def make_subseed_bin(parent_list):
-		seed_list = (parent_list.get_subseed_by_ss_idx(str(i+1)+'L') for i in range(len(parent_list)))
+		seed_list = (parent_list.get_split_by_idx(i+1) for i in range(len(parent_list)))
 		seed = parent_list.parent_seed
 
 		ret = int(seed.data.hex(),16)

+ 10 - 2
test/objtest_py_d/ot_btc_mainnet.py

@@ -15,8 +15,16 @@ from .ot_common import *
 
 tests = OrderedDict([
 	('AddrIdx', {
-		'bad':  ('s',1.1,12345678,-1),
-		'good': (('7',7),)
+		'bad':  ('s',1.1,10000000,-1,0),
+		'good': (('7',7),(1,1),(9999999,9999999))
+	}),
+	('SeedSplitIdx', {
+		'bad':  ('s',1.1,1025,-1,0),
+		'good': (('7',7),(1,1),(1024,1024))
+	}),
+	('SeedSplitCount', {
+		'bad':  ('s',2.1,1025,-1,0,1),
+		'good': (('7',7),(2,2),(1024,1024))
 	}),
 	('AddrIdxList', {
 		'bad':  ('x','5,9,1-2-3','8,-11','66,3-2'),

+ 22 - 35
test/unit_tests_d/ut_seedsplit.py

@@ -9,6 +9,7 @@ class unit_test(object):
 
 	def run_test(self,name):
 		from mmgen.seed import Seed
+		from mmgen.obj import SeedSplitIdx
 
 		def basic_ops():
 			test_data = {
@@ -33,44 +34,30 @@ class unit_test(object):
 					seed = Seed(seed_bin)
 					assert seed.sid == b, seed.sid
 
-					splitlist = seed.splitlist(2,id_str)
-					A = len(splitlist)
-					assert A == 2, A
+					for split_count,j,k,l in ((2,c,c,d),(5,e,f,h)):
 
-					s = splitlist.format()
-					vmsg_r('\n{}'.format(s))
-					assert len(s.strip().split('\n')) == 8, s
+						splitlist = seed.splitlist(split_count,id_str)
+						A = len(splitlist)
+						assert A == split_count, A
 
-					A = splitlist.get_split_by_idx(1).sid
-					B = splitlist.get_split_by_seed_id(c).sid
-					assert A == B == c, A
+						s = splitlist.format()
+						vmsg_r('\n{}'.format(s))
+						assert len(s.strip().split('\n')) == split_count+6, s
 
-					A = splitlist.get_split_by_idx(2).sid
-					B = splitlist.get_split_by_seed_id(d).sid
-					assert A == B == d, A
+						A = splitlist.get_split_by_idx(1).sid
+						B = splitlist.get_split_by_seed_id(j).sid
+						assert A == B == j, A
 
-					splitlist = seed.splitlist(5,id_str)
-					A = len(splitlist)
-					assert A == 5, A
+						A = splitlist.get_split_by_idx(split_count-1).sid
+						B = splitlist.get_split_by_seed_id(k).sid
+						assert A == B == k, A
 
-					s = splitlist.format()
-					vmsg_r('\n{}'.format(s))
-					assert len(s.strip().split('\n')) == 11, s
+						A = splitlist.get_split_by_idx(split_count).sid
+						B = splitlist.get_split_by_seed_id(l).sid
+						assert A == B == l, A
 
-					A = splitlist.get_split_by_idx(1).sid
-					B = splitlist.get_split_by_seed_id(e).sid
-					assert A == B == e, A
-
-					A = splitlist.get_split_by_idx(4).sid
-					B = splitlist.get_split_by_seed_id(f).sid
-					assert A == B == f, A
-
-					A = splitlist.get_split_by_idx(5).sid
-					B = splitlist.get_split_by_seed_id(h).sid
-					assert A == B == h, A
-
-					A = splitlist.join().sid
-					assert A == b, A
+						A = splitlist.join().sid
+						assert A == b, A
 
 				msg('OK')
 
@@ -80,7 +67,7 @@ class unit_test(object):
 			seed_bin = bytes.fromhex('deadbeef' * 8)
 			seed = Seed(seed_bin)
 
-			splitlist = seed.splitlist(g.max_seed_splits)
+			splitlist = seed.splitlist(SeedSplitIdx.max_val)
 			s = splitlist.format()
 #			vmsg_r('\n{}'.format(s))
 			assert len(s.strip().split('\n')) == 1030, s
@@ -105,8 +92,8 @@ class unit_test(object):
 			seed_bin = bytes.fromhex('1dabcdef' * 4)
 			seed = Seed(seed_bin)
 
-			g.max_seed_splits = ss_count
-			splitlist = seed.splitlist(g.max_seed_splits)
+			SeedSplitIdx.max_val = ss_count
+			splitlist = seed.splitlist(ss_count)
 			A = splitlist.get_split_by_idx(ss_count).sid
 			B = splitlist.get_split_by_seed_id(last_sid).sid
 			assert A == last_sid, A