aboutsummaryrefslogtreecommitdiff
path: root/factory
diff options
context:
space:
mode:
Diffstat (limited to 'factory')
-rw-r--r--factory/declarations.py73
1 files changed, 38 insertions, 35 deletions
diff --git a/factory/declarations.py b/factory/declarations.py
index 17f4434..abd384e 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -301,6 +301,40 @@ class ParameteredAttribute(OrderedDeclaration):
raise NotImplementedError()
+class _FactoryWrapper(object):
+ """Handle a 'factory' arg.
+
+ Such args can be either a Factory subclass, or a fully qualified import
+ path for that subclass (e.g 'myapp.factories.MyFactory').
+ """
+ def __init__(self, factory_or_path):
+ self.factory = None
+ self.module = self.name = ''
+ if isinstance(factory_or_path, type):
+ self.factory = factory_or_path
+ else:
+ if not (compat.is_string(factory_or_path) and '.' in factory_or_path):
+ raise ValueError(
+ "A factory= argument must receive either a class "
+ "or the fully qualified path to a Factory subclass; got "
+ "%r instead." % factory_or_path)
+ self.module, self.name = factory_or_path.rsplit('.', 1)
+
+ def get(self):
+ if self.factory is None:
+ self.factory = utils.import_object(
+ self.module,
+ self.name,
+ )
+ return self.factory
+
+ def __repr__(self):
+ if self.factory is None:
+ return '<_FactoryImport: %s.%s>' % (self.module, self.name)
+ else:
+ return '<_FactoryImport: %s>' % self.factory.__class__
+
+
class SubFactory(ParameteredAttribute):
"""Base class for attributes based upon a sub-factory.
@@ -314,26 +348,11 @@ class SubFactory(ParameteredAttribute):
def __init__(self, factory, **kwargs):
super(SubFactory, self).__init__(**kwargs)
- if isinstance(factory, type):
- self.factory = factory
- self.factory_module = self.factory_name = ''
- else:
- # Must be a string
- if not (compat.is_string(factory) and '.' in factory):
- raise ValueError(
- "The argument of a SubFactory must be either a class "
- "or the fully qualified path to a Factory class; got "
- "%r instead." % factory)
- self.factory = None
- self.factory_module, self.factory_name = factory.rsplit('.', 1)
+ self.factory_wrapper = _FactoryWrapper(factory)
def get_factory(self):
"""Retrieve the wrapped factory.Factory subclass."""
- if self.factory is None:
- # Must be a module path
- self.factory = utils.import_object(
- self.factory_module, self.factory_name)
- return self.factory
+ return self.factory_wrapper.get()
def generate(self, sequence, obj, create, params):
"""Evaluate the current definition and fill its attributes.
@@ -442,27 +461,11 @@ class RelatedFactory(PostGenerationDeclaration):
self.name = factory_related_name
self.defaults = defaults
-
- if isinstance(factory, type):
- self.factory = factory
- self.factory_module = self.factory_name = ''
- else:
- # Must be a string
- if not (compat.is_string(factory) and '.' in factory):
- raise ValueError(
- "The argument of a SubFactory must be either a class "
- "or the fully qualified path to a Factory class; got "
- "%r instead." % factory)
- self.factory = None
- self.factory_module, self.factory_name = factory.rsplit('.', 1)
+ self.factory_wrapper = _FactoryWrapper(factory)
def get_factory(self):
"""Retrieve the wrapped factory.Factory subclass."""
- if self.factory is None:
- # Must be a module path
- self.factory = utils.import_object(
- self.factory_module, self.factory_name)
- return self.factory
+ return self.factory_wrapper.get()
def call(self, obj, create, extracted=None, **kwargs):
passed_kwargs = dict(self.defaults)