diff options
-rw-r--r-- | factory/containers.py | 51 | ||||
-rw-r--r-- | factory/declarations.py | 39 | ||||
-rw-r--r-- | tests/test_declarations.py | 24 |
3 files changed, 62 insertions, 52 deletions
diff --git a/factory/containers.py b/factory/containers.py index 0859a10..dc3a457 100644 --- a/factory/containers.py +++ b/factory/containers.py @@ -172,30 +172,6 @@ class LazyValue(object): raise NotImplementedError("This is an abstract method.") -class SubFactoryWrapper(LazyValue): - """Lazy wrapper around a SubFactory. - - Attributes: - subfactory (declarations.SubFactory): the SubFactory being wrapped - subfields (DeclarationDict): Default values to override when evaluating - the SubFactory - create (bool): whether to 'create' or 'build' the SubFactory. - """ - - def __init__(self, subfactory, subfields, create, *args, **kwargs): - super(SubFactoryWrapper, self).__init__(*args, **kwargs) - self.subfactory = subfactory - self.subfields = subfields - self.create = create - - def evaluate(self, obj, containers=()): - expanded_containers = (obj,) - if containers: - expanded_containers += tuple(containers) - return self.subfactory.evaluate(self.create, self.subfields, - expanded_containers) - - class OrderedDeclarationWrapper(LazyValue): """Lazy wrapper around an OrderedDeclaration. @@ -206,10 +182,12 @@ class OrderedDeclarationWrapper(LazyValue): declaration """ - def __init__(self, declaration, sequence, *args, **kwargs): - super(OrderedDeclarationWrapper, self).__init__(*args, **kwargs) + def __init__(self, declaration, sequence, create, extra=None, **kwargs): + super(OrderedDeclarationWrapper, self).__init__(**kwargs) self.declaration = declaration self.sequence = sequence + self.create = create + self.extra = extra def evaluate(self, obj, containers=()): """Lazily evaluate the attached OrderedDeclaration. @@ -219,7 +197,14 @@ class OrderedDeclarationWrapper(LazyValue): containers (object list): the chain of containers of the object being built, its immediate holder being first. """ - return self.declaration.evaluate(self.sequence, obj, containers) + return self.declaration.evaluate(self.sequence, obj, + create=self.create, + extra=self.extra, + containers=containers, + ) + + def __repr__(self): + return '<%s for %r>' % (self.__class__.__name__, self.declaration) class AttributeBuilder(object): @@ -240,7 +225,7 @@ class AttributeBuilder(object): extra = {} self.factory = factory - self._containers = extra.pop('__containers', None) + self._containers = extra.pop('__containers', ()) self._attrs = factory.declarations(extra) attrs_with_subfields = [k for k, v in self._attrs.items() if self.has_subfields(v)] @@ -263,10 +248,12 @@ class AttributeBuilder(object): # OrderedDeclaration. wrapped_attrs = {} for k, v in self._attrs.items(): - if isinstance(v, declarations.SubFactory): - v = SubFactoryWrapper(v, self._subfields.get(k, {}), create) - elif isinstance(v, declarations.OrderedDeclaration): - v = OrderedDeclarationWrapper(v, self.factory.sequence) + if isinstance(v, declarations.OrderedDeclaration): + v = OrderedDeclarationWrapper(v, + sequence=self.factory.sequence, + create=create, + extra=self._subfields.get(k, {}), + ) wrapped_attrs[k] = v stub = LazyStub(wrapped_attrs, containers=self._containers, diff --git a/factory/declarations.py b/factory/declarations.py index 2122bd2..15d8d5b 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -37,7 +37,7 @@ class OrderedDeclaration(object): in the same factory. """ - def evaluate(self, sequence, obj, containers=()): + def evaluate(self, sequence, obj, create, extra=None, containers=()): """Evaluate this declaration. Args: @@ -47,6 +47,10 @@ class OrderedDeclaration(object): attributes containers (list of containers.LazyStub): The chain of SubFactory which led to building this object. + create (bool): whether the target class should be 'built' or + 'created' + extra (DeclarationDict or None): extracted key/value extracted from + the attribute prefix """ raise NotImplementedError('This is an abstract method') @@ -63,7 +67,7 @@ class LazyAttribute(OrderedDeclaration): super(LazyAttribute, self).__init__(*args, **kwargs) self.function = function - def evaluate(self, sequence, obj, containers=()): + def evaluate(self, sequence, obj, create, extra=None, containers=()): return self.function(obj) @@ -122,7 +126,7 @@ class SelfAttribute(OrderedDeclaration): self.attribute_name = attribute_name self.default = default - def evaluate(self, sequence, obj, containers=()): + def evaluate(self, sequence, obj, create, extra=None, containers=()): if self.depth > 1: # Fetching from a parent target = containers[self.depth - 2] @@ -130,6 +134,13 @@ class SelfAttribute(OrderedDeclaration): target = obj return deepgetattr(target, self.attribute_name, self.default) + def __repr__(self): + return '<%s(%r, default=%r)>' % ( + self.__class__.__name__, + self.attribute_name, + self.default, + ) + class Iterator(OrderedDeclaration): """Fill this value using the values returned by an iterator. @@ -150,7 +161,7 @@ class Iterator(OrderedDeclaration): else: self.iterator = iter(iterator) - def evaluate(self, sequence, obj, containers=()): + def evaluate(self, sequence, obj, create, extra=None, containers=()): value = next(self.iterator) if self.getter is None: return value @@ -173,7 +184,7 @@ class Sequence(OrderedDeclaration): self.function = function self.type = type - def evaluate(self, sequence, obj, containers=()): + def evaluate(self, sequence, obj, create, extra=None, containers=()): return self.function(self.type(sequence)) @@ -186,7 +197,7 @@ class LazyAttributeSequence(Sequence): type (function): A function converting an integer into the expected kind of counter for the 'function' attribute. """ - def evaluate(self, sequence, obj, containers=()): + def evaluate(self, sequence, obj, create, extra=None, containers=()): return self.function(obj, self.type(sequence)) @@ -204,7 +215,7 @@ class ContainerAttribute(OrderedDeclaration): self.function = function self.strict = strict - def evaluate(self, sequence, obj, containers=()): + def evaluate(self, sequence, obj, create, extra=None, containers=()): """Evaluate the current ContainerAttribute. Args: @@ -237,11 +248,20 @@ class ParameteredAttribute(OrderedDeclaration): CONTAINERS_FIELD = '__containers' + # Whether to add the current object to the stack of containers + EXTEND_CONTAINERS = False + def __init__(self, **kwargs): super(ParameteredAttribute, self).__init__() self.defaults = kwargs - def evaluate(self, create, extra, containers): + def _prepare_containers(self, obj, containers=()): + if self.EXTEND_CONTAINERS: + return (obj,) + tuple(containers) + + return containers + + def evaluate(self, sequence, obj, create, extra=None, containers=()): """Evaluate the current definition and fill its attributes. Uses attributes definition in the following order: @@ -260,6 +280,7 @@ class ParameteredAttribute(OrderedDeclaration): if extra: defaults.update(extra) if self.CONTAINERS_FIELD: + containers = self._prepare_containers(obj, containers) defaults[self.CONTAINERS_FIELD] = containers return self.generate(create, defaults) @@ -288,6 +309,8 @@ class SubFactory(ParameteredAttribute): factory (base.Factory): the wrapped factory """ + EXTEND_CONTAINERS = True + def __init__(self, factory, **kwargs): super(SubFactory, self).__init__(**kwargs) if isinstance(factory, type): diff --git a/tests/test_declarations.py b/tests/test_declarations.py index b7ae344..7b9b0af 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -33,7 +33,7 @@ from . import tools class OrderedDeclarationTestCase(unittest.TestCase): def test_errors(self): decl = declarations.OrderedDeclaration() - self.assertRaises(NotImplementedError, decl.evaluate, None, {}) + self.assertRaises(NotImplementedError, decl.evaluate, None, {}, False) class DigTestCase(unittest.TestCase): @@ -95,23 +95,23 @@ class SelfAttributeTestCase(unittest.TestCase): class IteratorTestCase(unittest.TestCase): def test_cycle(self): it = declarations.Iterator([1, 2]) - self.assertEqual(1, it.evaluate(0, None)) - self.assertEqual(2, it.evaluate(1, None)) - self.assertEqual(1, it.evaluate(2, None)) - self.assertEqual(2, it.evaluate(3, None)) + self.assertEqual(1, it.evaluate(0, None, False)) + self.assertEqual(2, it.evaluate(1, None, False)) + self.assertEqual(1, it.evaluate(2, None, False)) + self.assertEqual(2, it.evaluate(3, None, False)) def test_no_cycling(self): it = declarations.Iterator([1, 2], cycle=False) - self.assertEqual(1, it.evaluate(0, None)) - self.assertEqual(2, it.evaluate(1, None)) - self.assertRaises(StopIteration, it.evaluate, 2, None) + self.assertEqual(1, it.evaluate(0, None, False)) + self.assertEqual(2, it.evaluate(1, None, False)) + self.assertRaises(StopIteration, it.evaluate, 2, None, False) def test_getter(self): it = declarations.Iterator([(1, 2), (1, 3)], getter=lambda p: p[1]) - self.assertEqual(2, it.evaluate(0, None)) - self.assertEqual(3, it.evaluate(1, None)) - self.assertEqual(2, it.evaluate(2, None)) - self.assertEqual(3, it.evaluate(3, None)) + self.assertEqual(2, it.evaluate(0, None, False)) + self.assertEqual(3, it.evaluate(1, None, False)) + self.assertEqual(2, it.evaluate(2, None, False)) + self.assertEqual(3, it.evaluate(3, None, False)) class PostGenerationDeclarationTestCase(unittest.TestCase): |