summaryrefslogtreecommitdiff
path: root/factory/containers.py
diff options
context:
space:
mode:
Diffstat (limited to 'factory/containers.py')
-rw-r--r--factory/containers.py114
1 files changed, 54 insertions, 60 deletions
diff --git a/factory/containers.py b/factory/containers.py
index 2f92f62..ee2ad82 100644
--- a/factory/containers.py
+++ b/factory/containers.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2010 Mark Sandstrom
-# Copyright (c) 2011 Raphaël Barrois
+# Copyright (c) 2011-2013 Raphaël Barrois
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -21,12 +21,8 @@
# THE SOFTWARE.
-from factory import declarations
-
-
-#: String for splitting an attribute name into a
-#: (subfactory_name, subfactory_field) tuple.
-ATTR_SPLITTER = '__'
+from . import declarations
+from . import utils
class CyclicDefinitionError(Exception):
@@ -66,7 +62,7 @@ class LazyStub(object):
def __str__(self):
return '<LazyStub for %s with %s>' % (
- self.__target_class.__name__, self.__attrs.keys())
+ self.__target_class.__name__, list(self.__attrs.keys()))
def __fill__(self):
"""Fill this LazyStub, computing values of all defined attributes.
@@ -128,7 +124,7 @@ class DeclarationDict(dict):
return False
elif isinstance(value, declarations.OrderedDeclaration):
return True
- return (not name.startswith("_"))
+ return (not name.startswith("_") and not name.startswith("FACTORY_"))
def update_with_public(self, d):
"""Updates the DeclarationDict from a class definition dict.
@@ -140,7 +136,7 @@ class DeclarationDict(dict):
Returns a dict containing all remaining elements.
"""
remaining = {}
- for k, v in d.iteritems():
+ for k, v in d.items():
if self.is_declaration(k, v):
self[k] = v
else:
@@ -153,45 +149,29 @@ class DeclarationDict(dict):
Args:
extra (dict): additional attributes to include in the copy.
"""
- new = DeclarationDict()
+ new = self.__class__()
new.update(self)
if extra:
new.update(extra)
return new
+class PostGenerationDeclarationDict(DeclarationDict):
+ """Alternate DeclarationDict for PostGenerationDeclaration."""
+
+ def is_declaration(self, name, value):
+ """Captures instances of PostGenerationDeclaration."""
+ return isinstance(value, declarations.PostGenerationDeclaration)
+
+
class LazyValue(object):
"""Some kind of "lazy evaluating" object."""
- def evaluate(self, obj, containers=()):
+ def evaluate(self, obj, containers=()): # pragma: no cover
"""Compute the value, using the given object."""
raise NotImplementedError("This is an abstract method.")
-class SubFactoryWrapper(LazyValue):
- """Lazy wrapper around a SubFactory.
-
- Attributes:
- subfactory (declarations.SubFactory): the SubFactory being wrapped
- subfields (DeclarationDict): Default values to override when evaluating
- the SubFactory
- create (bool): whether to 'create' or 'build' the SubFactory.
- """
-
- def __init__(self, subfactory, subfields, create, *args, **kwargs):
- super(SubFactoryWrapper, self).__init__(*args, **kwargs)
- self.subfactory = subfactory
- self.subfields = subfields
- self.create = create
-
- def evaluate(self, obj, containers=()):
- expanded_containers = (obj,)
- if containers:
- expanded_containers += tuple(containers)
- return self.subfactory.evaluate(self.create, self.subfields,
- expanded_containers)
-
-
class OrderedDeclarationWrapper(LazyValue):
"""Lazy wrapper around an OrderedDeclaration.
@@ -202,10 +182,12 @@ class OrderedDeclarationWrapper(LazyValue):
declaration
"""
- def __init__(self, declaration, sequence, *args, **kwargs):
- super(OrderedDeclarationWrapper, self).__init__(*args, **kwargs)
+ def __init__(self, declaration, sequence, create, extra=None, **kwargs):
+ super(OrderedDeclarationWrapper, self).__init__(**kwargs)
self.declaration = declaration
self.sequence = sequence
+ self.create = create
+ self.extra = extra
def evaluate(self, obj, containers=()):
"""Lazily evaluate the attached OrderedDeclaration.
@@ -215,7 +197,14 @@ class OrderedDeclarationWrapper(LazyValue):
containers (object list): the chain of containers of the object
being built, its immediate holder being first.
"""
- return self.declaration.evaluate(self.sequence, obj, containers)
+ return self.declaration.evaluate(self.sequence, obj,
+ create=self.create,
+ extra=self.extra,
+ containers=containers,
+ )
+
+ def __repr__(self):
+ return '<%s for %r>' % (self.__class__.__name__, self.declaration)
class AttributeBuilder(object):
@@ -236,38 +225,43 @@ class AttributeBuilder(object):
extra = {}
self.factory = factory
- self._containers = extra.pop('__containers', None)
+ self._containers = extra.pop('__containers', ())
self._attrs = factory.declarations(extra)
- self._subfields = self._extract_subfields()
-
- def _extract_subfields(self):
- """Extract the subfields from the declarations list."""
- sub_fields = {}
- for key in list(self._attrs):
- if ATTR_SPLITTER in key:
- # Trying to define a default of a subfactory
- cls_name, attr_name = key.split(ATTR_SPLITTER, 1)
- if cls_name in self._attrs:
- sub_fields.setdefault(cls_name, {})[attr_name] = self._attrs.pop(key)
- return sub_fields
-
- def build(self, create):
+
+ attrs_with_subfields = [
+ k for k, v in self._attrs.items()
+ if self.has_subfields(v)]
+
+ self._subfields = utils.multi_extract_dict(
+ attrs_with_subfields, self._attrs)
+
+ def has_subfields(self, value):
+ return isinstance(value, declarations.ParameteredAttribute)
+
+ def build(self, create, force_sequence=None):
"""Build a dictionary of attributes.
Args:
create (bool): whether to 'build' or 'create' the subfactories.
+ force_sequence (int or None): if set to an int, use this value for
+ the sequence counter; don't advance the related counter.
"""
# Setup factory sequence.
- self.factory.sequence = self.factory._generate_next_sequence()
+ if force_sequence is None:
+ sequence = self.factory._generate_next_sequence()
+ else:
+ sequence = force_sequence
# Parse attribute declarations, wrapping SubFactory and
# OrderedDeclaration.
wrapped_attrs = {}
- for k, v in self._attrs.iteritems():
- if isinstance(v, declarations.SubFactory):
- v = SubFactoryWrapper(v, self._subfields.get(k, {}), create)
- elif isinstance(v, declarations.OrderedDeclaration):
- v = OrderedDeclarationWrapper(v, self.factory.sequence)
+ for k, v in self._attrs.items():
+ if isinstance(v, declarations.OrderedDeclaration):
+ v = OrderedDeclarationWrapper(v,
+ sequence=sequence,
+ create=create,
+ extra=self._subfields.get(k, {}),
+ )
wrapped_attrs[k] = v
stub = LazyStub(wrapped_attrs, containers=self._containers,