diff options
Diffstat (limited to 'factory')
-rw-r--r-- | factory/__init__.py | 11 | ||||
-rw-r--r-- | factory/alchemy.py | 3 | ||||
-rw-r--r-- | factory/base.py | 75 | ||||
-rw-r--r-- | factory/compat.py | 8 | ||||
-rw-r--r-- | factory/containers.py | 116 | ||||
-rw-r--r-- | factory/declarations.py | 66 | ||||
-rw-r--r-- | factory/errors.py | 42 | ||||
-rw-r--r-- | factory/fuzzy.py | 5 | ||||
-rw-r--r-- | factory/utils.py | 6 |
9 files changed, 278 insertions, 54 deletions
diff --git a/factory/__init__.py b/factory/__init__.py index 4a4a09f..ad9da80 100644 --- a/factory/__init__.py +++ b/factory/__init__.py @@ -20,7 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -__version__ = '2.6.0' +__version__ = '2.6.1' __author__ = 'Raphaël Barrois <raphael.barrois+fboy@polytechnique.org>' @@ -32,22 +32,27 @@ from .base import ( ListFactory, StubFactory, - FactoryError, - BUILD_STRATEGY, CREATE_STRATEGY, STUB_STRATEGY, use_strategy, ) + +from .errors import ( + FactoryError, +) + from .faker import Faker from .declarations import ( + LazyFunction, LazyAttribute, Iterator, Sequence, LazyAttributeSequence, SelfAttribute, + Trait, ContainerAttribute, SubFactory, Dict, diff --git a/factory/alchemy.py b/factory/alchemy.py index 20da6cf..a9aab23 100644 --- a/factory/alchemy.py +++ b/factory/alchemy.py @@ -27,6 +27,7 @@ class SQLAlchemyOptions(base.FactoryOptions): def _build_default_options(self): return super(SQLAlchemyOptions, self)._build_default_options() + [ base.OptionDefault('sqlalchemy_session', None, inherit=True), + base.OptionDefault('force_flush', False, inherit=True), ] @@ -43,4 +44,6 @@ class SQLAlchemyModelFactory(base.Factory): session = cls._meta.sqlalchemy_session obj = model_class(*args, **kwargs) session.add(obj) + if cls._meta.force_flush: + session.flush() return obj diff --git a/factory/base.py b/factory/base.py index 0f2af59..282e3b1 100644 --- a/factory/base.py +++ b/factory/base.py @@ -20,10 +20,12 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +import collections import logging from . import containers from . import declarations +from . import errors from . import utils logger = logging.getLogger('factory.generate') @@ -35,22 +37,6 @@ STUB_STRATEGY = 'stub' -class FactoryError(Exception): - """Any exception raised by factory_boy.""" - - -class AssociatedClassError(FactoryError): - """Exception for Factory subclasses lacking Meta.model.""" - - -class UnknownStrategy(FactoryError): - """Raised when a factory uses an unknown strategy.""" - - -class UnsupportedStrategy(FactoryError): - """Raised when trying to use a strategy on an incompatible Factory.""" - - # Factory metaclasses def get_factory_bases(bases): @@ -82,7 +68,7 @@ class FactoryMetaClass(type): elif cls._meta.strategy == STUB_STRATEGY: return cls.stub(**kwargs) else: - raise UnknownStrategy('Unknown Meta.strategy: {0}'.format( + raise errors.UnknownStrategy('Unknown Meta.strategy: {0}'.format( cls._meta.strategy)) def __new__(mcs, class_name, bases, attrs): @@ -107,6 +93,7 @@ class FactoryMetaClass(type): base_factory = None attrs_meta = attrs.pop('Meta', None) + attrs_params = attrs.pop('Params', None) base_meta = resolve_attribute('_meta', bases) options_class = resolve_attribute('_options_class', bases, FactoryOptions) @@ -121,6 +108,7 @@ class FactoryMetaClass(type): meta=attrs_meta, base_meta=base_meta, base_factory=base_factory, + params=attrs_params, ) return new_class @@ -163,6 +151,8 @@ class FactoryOptions(object): self.base_factory = None self.declarations = {} self.postgen_declarations = {} + self.parameters = {} + self.parameters_dependencies = {} def _build_default_options(self): """"Provide the default value for all allowed fields. @@ -201,7 +191,7 @@ class FactoryOptions(object): % (self.factory, ','.join(sorted(meta_attrs.keys())))) def contribute_to_class(self, factory, - meta=None, base_meta=None, base_factory=None): + meta=None, base_meta=None, base_factory=None, params=None): self.factory = factory self.base_factory = base_factory @@ -219,6 +209,7 @@ class FactoryOptions(object): continue self.declarations.update(parent._meta.declarations) self.postgen_declarations.update(parent._meta.postgen_declarations) + self.parameters.update(parent._meta.parameters) for k, v in vars(self.factory).items(): if self._is_declaration(k, v): @@ -226,6 +217,13 @@ class FactoryOptions(object): if self._is_postgen_declaration(k, v): self.postgen_declarations[k] = v + if params is not None: + for k, v in vars(params).items(): + if not k.startswith('_'): + self.parameters[k] = v + + self.parameters_dependencies = self._compute_parameter_dependencies(self.parameters) + def _get_counter_reference(self): """Identify which factory should be used for a shared counter.""" @@ -257,6 +255,32 @@ class FactoryOptions(object): """Captures instances of PostGenerationDeclaration.""" return isinstance(value, declarations.PostGenerationDeclaration) + def _compute_parameter_dependencies(self, parameters): + """Find out in what order parameters should be called.""" + # Warning: parameters only provide reverse dependencies; we reverse them into standard dependencies. + # deep_revdeps: set of fields a field depend indirectly upon + deep_revdeps = collections.defaultdict(set) + # Actual, direct dependencies + deps = collections.defaultdict(set) + + for name, parameter in parameters.items(): + if isinstance(parameter, declarations.ComplexParameter): + field_revdeps = parameter.get_revdeps(parameters) + if not field_revdeps: + continue + deep_revdeps[name] = set.union(*(deep_revdeps[dep] for dep in field_revdeps)) + deep_revdeps[name] |= set(field_revdeps) + for dep in field_revdeps: + deps[dep].add(name) + + # Check for cyclical dependencies + cyclic = [name for name, field_deps in deep_revdeps.items() if name in field_deps] + if cyclic: + raise errors.CyclicDefinitionError( + "Cyclic definition detected on %s' Params around %s" + % (self.factory, ', '.join(cyclic))) + return deps + def __str__(self): return "<%s for %s>" % (self.__class__.__name__, self.factory.__class__.__name__) @@ -296,12 +320,12 @@ class BaseFactory(object): """Factory base support for sequences, attributes and stubs.""" # Backwards compatibility - UnknownStrategy = UnknownStrategy - UnsupportedStrategy = UnsupportedStrategy + UnknownStrategy = errors.UnknownStrategy + UnsupportedStrategy = errors.UnsupportedStrategy def __new__(cls, *args, **kwargs): """Would be called if trying to instantiate the class.""" - raise FactoryError('You cannot instantiate BaseFactory') + raise errors.FactoryError('You cannot instantiate BaseFactory') _meta = FactoryOptions() @@ -454,6 +478,9 @@ class BaseFactory(object): # Remove 'hidden' arguments. for arg in cls._meta.exclude: del kwargs[arg] + # Remove parameters, if defined + for arg in cls._meta.parameters: + kwargs.pop(arg, None) # Extract *args from **kwargs args = tuple(kwargs.pop(key) for key in cls._meta.inline_args) @@ -477,7 +504,7 @@ class BaseFactory(object): attrs (dict): attributes to use for generating the object """ if cls._meta.abstract: - raise FactoryError( + raise errors.FactoryError( "Cannot generate instances of abstract factory %(f)s; " "Ensure %(f)s.Meta.model is set and %(f)s.Meta.abstract " "is either not set or False." % dict(f=cls.__name__)) @@ -680,7 +707,7 @@ Factory = FactoryMetaClass('Factory', (BaseFactory,), { # Backwards compatibility -Factory.AssociatedClassError = AssociatedClassError # pylint: disable=W0201 +Factory.AssociatedClassError = errors.AssociatedClassError # pylint: disable=W0201 class StubFactory(Factory): @@ -695,7 +722,7 @@ class StubFactory(Factory): @classmethod def create(cls, **kwargs): - raise UnsupportedStrategy() + raise errors.UnsupportedStrategy() class BaseDictFactory(Factory): diff --git a/factory/compat.py b/factory/compat.py index 785d174..737d91a 100644 --- a/factory/compat.py +++ b/factory/compat.py @@ -42,14 +42,6 @@ else: # pragma: no cover from io import BytesIO -if sys.version_info[:2] == (2, 6): # pragma: no cover - def float_to_decimal(fl): - return decimal.Decimal(str(fl)) -else: # pragma: no cover - def float_to_decimal(fl): - return decimal.Decimal(fl) - - try: # pragma: no cover # Python >= 3.2 UTC = datetime.timezone.utc diff --git a/factory/containers.py b/factory/containers.py index ec33ca1..4961115 100644 --- a/factory/containers.py +++ b/factory/containers.py @@ -25,13 +25,10 @@ import logging logger = logging.getLogger(__name__) from . import declarations +from . import errors from . import utils -class CyclicDefinitionError(Exception): - """Raised when cyclic definition were found.""" - - class LazyStub(object): """A generic container that only allows getting attributes. @@ -93,7 +90,7 @@ class LazyStub(object): attributes being computed. """ if name in self.__pending: - raise CyclicDefinitionError( + raise errors.CyclicDefinitionError( "Cyclic lazy attribute definition for %s; cycle found in %r." % (name, self.__pending)) elif name in self.__values: @@ -114,7 +111,6 @@ class LazyStub(object): "The parameter %s is unknown. Evaluated attributes are %r, " "definitions are %r." % (name, self.__values, self.__attrs)) - def __setattr__(self, name, value): """Prevent setting attributes once __init__ is done.""" if not self.__initialized: @@ -123,6 +119,69 @@ class LazyStub(object): raise AttributeError('Setting of object attributes is not allowed') +class DeclarationStack(object): + """An ordered stack of declarations. + + This is intended to handle declaration precedence among different mutating layers. + """ + def __init__(self, ordering): + self.ordering = ordering + self.layers = dict((name, {}) for name in self.ordering) + + def __getitem__(self, key): + return self.layers[key] + + def __setitem__(self, key, value): + assert key in self.ordering + self.layers[key] = value + + def current(self): + """Retrieve the current, flattened declarations dict.""" + result = {} + for layer in self.ordering: + result.update(self.layers[layer]) + return result + + +class ParameterResolver(object): + """Resolve a factory's parameter declarations.""" + def __init__(self, parameters, deps): + self.parameters = parameters + self.deps = deps + self.declaration_stack = None + + self.resolved = set() + + def resolve_one(self, name): + """Compute one field is needed, taking dependencies into accounts.""" + if name in self.resolved: + return + + for dep in self.deps.get(name, ()): + self.resolve_one(dep) + + self.compute(name) + self.resolved.add(name) + + def compute(self, name): + """Actually compute the value for a given name.""" + value = self.parameters[name] + if isinstance(value, declarations.ComplexParameter): + overrides = value.compute(name, self.declaration_stack.current()) + else: + overrides = {name: value} + self.declaration_stack['overrides'].update(overrides) + + def resolve(self, declaration_stack): + """Resolve parameters for a given declaration stack. + + Modifies the stack in-place. + """ + self.declaration_stack = declaration_stack + for name in self.parameters: + self.resolve_one(name) + + class LazyValue(object): """Some kind of "lazy evaluating" object.""" @@ -131,7 +190,7 @@ class LazyValue(object): raise NotImplementedError("This is an abstract method.") -class OrderedDeclarationWrapper(LazyValue): +class DeclarationWrapper(LazyValue): """Lazy wrapper around an OrderedDeclaration. Attributes: @@ -142,7 +201,7 @@ class OrderedDeclarationWrapper(LazyValue): """ def __init__(self, declaration, sequence, create, extra=None, **kwargs): - super(OrderedDeclarationWrapper, self).__init__(**kwargs) + super(DeclarationWrapper, self).__init__(**kwargs) self.declaration = declaration self.sequence = sequence self.create = create @@ -172,7 +231,7 @@ class AttributeBuilder(object): Attributes: factory (base.Factory): the Factory for which attributes are being built - _attrs (DeclarationDict): the attribute declarations for the factory + _declarations (DeclarationDict): the attribute declarations for the factory _subfields (dict): dict mapping an attribute name to a dict of overridden default values for the related SubFactory. """ @@ -185,20 +244,47 @@ class AttributeBuilder(object): self.factory = factory self._containers = extra.pop('__containers', ()) - self._attrs = factory.declarations(extra) + + initial_declarations = dict(factory._meta.declarations) self._log_ctx = log_ctx - initial_declarations = factory.declarations({}) + # Parameters + # ---------- + self._declarations = self.merge_declarations(initial_declarations, extra) + + # Subfields + # --------- + attrs_with_subfields = [ k for k, v in initial_declarations.items() - if self.has_subfields(v)] + if self.has_subfields(v) + ] + # Extract subfields; THIS MODIFIES self._declarations. self._subfields = utils.multi_extract_dict( - attrs_with_subfields, self._attrs) + attrs_with_subfields, self._declarations) def has_subfields(self, value): return isinstance(value, declarations.ParameteredAttribute) + def merge_declarations(self, initial, extra): + """Compute the final declarations, taking into account paramter-based overrides.""" + # Precedence order: + # - Start with class-level declarations + # - Add overrides from parameters + # - Finally, use callsite-level declarations & values + declaration_stack = DeclarationStack(['initial', 'overrides', 'extra']) + declaration_stack['initial'] = initial.copy() + declaration_stack['extra'] = extra.copy() + + # Actually compute the final stack + resolver = ParameterResolver( + parameters=self.factory._meta.parameters, + deps=self.factory._meta.parameters_dependencies, + ) + resolver.resolve(declaration_stack) + return declaration_stack.current() + def build(self, create, force_sequence=None): """Build a dictionary of attributes. @@ -216,9 +302,9 @@ class AttributeBuilder(object): # Parse attribute declarations, wrapping SubFactory and # OrderedDeclaration. wrapped_attrs = {} - for k, v in self._attrs.items(): + for k, v in self._declarations.items(): if isinstance(v, declarations.OrderedDeclaration): - v = OrderedDeclarationWrapper(v, + v = DeclarationWrapper(v, sequence=sequence, create=create, extra=self._subfields.get(k, {}), diff --git a/factory/declarations.py b/factory/declarations.py index f0dbfe5..895f2ac 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -57,6 +57,23 @@ class OrderedDeclaration(object): raise NotImplementedError('This is an abstract method') +class LazyFunction(OrderedDeclaration): + """Simplest OrderedDeclaration computed by calling the given function. + + Attributes: + function (function): a function without arguments and + returning the computed value. + """ + + def __init__(self, function, *args, **kwargs): + super(LazyFunction, self).__init__(*args, **kwargs) + self.function = function + + def evaluate(self, sequence, obj, create, extra=None, containers=()): + logger.debug("LazyFunction: Evaluating %r on %r", self.function, obj) + return self.function() + + class LazyAttribute(OrderedDeclaration): """Specific OrderedDeclaration computed using a lambda. @@ -423,6 +440,55 @@ class List(SubFactory): **params) +# Parameters +# ========== + + +class ComplexParameter(object): + """A complex parameter, to be used in a Factory.Params section. + + Must implement: + - A "compute" function, performing the actual declaration override + - Optionally, a get_revdeps() function (to compute other parameters it may alter) + """ + + def compute(self, field_name, declarations): + """Compute the overrides for this parameter. + + Args: + - field_name (str): the field this parameter is installed at + - declarations (dict): the global factory declarations + + Returns: + dict: the declarations to override + """ + raise NotImplementedError() + + def get_revdeps(self, parameters): + """Retrieve the list of other parameters modified by this one.""" + return [] + + +class Trait(ComplexParameter): + """The simplest complex parameter, it enables a bunch of new declarations based on a boolean flag.""" + def __init__(self, **overrides): + self.overrides = overrides + + def compute(self, field_name, declarations): + if declarations.get(field_name): + return self.overrides + else: + return {} + + def get_revdeps(self, parameters): + """This might alter fields it's injecting.""" + return [param for param in parameters if param in self.overrides] + + +# Post-generation +# =============== + + class ExtractionContext(object): """Private class holding all required context from extraction to postgen.""" def __init__(self, value=None, did_extract=False, extra=None, for_field=''): diff --git a/factory/errors.py b/factory/errors.py new file mode 100644 index 0000000..79d85f4 --- /dev/null +++ b/factory/errors.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2011-2015 Raphaël Barrois +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +class FactoryError(Exception): + """Any exception raised by factory_boy.""" + + +class AssociatedClassError(FactoryError): + """Exception for Factory subclasses lacking Meta.model.""" + + +class UnknownStrategy(FactoryError): + """Raised when a factory uses an unknown strategy.""" + + +class UnsupportedStrategy(FactoryError): + """Raised when trying to use a strategy on an incompatible Factory.""" + + +class CyclicDefinitionError(FactoryError): + """Raised when a cyclical declaration occurs.""" + + diff --git a/factory/fuzzy.py b/factory/fuzzy.py index 923d8b7..71d1884 100644 --- a/factory/fuzzy.py +++ b/factory/fuzzy.py @@ -164,7 +164,7 @@ class FuzzyDecimal(BaseFuzzyAttribute): super(FuzzyDecimal, self).__init__(**kwargs) def fuzz(self): - base = compat.float_to_decimal(_random.uniform(self.low, self.high)) + base = decimal.Decimal(str(_random.uniform(self.low, self.high))) return base.quantize(decimal.Decimal(10) ** -self.precision) @@ -217,6 +217,9 @@ class BaseFuzzyDateTime(BaseFuzzyAttribute): """%s boundaries should have start <= end, got %r > %r""" % ( self.__class__.__name__, start_dt, end_dt)) + def _now(self): + raise NotImplementedError() + def __init__(self, start_dt, end_dt=None, force_year=None, force_month=None, force_day=None, force_hour=None, force_minute=None, force_second=None, diff --git a/factory/utils.py b/factory/utils.py index 806b1ec..cfae4ec 100644 --- a/factory/utils.py +++ b/factory/utils.py @@ -35,7 +35,7 @@ def extract_dict(prefix, kwargs, pop=True, exclude=()): Args: prefix (str): the prefix to use for lookups - kwargs (dict): the dict from which values should be extracted + kwargs (dict): the dict from which values should be extracted; WILL BE MODIFIED. pop (bool): whether to use pop (True) or get (False) exclude (iterable): list of prefixed keys that shouldn't be extracted @@ -68,7 +68,7 @@ def multi_extract_dict(prefixes, kwargs, pop=True, exclude=()): Args: prefixes (str list): the prefixes to use for lookups - kwargs (dict): the dict from which values should be extracted + kwargs (dict): the dict from which values should be extracted; WILL BE MODIFIED. pop (bool): whether to use pop (True) or get (False) exclude (iterable): list of prefixed keys that shouldn't be extracted @@ -101,7 +101,7 @@ def import_object(module_name, attribute_name): def _safe_repr(obj): try: obj_repr = repr(obj) - except UnicodeError: + except Exception: return '<bad_repr object at %s>' % id(obj) try: # Convert to "text type" (= unicode) |