diff options
author | Thomas Goirand <thomas@goirand.fr> | 2013-05-12 05:32:34 +0000 |
---|---|---|
committer | Thomas Goirand <thomas@goirand.fr> | 2013-05-12 05:32:34 +0000 |
commit | 28991f9514e3cd78a528bbbe956d9b4536c416e0 (patch) | |
tree | a3871392d2382f60490824d79058f8a71ae1c34e /factory/declarations.py | |
parent | 57fa2e21aed37c1af2a87f36a998046b73092a21 (diff) | |
parent | 876845102c4a217496d0f6435bfe1e3726d31fe4 (diff) | |
download | factory-boy-28991f9514e3cd78a528bbbe956d9b4536c416e0.tar factory-boy-28991f9514e3cd78a528bbbe956d9b4536c416e0.tar.gz |
Merge tag '2.0.2' into debian/unstable
Release of factory_boy 2.0.2
Conflicts:
docs/changelog.rst
docs/index.rst
docs/subfactory.rst
tests/cyclic/bar.py
tests/cyclic/foo.py
Diffstat (limited to 'factory/declarations.py')
-rw-r--r-- | factory/declarations.py | 282 |
1 files changed, 197 insertions, 85 deletions
diff --git a/factory/declarations.py b/factory/declarations.py index 83c32ab..969d780 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright (c) 2010 Mark Sandstrom -# Copyright (c) 2011 Raphaël Barrois +# Copyright (c) 2011-2013 Raphaël Barrois # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,7 +23,8 @@ import itertools -from factory import utils +from . import compat +from . import utils class OrderedDeclaration(object): @@ -34,7 +35,7 @@ class OrderedDeclaration(object): in the same factory. """ - def evaluate(self, sequence, obj, containers=()): + def evaluate(self, sequence, obj, create, extra=None, containers=()): """Evaluate this declaration. Args: @@ -44,6 +45,10 @@ class OrderedDeclaration(object): attributes containers (list of containers.LazyStub): The chain of SubFactory which led to building this object. + create (bool): whether the target class should be 'built' or + 'created' + extra (DeclarationDict or None): extracted key/value extracted from + the attribute prefix """ raise NotImplementedError('This is an abstract method') @@ -60,7 +65,7 @@ class LazyAttribute(OrderedDeclaration): super(LazyAttribute, self).__init__(*args, **kwargs) self.function = function - def evaluate(self, sequence, obj, containers=()): + def evaluate(self, sequence, obj, create, extra=None, containers=()): return self.function(obj) @@ -100,7 +105,11 @@ def deepgetattr(obj, name, default=_UNSPECIFIED): class SelfAttribute(OrderedDeclaration): """Specific OrderedDeclaration copying values from other fields. + If the field name starts with two dots or more, the lookup will be anchored + in the related 'parent'. + Attributes: + depth (int): the number of steps to go up in the containers chain attribute_name (str): the name of the attribute to copy. default (object): the default value to use if the attribute doesn't exist. @@ -108,11 +117,27 @@ class SelfAttribute(OrderedDeclaration): def __init__(self, attribute_name, default=_UNSPECIFIED, *args, **kwargs): super(SelfAttribute, self).__init__(*args, **kwargs) + depth = len(attribute_name) - len(attribute_name.lstrip('.')) + attribute_name = attribute_name[depth:] + + self.depth = depth self.attribute_name = attribute_name self.default = default - def evaluate(self, sequence, obj, containers=()): - return deepgetattr(obj, self.attribute_name, self.default) + def evaluate(self, sequence, obj, create, extra=None, containers=()): + if self.depth > 1: + # Fetching from a parent + target = containers[self.depth - 2] + else: + target = obj + return deepgetattr(target, self.attribute_name, self.default) + + def __repr__(self): + return '<%s(%r, default=%r)>' % ( + self.__class__.__name__, + self.attribute_name, + self.default, + ) class Iterator(OrderedDeclaration): @@ -122,25 +147,23 @@ class Iterator(OrderedDeclaration): Attributes: iterator (iterable): the iterator whose value should be used. + getter (callable or None): a function to parse returned values """ - def __init__(self, iterator): + def __init__(self, iterator, cycle=True, getter=None): super(Iterator, self).__init__() - self.iterator = iter(iterator) - - def evaluate(self, sequence, obj, containers=()): - return self.iterator.next() + self.getter = getter + if cycle: + self.iterator = itertools.cycle(iterator) + else: + self.iterator = iter(iterator) -class InfiniteIterator(Iterator): - """Same as Iterator, but make the iterator infinite by cycling at the end. - - Attributes: - iterator (iterable): the iterator, once made infinite. - """ - - def __init__(self, iterator): - return super(InfiniteIterator, self).__init__(itertools.cycle(iterator)) + def evaluate(self, sequence, obj, create, extra=None, containers=()): + value = next(self.iterator) + if self.getter is None: + return value + return self.getter(value) class Sequence(OrderedDeclaration): @@ -154,12 +177,12 @@ class Sequence(OrderedDeclaration): type (function): A function converting an integer into the expected kind of counter for the 'function' attribute. """ - def __init__(self, function, type=str): + def __init__(self, function, type=int): # pylint: disable=W0622 super(Sequence, self).__init__() self.function = function self.type = type - def evaluate(self, sequence, obj, containers=()): + def evaluate(self, sequence, obj, create, extra=None, containers=()): return self.function(self.type(sequence)) @@ -172,7 +195,7 @@ class LazyAttributeSequence(Sequence): type (function): A function converting an integer into the expected kind of counter for the 'function' attribute. """ - def evaluate(self, sequence, obj, containers=()): + def evaluate(self, sequence, obj, create, extra=None, containers=()): return self.function(obj, self.type(sequence)) @@ -190,7 +213,7 @@ class ContainerAttribute(OrderedDeclaration): self.function = function self.strict = strict - def evaluate(self, sequence, obj, containers=()): + def evaluate(self, sequence, obj, create, extra=None, containers=()): """Evaluate the current ContainerAttribute. Args: @@ -217,12 +240,23 @@ class SubFactory(OrderedDeclaration): factory (base.Factory): the wrapped factory """ - def __init__(self, factory, **kwargs): - super(SubFactory, self).__init__() + CONTAINERS_FIELD = '__containers' + + # Whether to add the current object to the stack of containers + EXTEND_CONTAINERS = False + + def __init__(self, **kwargs): + super(ParameteredAttribute, self).__init__() self.defaults = kwargs self.factory = factory - def evaluate(self, create, extra, containers): + def _prepare_containers(self, obj, containers=()): + if self.EXTEND_CONTAINERS: + return (obj,) + tuple(containers) + + return containers + + def evaluate(self, sequence, obj, create, extra=None, containers=()): """Evaluate the current definition and fill its attributes. Uses attributes definition in the following order: @@ -242,25 +276,105 @@ class SubFactory(OrderedDeclaration): defaults = dict(self.defaults) if extra: defaults.update(extra) - defaults['__containers'] = containers + if self.CONTAINERS_FIELD: + containers = self._prepare_containers(obj, containers) + defaults[self.CONTAINERS_FIELD] = containers - if create: - return self.factory.create(**defaults) - else: - return self.factory.build(**defaults) + return self.generate(sequence, obj, create, defaults) + def generate(self, sequence, obj, create, params): # pragma: no cover + """Actually generate the related attribute. -class PostGenerationDeclaration(object): - """Declarations to be called once the target object has been generated. + Args: + sequence (int): the current sequence number + obj (LazyStub): the object being constructed + create (bool): whether the calling factory was in 'create' or + 'build' mode + params (dict): parameters inherited from init and evaluation-time + overrides. + + Returns: + Computed value for the current declaration. + """ + raise NotImplementedError() + + +class SubFactory(ParameteredAttribute): + """Base class for attributes based upon a sub-factory. Attributes: - extract_prefix (str): prefix to use when extracting attributes from - the factory's declaration for this declaration. If empty, uses - the attribute name of the PostGenerationDeclaration. + defaults (dict): Overrides to the defaults defined in the wrapped + factory + factory (base.Factory): the wrapped factory """ - def __init__(self, extract_prefix=None): - self.extract_prefix = extract_prefix + EXTEND_CONTAINERS = True + + def __init__(self, factory, **kwargs): + super(SubFactory, self).__init__(**kwargs) + if isinstance(factory, type): + self.factory = factory + self.factory_module = self.factory_name = '' + else: + # Must be a string + if not (compat.is_string(factory) and '.' in factory): + raise ValueError( + "The argument of a SubFactory must be either a class " + "or the fully qualified path to a Factory class; got " + "%r instead." % factory) + self.factory = None + self.factory_module, self.factory_name = factory.rsplit('.', 1) + + def get_factory(self): + """Retrieve the wrapped factory.Factory subclass.""" + if self.factory is None: + # Must be a module path + self.factory = utils.import_object( + self.factory_module, self.factory_name) + return self.factory + + def generate(self, sequence, obj, create, params): + """Evaluate the current definition and fill its attributes. + + Args: + create (bool): whether the subfactory should call 'build' or + 'create' + params (containers.DeclarationDict): extra values that should + override the wrapped factory's defaults + """ + subfactory = self.get_factory() + return subfactory.simple_generate(create, **params) + + +class Dict(SubFactory): + """Fill a dict with usual declarations.""" + + def __init__(self, params, dict_factory='factory.DictFactory'): + super(Dict, self).__init__(dict_factory, **dict(params)) + + def generate(self, sequence, obj, create, params): + dict_factory = self.get_factory() + return dict_factory.simple_generate(create, + __sequence=sequence, + **params) + + +class List(SubFactory): + """Fill a list with standard declarations.""" + + def __init__(self, params, list_factory='factory.ListFactory'): + params = dict((str(i), v) for i, v in enumerate(params)) + super(List, self).__init__(list_factory, **params) + + def generate(self, sequence, obj, create, params): + list_factory = self.get_factory() + return list_factory.simple_generate(create, + __sequence=sequence, + **params) + + +class PostGenerationDeclaration(object): + """Declarations to be called once the target object has been generated.""" def extract(self, name, attrs): """Extract relevant attributes from a dict. @@ -275,12 +389,8 @@ class PostGenerationDeclaration(object): (object, dict): a tuple containing the attribute at 'name' (if provided) and a dict of extracted attributes """ - if self.extract_prefix: - extract_prefix = self.extract_prefix - else: - extract_prefix = name - extracted = attrs.pop(extract_prefix, None) - kwargs = utils.extract_dict(extract_prefix, attrs) + extracted = attrs.pop(name, None) + kwargs = utils.extract_dict(name, attrs) return extracted, kwargs def call(self, obj, create, extracted=None, **kwargs): @@ -289,7 +399,7 @@ class PostGenerationDeclaration(object): Args: obj (object): the newly generated object create (bool): whether the object was 'built' or 'created' - extracted (object): the value given for <extract_prefix> in the + extracted (object): the value given for <name> in the object definition, or None if not provided. kwargs (dict): declarations extracted from the object definition for this hook @@ -299,18 +409,12 @@ class PostGenerationDeclaration(object): class PostGeneration(PostGenerationDeclaration): """Calls a given function once the object has been generated.""" - def __init__(self, function, extract_prefix=None): - super(PostGeneration, self).__init__(extract_prefix) + def __init__(self, function): + super(PostGeneration, self).__init__() self.function = function def call(self, obj, create, extracted=None, **kwargs): - self.function(obj, create, extracted, **kwargs) - - -def post_generation(extract_prefix=None): - def decorator(fun): - return PostGeneration(fun, extract_prefix=extract_prefix) - return decorator + return self.function(obj, create, extracted, **kwargs) class RelatedFactory(PostGenerationDeclaration): @@ -324,17 +428,39 @@ class RelatedFactory(PostGenerationDeclaration): """ def __init__(self, factory, name='', **defaults): - super(RelatedFactory, self).__init__(extract_prefix=None) - self.factory = factory + super(RelatedFactory, self).__init__() self.name = name self.defaults = defaults + if isinstance(factory, type): + self.factory = factory + self.factory_module = self.factory_name = '' + else: + # Must be a string + if not (compat.is_string(factory) and '.' in factory): + raise ValueError( + "The argument of a SubFactory must be either a class " + "or the fully qualified path to a Factory class; got " + "%r instead." % factory) + self.factory = None + self.factory_module, self.factory_name = factory.rsplit('.', 1) + + def get_factory(self): + """Retrieve the wrapped factory.Factory subclass.""" + if self.factory is None: + # Must be a module path + self.factory = utils.import_object( + self.factory_module, self.factory_name) + return self.factory + def call(self, obj, create, extracted=None, **kwargs): passed_kwargs = dict(self.defaults) passed_kwargs.update(kwargs) if self.name: passed_kwargs[self.name] = obj - self.factory.simple_generate(create, **passed_kwargs) + + factory = self.get_factory() + factory.simple_generate(create, **passed_kwargs) class PostGenerationMethodCall(PostGenerationDeclaration): @@ -348,39 +474,25 @@ class PostGenerationMethodCall(PostGenerationDeclaration): Example: class UserFactory(factory.Factory): ... - password = factory.PostGenerationMethodCall('set_password', password='') + password = factory.PostGenerationMethodCall('set_pass', password='') """ - def __init__(self, method_name, extract_prefix=None, *args, **kwargs): - super(PostGenerationMethodCall, self).__init__(extract_prefix) + def __init__(self, method_name, *args, **kwargs): + super(PostGenerationMethodCall, self).__init__() self.method_name = method_name self.method_args = args self.method_kwargs = kwargs def call(self, obj, create, extracted=None, **kwargs): + if extracted is None: + passed_args = self.method_args + + elif len(self.method_args) <= 1: + # Max one argument expected + passed_args = (extracted,) + else: + passed_args = tuple(extracted) + passed_kwargs = dict(self.method_kwargs) passed_kwargs.update(kwargs) method = getattr(obj, self.method_name) - method(*self.method_args, **passed_kwargs) - - -# Decorators... in case lambdas don't cut it - -def lazy_attribute(func): - return LazyAttribute(func) - -def iterator(func): - """Turn a generator function into an iterator attribute.""" - return Iterator(func()) - -def infinite_iterator(func): - """Turn a generator function into an infinite iterator attribute.""" - return InfiniteIterator(func()) - -def sequence(func): - return Sequence(func) - -def lazy_attribute_sequence(func): - return LazyAttributeSequence(func) - -def container_attribute(func): - return ContainerAttribute(func, strict=False) + method(*passed_args, **passed_kwargs) |