summaryrefslogtreecommitdiff
path: root/factory/declarations.py
diff options
context:
space:
mode:
Diffstat (limited to 'factory/declarations.py')
-rw-r--r--factory/declarations.py25
1 files changed, 23 insertions, 2 deletions
diff --git a/factory/declarations.py b/factory/declarations.py
index 83f4d32..d3d7659 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -445,16 +445,37 @@ class RelatedFactory(PostGenerationDeclaration):
def __init__(self, factory, name='', **defaults):
super(RelatedFactory, self).__init__(extract_prefix=None)
- self.factory = factory
self.name = name
self.defaults = defaults
+ if isinstance(factory, type):
+ self.factory = factory
+ self.factory_module = self.factory_name = ''
+ else:
+ # Must be a string
+ if not isinstance(factory, basestring) or '.' not in factory:
+ raise ValueError(
+ "The argument of a SubFactory must be either a class "
+ "or the fully qualified path to a Factory class; got "
+ "%r instead." % factory)
+ self.factory = None
+ self.factory_module, self.factory_name = factory.rsplit('.', 1)
+
+ def get_factory(self):
+ """Retrieve the wrapped factory.Factory subclass."""
+ if self.factory is None:
+ # Must be a module path
+ self.factory = utils.import_object(self.factory_module, self.factory_name)
+ return self.factory
+
def call(self, obj, create, extracted=None, **kwargs):
passed_kwargs = dict(self.defaults)
passed_kwargs.update(kwargs)
if self.name:
passed_kwargs[self.name] = obj
- self.factory.simple_generate(create, **passed_kwargs)
+
+ factory = self.get_factory()
+ factory.simple_generate(create, **passed_kwargs)
class PostGenerationMethodCall(PostGenerationDeclaration):