summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/reference.rst33
-rw-r--r--factory/declarations.py25
-rw-r--r--tests/test_declarations.py35
3 files changed, 82 insertions, 11 deletions
diff --git a/docs/reference.rst b/docs/reference.rst
index efb70e8..e2246aa 100644
--- a/docs/reference.rst
+++ b/docs/reference.rst
@@ -215,7 +215,7 @@ factory_boy supports two main strategies for generating instances, plus stubs.
object, then save it:
.. code-block:: pycon
-
+
>>> obj = self._associated_class(*args, **kwargs)
>>> obj.save()
>>> return obj
@@ -849,18 +849,32 @@ To support this pattern, factory_boy provides the following tools:
RelatedFactory
""""""""""""""
-.. class:: RelatedFactory(some_factory, related_field, **kwargs)
+.. class:: RelatedFactory(factory, name='', **kwargs)
.. OHAI_VIM**
-A :class:`RelatedFactory` behaves mostly like a :class:`SubFactory`,
-with the main difference that it should be provided with a ``related_field`` name
-as second argument.
+ A :class:`RelatedFactory` behaves mostly like a :class:`SubFactory`,
+ with the main difference that the related :class:`Factory` will be generated
+ *after* the base :class:`Factory`.
+
+
+ .. attribute:: factory
+
+ As for :class:`SubFactory`, the :attr:`factory` argument can be:
+
+ - A :class:`Factory` subclass
+ - Or the fully qualified path to a :class:`Factory` subclass
+ (see :ref:`subfactory-circular` for details)
+
+ .. attribute:: name
+
+ The generated object (where the :class:`RelatedFactory` attribute will
+ set) may be passed to the related factory if the :attr:`name` parameter
+ is set.
+
+ It will be passed as a keyword argument, using the :attr:`name` value as
+ keyword:
-Once the base object has been built (or created), the :class:`RelatedFactory` will
-build the :class:`Factory` passed as first argument (with the same strategy),
-passing in the base object as a keyword argument whose name is passed in the
-``related_field`` argument:
.. code-block:: python
@@ -882,6 +896,7 @@ passing in the base object as a keyword argument whose name is passed in the
>>> City.objects.get(capital_of=france)
<City: Paris>
+
Extra kwargs may be passed to the related factory, through the usual ``ATTR__SUBATTR`` syntax:
.. code-block:: pycon
diff --git a/factory/declarations.py b/factory/declarations.py
index 83f4d32..d3d7659 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -445,16 +445,37 @@ class RelatedFactory(PostGenerationDeclaration):
def __init__(self, factory, name='', **defaults):
super(RelatedFactory, self).__init__(extract_prefix=None)
- self.factory = factory
self.name = name
self.defaults = defaults
+ if isinstance(factory, type):
+ self.factory = factory
+ self.factory_module = self.factory_name = ''
+ else:
+ # Must be a string
+ if not isinstance(factory, basestring) or '.' not in factory:
+ raise ValueError(
+ "The argument of a SubFactory must be either a class "
+ "or the fully qualified path to a Factory class; got "
+ "%r instead." % factory)
+ self.factory = None
+ self.factory_module, self.factory_name = factory.rsplit('.', 1)
+
+ def get_factory(self):
+ """Retrieve the wrapped factory.Factory subclass."""
+ if self.factory is None:
+ # Must be a module path
+ self.factory = utils.import_object(self.factory_module, self.factory_name)
+ return self.factory
+
def call(self, obj, create, extracted=None, **kwargs):
passed_kwargs = dict(self.defaults)
passed_kwargs.update(kwargs)
if self.name:
passed_kwargs[self.name] = obj
- self.factory.simple_generate(create, **passed_kwargs)
+
+ factory = self.get_factory()
+ factory.simple_generate(create, **passed_kwargs)
class PostGenerationMethodCall(PostGenerationDeclaration):
diff --git a/tests/test_declarations.py b/tests/test_declarations.py
index c57e77d..cc921d4 100644
--- a/tests/test_declarations.py
+++ b/tests/test_declarations.py
@@ -260,6 +260,41 @@ class SubFactoryTestCase(unittest.TestCase):
datetime.date = orig_date
+class RelatedFactoryTestCase(unittest.TestCase):
+
+ def test_arg(self):
+ self.assertRaises(ValueError, declarations.RelatedFactory, 'UnqualifiedSymbol')
+
+ def test_lazyness(self):
+ f = declarations.RelatedFactory('factory.declarations.Sequence', x=3)
+ self.assertEqual(None, f.factory)
+
+ self.assertEqual({'x': 3}, f.defaults)
+
+ factory_class = f.get_factory()
+ self.assertEqual(declarations.Sequence, factory_class)
+
+ def test_cache(self):
+ """Ensure that RelatedFactory tries to import only once."""
+ orig_date = datetime.date
+ f = declarations.RelatedFactory('datetime.date')
+ self.assertEqual(None, f.factory)
+
+ factory_class = f.get_factory()
+ self.assertEqual(orig_date, factory_class)
+
+ try:
+ # Modify original value
+ datetime.date = None
+ # Repeat import
+ factory_class = f.get_factory()
+ self.assertEqual(orig_date, factory_class)
+
+ finally:
+ # IMPORTANT: restore attribute.
+ datetime.date = orig_date
+
+
class CircularSubFactoryTestCase(unittest.TestCase):
def test_circularsubfactory_deprecated(self):