diff options
Diffstat (limited to 'factory/containers.py')
-rw-r--r-- | factory/containers.py | 108 |
1 files changed, 99 insertions, 9 deletions
diff --git a/factory/containers.py b/factory/containers.py index c591988..d3f39c4 100644 --- a/factory/containers.py +++ b/factory/containers.py @@ -117,6 +117,69 @@ class LazyStub(object): raise AttributeError('Setting of object attributes is not allowed') +class DeclarationStack(object): + """An ordered stack of declarations. + + This is intended to handle declaration precedence among different mutating layers. + """ + def __init__(self, ordering): + self.ordering = ordering + self.layers = dict((name, {}) for name in self.ordering) + + def __getitem__(self, key): + return self.layers[key] + + def __setitem__(self, key, value): + assert key in self.ordering + self.layers[key] = value + + def current(self): + """Retrieve the current, flattened declarations dict.""" + result = {} + for layer in self.ordering: + result.update(self.layers[layer]) + return result + + +class ParameterResolver(object): + """Resolve a factory's parameter declarations.""" + def __init__(self, parameters, deps): + self.parameters = parameters + self.deps = deps + self.declaration_stack = None + + self.resolved = set() + + def resolve_one(self, name): + """Compute one field is needed, taking dependencies into accounts.""" + if name in self.resolved: + return + + for dep in self.deps.get(name, ()): + self.resolve_one(dep) + + self.compute(name) + self.resolved.add(name) + + def compute(self, name): + """Actually compute the value for a given name.""" + value = self.parameters[name] + if isinstance(value, declarations.ComplexParameter): + overrides = value.compute(name, self.declaration_stack.current()) + else: + overrides = {name: value} + self.declaration_stack['overrides'].update(overrides) + + def resolve(self, declaration_stack): + """Resolve parameters for a given declaration stack. + + Modifies the stack in-place. + """ + self.declaration_stack = declaration_stack + for name in self.parameters: + self.resolve_one(name) + + class LazyValue(object): """Some kind of "lazy evaluating" object.""" @@ -125,7 +188,7 @@ class LazyValue(object): raise NotImplementedError("This is an abstract method.") -class OrderedDeclarationWrapper(LazyValue): +class DeclarationWrapper(LazyValue): """Lazy wrapper around an OrderedDeclaration. Attributes: @@ -136,7 +199,7 @@ class OrderedDeclarationWrapper(LazyValue): """ def __init__(self, declaration, sequence, create, extra=None, **kwargs): - super(OrderedDeclarationWrapper, self).__init__(**kwargs) + super(DeclarationWrapper, self).__init__(**kwargs) self.declaration = declaration self.sequence = sequence self.create = create @@ -166,7 +229,7 @@ class AttributeBuilder(object): Attributes: factory (base.Factory): the Factory for which attributes are being built - _attrs (DeclarationDict): the attribute declarations for the factory + _declarations (DeclarationDict): the attribute declarations for the factory _subfields (dict): dict mapping an attribute name to a dict of overridden default values for the related SubFactory. """ @@ -179,20 +242,47 @@ class AttributeBuilder(object): self.factory = factory self._containers = extra.pop('__containers', ()) - self._attrs = factory.declarations(extra) + + initial_declarations = dict(factory._meta.declarations) self._log_ctx = log_ctx - initial_declarations = factory.declarations({}) + # Parameters + # ---------- + self._declarations = self.merge_declarations(initial_declarations, extra) + + # Subfields + # --------- + attrs_with_subfields = [ k for k, v in initial_declarations.items() - if self.has_subfields(v)] + if self.has_subfields(v) + ] + # Extract subfields; THIS MODIFIES self._declarations. self._subfields = utils.multi_extract_dict( - attrs_with_subfields, self._attrs) + attrs_with_subfields, self._declarations) def has_subfields(self, value): return isinstance(value, declarations.ParameteredAttribute) + def merge_declarations(self, initial, extra): + """Compute the final declarations, taking into account paramter-based overrides.""" + # Precedence order: + # - Start with class-level declarations + # - Add overrides from parameters + # - Finally, use callsite-level declarations & values + declaration_stack = DeclarationStack(['initial', 'overrides', 'extra']) + declaration_stack['initial'] = initial.copy() + declaration_stack['extra'] = extra.copy() + + # Actually compute the final stack + resolver = ParameterResolver( + parameters=self.factory._meta.parameters, + deps=self.factory._meta.parameters_dependencies, + ) + resolver.resolve(declaration_stack) + return declaration_stack.current() + def build(self, create, force_sequence=None): """Build a dictionary of attributes. @@ -210,9 +300,9 @@ class AttributeBuilder(object): # Parse attribute declarations, wrapping SubFactory and # OrderedDeclaration. wrapped_attrs = {} - for k, v in self._attrs.items(): + for k, v in self._declarations.items(): if isinstance(v, declarations.OrderedDeclaration): - v = OrderedDeclarationWrapper(v, + v = DeclarationWrapper(v, sequence=sequence, create=create, extra=self._subfields.get(k, {}), |