diff options
-rw-r--r-- | docs/reference.rst | 33 | ||||
-rw-r--r-- | factory/declarations.py | 25 | ||||
-rw-r--r-- | tests/test_declarations.py | 35 |
3 files changed, 82 insertions, 11 deletions
diff --git a/docs/reference.rst b/docs/reference.rst index efb70e8..e2246aa 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -215,7 +215,7 @@ factory_boy supports two main strategies for generating instances, plus stubs. object, then save it: .. code-block:: pycon - + >>> obj = self._associated_class(*args, **kwargs) >>> obj.save() >>> return obj @@ -849,18 +849,32 @@ To support this pattern, factory_boy provides the following tools: RelatedFactory """""""""""""" -.. class:: RelatedFactory(some_factory, related_field, **kwargs) +.. class:: RelatedFactory(factory, name='', **kwargs) .. OHAI_VIM** -A :class:`RelatedFactory` behaves mostly like a :class:`SubFactory`, -with the main difference that it should be provided with a ``related_field`` name -as second argument. + A :class:`RelatedFactory` behaves mostly like a :class:`SubFactory`, + with the main difference that the related :class:`Factory` will be generated + *after* the base :class:`Factory`. + + + .. attribute:: factory + + As for :class:`SubFactory`, the :attr:`factory` argument can be: + + - A :class:`Factory` subclass + - Or the fully qualified path to a :class:`Factory` subclass + (see :ref:`subfactory-circular` for details) + + .. attribute:: name + + The generated object (where the :class:`RelatedFactory` attribute will + set) may be passed to the related factory if the :attr:`name` parameter + is set. + + It will be passed as a keyword argument, using the :attr:`name` value as + keyword: -Once the base object has been built (or created), the :class:`RelatedFactory` will -build the :class:`Factory` passed as first argument (with the same strategy), -passing in the base object as a keyword argument whose name is passed in the -``related_field`` argument: .. code-block:: python @@ -882,6 +896,7 @@ passing in the base object as a keyword argument whose name is passed in the >>> City.objects.get(capital_of=france) <City: Paris> + Extra kwargs may be passed to the related factory, through the usual ``ATTR__SUBATTR`` syntax: .. code-block:: pycon diff --git a/factory/declarations.py b/factory/declarations.py index 83f4d32..d3d7659 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -445,16 +445,37 @@ class RelatedFactory(PostGenerationDeclaration): def __init__(self, factory, name='', **defaults): super(RelatedFactory, self).__init__(extract_prefix=None) - self.factory = factory self.name = name self.defaults = defaults + if isinstance(factory, type): + self.factory = factory + self.factory_module = self.factory_name = '' + else: + # Must be a string + if not isinstance(factory, basestring) or '.' not in factory: + raise ValueError( + "The argument of a SubFactory must be either a class " + "or the fully qualified path to a Factory class; got " + "%r instead." % factory) + self.factory = None + self.factory_module, self.factory_name = factory.rsplit('.', 1) + + def get_factory(self): + """Retrieve the wrapped factory.Factory subclass.""" + if self.factory is None: + # Must be a module path + self.factory = utils.import_object(self.factory_module, self.factory_name) + return self.factory + def call(self, obj, create, extracted=None, **kwargs): passed_kwargs = dict(self.defaults) passed_kwargs.update(kwargs) if self.name: passed_kwargs[self.name] = obj - self.factory.simple_generate(create, **passed_kwargs) + + factory = self.get_factory() + factory.simple_generate(create, **passed_kwargs) class PostGenerationMethodCall(PostGenerationDeclaration): diff --git a/tests/test_declarations.py b/tests/test_declarations.py index c57e77d..cc921d4 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -260,6 +260,41 @@ class SubFactoryTestCase(unittest.TestCase): datetime.date = orig_date +class RelatedFactoryTestCase(unittest.TestCase): + + def test_arg(self): + self.assertRaises(ValueError, declarations.RelatedFactory, 'UnqualifiedSymbol') + + def test_lazyness(self): + f = declarations.RelatedFactory('factory.declarations.Sequence', x=3) + self.assertEqual(None, f.factory) + + self.assertEqual({'x': 3}, f.defaults) + + factory_class = f.get_factory() + self.assertEqual(declarations.Sequence, factory_class) + + def test_cache(self): + """Ensure that RelatedFactory tries to import only once.""" + orig_date = datetime.date + f = declarations.RelatedFactory('datetime.date') + self.assertEqual(None, f.factory) + + factory_class = f.get_factory() + self.assertEqual(orig_date, factory_class) + + try: + # Modify original value + datetime.date = None + # Repeat import + factory_class = f.get_factory() + self.assertEqual(orig_date, factory_class) + + finally: + # IMPORTANT: restore attribute. + datetime.date = orig_date + + class CircularSubFactoryTestCase(unittest.TestCase): def test_circularsubfactory_deprecated(self): |