From 9422cf12516143650f1014f34f996260c00d4c0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Barrois?= Date: Sun, 3 Mar 2013 22:10:42 +0100 Subject: Allow symbol names in RelatedFactory (Closes #30). MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This works exactly as for SubFactory. Signed-off-by: Raphaƫl Barrois --- factory/declarations.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) (limited to 'factory/declarations.py') diff --git a/factory/declarations.py b/factory/declarations.py index 83f4d32..d3d7659 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -445,16 +445,37 @@ class RelatedFactory(PostGenerationDeclaration): def __init__(self, factory, name='', **defaults): super(RelatedFactory, self).__init__(extract_prefix=None) - self.factory = factory self.name = name self.defaults = defaults + if isinstance(factory, type): + self.factory = factory + self.factory_module = self.factory_name = '' + else: + # Must be a string + if not isinstance(factory, basestring) or '.' not in factory: + raise ValueError( + "The argument of a SubFactory must be either a class " + "or the fully qualified path to a Factory class; got " + "%r instead." % factory) + self.factory = None + self.factory_module, self.factory_name = factory.rsplit('.', 1) + + def get_factory(self): + """Retrieve the wrapped factory.Factory subclass.""" + if self.factory is None: + # Must be a module path + self.factory = utils.import_object(self.factory_module, self.factory_name) + return self.factory + def call(self, obj, create, extracted=None, **kwargs): passed_kwargs = dict(self.defaults) passed_kwargs.update(kwargs) if self.name: passed_kwargs[self.name] = obj - self.factory.simple_generate(create, **passed_kwargs) + + factory = self.get_factory() + factory.simple_generate(create, **passed_kwargs) class PostGenerationMethodCall(PostGenerationDeclaration): -- cgit v1.2.3