summaryrefslogtreecommitdiff
path: root/factory/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'factory/base.py')
-rw-r--r--factory/base.py75
1 files changed, 51 insertions, 24 deletions
diff --git a/factory/base.py b/factory/base.py
index 0f2af59..282e3b1 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -20,10 +20,12 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
+import collections
import logging
from . import containers
from . import declarations
+from . import errors
from . import utils
logger = logging.getLogger('factory.generate')
@@ -35,22 +37,6 @@ STUB_STRATEGY = 'stub'
-class FactoryError(Exception):
- """Any exception raised by factory_boy."""
-
-
-class AssociatedClassError(FactoryError):
- """Exception for Factory subclasses lacking Meta.model."""
-
-
-class UnknownStrategy(FactoryError):
- """Raised when a factory uses an unknown strategy."""
-
-
-class UnsupportedStrategy(FactoryError):
- """Raised when trying to use a strategy on an incompatible Factory."""
-
-
# Factory metaclasses
def get_factory_bases(bases):
@@ -82,7 +68,7 @@ class FactoryMetaClass(type):
elif cls._meta.strategy == STUB_STRATEGY:
return cls.stub(**kwargs)
else:
- raise UnknownStrategy('Unknown Meta.strategy: {0}'.format(
+ raise errors.UnknownStrategy('Unknown Meta.strategy: {0}'.format(
cls._meta.strategy))
def __new__(mcs, class_name, bases, attrs):
@@ -107,6 +93,7 @@ class FactoryMetaClass(type):
base_factory = None
attrs_meta = attrs.pop('Meta', None)
+ attrs_params = attrs.pop('Params', None)
base_meta = resolve_attribute('_meta', bases)
options_class = resolve_attribute('_options_class', bases, FactoryOptions)
@@ -121,6 +108,7 @@ class FactoryMetaClass(type):
meta=attrs_meta,
base_meta=base_meta,
base_factory=base_factory,
+ params=attrs_params,
)
return new_class
@@ -163,6 +151,8 @@ class FactoryOptions(object):
self.base_factory = None
self.declarations = {}
self.postgen_declarations = {}
+ self.parameters = {}
+ self.parameters_dependencies = {}
def _build_default_options(self):
""""Provide the default value for all allowed fields.
@@ -201,7 +191,7 @@ class FactoryOptions(object):
% (self.factory, ','.join(sorted(meta_attrs.keys()))))
def contribute_to_class(self, factory,
- meta=None, base_meta=None, base_factory=None):
+ meta=None, base_meta=None, base_factory=None, params=None):
self.factory = factory
self.base_factory = base_factory
@@ -219,6 +209,7 @@ class FactoryOptions(object):
continue
self.declarations.update(parent._meta.declarations)
self.postgen_declarations.update(parent._meta.postgen_declarations)
+ self.parameters.update(parent._meta.parameters)
for k, v in vars(self.factory).items():
if self._is_declaration(k, v):
@@ -226,6 +217,13 @@ class FactoryOptions(object):
if self._is_postgen_declaration(k, v):
self.postgen_declarations[k] = v
+ if params is not None:
+ for k, v in vars(params).items():
+ if not k.startswith('_'):
+ self.parameters[k] = v
+
+ self.parameters_dependencies = self._compute_parameter_dependencies(self.parameters)
+
def _get_counter_reference(self):
"""Identify which factory should be used for a shared counter."""
@@ -257,6 +255,32 @@ class FactoryOptions(object):
"""Captures instances of PostGenerationDeclaration."""
return isinstance(value, declarations.PostGenerationDeclaration)
+ def _compute_parameter_dependencies(self, parameters):
+ """Find out in what order parameters should be called."""
+ # Warning: parameters only provide reverse dependencies; we reverse them into standard dependencies.
+ # deep_revdeps: set of fields a field depend indirectly upon
+ deep_revdeps = collections.defaultdict(set)
+ # Actual, direct dependencies
+ deps = collections.defaultdict(set)
+
+ for name, parameter in parameters.items():
+ if isinstance(parameter, declarations.ComplexParameter):
+ field_revdeps = parameter.get_revdeps(parameters)
+ if not field_revdeps:
+ continue
+ deep_revdeps[name] = set.union(*(deep_revdeps[dep] for dep in field_revdeps))
+ deep_revdeps[name] |= set(field_revdeps)
+ for dep in field_revdeps:
+ deps[dep].add(name)
+
+ # Check for cyclical dependencies
+ cyclic = [name for name, field_deps in deep_revdeps.items() if name in field_deps]
+ if cyclic:
+ raise errors.CyclicDefinitionError(
+ "Cyclic definition detected on %s' Params around %s"
+ % (self.factory, ', '.join(cyclic)))
+ return deps
+
def __str__(self):
return "<%s for %s>" % (self.__class__.__name__, self.factory.__class__.__name__)
@@ -296,12 +320,12 @@ class BaseFactory(object):
"""Factory base support for sequences, attributes and stubs."""
# Backwards compatibility
- UnknownStrategy = UnknownStrategy
- UnsupportedStrategy = UnsupportedStrategy
+ UnknownStrategy = errors.UnknownStrategy
+ UnsupportedStrategy = errors.UnsupportedStrategy
def __new__(cls, *args, **kwargs):
"""Would be called if trying to instantiate the class."""
- raise FactoryError('You cannot instantiate BaseFactory')
+ raise errors.FactoryError('You cannot instantiate BaseFactory')
_meta = FactoryOptions()
@@ -454,6 +478,9 @@ class BaseFactory(object):
# Remove 'hidden' arguments.
for arg in cls._meta.exclude:
del kwargs[arg]
+ # Remove parameters, if defined
+ for arg in cls._meta.parameters:
+ kwargs.pop(arg, None)
# Extract *args from **kwargs
args = tuple(kwargs.pop(key) for key in cls._meta.inline_args)
@@ -477,7 +504,7 @@ class BaseFactory(object):
attrs (dict): attributes to use for generating the object
"""
if cls._meta.abstract:
- raise FactoryError(
+ raise errors.FactoryError(
"Cannot generate instances of abstract factory %(f)s; "
"Ensure %(f)s.Meta.model is set and %(f)s.Meta.abstract "
"is either not set or False." % dict(f=cls.__name__))
@@ -680,7 +707,7 @@ Factory = FactoryMetaClass('Factory', (BaseFactory,), {
# Backwards compatibility
-Factory.AssociatedClassError = AssociatedClassError # pylint: disable=W0201
+Factory.AssociatedClassError = errors.AssociatedClassError # pylint: disable=W0201
class StubFactory(Factory):
@@ -695,7 +722,7 @@ class StubFactory(Factory):
@classmethod
def create(cls, **kwargs):
- raise UnsupportedStrategy()
+ raise errors.UnsupportedStrategy()
class BaseDictFactory(Factory):