diff options
author | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2016-04-02 16:13:34 +0200 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2016-04-02 17:11:46 +0200 |
commit | c77962de7dd7206ccab85b44da173832acbf5921 (patch) | |
tree | 0913b772d5181f654d5ce824753186a2252e9691 /factory/base.py | |
parent | eea28cce1544021f3d152782c9932a20402d6240 (diff) | |
download | factory-boy-c77962de7dd7206ccab85b44da173832acbf5921.tar factory-boy-c77962de7dd7206ccab85b44da173832acbf5921.tar.gz |
Add a new Params section to factories.
This handles parameters that alter the declarations of a factory.
A few technical notes:
- A parameter's outcome may alter other parameters
- In order to fix that, we perform a (simple) cyclic definition
detection at class declaration time.
- Parameters may only be either naked values or ComplexParameter
subclasses
- Parameters are never passed to the underlying class
Diffstat (limited to 'factory/base.py')
-rw-r--r-- | factory/base.py | 44 |
1 files changed, 43 insertions, 1 deletions
diff --git a/factory/base.py b/factory/base.py index 1ddb742..282e3b1 100644 --- a/factory/base.py +++ b/factory/base.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +import collections import logging from . import containers @@ -92,6 +93,7 @@ class FactoryMetaClass(type): base_factory = None attrs_meta = attrs.pop('Meta', None) + attrs_params = attrs.pop('Params', None) base_meta = resolve_attribute('_meta', bases) options_class = resolve_attribute('_options_class', bases, FactoryOptions) @@ -106,6 +108,7 @@ class FactoryMetaClass(type): meta=attrs_meta, base_meta=base_meta, base_factory=base_factory, + params=attrs_params, ) return new_class @@ -148,6 +151,8 @@ class FactoryOptions(object): self.base_factory = None self.declarations = {} self.postgen_declarations = {} + self.parameters = {} + self.parameters_dependencies = {} def _build_default_options(self): """"Provide the default value for all allowed fields. @@ -186,7 +191,7 @@ class FactoryOptions(object): % (self.factory, ','.join(sorted(meta_attrs.keys())))) def contribute_to_class(self, factory, - meta=None, base_meta=None, base_factory=None): + meta=None, base_meta=None, base_factory=None, params=None): self.factory = factory self.base_factory = base_factory @@ -204,6 +209,7 @@ class FactoryOptions(object): continue self.declarations.update(parent._meta.declarations) self.postgen_declarations.update(parent._meta.postgen_declarations) + self.parameters.update(parent._meta.parameters) for k, v in vars(self.factory).items(): if self._is_declaration(k, v): @@ -211,6 +217,13 @@ class FactoryOptions(object): if self._is_postgen_declaration(k, v): self.postgen_declarations[k] = v + if params is not None: + for k, v in vars(params).items(): + if not k.startswith('_'): + self.parameters[k] = v + + self.parameters_dependencies = self._compute_parameter_dependencies(self.parameters) + def _get_counter_reference(self): """Identify which factory should be used for a shared counter.""" @@ -242,6 +255,32 @@ class FactoryOptions(object): """Captures instances of PostGenerationDeclaration.""" return isinstance(value, declarations.PostGenerationDeclaration) + def _compute_parameter_dependencies(self, parameters): + """Find out in what order parameters should be called.""" + # Warning: parameters only provide reverse dependencies; we reverse them into standard dependencies. + # deep_revdeps: set of fields a field depend indirectly upon + deep_revdeps = collections.defaultdict(set) + # Actual, direct dependencies + deps = collections.defaultdict(set) + + for name, parameter in parameters.items(): + if isinstance(parameter, declarations.ComplexParameter): + field_revdeps = parameter.get_revdeps(parameters) + if not field_revdeps: + continue + deep_revdeps[name] = set.union(*(deep_revdeps[dep] for dep in field_revdeps)) + deep_revdeps[name] |= set(field_revdeps) + for dep in field_revdeps: + deps[dep].add(name) + + # Check for cyclical dependencies + cyclic = [name for name, field_deps in deep_revdeps.items() if name in field_deps] + if cyclic: + raise errors.CyclicDefinitionError( + "Cyclic definition detected on %s' Params around %s" + % (self.factory, ', '.join(cyclic))) + return deps + def __str__(self): return "<%s for %s>" % (self.__class__.__name__, self.factory.__class__.__name__) @@ -439,6 +478,9 @@ class BaseFactory(object): # Remove 'hidden' arguments. for arg in cls._meta.exclude: del kwargs[arg] + # Remove parameters, if defined + for arg in cls._meta.parameters: + kwargs.pop(arg, None) # Extract *args from **kwargs args = tuple(kwargs.pop(key) for key in cls._meta.inline_args) |