summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--factory/containers.py51
-rw-r--r--factory/declarations.py39
-rw-r--r--tests/test_declarations.py24
3 files changed, 62 insertions, 52 deletions
diff --git a/factory/containers.py b/factory/containers.py
index 0859a10..dc3a457 100644
--- a/factory/containers.py
+++ b/factory/containers.py
@@ -172,30 +172,6 @@ class LazyValue(object):
raise NotImplementedError("This is an abstract method.")
-class SubFactoryWrapper(LazyValue):
- """Lazy wrapper around a SubFactory.
-
- Attributes:
- subfactory (declarations.SubFactory): the SubFactory being wrapped
- subfields (DeclarationDict): Default values to override when evaluating
- the SubFactory
- create (bool): whether to 'create' or 'build' the SubFactory.
- """
-
- def __init__(self, subfactory, subfields, create, *args, **kwargs):
- super(SubFactoryWrapper, self).__init__(*args, **kwargs)
- self.subfactory = subfactory
- self.subfields = subfields
- self.create = create
-
- def evaluate(self, obj, containers=()):
- expanded_containers = (obj,)
- if containers:
- expanded_containers += tuple(containers)
- return self.subfactory.evaluate(self.create, self.subfields,
- expanded_containers)
-
-
class OrderedDeclarationWrapper(LazyValue):
"""Lazy wrapper around an OrderedDeclaration.
@@ -206,10 +182,12 @@ class OrderedDeclarationWrapper(LazyValue):
declaration
"""
- def __init__(self, declaration, sequence, *args, **kwargs):
- super(OrderedDeclarationWrapper, self).__init__(*args, **kwargs)
+ def __init__(self, declaration, sequence, create, extra=None, **kwargs):
+ super(OrderedDeclarationWrapper, self).__init__(**kwargs)
self.declaration = declaration
self.sequence = sequence
+ self.create = create
+ self.extra = extra
def evaluate(self, obj, containers=()):
"""Lazily evaluate the attached OrderedDeclaration.
@@ -219,7 +197,14 @@ class OrderedDeclarationWrapper(LazyValue):
containers (object list): the chain of containers of the object
being built, its immediate holder being first.
"""
- return self.declaration.evaluate(self.sequence, obj, containers)
+ return self.declaration.evaluate(self.sequence, obj,
+ create=self.create,
+ extra=self.extra,
+ containers=containers,
+ )
+
+ def __repr__(self):
+ return '<%s for %r>' % (self.__class__.__name__, self.declaration)
class AttributeBuilder(object):
@@ -240,7 +225,7 @@ class AttributeBuilder(object):
extra = {}
self.factory = factory
- self._containers = extra.pop('__containers', None)
+ self._containers = extra.pop('__containers', ())
self._attrs = factory.declarations(extra)
attrs_with_subfields = [k for k, v in self._attrs.items() if self.has_subfields(v)]
@@ -263,10 +248,12 @@ class AttributeBuilder(object):
# OrderedDeclaration.
wrapped_attrs = {}
for k, v in self._attrs.items():
- if isinstance(v, declarations.SubFactory):
- v = SubFactoryWrapper(v, self._subfields.get(k, {}), create)
- elif isinstance(v, declarations.OrderedDeclaration):
- v = OrderedDeclarationWrapper(v, self.factory.sequence)
+ if isinstance(v, declarations.OrderedDeclaration):
+ v = OrderedDeclarationWrapper(v,
+ sequence=self.factory.sequence,
+ create=create,
+ extra=self._subfields.get(k, {}),
+ )
wrapped_attrs[k] = v
stub = LazyStub(wrapped_attrs, containers=self._containers,
diff --git a/factory/declarations.py b/factory/declarations.py
index 2122bd2..15d8d5b 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -37,7 +37,7 @@ class OrderedDeclaration(object):
in the same factory.
"""
- def evaluate(self, sequence, obj, containers=()):
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
"""Evaluate this declaration.
Args:
@@ -47,6 +47,10 @@ class OrderedDeclaration(object):
attributes
containers (list of containers.LazyStub): The chain of SubFactory
which led to building this object.
+ create (bool): whether the target class should be 'built' or
+ 'created'
+ extra (DeclarationDict or None): extracted key/value extracted from
+ the attribute prefix
"""
raise NotImplementedError('This is an abstract method')
@@ -63,7 +67,7 @@ class LazyAttribute(OrderedDeclaration):
super(LazyAttribute, self).__init__(*args, **kwargs)
self.function = function
- def evaluate(self, sequence, obj, containers=()):
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
return self.function(obj)
@@ -122,7 +126,7 @@ class SelfAttribute(OrderedDeclaration):
self.attribute_name = attribute_name
self.default = default
- def evaluate(self, sequence, obj, containers=()):
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
if self.depth > 1:
# Fetching from a parent
target = containers[self.depth - 2]
@@ -130,6 +134,13 @@ class SelfAttribute(OrderedDeclaration):
target = obj
return deepgetattr(target, self.attribute_name, self.default)
+ def __repr__(self):
+ return '<%s(%r, default=%r)>' % (
+ self.__class__.__name__,
+ self.attribute_name,
+ self.default,
+ )
+
class Iterator(OrderedDeclaration):
"""Fill this value using the values returned by an iterator.
@@ -150,7 +161,7 @@ class Iterator(OrderedDeclaration):
else:
self.iterator = iter(iterator)
- def evaluate(self, sequence, obj, containers=()):
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
value = next(self.iterator)
if self.getter is None:
return value
@@ -173,7 +184,7 @@ class Sequence(OrderedDeclaration):
self.function = function
self.type = type
- def evaluate(self, sequence, obj, containers=()):
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
return self.function(self.type(sequence))
@@ -186,7 +197,7 @@ class LazyAttributeSequence(Sequence):
type (function): A function converting an integer into the expected kind
of counter for the 'function' attribute.
"""
- def evaluate(self, sequence, obj, containers=()):
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
return self.function(obj, self.type(sequence))
@@ -204,7 +215,7 @@ class ContainerAttribute(OrderedDeclaration):
self.function = function
self.strict = strict
- def evaluate(self, sequence, obj, containers=()):
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
"""Evaluate the current ContainerAttribute.
Args:
@@ -237,11 +248,20 @@ class ParameteredAttribute(OrderedDeclaration):
CONTAINERS_FIELD = '__containers'
+ # Whether to add the current object to the stack of containers
+ EXTEND_CONTAINERS = False
+
def __init__(self, **kwargs):
super(ParameteredAttribute, self).__init__()
self.defaults = kwargs
- def evaluate(self, create, extra, containers):
+ def _prepare_containers(self, obj, containers=()):
+ if self.EXTEND_CONTAINERS:
+ return (obj,) + tuple(containers)
+
+ return containers
+
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
"""Evaluate the current definition and fill its attributes.
Uses attributes definition in the following order:
@@ -260,6 +280,7 @@ class ParameteredAttribute(OrderedDeclaration):
if extra:
defaults.update(extra)
if self.CONTAINERS_FIELD:
+ containers = self._prepare_containers(obj, containers)
defaults[self.CONTAINERS_FIELD] = containers
return self.generate(create, defaults)
@@ -288,6 +309,8 @@ class SubFactory(ParameteredAttribute):
factory (base.Factory): the wrapped factory
"""
+ EXTEND_CONTAINERS = True
+
def __init__(self, factory, **kwargs):
super(SubFactory, self).__init__(**kwargs)
if isinstance(factory, type):
diff --git a/tests/test_declarations.py b/tests/test_declarations.py
index b7ae344..7b9b0af 100644
--- a/tests/test_declarations.py
+++ b/tests/test_declarations.py
@@ -33,7 +33,7 @@ from . import tools
class OrderedDeclarationTestCase(unittest.TestCase):
def test_errors(self):
decl = declarations.OrderedDeclaration()
- self.assertRaises(NotImplementedError, decl.evaluate, None, {})
+ self.assertRaises(NotImplementedError, decl.evaluate, None, {}, False)
class DigTestCase(unittest.TestCase):
@@ -95,23 +95,23 @@ class SelfAttributeTestCase(unittest.TestCase):
class IteratorTestCase(unittest.TestCase):
def test_cycle(self):
it = declarations.Iterator([1, 2])
- self.assertEqual(1, it.evaluate(0, None))
- self.assertEqual(2, it.evaluate(1, None))
- self.assertEqual(1, it.evaluate(2, None))
- self.assertEqual(2, it.evaluate(3, None))
+ self.assertEqual(1, it.evaluate(0, None, False))
+ self.assertEqual(2, it.evaluate(1, None, False))
+ self.assertEqual(1, it.evaluate(2, None, False))
+ self.assertEqual(2, it.evaluate(3, None, False))
def test_no_cycling(self):
it = declarations.Iterator([1, 2], cycle=False)
- self.assertEqual(1, it.evaluate(0, None))
- self.assertEqual(2, it.evaluate(1, None))
- self.assertRaises(StopIteration, it.evaluate, 2, None)
+ self.assertEqual(1, it.evaluate(0, None, False))
+ self.assertEqual(2, it.evaluate(1, None, False))
+ self.assertRaises(StopIteration, it.evaluate, 2, None, False)
def test_getter(self):
it = declarations.Iterator([(1, 2), (1, 3)], getter=lambda p: p[1])
- self.assertEqual(2, it.evaluate(0, None))
- self.assertEqual(3, it.evaluate(1, None))
- self.assertEqual(2, it.evaluate(2, None))
- self.assertEqual(3, it.evaluate(3, None))
+ self.assertEqual(2, it.evaluate(0, None, False))
+ self.assertEqual(3, it.evaluate(1, None, False))
+ self.assertEqual(2, it.evaluate(2, None, False))
+ self.assertEqual(3, it.evaluate(3, None, False))
class PostGenerationDeclarationTestCase(unittest.TestCase):