Browse Source

seed.py: new SeedSourceMeta metaclass

The MMGen Project 5 years ago
parent
commit
f052968e21
2 changed files with 14 additions and 17 deletions
  1. 3 3
      mmgen/filename.py
  2. 11 14
      mmgen/seed.py

+ 3 - 3
mmgen/filename.py

@@ -42,7 +42,7 @@ class Filename(MMGenObject):
 		from mmgen.seed import SeedSource
 		from mmgen.tx import MMGenTX
 		if ftype:
-			if type(ftype) == type:
+			if isinstance(ftype,type):
 				if issubclass(ftype,(SeedSource,MMGenTX)):
 					self.ftype = ftype
 				# elif: # other MMGen file types
@@ -96,8 +96,8 @@ class MMGenFileList(list,MMGenObject):
 		self.sort(key=lambda a: getattr(a,key),reverse=reverse)
 
 def find_files_in_dir(ftype,fdir,no_dups=False):
-	if type(ftype) != type:
-		die(3,"'{}': not a type".format(ftype))
+	if not isinstance(ftype,type):
+		die(3,"'{}': is of type {} (not a subclass of type 'type')".format(ftype,type(ftype)))
 
 	from mmgen.seed import SeedSource
 	if not issubclass(ftype,SeedSource):

+ 11 - 14
mmgen/seed.py

@@ -499,7 +499,13 @@ class SeedShareMasterJoining(SeedShareMaster):
 		self.count = count
 		self.derived_seed = SeedBase(self.make_derived_seed_bin(self.id_str,self.count))
 
-class SeedSource(MMGenObject):
+class SeedSourceMeta(type):
+	wallet_classes = set() # one-instance class, so store data in class attr
+	def __init__(cls,name,bases,namespace):
+		cls.wallet_classes.add(cls)
+		cls.wallet_classes -= set(bases)
+
+class SeedSource(MMGenObject,metaclass=SeedSourceMeta):
 
 	desc = g.proj_name + ' seed source'
 	file_mode = 'text'
@@ -617,23 +623,14 @@ class SeedSource(MMGenObject):
 				die(2,'Passphrase from password file, so exiting')
 			msg('Trying again...')
 
-	@classmethod
-	def get_subclasses(cls): # returns calling class too
-		def GetSubclassesTree(cls,acc):
-			acc += [cls]
-			for c in cls.__subclasses__(): GetSubclassesTree(c,acc)
-		acc = []
-		GetSubclassesTree(cls,acc)
-		return acc
-
 	@classmethod
 	def get_extensions(cls):
-		return [s.ext for s in cls.get_subclasses() if hasattr(s,'ext')]
+		return [c.ext for c in cls.wallet_classes if hasattr(c,'ext')]
 
 	@classmethod
 	def fmt_code_to_type(cls,fmt_code):
 		if fmt_code:
-			for c in cls.get_subclasses():
+			for c in cls.wallet_classes:
 				if fmt_code in getattr(c,'fmt_codes',[]):
 					return c
 		return None
@@ -641,7 +638,7 @@ class SeedSource(MMGenObject):
 	@classmethod
 	def ext_to_type(cls,ext):
 		if ext:
-			for c in cls.get_subclasses():
+			for c in cls.wallet_classes:
 				if ext == getattr(c,'ext',None):
 					return c
 		return None
@@ -649,7 +646,7 @@ class SeedSource(MMGenObject):
 	@classmethod
 	def format_fmt_codes(cls):
 		d = [(c.__name__,('.'+c.ext if c.ext else str(c.ext)),','.join(c.fmt_codes))
-					for c in cls.get_subclasses()
+					for c in cls.wallet_classes
 				if hasattr(c,'fmt_codes')]
 		w = max(len(i[0]) for i in d)
 		ret = ['{:<{w}}  {:<9} {}'.format(a,b,c,w=w) for a,b,c in [