diff options
author | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-03-03 22:10:42 +0100 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-03-03 22:10:42 +0100 |
commit | 9422cf12516143650f1014f34f996260c00d4c0a (patch) | |
tree | 879113f8d65560dc9fd435872ffbf1cc44af01de /factory | |
parent | f8708d936be1aa53a8b61f95cda6edcdbd8fc00a (diff) | |
download | factory-boy-9422cf12516143650f1014f34f996260c00d4c0a.tar factory-boy-9422cf12516143650f1014f34f996260c00d4c0a.tar.gz |
Allow symbol names in RelatedFactory (Closes #30).
This works exactly as for SubFactory.
Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
Diffstat (limited to 'factory')
-rw-r--r-- | factory/declarations.py | 25 |
1 files changed, 23 insertions, 2 deletions
diff --git a/factory/declarations.py b/factory/declarations.py index 83f4d32..d3d7659 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -445,16 +445,37 @@ class RelatedFactory(PostGenerationDeclaration): def __init__(self, factory, name='', **defaults): super(RelatedFactory, self).__init__(extract_prefix=None) - self.factory = factory self.name = name self.defaults = defaults + if isinstance(factory, type): + self.factory = factory + self.factory_module = self.factory_name = '' + else: + # Must be a string + if not isinstance(factory, basestring) or '.' not in factory: + raise ValueError( + "The argument of a SubFactory must be either a class " + "or the fully qualified path to a Factory class; got " + "%r instead." % factory) + self.factory = None + self.factory_module, self.factory_name = factory.rsplit('.', 1) + + def get_factory(self): + """Retrieve the wrapped factory.Factory subclass.""" + if self.factory is None: + # Must be a module path + self.factory = utils.import_object(self.factory_module, self.factory_name) + return self.factory + def call(self, obj, create, extracted=None, **kwargs): passed_kwargs = dict(self.defaults) passed_kwargs.update(kwargs) if self.name: passed_kwargs[self.name] = obj - self.factory.simple_generate(create, **passed_kwargs) + + factory = self.get_factory() + factory.simple_generate(create, **passed_kwargs) class PostGenerationMethodCall(PostGenerationDeclaration): |