summaryrefslogtreecommitdiff
path: root/factory/declarations.py
diff options
context:
space:
mode:
authorThomas Goirand <thomas@goirand.fr>2013-05-12 05:32:34 +0000
committerThomas Goirand <thomas@goirand.fr>2013-05-12 05:32:34 +0000
commit28991f9514e3cd78a528bbbe956d9b4536c416e0 (patch)
treea3871392d2382f60490824d79058f8a71ae1c34e /factory/declarations.py
parent57fa2e21aed37c1af2a87f36a998046b73092a21 (diff)
parent876845102c4a217496d0f6435bfe1e3726d31fe4 (diff)
downloadfactory-boy-28991f9514e3cd78a528bbbe956d9b4536c416e0.tar
factory-boy-28991f9514e3cd78a528bbbe956d9b4536c416e0.tar.gz
Merge tag '2.0.2' into debian/unstable
Release of factory_boy 2.0.2 Conflicts: docs/changelog.rst docs/index.rst docs/subfactory.rst tests/cyclic/bar.py tests/cyclic/foo.py
Diffstat (limited to 'factory/declarations.py')
-rw-r--r--factory/declarations.py282
1 files changed, 197 insertions, 85 deletions
diff --git a/factory/declarations.py b/factory/declarations.py
index 83c32ab..969d780 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2010 Mark Sandstrom
-# Copyright (c) 2011 Raphaël Barrois
+# Copyright (c) 2011-2013 Raphaël Barrois
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,7 +23,8 @@
import itertools
-from factory import utils
+from . import compat
+from . import utils
class OrderedDeclaration(object):
@@ -34,7 +35,7 @@ class OrderedDeclaration(object):
in the same factory.
"""
- def evaluate(self, sequence, obj, containers=()):
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
"""Evaluate this declaration.
Args:
@@ -44,6 +45,10 @@ class OrderedDeclaration(object):
attributes
containers (list of containers.LazyStub): The chain of SubFactory
which led to building this object.
+ create (bool): whether the target class should be 'built' or
+ 'created'
+ extra (DeclarationDict or None): extracted key/value extracted from
+ the attribute prefix
"""
raise NotImplementedError('This is an abstract method')
@@ -60,7 +65,7 @@ class LazyAttribute(OrderedDeclaration):
super(LazyAttribute, self).__init__(*args, **kwargs)
self.function = function
- def evaluate(self, sequence, obj, containers=()):
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
return self.function(obj)
@@ -100,7 +105,11 @@ def deepgetattr(obj, name, default=_UNSPECIFIED):
class SelfAttribute(OrderedDeclaration):
"""Specific OrderedDeclaration copying values from other fields.
+ If the field name starts with two dots or more, the lookup will be anchored
+ in the related 'parent'.
+
Attributes:
+ depth (int): the number of steps to go up in the containers chain
attribute_name (str): the name of the attribute to copy.
default (object): the default value to use if the attribute doesn't
exist.
@@ -108,11 +117,27 @@ class SelfAttribute(OrderedDeclaration):
def __init__(self, attribute_name, default=_UNSPECIFIED, *args, **kwargs):
super(SelfAttribute, self).__init__(*args, **kwargs)
+ depth = len(attribute_name) - len(attribute_name.lstrip('.'))
+ attribute_name = attribute_name[depth:]
+
+ self.depth = depth
self.attribute_name = attribute_name
self.default = default
- def evaluate(self, sequence, obj, containers=()):
- return deepgetattr(obj, self.attribute_name, self.default)
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
+ if self.depth > 1:
+ # Fetching from a parent
+ target = containers[self.depth - 2]
+ else:
+ target = obj
+ return deepgetattr(target, self.attribute_name, self.default)
+
+ def __repr__(self):
+ return '<%s(%r, default=%r)>' % (
+ self.__class__.__name__,
+ self.attribute_name,
+ self.default,
+ )
class Iterator(OrderedDeclaration):
@@ -122,25 +147,23 @@ class Iterator(OrderedDeclaration):
Attributes:
iterator (iterable): the iterator whose value should be used.
+ getter (callable or None): a function to parse returned values
"""
- def __init__(self, iterator):
+ def __init__(self, iterator, cycle=True, getter=None):
super(Iterator, self).__init__()
- self.iterator = iter(iterator)
-
- def evaluate(self, sequence, obj, containers=()):
- return self.iterator.next()
+ self.getter = getter
+ if cycle:
+ self.iterator = itertools.cycle(iterator)
+ else:
+ self.iterator = iter(iterator)
-class InfiniteIterator(Iterator):
- """Same as Iterator, but make the iterator infinite by cycling at the end.
-
- Attributes:
- iterator (iterable): the iterator, once made infinite.
- """
-
- def __init__(self, iterator):
- return super(InfiniteIterator, self).__init__(itertools.cycle(iterator))
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
+ value = next(self.iterator)
+ if self.getter is None:
+ return value
+ return self.getter(value)
class Sequence(OrderedDeclaration):
@@ -154,12 +177,12 @@ class Sequence(OrderedDeclaration):
type (function): A function converting an integer into the expected kind
of counter for the 'function' attribute.
"""
- def __init__(self, function, type=str):
+ def __init__(self, function, type=int): # pylint: disable=W0622
super(Sequence, self).__init__()
self.function = function
self.type = type
- def evaluate(self, sequence, obj, containers=()):
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
return self.function(self.type(sequence))
@@ -172,7 +195,7 @@ class LazyAttributeSequence(Sequence):
type (function): A function converting an integer into the expected kind
of counter for the 'function' attribute.
"""
- def evaluate(self, sequence, obj, containers=()):
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
return self.function(obj, self.type(sequence))
@@ -190,7 +213,7 @@ class ContainerAttribute(OrderedDeclaration):
self.function = function
self.strict = strict
- def evaluate(self, sequence, obj, containers=()):
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
"""Evaluate the current ContainerAttribute.
Args:
@@ -217,12 +240,23 @@ class SubFactory(OrderedDeclaration):
factory (base.Factory): the wrapped factory
"""
- def __init__(self, factory, **kwargs):
- super(SubFactory, self).__init__()
+ CONTAINERS_FIELD = '__containers'
+
+ # Whether to add the current object to the stack of containers
+ EXTEND_CONTAINERS = False
+
+ def __init__(self, **kwargs):
+ super(ParameteredAttribute, self).__init__()
self.defaults = kwargs
self.factory = factory
- def evaluate(self, create, extra, containers):
+ def _prepare_containers(self, obj, containers=()):
+ if self.EXTEND_CONTAINERS:
+ return (obj,) + tuple(containers)
+
+ return containers
+
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
"""Evaluate the current definition and fill its attributes.
Uses attributes definition in the following order:
@@ -242,25 +276,105 @@ class SubFactory(OrderedDeclaration):
defaults = dict(self.defaults)
if extra:
defaults.update(extra)
- defaults['__containers'] = containers
+ if self.CONTAINERS_FIELD:
+ containers = self._prepare_containers(obj, containers)
+ defaults[self.CONTAINERS_FIELD] = containers
- if create:
- return self.factory.create(**defaults)
- else:
- return self.factory.build(**defaults)
+ return self.generate(sequence, obj, create, defaults)
+ def generate(self, sequence, obj, create, params): # pragma: no cover
+ """Actually generate the related attribute.
-class PostGenerationDeclaration(object):
- """Declarations to be called once the target object has been generated.
+ Args:
+ sequence (int): the current sequence number
+ obj (LazyStub): the object being constructed
+ create (bool): whether the calling factory was in 'create' or
+ 'build' mode
+ params (dict): parameters inherited from init and evaluation-time
+ overrides.
+
+ Returns:
+ Computed value for the current declaration.
+ """
+ raise NotImplementedError()
+
+
+class SubFactory(ParameteredAttribute):
+ """Base class for attributes based upon a sub-factory.
Attributes:
- extract_prefix (str): prefix to use when extracting attributes from
- the factory's declaration for this declaration. If empty, uses
- the attribute name of the PostGenerationDeclaration.
+ defaults (dict): Overrides to the defaults defined in the wrapped
+ factory
+ factory (base.Factory): the wrapped factory
"""
- def __init__(self, extract_prefix=None):
- self.extract_prefix = extract_prefix
+ EXTEND_CONTAINERS = True
+
+ def __init__(self, factory, **kwargs):
+ super(SubFactory, self).__init__(**kwargs)
+ if isinstance(factory, type):
+ self.factory = factory
+ self.factory_module = self.factory_name = ''
+ else:
+ # Must be a string
+ if not (compat.is_string(factory) and '.' in factory):
+ raise ValueError(
+ "The argument of a SubFactory must be either a class "
+ "or the fully qualified path to a Factory class; got "
+ "%r instead." % factory)
+ self.factory = None
+ self.factory_module, self.factory_name = factory.rsplit('.', 1)
+
+ def get_factory(self):
+ """Retrieve the wrapped factory.Factory subclass."""
+ if self.factory is None:
+ # Must be a module path
+ self.factory = utils.import_object(
+ self.factory_module, self.factory_name)
+ return self.factory
+
+ def generate(self, sequence, obj, create, params):
+ """Evaluate the current definition and fill its attributes.
+
+ Args:
+ create (bool): whether the subfactory should call 'build' or
+ 'create'
+ params (containers.DeclarationDict): extra values that should
+ override the wrapped factory's defaults
+ """
+ subfactory = self.get_factory()
+ return subfactory.simple_generate(create, **params)
+
+
+class Dict(SubFactory):
+ """Fill a dict with usual declarations."""
+
+ def __init__(self, params, dict_factory='factory.DictFactory'):
+ super(Dict, self).__init__(dict_factory, **dict(params))
+
+ def generate(self, sequence, obj, create, params):
+ dict_factory = self.get_factory()
+ return dict_factory.simple_generate(create,
+ __sequence=sequence,
+ **params)
+
+
+class List(SubFactory):
+ """Fill a list with standard declarations."""
+
+ def __init__(self, params, list_factory='factory.ListFactory'):
+ params = dict((str(i), v) for i, v in enumerate(params))
+ super(List, self).__init__(list_factory, **params)
+
+ def generate(self, sequence, obj, create, params):
+ list_factory = self.get_factory()
+ return list_factory.simple_generate(create,
+ __sequence=sequence,
+ **params)
+
+
+class PostGenerationDeclaration(object):
+ """Declarations to be called once the target object has been generated."""
def extract(self, name, attrs):
"""Extract relevant attributes from a dict.
@@ -275,12 +389,8 @@ class PostGenerationDeclaration(object):
(object, dict): a tuple containing the attribute at 'name' (if
provided) and a dict of extracted attributes
"""
- if self.extract_prefix:
- extract_prefix = self.extract_prefix
- else:
- extract_prefix = name
- extracted = attrs.pop(extract_prefix, None)
- kwargs = utils.extract_dict(extract_prefix, attrs)
+ extracted = attrs.pop(name, None)
+ kwargs = utils.extract_dict(name, attrs)
return extracted, kwargs
def call(self, obj, create, extracted=None, **kwargs):
@@ -289,7 +399,7 @@ class PostGenerationDeclaration(object):
Args:
obj (object): the newly generated object
create (bool): whether the object was 'built' or 'created'
- extracted (object): the value given for <extract_prefix> in the
+ extracted (object): the value given for <name> in the
object definition, or None if not provided.
kwargs (dict): declarations extracted from the object
definition for this hook
@@ -299,18 +409,12 @@ class PostGenerationDeclaration(object):
class PostGeneration(PostGenerationDeclaration):
"""Calls a given function once the object has been generated."""
- def __init__(self, function, extract_prefix=None):
- super(PostGeneration, self).__init__(extract_prefix)
+ def __init__(self, function):
+ super(PostGeneration, self).__init__()
self.function = function
def call(self, obj, create, extracted=None, **kwargs):
- self.function(obj, create, extracted, **kwargs)
-
-
-def post_generation(extract_prefix=None):
- def decorator(fun):
- return PostGeneration(fun, extract_prefix=extract_prefix)
- return decorator
+ return self.function(obj, create, extracted, **kwargs)
class RelatedFactory(PostGenerationDeclaration):
@@ -324,17 +428,39 @@ class RelatedFactory(PostGenerationDeclaration):
"""
def __init__(self, factory, name='', **defaults):
- super(RelatedFactory, self).__init__(extract_prefix=None)
- self.factory = factory
+ super(RelatedFactory, self).__init__()
self.name = name
self.defaults = defaults
+ if isinstance(factory, type):
+ self.factory = factory
+ self.factory_module = self.factory_name = ''
+ else:
+ # Must be a string
+ if not (compat.is_string(factory) and '.' in factory):
+ raise ValueError(
+ "The argument of a SubFactory must be either a class "
+ "or the fully qualified path to a Factory class; got "
+ "%r instead." % factory)
+ self.factory = None
+ self.factory_module, self.factory_name = factory.rsplit('.', 1)
+
+ def get_factory(self):
+ """Retrieve the wrapped factory.Factory subclass."""
+ if self.factory is None:
+ # Must be a module path
+ self.factory = utils.import_object(
+ self.factory_module, self.factory_name)
+ return self.factory
+
def call(self, obj, create, extracted=None, **kwargs):
passed_kwargs = dict(self.defaults)
passed_kwargs.update(kwargs)
if self.name:
passed_kwargs[self.name] = obj
- self.factory.simple_generate(create, **passed_kwargs)
+
+ factory = self.get_factory()
+ factory.simple_generate(create, **passed_kwargs)
class PostGenerationMethodCall(PostGenerationDeclaration):
@@ -348,39 +474,25 @@ class PostGenerationMethodCall(PostGenerationDeclaration):
Example:
class UserFactory(factory.Factory):
...
- password = factory.PostGenerationMethodCall('set_password', password='')
+ password = factory.PostGenerationMethodCall('set_pass', password='')
"""
- def __init__(self, method_name, extract_prefix=None, *args, **kwargs):
- super(PostGenerationMethodCall, self).__init__(extract_prefix)
+ def __init__(self, method_name, *args, **kwargs):
+ super(PostGenerationMethodCall, self).__init__()
self.method_name = method_name
self.method_args = args
self.method_kwargs = kwargs
def call(self, obj, create, extracted=None, **kwargs):
+ if extracted is None:
+ passed_args = self.method_args
+
+ elif len(self.method_args) <= 1:
+ # Max one argument expected
+ passed_args = (extracted,)
+ else:
+ passed_args = tuple(extracted)
+
passed_kwargs = dict(self.method_kwargs)
passed_kwargs.update(kwargs)
method = getattr(obj, self.method_name)
- method(*self.method_args, **passed_kwargs)
-
-
-# Decorators... in case lambdas don't cut it
-
-def lazy_attribute(func):
- return LazyAttribute(func)
-
-def iterator(func):
- """Turn a generator function into an iterator attribute."""
- return Iterator(func())
-
-def infinite_iterator(func):
- """Turn a generator function into an infinite iterator attribute."""
- return InfiniteIterator(func())
-
-def sequence(func):
- return Sequence(func)
-
-def lazy_attribute_sequence(func):
- return LazyAttributeSequence(func)
-
-def container_attribute(func):
- return ContainerAttribute(func, strict=False)
+ method(*passed_args, **passed_kwargs)