diff options
-rw-r--r-- | docs/subfactory.rst | 25 | ||||
-rw-r--r-- | factory/__init__.py | 1 | ||||
-rw-r--r-- | factory/declarations.py | 105 | ||||
-rw-r--r-- | factory/utils.py | 12 | ||||
-rw-r--r-- | tests/test_declarations.py | 31 | ||||
-rw-r--r-- | tests/test_utils.py | 15 |
6 files changed, 170 insertions, 19 deletions
diff --git a/docs/subfactory.rst b/docs/subfactory.rst index 1815312..f8642f3 100644 --- a/docs/subfactory.rst +++ b/docs/subfactory.rst @@ -54,3 +54,28 @@ Fields of the SubFactory can also be overridden when instantiating the external <User: Henry Jones> >>> c.owner.email henry.jones@example.org + + +Circular dependencies +--------------------- + +In order to solve circular dependency issues, Factory Boy provides the :class:`~factory.CircularSubFactory` class. + +This class expects a module name and a factory name to import from that module; the given module will be imported +(as an absolute import) when the factory is first accessed:: + + # foo/factories.py + import factory + + from bar import factories + + class FooFactory(factory.Factory): + bar = factory.SubFactory(factories.BarFactory) + + + # bar/factories.py + import factory + + class BarFactory(factory.Factory): + # Avoid circular import + foo = factory.CircularSubFactory('foo.factories', 'FooFactory', bar=None) diff --git a/factory/__init__.py b/factory/__init__.py index 3753461..789d88e 100644 --- a/factory/__init__.py +++ b/factory/__init__.py @@ -60,6 +60,7 @@ from declarations import ( SelfAttribute, ContainerAttribute, SubFactory, + CircularSubFactory, PostGeneration, RelatedFactory, diff --git a/factory/declarations.py b/factory/declarations.py index 83c32ab..5e45255 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -208,46 +208,115 @@ class ContainerAttribute(OrderedDeclaration): return self.function(obj, containers) -class SubFactory(OrderedDeclaration): - """Base class for attributes based upon a sub-factory. +class ParameteredAttribute(OrderedDeclaration): + """Base class for attributes expecting parameters. Attributes: - defaults (dict): Overrides to the defaults defined in the wrapped - factory - factory (base.Factory): the wrapped factory + defaults (dict): Default values for the paramters. + May be overridden by call-time parameters. + + Class attributes: + CONTAINERS_FIELD (str): name of the field, if any, where container + information (e.g for SubFactory) should be stored. If empty, + containers data isn't merged into generate() parameters. """ - def __init__(self, factory, **kwargs): - super(SubFactory, self).__init__() + CONTAINERS_FIELD = '__containers' + + def __init__(self, **kwargs): + super(ParameteredAttribute, self).__init__() self.defaults = kwargs - self.factory = factory def evaluate(self, create, extra, containers): """Evaluate the current definition and fill its attributes. Uses attributes definition in the following order: - - attributes defined in the wrapped factory class - - values defined when defining the SubFactory - - additional values defined in attributes + - values defined when defining the ParameteredAttribute + - additional values defined when instantiating the containing factory Args: - create (bool): whether the subfactory should call 'build' or - 'create' + create (bool): whether the parent factory is being 'built' or + 'created' extra (containers.DeclarationDict): extra values that should - override the wrapped factory's defaults + override the defaults containers (list of LazyStub): List of LazyStub for the chain of factories being evaluated, the calling stub being first. """ - defaults = dict(self.defaults) if extra: defaults.update(extra) - defaults['__containers'] = containers + if self.CONTAINERS_FIELD: + defaults[self.CONTAINERS_FIELD] = containers + + return self.generate(create, defaults) + + def generate(self, create, params): + """Actually generate the related attribute. + + Args: + create (bool): whether the calling factory was in 'create' or + 'build' mode + params (dict): parameters inherited from init and evaluation-time + overrides. + + Returns: + Computed value for the current declaration. + """ + raise NotImplementedError() + + +class SubFactory(ParameteredAttribute): + """Base class for attributes based upon a sub-factory. + Attributes: + defaults (dict): Overrides to the defaults defined in the wrapped + factory + factory (base.Factory): the wrapped factory + """ + + def __init__(self, factory, **kwargs): + super(SubFactory, self).__init__(**kwargs) + self.factory = factory + + def get_factory(self): + """Retrieve the wrapped factory.Factory subclass.""" + return self.factory + + def generate(self, create, params): + """Evaluate the current definition and fill its attributes. + + Args: + create (bool): whether the subfactory should call 'build' or + 'create' + params (containers.DeclarationDict): extra values that should + override the wrapped factory's defaults + """ + subfactory = self.get_factory() if create: - return self.factory.create(**defaults) + return subfactory.create(**params) else: - return self.factory.build(**defaults) + return subfactory.build(**params) + + +class CircularSubFactory(SubFactory): + """Use to solve circular dependencies issues.""" + def __init__(self, module_name, factory_name, **kwargs): + super(CircularSubFactory, self).__init__(None, **kwargs) + self.module_name = module_name + self.factory_name = factory_name + + def get_factory(self): + """Retrieve the factory.Factory subclass. + + Its value is cached in the 'factory' attribute, and retrieved through + the factory_getter callable. + """ + if self.factory is None: + factory_class = utils.import_object( + self.module_name, self.factory_name) + + self.factory = factory_class + return self.factory class PostGenerationDeclaration(object): diff --git a/factory/utils.py b/factory/utils.py index c592da4..e7cdf5f 100644 --- a/factory/utils.py +++ b/factory/utils.py @@ -81,3 +81,15 @@ def multi_extract_dict(prefixes, kwargs, pop=True, exclude=()): ['%s%s%s' % (prefix, ATTR_SPLITTER, key) for key in extracted]) return results + + +def import_object(module_name, attribute_name): + """Import an object from its absolute path. + + Example: + >>> import_object('datetime', 'datetime') + <type 'datetime.datetime'> + """ + module = __import__(module_name, {}, {}, [attribute_name], 0) + return getattr(module, attribute_name) + diff --git a/tests/test_declarations.py b/tests/test_declarations.py index 1c0502b..c0b3539 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -20,8 +20,9 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +import datetime -from factory.declarations import deepgetattr, OrderedDeclaration, \ +from factory.declarations import deepgetattr, CircularSubFactory, OrderedDeclaration, \ PostGenerationDeclaration, Sequence from .compat import unittest @@ -73,6 +74,34 @@ class PostGenerationDeclarationTestCase(unittest.TestCase): self.assertEqual(kwargs, {'baz': 1}) +class CircularSubFactoryTestCase(unittest.TestCase): + def test_lazyness(self): + f = CircularSubFactory('factory.declarations', 'Sequence', x=3) + self.assertEqual(None, f.factory) + + self.assertEqual({'x': 3}, f.defaults) + + factory_class = f.get_factory() + self.assertEqual(Sequence, factory_class) + + def test_cache(self): + orig_date = datetime.date + f = CircularSubFactory('datetime', 'date') + self.assertEqual(None, f.factory) + + factory_class = f.get_factory() + self.assertEqual(orig_date, factory_class) + + try: + # Modify original value + datetime.date = None + # Repeat import + factory_class = f.get_factory() + self.assertEqual(orig_date, factory_class) + + finally: + # IMPORTANT: restore attribute. + datetime.date = orig_date if __name__ == '__main__': unittest.main() diff --git a/tests/test_utils.py b/tests/test_utils.py index f30c0e3..dbc357b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -212,3 +212,18 @@ class MultiExtractDictTestCase(unittest.TestCase): self.assertNotIn('foo__foo__bar', d) self.assertNotIn('bar__foo', d) self.assertNotIn('bar__bar__baz', d) + +class ImportObjectTestCase(unittest.TestCase): + def test_datetime(self): + imported = utils.import_object('datetime', 'date') + import datetime + d = datetime.date + self.assertEqual(d, imported) + + def test_unknown_attribute(self): + self.assertRaises(AttributeError, utils.import_object, + 'datetime', 'foo') + + def test_invalid_module(self): + self.assertRaises(ImportError, utils.import_object, + 'this-is-an-invalid-module', '__name__') |