summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/subfactory.rst25
-rw-r--r--factory/__init__.py1
-rw-r--r--factory/declarations.py105
-rw-r--r--factory/utils.py12
-rw-r--r--tests/test_declarations.py31
-rw-r--r--tests/test_utils.py15
6 files changed, 170 insertions, 19 deletions
diff --git a/docs/subfactory.rst b/docs/subfactory.rst
index 1815312..f8642f3 100644
--- a/docs/subfactory.rst
+++ b/docs/subfactory.rst
@@ -54,3 +54,28 @@ Fields of the SubFactory can also be overridden when instantiating the external
<User: Henry Jones>
>>> c.owner.email
henry.jones@example.org
+
+
+Circular dependencies
+---------------------
+
+In order to solve circular dependency issues, Factory Boy provides the :class:`~factory.CircularSubFactory` class.
+
+This class expects a module name and a factory name to import from that module; the given module will be imported
+(as an absolute import) when the factory is first accessed::
+
+ # foo/factories.py
+ import factory
+
+ from bar import factories
+
+ class FooFactory(factory.Factory):
+ bar = factory.SubFactory(factories.BarFactory)
+
+
+ # bar/factories.py
+ import factory
+
+ class BarFactory(factory.Factory):
+ # Avoid circular import
+ foo = factory.CircularSubFactory('foo.factories', 'FooFactory', bar=None)
diff --git a/factory/__init__.py b/factory/__init__.py
index 3753461..789d88e 100644
--- a/factory/__init__.py
+++ b/factory/__init__.py
@@ -60,6 +60,7 @@ from declarations import (
SelfAttribute,
ContainerAttribute,
SubFactory,
+ CircularSubFactory,
PostGeneration,
RelatedFactory,
diff --git a/factory/declarations.py b/factory/declarations.py
index 83c32ab..5e45255 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -208,46 +208,115 @@ class ContainerAttribute(OrderedDeclaration):
return self.function(obj, containers)
-class SubFactory(OrderedDeclaration):
- """Base class for attributes based upon a sub-factory.
+class ParameteredAttribute(OrderedDeclaration):
+ """Base class for attributes expecting parameters.
Attributes:
- defaults (dict): Overrides to the defaults defined in the wrapped
- factory
- factory (base.Factory): the wrapped factory
+ defaults (dict): Default values for the paramters.
+ May be overridden by call-time parameters.
+
+ Class attributes:
+ CONTAINERS_FIELD (str): name of the field, if any, where container
+ information (e.g for SubFactory) should be stored. If empty,
+ containers data isn't merged into generate() parameters.
"""
- def __init__(self, factory, **kwargs):
- super(SubFactory, self).__init__()
+ CONTAINERS_FIELD = '__containers'
+
+ def __init__(self, **kwargs):
+ super(ParameteredAttribute, self).__init__()
self.defaults = kwargs
- self.factory = factory
def evaluate(self, create, extra, containers):
"""Evaluate the current definition and fill its attributes.
Uses attributes definition in the following order:
- - attributes defined in the wrapped factory class
- - values defined when defining the SubFactory
- - additional values defined in attributes
+ - values defined when defining the ParameteredAttribute
+ - additional values defined when instantiating the containing factory
Args:
- create (bool): whether the subfactory should call 'build' or
- 'create'
+ create (bool): whether the parent factory is being 'built' or
+ 'created'
extra (containers.DeclarationDict): extra values that should
- override the wrapped factory's defaults
+ override the defaults
containers (list of LazyStub): List of LazyStub for the chain of
factories being evaluated, the calling stub being first.
"""
-
defaults = dict(self.defaults)
if extra:
defaults.update(extra)
- defaults['__containers'] = containers
+ if self.CONTAINERS_FIELD:
+ defaults[self.CONTAINERS_FIELD] = containers
+
+ return self.generate(create, defaults)
+
+ def generate(self, create, params):
+ """Actually generate the related attribute.
+
+ Args:
+ create (bool): whether the calling factory was in 'create' or
+ 'build' mode
+ params (dict): parameters inherited from init and evaluation-time
+ overrides.
+
+ Returns:
+ Computed value for the current declaration.
+ """
+ raise NotImplementedError()
+
+
+class SubFactory(ParameteredAttribute):
+ """Base class for attributes based upon a sub-factory.
+ Attributes:
+ defaults (dict): Overrides to the defaults defined in the wrapped
+ factory
+ factory (base.Factory): the wrapped factory
+ """
+
+ def __init__(self, factory, **kwargs):
+ super(SubFactory, self).__init__(**kwargs)
+ self.factory = factory
+
+ def get_factory(self):
+ """Retrieve the wrapped factory.Factory subclass."""
+ return self.factory
+
+ def generate(self, create, params):
+ """Evaluate the current definition and fill its attributes.
+
+ Args:
+ create (bool): whether the subfactory should call 'build' or
+ 'create'
+ params (containers.DeclarationDict): extra values that should
+ override the wrapped factory's defaults
+ """
+ subfactory = self.get_factory()
if create:
- return self.factory.create(**defaults)
+ return subfactory.create(**params)
else:
- return self.factory.build(**defaults)
+ return subfactory.build(**params)
+
+
+class CircularSubFactory(SubFactory):
+ """Use to solve circular dependencies issues."""
+ def __init__(self, module_name, factory_name, **kwargs):
+ super(CircularSubFactory, self).__init__(None, **kwargs)
+ self.module_name = module_name
+ self.factory_name = factory_name
+
+ def get_factory(self):
+ """Retrieve the factory.Factory subclass.
+
+ Its value is cached in the 'factory' attribute, and retrieved through
+ the factory_getter callable.
+ """
+ if self.factory is None:
+ factory_class = utils.import_object(
+ self.module_name, self.factory_name)
+
+ self.factory = factory_class
+ return self.factory
class PostGenerationDeclaration(object):
diff --git a/factory/utils.py b/factory/utils.py
index c592da4..e7cdf5f 100644
--- a/factory/utils.py
+++ b/factory/utils.py
@@ -81,3 +81,15 @@ def multi_extract_dict(prefixes, kwargs, pop=True, exclude=()):
['%s%s%s' % (prefix, ATTR_SPLITTER, key) for key in extracted])
return results
+
+
+def import_object(module_name, attribute_name):
+ """Import an object from its absolute path.
+
+ Example:
+ >>> import_object('datetime', 'datetime')
+ <type 'datetime.datetime'>
+ """
+ module = __import__(module_name, {}, {}, [attribute_name], 0)
+ return getattr(module, attribute_name)
+
diff --git a/tests/test_declarations.py b/tests/test_declarations.py
index 1c0502b..c0b3539 100644
--- a/tests/test_declarations.py
+++ b/tests/test_declarations.py
@@ -20,8 +20,9 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
+import datetime
-from factory.declarations import deepgetattr, OrderedDeclaration, \
+from factory.declarations import deepgetattr, CircularSubFactory, OrderedDeclaration, \
PostGenerationDeclaration, Sequence
from .compat import unittest
@@ -73,6 +74,34 @@ class PostGenerationDeclarationTestCase(unittest.TestCase):
self.assertEqual(kwargs, {'baz': 1})
+class CircularSubFactoryTestCase(unittest.TestCase):
+ def test_lazyness(self):
+ f = CircularSubFactory('factory.declarations', 'Sequence', x=3)
+ self.assertEqual(None, f.factory)
+
+ self.assertEqual({'x': 3}, f.defaults)
+
+ factory_class = f.get_factory()
+ self.assertEqual(Sequence, factory_class)
+
+ def test_cache(self):
+ orig_date = datetime.date
+ f = CircularSubFactory('datetime', 'date')
+ self.assertEqual(None, f.factory)
+
+ factory_class = f.get_factory()
+ self.assertEqual(orig_date, factory_class)
+
+ try:
+ # Modify original value
+ datetime.date = None
+ # Repeat import
+ factory_class = f.get_factory()
+ self.assertEqual(orig_date, factory_class)
+
+ finally:
+ # IMPORTANT: restore attribute.
+ datetime.date = orig_date
if __name__ == '__main__':
unittest.main()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index f30c0e3..dbc357b 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -212,3 +212,18 @@ class MultiExtractDictTestCase(unittest.TestCase):
self.assertNotIn('foo__foo__bar', d)
self.assertNotIn('bar__foo', d)
self.assertNotIn('bar__bar__baz', d)
+
+class ImportObjectTestCase(unittest.TestCase):
+ def test_datetime(self):
+ imported = utils.import_object('datetime', 'date')
+ import datetime
+ d = datetime.date
+ self.assertEqual(d, imported)
+
+ def test_unknown_attribute(self):
+ self.assertRaises(AttributeError, utils.import_object,
+ 'datetime', 'foo')
+
+ def test_invalid_module(self):
+ self.assertRaises(ImportError, utils.import_object,
+ 'this-is-an-invalid-module', '__name__')