summaryrefslogtreecommitdiff
path: root/factory/declarations.py
diff options
context:
space:
mode:
Diffstat (limited to 'factory/declarations.py')
-rw-r--r--factory/declarations.py140
1 files changed, 130 insertions, 10 deletions
diff --git a/factory/declarations.py b/factory/declarations.py
index 4a5bf97..41d99a3 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -20,6 +20,10 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
+
+import itertools
+
+
class OrderedDeclaration(object):
"""A factory declaration.
@@ -28,7 +32,7 @@ class OrderedDeclaration(object):
in the same factory.
"""
- def evaluate(self, sequence, obj):
+ def evaluate(self, sequence, obj, containers=()):
"""Evaluate this declaration.
Args:
@@ -36,6 +40,8 @@ class OrderedDeclaration(object):
the current instance
obj (containers.LazyStub): The object holding currently computed
attributes
+ containers (list of containers.LazyStub): The chain of SubFactory
+ which led to building this object.
"""
raise NotImplementedError('This is an abstract method')
@@ -52,25 +58,87 @@ class LazyAttribute(OrderedDeclaration):
super(LazyAttribute, self).__init__(*args, **kwargs)
self.function = function
- def evaluate(self, sequence, obj):
+ def evaluate(self, sequence, obj, containers=()):
return self.function(obj)
+class _UNSPECIFIED(object):
+ pass
+
+
+def deepgetattr(obj, name, default=_UNSPECIFIED):
+ """Try to retrieve the given attribute of an object, digging on '.'.
+
+ This is an extended getattr, digging deeper if '.' is found.
+
+ Args:
+ obj (object): the object of which an attribute should be read
+ name (str): the name of an attribute to look up.
+ default (object): the default value to use if the attribute wasn't found
+
+ Returns:
+ the attribute pointed to by 'name', splitting on '.'.
+
+ Raises:
+ AttributeError: if obj has no 'name' attribute.
+ """
+ try:
+ if '.' in name:
+ attr, subname = name.split('.', 1)
+ return deepgetattr(getattr(obj, attr), subname, default)
+ else:
+ return getattr(obj, name)
+ except AttributeError:
+ if default is _UNSPECIFIED:
+ raise
+ else:
+ return default
+
+
class SelfAttribute(OrderedDeclaration):
"""Specific OrderedDeclaration copying values from other fields.
Attributes:
attribute_name (str): the name of the attribute to copy.
+ default (object): the default value to use if the attribute doesn't
+ exist.
"""
- def __init__(self, attribute_name, *args, **kwargs):
+ def __init__(self, attribute_name, default=_UNSPECIFIED, *args, **kwargs):
super(SelfAttribute, self).__init__(*args, **kwargs)
self.attribute_name = attribute_name
+ self.default = default
+
+ def evaluate(self, sequence, obj, containers=()):
+ return deepgetattr(obj, self.attribute_name, self.default)
+
+
+class Iterator(OrderedDeclaration):
+ """Fill this value using the values returned by an iterator.
+
+ Warning: the iterator should not end !
+
+ Attributes:
+ iterator (iterable): the iterator whose value should be used.
+ """
+
+ def __init__(self, iterator):
+ super(Iterator, self).__init__()
+ self.iterator = iter(iterator)
+
+ def evaluate(self, sequence, obj, containers=()):
+ return self.iterator.next()
- def evaluate(self, sequence, obj):
- # TODO(rbarrois): allow the use of ATTR_SPLITTER to fetch fields of
- # subfactories.
- return getattr(obj, self.attribute_name)
+
+class InfiniteIterator(Iterator):
+ """Same as Iterator, but make the iterator infinite by cycling at the end.
+
+ Attributes:
+ iterator (iterable): the iterator, once made infinite.
+ """
+
+ def __init__(self, iterator):
+ return super(InfiniteIterator, self).__init__(itertools.cycle(iterator))
class Sequence(OrderedDeclaration):
@@ -89,7 +157,7 @@ class Sequence(OrderedDeclaration):
self.function = function
self.type = type
- def evaluate(self, sequence, obj):
+ def evaluate(self, sequence, obj, containers=()):
return self.function(self.type(sequence))
@@ -102,10 +170,42 @@ class LazyAttributeSequence(Sequence):
type (function): A function converting an integer into the expected kind
of counter for the 'function' attribute.
"""
- def evaluate(self, sequence, obj):
+ def evaluate(self, sequence, obj, containers=()):
return self.function(obj, self.type(sequence))
+class ContainerAttribute(OrderedDeclaration):
+ """Variant of LazyAttribute, also receives the containers of the object.
+
+ Attributes:
+ function (function): A function, expecting the current LazyStub and the
+ (optional) object having a subfactory containing this attribute.
+ strict (bool): Whether evaluating should fail when the containers are
+ not passed in (i.e used outside a SubFactory).
+ """
+ def __init__(self, function, strict=True, *args, **kwargs):
+ super(ContainerAttribute, self).__init__(*args, **kwargs)
+ self.function = function
+ self.strict = strict
+
+ def evaluate(self, sequence, obj, containers=()):
+ """Evaluate the current ContainerAttribute.
+
+ Args:
+ obj (LazyStub): a lazy stub of the object being constructed, if
+ needed.
+ containers (list of LazyStub): a list of lazy stubs of factories
+ being evaluated in a chain, each item being a future field of
+ next one.
+ """
+ if self.strict and not containers:
+ raise TypeError(
+ "A ContainerAttribute in 'strict' mode can only be used "
+ "within a SubFactory.")
+
+ return self.function(obj, containers)
+
+
class SubFactory(OrderedDeclaration):
"""Base class for attributes based upon a sub-factory.
@@ -120,18 +220,27 @@ class SubFactory(OrderedDeclaration):
self.defaults = kwargs
self.factory = factory
- def evaluate(self, create, extra):
+ def evaluate(self, create, extra, containers):
"""Evaluate the current definition and fill its attributes.
Uses attributes definition in the following order:
- attributes defined in the wrapped factory class
- values defined when defining the SubFactory
- additional values defined in attributes
+
+ Args:
+ create (bool): whether the subfactory should call 'build' or
+ 'create'
+ extra (containers.DeclarationDict): extra values that should
+ override the wrapped factory's defaults
+ containers (list of LazyStub): List of LazyStub for the chain of
+ factories being evaluated, the calling stub being first.
"""
defaults = dict(self.defaults)
if extra:
defaults.update(extra)
+ defaults['__containers'] = containers
attrs = self.factory.attributes(create, defaults)
@@ -146,8 +255,19 @@ class SubFactory(OrderedDeclaration):
def lazy_attribute(func):
return LazyAttribute(func)
+def iterator(func):
+ """Turn a generator function into an iterator attribute."""
+ return Iterator(func())
+
+def infinite_iterator(func):
+ """Turn a generator function into an infinite iterator attribute."""
+ return InfiniteIterator(func())
+
def sequence(func):
return Sequence(func)
def lazy_attribute_sequence(func):
return LazyAttributeSequence(func)
+
+def container_attribute(func):
+ return ContainerAttribute(func, strict=False)