summaryrefslogtreecommitdiff
path: root/factory
diff options
context:
space:
mode:
Diffstat (limited to 'factory')
-rw-r--r--factory/__init__.py11
-rw-r--r--factory/alchemy.py3
-rw-r--r--factory/base.py75
-rw-r--r--factory/compat.py8
-rw-r--r--factory/containers.py116
-rw-r--r--factory/declarations.py66
-rw-r--r--factory/errors.py42
-rw-r--r--factory/fuzzy.py5
-rw-r--r--factory/utils.py6
9 files changed, 278 insertions, 54 deletions
diff --git a/factory/__init__.py b/factory/__init__.py
index 4a4a09f..ad9da80 100644
--- a/factory/__init__.py
+++ b/factory/__init__.py
@@ -20,7 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
-__version__ = '2.6.0'
+__version__ = '2.6.1'
__author__ = 'Raphaël Barrois <raphael.barrois+fboy@polytechnique.org>'
@@ -32,22 +32,27 @@ from .base import (
ListFactory,
StubFactory,
- FactoryError,
-
BUILD_STRATEGY,
CREATE_STRATEGY,
STUB_STRATEGY,
use_strategy,
)
+
+from .errors import (
+ FactoryError,
+)
+
from .faker import Faker
from .declarations import (
+ LazyFunction,
LazyAttribute,
Iterator,
Sequence,
LazyAttributeSequence,
SelfAttribute,
+ Trait,
ContainerAttribute,
SubFactory,
Dict,
diff --git a/factory/alchemy.py b/factory/alchemy.py
index 20da6cf..a9aab23 100644
--- a/factory/alchemy.py
+++ b/factory/alchemy.py
@@ -27,6 +27,7 @@ class SQLAlchemyOptions(base.FactoryOptions):
def _build_default_options(self):
return super(SQLAlchemyOptions, self)._build_default_options() + [
base.OptionDefault('sqlalchemy_session', None, inherit=True),
+ base.OptionDefault('force_flush', False, inherit=True),
]
@@ -43,4 +44,6 @@ class SQLAlchemyModelFactory(base.Factory):
session = cls._meta.sqlalchemy_session
obj = model_class(*args, **kwargs)
session.add(obj)
+ if cls._meta.force_flush:
+ session.flush()
return obj
diff --git a/factory/base.py b/factory/base.py
index 0f2af59..282e3b1 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -20,10 +20,12 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
+import collections
import logging
from . import containers
from . import declarations
+from . import errors
from . import utils
logger = logging.getLogger('factory.generate')
@@ -35,22 +37,6 @@ STUB_STRATEGY = 'stub'
-class FactoryError(Exception):
- """Any exception raised by factory_boy."""
-
-
-class AssociatedClassError(FactoryError):
- """Exception for Factory subclasses lacking Meta.model."""
-
-
-class UnknownStrategy(FactoryError):
- """Raised when a factory uses an unknown strategy."""
-
-
-class UnsupportedStrategy(FactoryError):
- """Raised when trying to use a strategy on an incompatible Factory."""
-
-
# Factory metaclasses
def get_factory_bases(bases):
@@ -82,7 +68,7 @@ class FactoryMetaClass(type):
elif cls._meta.strategy == STUB_STRATEGY:
return cls.stub(**kwargs)
else:
- raise UnknownStrategy('Unknown Meta.strategy: {0}'.format(
+ raise errors.UnknownStrategy('Unknown Meta.strategy: {0}'.format(
cls._meta.strategy))
def __new__(mcs, class_name, bases, attrs):
@@ -107,6 +93,7 @@ class FactoryMetaClass(type):
base_factory = None
attrs_meta = attrs.pop('Meta', None)
+ attrs_params = attrs.pop('Params', None)
base_meta = resolve_attribute('_meta', bases)
options_class = resolve_attribute('_options_class', bases, FactoryOptions)
@@ -121,6 +108,7 @@ class FactoryMetaClass(type):
meta=attrs_meta,
base_meta=base_meta,
base_factory=base_factory,
+ params=attrs_params,
)
return new_class
@@ -163,6 +151,8 @@ class FactoryOptions(object):
self.base_factory = None
self.declarations = {}
self.postgen_declarations = {}
+ self.parameters = {}
+ self.parameters_dependencies = {}
def _build_default_options(self):
""""Provide the default value for all allowed fields.
@@ -201,7 +191,7 @@ class FactoryOptions(object):
% (self.factory, ','.join(sorted(meta_attrs.keys()))))
def contribute_to_class(self, factory,
- meta=None, base_meta=None, base_factory=None):
+ meta=None, base_meta=None, base_factory=None, params=None):
self.factory = factory
self.base_factory = base_factory
@@ -219,6 +209,7 @@ class FactoryOptions(object):
continue
self.declarations.update(parent._meta.declarations)
self.postgen_declarations.update(parent._meta.postgen_declarations)
+ self.parameters.update(parent._meta.parameters)
for k, v in vars(self.factory).items():
if self._is_declaration(k, v):
@@ -226,6 +217,13 @@ class FactoryOptions(object):
if self._is_postgen_declaration(k, v):
self.postgen_declarations[k] = v
+ if params is not None:
+ for k, v in vars(params).items():
+ if not k.startswith('_'):
+ self.parameters[k] = v
+
+ self.parameters_dependencies = self._compute_parameter_dependencies(self.parameters)
+
def _get_counter_reference(self):
"""Identify which factory should be used for a shared counter."""
@@ -257,6 +255,32 @@ class FactoryOptions(object):
"""Captures instances of PostGenerationDeclaration."""
return isinstance(value, declarations.PostGenerationDeclaration)
+ def _compute_parameter_dependencies(self, parameters):
+ """Find out in what order parameters should be called."""
+ # Warning: parameters only provide reverse dependencies; we reverse them into standard dependencies.
+ # deep_revdeps: set of fields a field depend indirectly upon
+ deep_revdeps = collections.defaultdict(set)
+ # Actual, direct dependencies
+ deps = collections.defaultdict(set)
+
+ for name, parameter in parameters.items():
+ if isinstance(parameter, declarations.ComplexParameter):
+ field_revdeps = parameter.get_revdeps(parameters)
+ if not field_revdeps:
+ continue
+ deep_revdeps[name] = set.union(*(deep_revdeps[dep] for dep in field_revdeps))
+ deep_revdeps[name] |= set(field_revdeps)
+ for dep in field_revdeps:
+ deps[dep].add(name)
+
+ # Check for cyclical dependencies
+ cyclic = [name for name, field_deps in deep_revdeps.items() if name in field_deps]
+ if cyclic:
+ raise errors.CyclicDefinitionError(
+ "Cyclic definition detected on %s' Params around %s"
+ % (self.factory, ', '.join(cyclic)))
+ return deps
+
def __str__(self):
return "<%s for %s>" % (self.__class__.__name__, self.factory.__class__.__name__)
@@ -296,12 +320,12 @@ class BaseFactory(object):
"""Factory base support for sequences, attributes and stubs."""
# Backwards compatibility
- UnknownStrategy = UnknownStrategy
- UnsupportedStrategy = UnsupportedStrategy
+ UnknownStrategy = errors.UnknownStrategy
+ UnsupportedStrategy = errors.UnsupportedStrategy
def __new__(cls, *args, **kwargs):
"""Would be called if trying to instantiate the class."""
- raise FactoryError('You cannot instantiate BaseFactory')
+ raise errors.FactoryError('You cannot instantiate BaseFactory')
_meta = FactoryOptions()
@@ -454,6 +478,9 @@ class BaseFactory(object):
# Remove 'hidden' arguments.
for arg in cls._meta.exclude:
del kwargs[arg]
+ # Remove parameters, if defined
+ for arg in cls._meta.parameters:
+ kwargs.pop(arg, None)
# Extract *args from **kwargs
args = tuple(kwargs.pop(key) for key in cls._meta.inline_args)
@@ -477,7 +504,7 @@ class BaseFactory(object):
attrs (dict): attributes to use for generating the object
"""
if cls._meta.abstract:
- raise FactoryError(
+ raise errors.FactoryError(
"Cannot generate instances of abstract factory %(f)s; "
"Ensure %(f)s.Meta.model is set and %(f)s.Meta.abstract "
"is either not set or False." % dict(f=cls.__name__))
@@ -680,7 +707,7 @@ Factory = FactoryMetaClass('Factory', (BaseFactory,), {
# Backwards compatibility
-Factory.AssociatedClassError = AssociatedClassError # pylint: disable=W0201
+Factory.AssociatedClassError = errors.AssociatedClassError # pylint: disable=W0201
class StubFactory(Factory):
@@ -695,7 +722,7 @@ class StubFactory(Factory):
@classmethod
def create(cls, **kwargs):
- raise UnsupportedStrategy()
+ raise errors.UnsupportedStrategy()
class BaseDictFactory(Factory):
diff --git a/factory/compat.py b/factory/compat.py
index 785d174..737d91a 100644
--- a/factory/compat.py
+++ b/factory/compat.py
@@ -42,14 +42,6 @@ else: # pragma: no cover
from io import BytesIO
-if sys.version_info[:2] == (2, 6): # pragma: no cover
- def float_to_decimal(fl):
- return decimal.Decimal(str(fl))
-else: # pragma: no cover
- def float_to_decimal(fl):
- return decimal.Decimal(fl)
-
-
try: # pragma: no cover
# Python >= 3.2
UTC = datetime.timezone.utc
diff --git a/factory/containers.py b/factory/containers.py
index ec33ca1..4961115 100644
--- a/factory/containers.py
+++ b/factory/containers.py
@@ -25,13 +25,10 @@ import logging
logger = logging.getLogger(__name__)
from . import declarations
+from . import errors
from . import utils
-class CyclicDefinitionError(Exception):
- """Raised when cyclic definition were found."""
-
-
class LazyStub(object):
"""A generic container that only allows getting attributes.
@@ -93,7 +90,7 @@ class LazyStub(object):
attributes being computed.
"""
if name in self.__pending:
- raise CyclicDefinitionError(
+ raise errors.CyclicDefinitionError(
"Cyclic lazy attribute definition for %s; cycle found in %r." %
(name, self.__pending))
elif name in self.__values:
@@ -114,7 +111,6 @@ class LazyStub(object):
"The parameter %s is unknown. Evaluated attributes are %r, "
"definitions are %r." % (name, self.__values, self.__attrs))
-
def __setattr__(self, name, value):
"""Prevent setting attributes once __init__ is done."""
if not self.__initialized:
@@ -123,6 +119,69 @@ class LazyStub(object):
raise AttributeError('Setting of object attributes is not allowed')
+class DeclarationStack(object):
+ """An ordered stack of declarations.
+
+ This is intended to handle declaration precedence among different mutating layers.
+ """
+ def __init__(self, ordering):
+ self.ordering = ordering
+ self.layers = dict((name, {}) for name in self.ordering)
+
+ def __getitem__(self, key):
+ return self.layers[key]
+
+ def __setitem__(self, key, value):
+ assert key in self.ordering
+ self.layers[key] = value
+
+ def current(self):
+ """Retrieve the current, flattened declarations dict."""
+ result = {}
+ for layer in self.ordering:
+ result.update(self.layers[layer])
+ return result
+
+
+class ParameterResolver(object):
+ """Resolve a factory's parameter declarations."""
+ def __init__(self, parameters, deps):
+ self.parameters = parameters
+ self.deps = deps
+ self.declaration_stack = None
+
+ self.resolved = set()
+
+ def resolve_one(self, name):
+ """Compute one field is needed, taking dependencies into accounts."""
+ if name in self.resolved:
+ return
+
+ for dep in self.deps.get(name, ()):
+ self.resolve_one(dep)
+
+ self.compute(name)
+ self.resolved.add(name)
+
+ def compute(self, name):
+ """Actually compute the value for a given name."""
+ value = self.parameters[name]
+ if isinstance(value, declarations.ComplexParameter):
+ overrides = value.compute(name, self.declaration_stack.current())
+ else:
+ overrides = {name: value}
+ self.declaration_stack['overrides'].update(overrides)
+
+ def resolve(self, declaration_stack):
+ """Resolve parameters for a given declaration stack.
+
+ Modifies the stack in-place.
+ """
+ self.declaration_stack = declaration_stack
+ for name in self.parameters:
+ self.resolve_one(name)
+
+
class LazyValue(object):
"""Some kind of "lazy evaluating" object."""
@@ -131,7 +190,7 @@ class LazyValue(object):
raise NotImplementedError("This is an abstract method.")
-class OrderedDeclarationWrapper(LazyValue):
+class DeclarationWrapper(LazyValue):
"""Lazy wrapper around an OrderedDeclaration.
Attributes:
@@ -142,7 +201,7 @@ class OrderedDeclarationWrapper(LazyValue):
"""
def __init__(self, declaration, sequence, create, extra=None, **kwargs):
- super(OrderedDeclarationWrapper, self).__init__(**kwargs)
+ super(DeclarationWrapper, self).__init__(**kwargs)
self.declaration = declaration
self.sequence = sequence
self.create = create
@@ -172,7 +231,7 @@ class AttributeBuilder(object):
Attributes:
factory (base.Factory): the Factory for which attributes are being
built
- _attrs (DeclarationDict): the attribute declarations for the factory
+ _declarations (DeclarationDict): the attribute declarations for the factory
_subfields (dict): dict mapping an attribute name to a dict of
overridden default values for the related SubFactory.
"""
@@ -185,20 +244,47 @@ class AttributeBuilder(object):
self.factory = factory
self._containers = extra.pop('__containers', ())
- self._attrs = factory.declarations(extra)
+
+ initial_declarations = dict(factory._meta.declarations)
self._log_ctx = log_ctx
- initial_declarations = factory.declarations({})
+ # Parameters
+ # ----------
+ self._declarations = self.merge_declarations(initial_declarations, extra)
+
+ # Subfields
+ # ---------
+
attrs_with_subfields = [
k for k, v in initial_declarations.items()
- if self.has_subfields(v)]
+ if self.has_subfields(v)
+ ]
+ # Extract subfields; THIS MODIFIES self._declarations.
self._subfields = utils.multi_extract_dict(
- attrs_with_subfields, self._attrs)
+ attrs_with_subfields, self._declarations)
def has_subfields(self, value):
return isinstance(value, declarations.ParameteredAttribute)
+ def merge_declarations(self, initial, extra):
+ """Compute the final declarations, taking into account paramter-based overrides."""
+ # Precedence order:
+ # - Start with class-level declarations
+ # - Add overrides from parameters
+ # - Finally, use callsite-level declarations & values
+ declaration_stack = DeclarationStack(['initial', 'overrides', 'extra'])
+ declaration_stack['initial'] = initial.copy()
+ declaration_stack['extra'] = extra.copy()
+
+ # Actually compute the final stack
+ resolver = ParameterResolver(
+ parameters=self.factory._meta.parameters,
+ deps=self.factory._meta.parameters_dependencies,
+ )
+ resolver.resolve(declaration_stack)
+ return declaration_stack.current()
+
def build(self, create, force_sequence=None):
"""Build a dictionary of attributes.
@@ -216,9 +302,9 @@ class AttributeBuilder(object):
# Parse attribute declarations, wrapping SubFactory and
# OrderedDeclaration.
wrapped_attrs = {}
- for k, v in self._attrs.items():
+ for k, v in self._declarations.items():
if isinstance(v, declarations.OrderedDeclaration):
- v = OrderedDeclarationWrapper(v,
+ v = DeclarationWrapper(v,
sequence=sequence,
create=create,
extra=self._subfields.get(k, {}),
diff --git a/factory/declarations.py b/factory/declarations.py
index f0dbfe5..895f2ac 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -57,6 +57,23 @@ class OrderedDeclaration(object):
raise NotImplementedError('This is an abstract method')
+class LazyFunction(OrderedDeclaration):
+ """Simplest OrderedDeclaration computed by calling the given function.
+
+ Attributes:
+ function (function): a function without arguments and
+ returning the computed value.
+ """
+
+ def __init__(self, function, *args, **kwargs):
+ super(LazyFunction, self).__init__(*args, **kwargs)
+ self.function = function
+
+ def evaluate(self, sequence, obj, create, extra=None, containers=()):
+ logger.debug("LazyFunction: Evaluating %r on %r", self.function, obj)
+ return self.function()
+
+
class LazyAttribute(OrderedDeclaration):
"""Specific OrderedDeclaration computed using a lambda.
@@ -423,6 +440,55 @@ class List(SubFactory):
**params)
+# Parameters
+# ==========
+
+
+class ComplexParameter(object):
+ """A complex parameter, to be used in a Factory.Params section.
+
+ Must implement:
+ - A "compute" function, performing the actual declaration override
+ - Optionally, a get_revdeps() function (to compute other parameters it may alter)
+ """
+
+ def compute(self, field_name, declarations):
+ """Compute the overrides for this parameter.
+
+ Args:
+ - field_name (str): the field this parameter is installed at
+ - declarations (dict): the global factory declarations
+
+ Returns:
+ dict: the declarations to override
+ """
+ raise NotImplementedError()
+
+ def get_revdeps(self, parameters):
+ """Retrieve the list of other parameters modified by this one."""
+ return []
+
+
+class Trait(ComplexParameter):
+ """The simplest complex parameter, it enables a bunch of new declarations based on a boolean flag."""
+ def __init__(self, **overrides):
+ self.overrides = overrides
+
+ def compute(self, field_name, declarations):
+ if declarations.get(field_name):
+ return self.overrides
+ else:
+ return {}
+
+ def get_revdeps(self, parameters):
+ """This might alter fields it's injecting."""
+ return [param for param in parameters if param in self.overrides]
+
+
+# Post-generation
+# ===============
+
+
class ExtractionContext(object):
"""Private class holding all required context from extraction to postgen."""
def __init__(self, value=None, did_extract=False, extra=None, for_field=''):
diff --git a/factory/errors.py b/factory/errors.py
new file mode 100644
index 0000000..79d85f4
--- /dev/null
+++ b/factory/errors.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2011-2015 Raphaël Barrois
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+
+class FactoryError(Exception):
+ """Any exception raised by factory_boy."""
+
+
+class AssociatedClassError(FactoryError):
+ """Exception for Factory subclasses lacking Meta.model."""
+
+
+class UnknownStrategy(FactoryError):
+ """Raised when a factory uses an unknown strategy."""
+
+
+class UnsupportedStrategy(FactoryError):
+ """Raised when trying to use a strategy on an incompatible Factory."""
+
+
+class CyclicDefinitionError(FactoryError):
+ """Raised when a cyclical declaration occurs."""
+
+
diff --git a/factory/fuzzy.py b/factory/fuzzy.py
index 923d8b7..71d1884 100644
--- a/factory/fuzzy.py
+++ b/factory/fuzzy.py
@@ -164,7 +164,7 @@ class FuzzyDecimal(BaseFuzzyAttribute):
super(FuzzyDecimal, self).__init__(**kwargs)
def fuzz(self):
- base = compat.float_to_decimal(_random.uniform(self.low, self.high))
+ base = decimal.Decimal(str(_random.uniform(self.low, self.high)))
return base.quantize(decimal.Decimal(10) ** -self.precision)
@@ -217,6 +217,9 @@ class BaseFuzzyDateTime(BaseFuzzyAttribute):
"""%s boundaries should have start <= end, got %r > %r""" % (
self.__class__.__name__, start_dt, end_dt))
+ def _now(self):
+ raise NotImplementedError()
+
def __init__(self, start_dt, end_dt=None,
force_year=None, force_month=None, force_day=None,
force_hour=None, force_minute=None, force_second=None,
diff --git a/factory/utils.py b/factory/utils.py
index 806b1ec..cfae4ec 100644
--- a/factory/utils.py
+++ b/factory/utils.py
@@ -35,7 +35,7 @@ def extract_dict(prefix, kwargs, pop=True, exclude=()):
Args:
prefix (str): the prefix to use for lookups
- kwargs (dict): the dict from which values should be extracted
+ kwargs (dict): the dict from which values should be extracted; WILL BE MODIFIED.
pop (bool): whether to use pop (True) or get (False)
exclude (iterable): list of prefixed keys that shouldn't be extracted
@@ -68,7 +68,7 @@ def multi_extract_dict(prefixes, kwargs, pop=True, exclude=()):
Args:
prefixes (str list): the prefixes to use for lookups
- kwargs (dict): the dict from which values should be extracted
+ kwargs (dict): the dict from which values should be extracted; WILL BE MODIFIED.
pop (bool): whether to use pop (True) or get (False)
exclude (iterable): list of prefixed keys that shouldn't be extracted
@@ -101,7 +101,7 @@ def import_object(module_name, attribute_name):
def _safe_repr(obj):
try:
obj_repr = repr(obj)
- except UnicodeError:
+ except Exception:
return '<bad_repr object at %s>' % id(obj)
try: # Convert to "text type" (= unicode)