diff options
Diffstat (limited to 'factory/declarations.py')
-rw-r--r-- | factory/declarations.py | 140 |
1 files changed, 130 insertions, 10 deletions
diff --git a/factory/declarations.py b/factory/declarations.py index 4a5bf97..41d99a3 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -20,6 +20,10 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. + +import itertools + + class OrderedDeclaration(object): """A factory declaration. @@ -28,7 +32,7 @@ class OrderedDeclaration(object): in the same factory. """ - def evaluate(self, sequence, obj): + def evaluate(self, sequence, obj, containers=()): """Evaluate this declaration. Args: @@ -36,6 +40,8 @@ class OrderedDeclaration(object): the current instance obj (containers.LazyStub): The object holding currently computed attributes + containers (list of containers.LazyStub): The chain of SubFactory + which led to building this object. """ raise NotImplementedError('This is an abstract method') @@ -52,25 +58,87 @@ class LazyAttribute(OrderedDeclaration): super(LazyAttribute, self).__init__(*args, **kwargs) self.function = function - def evaluate(self, sequence, obj): + def evaluate(self, sequence, obj, containers=()): return self.function(obj) +class _UNSPECIFIED(object): + pass + + +def deepgetattr(obj, name, default=_UNSPECIFIED): + """Try to retrieve the given attribute of an object, digging on '.'. + + This is an extended getattr, digging deeper if '.' is found. + + Args: + obj (object): the object of which an attribute should be read + name (str): the name of an attribute to look up. + default (object): the default value to use if the attribute wasn't found + + Returns: + the attribute pointed to by 'name', splitting on '.'. + + Raises: + AttributeError: if obj has no 'name' attribute. + """ + try: + if '.' in name: + attr, subname = name.split('.', 1) + return deepgetattr(getattr(obj, attr), subname, default) + else: + return getattr(obj, name) + except AttributeError: + if default is _UNSPECIFIED: + raise + else: + return default + + class SelfAttribute(OrderedDeclaration): """Specific OrderedDeclaration copying values from other fields. Attributes: attribute_name (str): the name of the attribute to copy. + default (object): the default value to use if the attribute doesn't + exist. """ - def __init__(self, attribute_name, *args, **kwargs): + def __init__(self, attribute_name, default=_UNSPECIFIED, *args, **kwargs): super(SelfAttribute, self).__init__(*args, **kwargs) self.attribute_name = attribute_name + self.default = default + + def evaluate(self, sequence, obj, containers=()): + return deepgetattr(obj, self.attribute_name, self.default) + + +class Iterator(OrderedDeclaration): + """Fill this value using the values returned by an iterator. + + Warning: the iterator should not end ! + + Attributes: + iterator (iterable): the iterator whose value should be used. + """ + + def __init__(self, iterator): + super(Iterator, self).__init__() + self.iterator = iter(iterator) + + def evaluate(self, sequence, obj, containers=()): + return self.iterator.next() - def evaluate(self, sequence, obj): - # TODO(rbarrois): allow the use of ATTR_SPLITTER to fetch fields of - # subfactories. - return getattr(obj, self.attribute_name) + +class InfiniteIterator(Iterator): + """Same as Iterator, but make the iterator infinite by cycling at the end. + + Attributes: + iterator (iterable): the iterator, once made infinite. + """ + + def __init__(self, iterator): + return super(InfiniteIterator, self).__init__(itertools.cycle(iterator)) class Sequence(OrderedDeclaration): @@ -89,7 +157,7 @@ class Sequence(OrderedDeclaration): self.function = function self.type = type - def evaluate(self, sequence, obj): + def evaluate(self, sequence, obj, containers=()): return self.function(self.type(sequence)) @@ -102,10 +170,42 @@ class LazyAttributeSequence(Sequence): type (function): A function converting an integer into the expected kind of counter for the 'function' attribute. """ - def evaluate(self, sequence, obj): + def evaluate(self, sequence, obj, containers=()): return self.function(obj, self.type(sequence)) +class ContainerAttribute(OrderedDeclaration): + """Variant of LazyAttribute, also receives the containers of the object. + + Attributes: + function (function): A function, expecting the current LazyStub and the + (optional) object having a subfactory containing this attribute. + strict (bool): Whether evaluating should fail when the containers are + not passed in (i.e used outside a SubFactory). + """ + def __init__(self, function, strict=True, *args, **kwargs): + super(ContainerAttribute, self).__init__(*args, **kwargs) + self.function = function + self.strict = strict + + def evaluate(self, sequence, obj, containers=()): + """Evaluate the current ContainerAttribute. + + Args: + obj (LazyStub): a lazy stub of the object being constructed, if + needed. + containers (list of LazyStub): a list of lazy stubs of factories + being evaluated in a chain, each item being a future field of + next one. + """ + if self.strict and not containers: + raise TypeError( + "A ContainerAttribute in 'strict' mode can only be used " + "within a SubFactory.") + + return self.function(obj, containers) + + class SubFactory(OrderedDeclaration): """Base class for attributes based upon a sub-factory. @@ -120,18 +220,27 @@ class SubFactory(OrderedDeclaration): self.defaults = kwargs self.factory = factory - def evaluate(self, create, extra): + def evaluate(self, create, extra, containers): """Evaluate the current definition and fill its attributes. Uses attributes definition in the following order: - attributes defined in the wrapped factory class - values defined when defining the SubFactory - additional values defined in attributes + + Args: + create (bool): whether the subfactory should call 'build' or + 'create' + extra (containers.DeclarationDict): extra values that should + override the wrapped factory's defaults + containers (list of LazyStub): List of LazyStub for the chain of + factories being evaluated, the calling stub being first. """ defaults = dict(self.defaults) if extra: defaults.update(extra) + defaults['__containers'] = containers attrs = self.factory.attributes(create, defaults) @@ -146,8 +255,19 @@ class SubFactory(OrderedDeclaration): def lazy_attribute(func): return LazyAttribute(func) +def iterator(func): + """Turn a generator function into an iterator attribute.""" + return Iterator(func()) + +def infinite_iterator(func): + """Turn a generator function into an infinite iterator attribute.""" + return InfiniteIterator(func()) + def sequence(func): return Sequence(func) def lazy_attribute_sequence(func): return LazyAttributeSequence(func) + +def container_attribute(func): + return ContainerAttribute(func, strict=False) |