diff options
Diffstat (limited to 'factory')
-rw-r--r-- | factory/declarations.py | 73 |
1 files changed, 38 insertions, 35 deletions
diff --git a/factory/declarations.py b/factory/declarations.py index 17f4434..abd384e 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -301,6 +301,40 @@ class ParameteredAttribute(OrderedDeclaration): raise NotImplementedError() +class _FactoryWrapper(object): + """Handle a 'factory' arg. + + Such args can be either a Factory subclass, or a fully qualified import + path for that subclass (e.g 'myapp.factories.MyFactory'). + """ + def __init__(self, factory_or_path): + self.factory = None + self.module = self.name = '' + if isinstance(factory_or_path, type): + self.factory = factory_or_path + else: + if not (compat.is_string(factory_or_path) and '.' in factory_or_path): + raise ValueError( + "A factory= argument must receive either a class " + "or the fully qualified path to a Factory subclass; got " + "%r instead." % factory_or_path) + self.module, self.name = factory_or_path.rsplit('.', 1) + + def get(self): + if self.factory is None: + self.factory = utils.import_object( + self.module, + self.name, + ) + return self.factory + + def __repr__(self): + if self.factory is None: + return '<_FactoryImport: %s.%s>' % (self.module, self.name) + else: + return '<_FactoryImport: %s>' % self.factory.__class__ + + class SubFactory(ParameteredAttribute): """Base class for attributes based upon a sub-factory. @@ -314,26 +348,11 @@ class SubFactory(ParameteredAttribute): def __init__(self, factory, **kwargs): super(SubFactory, self).__init__(**kwargs) - if isinstance(factory, type): - self.factory = factory - self.factory_module = self.factory_name = '' - else: - # Must be a string - if not (compat.is_string(factory) and '.' in factory): - raise ValueError( - "The argument of a SubFactory must be either a class " - "or the fully qualified path to a Factory class; got " - "%r instead." % factory) - self.factory = None - self.factory_module, self.factory_name = factory.rsplit('.', 1) + self.factory_wrapper = _FactoryWrapper(factory) def get_factory(self): """Retrieve the wrapped factory.Factory subclass.""" - if self.factory is None: - # Must be a module path - self.factory = utils.import_object( - self.factory_module, self.factory_name) - return self.factory + return self.factory_wrapper.get() def generate(self, sequence, obj, create, params): """Evaluate the current definition and fill its attributes. @@ -442,27 +461,11 @@ class RelatedFactory(PostGenerationDeclaration): self.name = factory_related_name self.defaults = defaults - - if isinstance(factory, type): - self.factory = factory - self.factory_module = self.factory_name = '' - else: - # Must be a string - if not (compat.is_string(factory) and '.' in factory): - raise ValueError( - "The argument of a SubFactory must be either a class " - "or the fully qualified path to a Factory class; got " - "%r instead." % factory) - self.factory = None - self.factory_module, self.factory_name = factory.rsplit('.', 1) + self.factory_wrapper = _FactoryWrapper(factory) def get_factory(self): """Retrieve the wrapped factory.Factory subclass.""" - if self.factory is None: - # Must be a module path - self.factory = utils.import_object( - self.factory_module, self.factory_name) - return self.factory + return self.factory_wrapper.get() def call(self, obj, create, extracted=None, **kwargs): passed_kwargs = dict(self.defaults) |