summaryrefslogtreecommitdiff
path: root/factory/declarations.py
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2013-03-03 22:10:42 +0100
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2013-03-03 22:10:42 +0100
commit9422cf12516143650f1014f34f996260c00d4c0a (patch)
tree879113f8d65560dc9fd435872ffbf1cc44af01de /factory/declarations.py
parentf8708d936be1aa53a8b61f95cda6edcdbd8fc00a (diff)
downloadfactory-boy-9422cf12516143650f1014f34f996260c00d4c0a.tar
factory-boy-9422cf12516143650f1014f34f996260c00d4c0a.tar.gz
Allow symbol names in RelatedFactory (Closes #30).
This works exactly as for SubFactory. Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
Diffstat (limited to 'factory/declarations.py')
-rw-r--r--factory/declarations.py25
1 files changed, 23 insertions, 2 deletions
diff --git a/factory/declarations.py b/factory/declarations.py
index 83f4d32..d3d7659 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -445,16 +445,37 @@ class RelatedFactory(PostGenerationDeclaration):
def __init__(self, factory, name='', **defaults):
super(RelatedFactory, self).__init__(extract_prefix=None)
- self.factory = factory
self.name = name
self.defaults = defaults
+ if isinstance(factory, type):
+ self.factory = factory
+ self.factory_module = self.factory_name = ''
+ else:
+ # Must be a string
+ if not isinstance(factory, basestring) or '.' not in factory:
+ raise ValueError(
+ "The argument of a SubFactory must be either a class "
+ "or the fully qualified path to a Factory class; got "
+ "%r instead." % factory)
+ self.factory = None
+ self.factory_module, self.factory_name = factory.rsplit('.', 1)
+
+ def get_factory(self):
+ """Retrieve the wrapped factory.Factory subclass."""
+ if self.factory is None:
+ # Must be a module path
+ self.factory = utils.import_object(self.factory_module, self.factory_name)
+ return self.factory
+
def call(self, obj, create, extracted=None, **kwargs):
passed_kwargs = dict(self.defaults)
passed_kwargs.update(kwargs)
if self.name:
passed_kwargs[self.name] = obj
- self.factory.simple_generate(create, **passed_kwargs)
+
+ factory = self.get_factory()
+ factory.simple_generate(create, **passed_kwargs)
class PostGenerationMethodCall(PostGenerationDeclaration):