From c77962de7dd7206ccab85b44da173832acbf5921 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Barrois?= Date: Sat, 2 Apr 2016 16:13:34 +0200 Subject: Add a new Params section to factories. This handles parameters that alter the declarations of a factory. A few technical notes: - A parameter's outcome may alter other parameters - In order to fix that, we perform a (simple) cyclic definition detection at class declaration time. - Parameters may only be either naked values or ComplexParameter subclasses - Parameters are never passed to the underlying class --- docs/reference.rst | 46 +++++++++++++++++ factory/base.py | 44 ++++++++++++++++- factory/containers.py | 108 ++++++++++++++++++++++++++++++++++++---- factory/declarations.py | 33 +++++++++++++ factory/utils.py | 4 +- tests/test_containers.py | 125 +++++++++-------------------------------------- tests/test_using.py | 9 ++++ 7 files changed, 256 insertions(+), 113 deletions(-) diff --git a/docs/reference.rst b/docs/reference.rst index e2f63db..8550f88 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -1299,6 +1299,52 @@ with the :class:`Dict` and :class:`List` attributes: argument, if another type (tuple, set, ...) is required. +Parameters +"""""""""" + +Some models have many fields that can be summarized by a few parameters; for instance, +a train with many cars — each complete with serial number, manufacturer, ...; +or an order that can be pending/shipped/received, with a few fields to describe each step. + +When building instances of such models, a couple of parameters can be enough to determine +all other fields; this is handled by the :class:`~Factory.Params` section of a :class:`Factory` declaration. + + +Simple parameters +~~~~~~~~~~~~~~~~~ + +Some factories only need little data: + +.. code-block:: python + + class ConferenceFactory(factory.Factory): + class Meta: + model = Conference + + class Params: + duration = 'short' # Or 'long' + + start_date = factory.fuzzy.FuzzyDate() + end_date = factory.LazyAttribute( + lambda o: o.start_date + datetime.timedelta(days=2 if o.duration == 'short' else 7) + ) + sprints_start = factory.LazyAttribute( + lambda o: o.end_date - datetime.timedelta(days=0 if o.duration == 'short' else 1) + ) + +.. code-block:: pycon + + >>> Conference(duration='short') + + >>> Conference(duration='long') + + + +Any simple parameter provided to the :class:`Factory.Params` section is available to the whole factory, +but not passed to the final class (similar to the :attr:`~FactoryOptions.exclude` behavior). + + + Post-generation hooks """"""""""""""""""""" diff --git a/factory/base.py b/factory/base.py index 1ddb742..282e3b1 100644 --- a/factory/base.py +++ b/factory/base.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +import collections import logging from . import containers @@ -92,6 +93,7 @@ class FactoryMetaClass(type): base_factory = None attrs_meta = attrs.pop('Meta', None) + attrs_params = attrs.pop('Params', None) base_meta = resolve_attribute('_meta', bases) options_class = resolve_attribute('_options_class', bases, FactoryOptions) @@ -106,6 +108,7 @@ class FactoryMetaClass(type): meta=attrs_meta, base_meta=base_meta, base_factory=base_factory, + params=attrs_params, ) return new_class @@ -148,6 +151,8 @@ class FactoryOptions(object): self.base_factory = None self.declarations = {} self.postgen_declarations = {} + self.parameters = {} + self.parameters_dependencies = {} def _build_default_options(self): """"Provide the default value for all allowed fields. @@ -186,7 +191,7 @@ class FactoryOptions(object): % (self.factory, ','.join(sorted(meta_attrs.keys())))) def contribute_to_class(self, factory, - meta=None, base_meta=None, base_factory=None): + meta=None, base_meta=None, base_factory=None, params=None): self.factory = factory self.base_factory = base_factory @@ -204,6 +209,7 @@ class FactoryOptions(object): continue self.declarations.update(parent._meta.declarations) self.postgen_declarations.update(parent._meta.postgen_declarations) + self.parameters.update(parent._meta.parameters) for k, v in vars(self.factory).items(): if self._is_declaration(k, v): @@ -211,6 +217,13 @@ class FactoryOptions(object): if self._is_postgen_declaration(k, v): self.postgen_declarations[k] = v + if params is not None: + for k, v in vars(params).items(): + if not k.startswith('_'): + self.parameters[k] = v + + self.parameters_dependencies = self._compute_parameter_dependencies(self.parameters) + def _get_counter_reference(self): """Identify which factory should be used for a shared counter.""" @@ -242,6 +255,32 @@ class FactoryOptions(object): """Captures instances of PostGenerationDeclaration.""" return isinstance(value, declarations.PostGenerationDeclaration) + def _compute_parameter_dependencies(self, parameters): + """Find out in what order parameters should be called.""" + # Warning: parameters only provide reverse dependencies; we reverse them into standard dependencies. + # deep_revdeps: set of fields a field depend indirectly upon + deep_revdeps = collections.defaultdict(set) + # Actual, direct dependencies + deps = collections.defaultdict(set) + + for name, parameter in parameters.items(): + if isinstance(parameter, declarations.ComplexParameter): + field_revdeps = parameter.get_revdeps(parameters) + if not field_revdeps: + continue + deep_revdeps[name] = set.union(*(deep_revdeps[dep] for dep in field_revdeps)) + deep_revdeps[name] |= set(field_revdeps) + for dep in field_revdeps: + deps[dep].add(name) + + # Check for cyclical dependencies + cyclic = [name for name, field_deps in deep_revdeps.items() if name in field_deps] + if cyclic: + raise errors.CyclicDefinitionError( + "Cyclic definition detected on %s' Params around %s" + % (self.factory, ', '.join(cyclic))) + return deps + def __str__(self): return "<%s for %s>" % (self.__class__.__name__, self.factory.__class__.__name__) @@ -439,6 +478,9 @@ class BaseFactory(object): # Remove 'hidden' arguments. for arg in cls._meta.exclude: del kwargs[arg] + # Remove parameters, if defined + for arg in cls._meta.parameters: + kwargs.pop(arg, None) # Extract *args from **kwargs args = tuple(kwargs.pop(key) for key in cls._meta.inline_args) diff --git a/factory/containers.py b/factory/containers.py index c591988..d3f39c4 100644 --- a/factory/containers.py +++ b/factory/containers.py @@ -117,6 +117,69 @@ class LazyStub(object): raise AttributeError('Setting of object attributes is not allowed') +class DeclarationStack(object): + """An ordered stack of declarations. + + This is intended to handle declaration precedence among different mutating layers. + """ + def __init__(self, ordering): + self.ordering = ordering + self.layers = dict((name, {}) for name in self.ordering) + + def __getitem__(self, key): + return self.layers[key] + + def __setitem__(self, key, value): + assert key in self.ordering + self.layers[key] = value + + def current(self): + """Retrieve the current, flattened declarations dict.""" + result = {} + for layer in self.ordering: + result.update(self.layers[layer]) + return result + + +class ParameterResolver(object): + """Resolve a factory's parameter declarations.""" + def __init__(self, parameters, deps): + self.parameters = parameters + self.deps = deps + self.declaration_stack = None + + self.resolved = set() + + def resolve_one(self, name): + """Compute one field is needed, taking dependencies into accounts.""" + if name in self.resolved: + return + + for dep in self.deps.get(name, ()): + self.resolve_one(dep) + + self.compute(name) + self.resolved.add(name) + + def compute(self, name): + """Actually compute the value for a given name.""" + value = self.parameters[name] + if isinstance(value, declarations.ComplexParameter): + overrides = value.compute(name, self.declaration_stack.current()) + else: + overrides = {name: value} + self.declaration_stack['overrides'].update(overrides) + + def resolve(self, declaration_stack): + """Resolve parameters for a given declaration stack. + + Modifies the stack in-place. + """ + self.declaration_stack = declaration_stack + for name in self.parameters: + self.resolve_one(name) + + class LazyValue(object): """Some kind of "lazy evaluating" object.""" @@ -125,7 +188,7 @@ class LazyValue(object): raise NotImplementedError("This is an abstract method.") -class OrderedDeclarationWrapper(LazyValue): +class DeclarationWrapper(LazyValue): """Lazy wrapper around an OrderedDeclaration. Attributes: @@ -136,7 +199,7 @@ class OrderedDeclarationWrapper(LazyValue): """ def __init__(self, declaration, sequence, create, extra=None, **kwargs): - super(OrderedDeclarationWrapper, self).__init__(**kwargs) + super(DeclarationWrapper, self).__init__(**kwargs) self.declaration = declaration self.sequence = sequence self.create = create @@ -166,7 +229,7 @@ class AttributeBuilder(object): Attributes: factory (base.Factory): the Factory for which attributes are being built - _attrs (DeclarationDict): the attribute declarations for the factory + _declarations (DeclarationDict): the attribute declarations for the factory _subfields (dict): dict mapping an attribute name to a dict of overridden default values for the related SubFactory. """ @@ -179,20 +242,47 @@ class AttributeBuilder(object): self.factory = factory self._containers = extra.pop('__containers', ()) - self._attrs = factory.declarations(extra) + + initial_declarations = dict(factory._meta.declarations) self._log_ctx = log_ctx - initial_declarations = factory.declarations({}) + # Parameters + # ---------- + self._declarations = self.merge_declarations(initial_declarations, extra) + + # Subfields + # --------- + attrs_with_subfields = [ k for k, v in initial_declarations.items() - if self.has_subfields(v)] + if self.has_subfields(v) + ] + # Extract subfields; THIS MODIFIES self._declarations. self._subfields = utils.multi_extract_dict( - attrs_with_subfields, self._attrs) + attrs_with_subfields, self._declarations) def has_subfields(self, value): return isinstance(value, declarations.ParameteredAttribute) + def merge_declarations(self, initial, extra): + """Compute the final declarations, taking into account paramter-based overrides.""" + # Precedence order: + # - Start with class-level declarations + # - Add overrides from parameters + # - Finally, use callsite-level declarations & values + declaration_stack = DeclarationStack(['initial', 'overrides', 'extra']) + declaration_stack['initial'] = initial.copy() + declaration_stack['extra'] = extra.copy() + + # Actually compute the final stack + resolver = ParameterResolver( + parameters=self.factory._meta.parameters, + deps=self.factory._meta.parameters_dependencies, + ) + resolver.resolve(declaration_stack) + return declaration_stack.current() + def build(self, create, force_sequence=None): """Build a dictionary of attributes. @@ -210,9 +300,9 @@ class AttributeBuilder(object): # Parse attribute declarations, wrapping SubFactory and # OrderedDeclaration. wrapped_attrs = {} - for k, v in self._attrs.items(): + for k, v in self._declarations.items(): if isinstance(v, declarations.OrderedDeclaration): - v = OrderedDeclarationWrapper(v, + v = DeclarationWrapper(v, sequence=sequence, create=create, extra=self._subfields.get(k, {}), diff --git a/factory/declarations.py b/factory/declarations.py index 9ab7462..ad1f72f 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -440,6 +440,39 @@ class List(SubFactory): **params) +# Parameters +# ========== + + +class ComplexParameter(object): + """A complex parameter, to be used in a Factory.Params section. + + Must implement: + - A "compute" function, performing the actual declaration override + - Optionally, a get_revdeps() function (to compute other parameters it may alter) + """ + + def compute(self, field_name, declarations): + """Compute the overrides for this parameter. + + Args: + - field_name (str): the field this parameter is installed at + - declarations (dict): the global factory declarations + + Returns: + dict: the declarations to override + """ + raise NotImplementedError() + + def get_revdeps(self, parameters): + """Retrieve the list of other parameters modified by this one.""" + return [] + + +# Post-generation +# =============== + + class ExtractionContext(object): """Private class holding all required context from extraction to postgen.""" def __init__(self, value=None, did_extract=False, extra=None, for_field=''): diff --git a/factory/utils.py b/factory/utils.py index 15dba0a..cfae4ec 100644 --- a/factory/utils.py +++ b/factory/utils.py @@ -35,7 +35,7 @@ def extract_dict(prefix, kwargs, pop=True, exclude=()): Args: prefix (str): the prefix to use for lookups - kwargs (dict): the dict from which values should be extracted + kwargs (dict): the dict from which values should be extracted; WILL BE MODIFIED. pop (bool): whether to use pop (True) or get (False) exclude (iterable): list of prefixed keys that shouldn't be extracted @@ -68,7 +68,7 @@ def multi_extract_dict(prefixes, kwargs, pop=True, exclude=()): Args: prefixes (str list): the prefixes to use for lookups - kwargs (dict): the dict from which values should be extracted + kwargs (dict): the dict from which values should be extracted; WILL BE MODIFIED. pop (bool): whether to use pop (True) or get (False) exclude (iterable): list of prefixed keys that shouldn't be extracted diff --git a/tests/test_containers.py b/tests/test_containers.py index 20c773a..a308353 100644 --- a/tests/test_containers.py +++ b/tests/test_containers.py @@ -103,97 +103,56 @@ class LazyStubTestCase(unittest.TestCase): class AttributeBuilderTestCase(unittest.TestCase): - def test_empty(self): - """Tests building attributes from an empty definition.""" + + def make_fake_factory(self, decls): + class Meta: + declarations = decls + parameters = {} + parameters_dependencies = {} class FakeFactory(object): - @classmethod - def declarations(cls, extra): - return extra + _meta = Meta @classmethod def _generate_next_sequence(cls): return 1 + return FakeFactory + + def test_empty(self): + """Tests building attributes from an empty definition.""" + + FakeFactory = self.make_fake_factory({}) ab = containers.AttributeBuilder(FakeFactory) self.assertEqual({}, ab.build(create=False)) def test_factory_defined(self): - class FakeFactory(object): - @classmethod - def declarations(cls, extra): - d = {'one': 1} - d.update(extra) - return d - - @classmethod - def _generate_next_sequence(cls): - return 1 - + FakeFactory = self.make_fake_factory({'one': 1}) ab = containers.AttributeBuilder(FakeFactory) + self.assertEqual({'one': 1}, ab.build(create=False)) def test_extended(self): - class FakeFactory(object): - @classmethod - def declarations(cls, extra): - d = {'one': 1} - d.update(extra) - return d - - @classmethod - def _generate_next_sequence(cls): - return 1 - + FakeFactory = self.make_fake_factory({'one': 1}) ab = containers.AttributeBuilder(FakeFactory, {'two': 2}) self.assertEqual({'one': 1, 'two': 2}, ab.build(create=False)) def test_overridden(self): - class FakeFactory(object): - @classmethod - def declarations(cls, extra): - d = {'one': 1} - d.update(extra) - return d - - @classmethod - def _generate_next_sequence(cls): - return 1 - + FakeFactory = self.make_fake_factory({'one': 1}) ab = containers.AttributeBuilder(FakeFactory, {'one': 2}) self.assertEqual({'one': 2}, ab.build(create=False)) def test_factory_defined_sequence(self): seq = declarations.Sequence(lambda n: 'xx%d' % n) - - class FakeFactory(object): - @classmethod - def declarations(cls, extra): - d = {'one': seq} - d.update(extra) - return d - - @classmethod - def _generate_next_sequence(cls): - return 1 + FakeFactory = self.make_fake_factory({'one': seq}) ab = containers.AttributeBuilder(FakeFactory) self.assertEqual({'one': 'xx1'}, ab.build(create=False)) def test_additionnal_sequence(self): seq = declarations.Sequence(lambda n: 'xx%d' % n) - - class FakeFactory(object): - @classmethod - def declarations(cls, extra): - d = {'one': 1} - d.update(extra) - return d - - @classmethod - def _generate_next_sequence(cls): - return 1 + FakeFactory = self.make_fake_factory({'one': 1}) ab = containers.AttributeBuilder(FakeFactory, extra={'two': seq}) self.assertEqual({'one': 1, 'two': 'xx1'}, ab.build(create=False)) @@ -201,34 +160,14 @@ class AttributeBuilderTestCase(unittest.TestCase): def test_replaced_sequence(self): seq = declarations.Sequence(lambda n: 'xx%d' % n) seq2 = declarations.Sequence(lambda n: 'yy%d' % n) - - class FakeFactory(object): - @classmethod - def declarations(cls, extra): - d = {'one': seq} - d.update(extra) - return d - - @classmethod - def _generate_next_sequence(cls): - return 1 + FakeFactory = self.make_fake_factory({'one': seq}) ab = containers.AttributeBuilder(FakeFactory, extra={'one': seq2}) self.assertEqual({'one': 'yy1'}, ab.build(create=False)) def test_lazy_function(self): lf = declarations.LazyFunction(int) - - class FakeFactory(object): - @classmethod - def declarations(cls, extra): - d = {'one': 1, 'two': lf} - d.update(extra) - return d - - @classmethod - def _generate_next_sequence(cls): - return 1 + FakeFactory = self.make_fake_factory({'one': 1, 'two': lf}) ab = containers.AttributeBuilder(FakeFactory) self.assertEqual({'one': 1, 'two': 0}, ab.build(create=False)) @@ -241,17 +180,7 @@ class AttributeBuilderTestCase(unittest.TestCase): def test_lazy_attribute(self): la = declarations.LazyAttribute(lambda a: a.one * 2) - - class FakeFactory(object): - @classmethod - def declarations(cls, extra): - d = {'one': 1, 'two': la} - d.update(extra) - return d - - @classmethod - def _generate_next_sequence(cls): - return 1 + FakeFactory = self.make_fake_factory({'one': 1, 'two': la}) ab = containers.AttributeBuilder(FakeFactory) self.assertEqual({'one': 1, 'two': 2}, ab.build(create=False)) @@ -267,18 +196,12 @@ class AttributeBuilderTestCase(unittest.TestCase): pass sf = declarations.SubFactory(FakeInnerFactory) - - class FakeFactory(object): - @classmethod - def declarations(cls, extra): - d = {'one': sf, 'two': 2} - d.update(extra) - return d + FakeFactory = self.make_fake_factory({'one': sf, 'two': 2}) ab = containers.AttributeBuilder(FakeFactory, {'one__blah': 1, 'two__bar': 2}) self.assertTrue(ab.has_subfields(sf)) self.assertEqual(['one'], list(ab._subfields.keys())) - self.assertEqual(2, ab._attrs['two__bar']) + self.assertEqual(2, ab._declarations['two__bar']) def test_sub_factory(self): pass diff --git a/tests/test_using.py b/tests/test_using.py index 3ef5403..67db3bc 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -40,6 +40,15 @@ class TestObject(object): self.four = four self.five = five + def as_dict(self): + return dict( + one=self.one, + two=self.two, + three=self.three, + four=self.four, + five=self.five, + ) + class FakeModel(object): @classmethod -- cgit v1.2.3