summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polyconseil.fr>2011-09-06 12:04:51 +0200
committerRaphaël Barrois <raphael.barrois@polyconseil.fr>2011-09-06 12:04:51 +0200
commit5c770cb1918d8067db69d3c880a9a6ad1d31ddb8 (patch)
treea1fa05eb784a8ddb1554e12e0ecbae42bad08c67
parent2821684dec8a71e93728fc3e036f83592324b518 (diff)
downloadfactory-boy-5c770cb1918d8067db69d3c880a9a6ad1d31ddb8.tar
factory-boy-5c770cb1918d8067db69d3c880a9a6ad1d31ddb8.tar.gz
Huge refactoring and code cleanup.
Signed-off-by: Raphaël Barrois <raphael.barrois@polyconseil.fr>
-rw-r--r--factory/base.py24
-rw-r--r--factory/containers.py142
-rw-r--r--factory/declarations.py34
3 files changed, 117 insertions, 83 deletions
diff --git a/factory/base.py b/factory/base.py
index 380e015..34c8470 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -21,7 +21,7 @@
import re
import sys
-from containers import DeclarationsHolder, ObjectParamsWrapper, OrderedDeclarationDict, StubObject
+from containers import AttributeBuilder, DeclarationDict, ObjectParamsWrapper, StubObject
from declarations import OrderedDeclaration
# Strategies
@@ -73,14 +73,14 @@ class BaseFactoryMetaClass(type):
# If this isn't a subclass of Factory, don't do anything special.
return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs)
- declarations = DeclarationsHolder()
- for base in parent_factories:
- declarations.update_base(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {}))
+ declarations = DeclarationDict()
- non_factory_attrs = declarations.update_base(attrs)
+ #Add parent declarations in reverse order.
+ for base in reversed(parent_factories):
+ declarations.update_with_public(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {}))
+ non_factory_attrs = declarations.update_with_public(attrs)
non_factory_attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations
-
non_factory_attrs.update(extra_attrs)
return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, non_factory_attrs)
@@ -167,21 +167,21 @@ class BaseFactory(object):
return next_sequence
@classmethod
- def attributes(cls, create=False, **kwargs):
+ def attributes(cls, create=False, extra=None):
"""Build a dict of attribute values, respecting declaration order.
The process is:
- Handle 'orderless' attributes, overriding defaults with provided
kwargs when applicable
- Handle ordered attributes, overriding them with provided kwargs when
- applicable; the current list of computed attributes is available for
+ applicable; the current list of computed attributes is available
to the currently processed object.
"""
- return getattr(cls, CLASS_ATTRIBUTE_DECLARATIONS).build_attributes(cls, create, kwargs)
+ return AttributeBuilder(cls, extra).build(create)
@classmethod
- def declarations(cls):
- return DeclarationsHolder(getattr(cls, CLASS_ATTRIBUTE_DECLARATIONS))
+ def declarations(cls, extra_defs=None):
+ return getattr(cls, CLASS_ATTRIBUTE_DECLARATIONS).copy(extra_defs)
@classmethod
def build(cls, **kwargs):
@@ -198,11 +198,13 @@ class BaseFactory(object):
setattr(stub_object, name, value)
return stub_object
+
class StubFactory(BaseFactory):
__metaclass__ = BaseFactoryMetaClass
default_strategy = STUB_STRATEGY
+
class Factory(BaseFactory):
"""Factory base with build and create support.
diff --git a/factory/containers.py b/factory/containers.py
index 48a4015..0620c82 100644
--- a/factory/containers.py
+++ b/factory/containers.py
@@ -47,7 +47,7 @@ class ObjectParamsWrapper(object):
raise AttributeError("The param '{0}' does not exist. Perhaps your declarations are out of order?".format(name))
-class OrderedDeclarationDict(object):
+class OrderedDict(object):
def __init__(self, **kwargs):
self._order = {}
self._values = {}
@@ -88,18 +88,51 @@ class OrderedDeclarationDict(object):
for i in order:
yield self._order[i]
-class DeclarationsHolder(object):
- """Holds all declarations, ordered and unordered."""
- def __init__(self, defaults=None):
- if not defaults:
- defaults = {}
- self._ordered = OrderedDeclarationDict()
+class DeclarationDict(object):
+ """Holds a dict of declarations, keeping OrderedDeclaration at the end."""
+ def __init__(self, extra=None):
+ if not extra:
+ extra = {}
+ self._ordered = OrderedDict()
self._unordered = {}
- self.update_base(defaults)
+ self.update(extra)
+
+ def __setitem__(self, key, value):
+ if key in self:
+ del self[key]
+
+ if isinstance(value, OrderedDeclaration):
+ self._ordered[key] = value
+ else:
+ self._unordered[key] = value
+
+ def __getitem__(self, key):
+ """Try in _unordered first, then in _ordered."""
+ try:
+ return self._unordered[key]
+ except KeyError:
+ return self._ordered[key]
- def update_base(self, attrs):
- """Updates the DeclarationsHolder from a class definition.
+ def __delitem__(self, key):
+ if key in self._unordered:
+ del self._unordered[key]
+ else:
+ del self._ordered[key]
+
+ def pop(self, key, *args):
+ assert len(args) <= 1
+ try:
+ return self._unordered.pop(key)
+ except KeyError:
+ return self._ordered.pop(key, *args)
+
+ def update(self, d):
+ for k in d:
+ self[k] = d[k]
+
+ def update_with_public(self, d):
+ """Updates the DeclarationDict from a class definition dict.
Takes into account all public attributes and OrderedDeclaration
instances; ignores all attributes starting with '_'.
@@ -107,23 +140,25 @@ class DeclarationsHolder(object):
Returns a dict containing all remaining elements.
"""
remaining = {}
- for key, value in attrs.iteritems():
- if isinstance(value, OrderedDeclaration):
- self._ordered[key] = value
- elif not key.startswith('_'):
- self._unordered[key] = value
+ for k, v in d.iteritems():
+ if k.startswith('_') and not isinstance(v, OrderedDeclaration):
+ remaining[k] = v
else:
- remaining[key] = value
+ self[k] = v
return remaining
+ def copy(self, extra=None):
+ new = DeclarationDict()
+ new.update(self)
+ if extra:
+ new.update(extra)
+ return new
+
def __contains__(self, key):
- return key in self._ordered or key in self._unordered
+ return key in self._unordered or key in self._ordered
- def __getitem__(self, key):
- try:
- return self._unordered[key]
- except KeyError:
- return self._ordered[key]
+ def items(self):
+ return list(self.iteritems())
def iteritems(self):
for pair in self._unordered.iteritems():
@@ -131,58 +166,47 @@ class DeclarationsHolder(object):
for pair in self._ordered.iteritems():
yield pair
- def items(self):
- return list(self.iteritems())
+ def __iter__(self):
+ for k in self._unordered:
+ yield k
+ for k in self._ordered:
+ yield k
- def _extract_sub_fields(self, base):
- """Extract all subfields declaration from a given dict-like object.
- Will compare with attributes declared in the current object, and
- will pop() values from the given base.
- """
- sub_fields = dict()
+class AttributeBuilder(object):
+ """Builds attributes from a factory and extra data."""
+
+ def __init__(self, factory, extra=None):
+ if not extra:
+ extra = {}
+ self.factory = factory
+ self._attrs = factory.declarations(extra)
+ self._subfield = self._extract_subfields()
- for key in list(base):
+ def _extract_subfields(self):
+ sub_fields = {}
+ for key in list(self._attrs):
if ATTR_SPLITTER in key:
cls_name, attr_name = key.split(ATTR_SPLITTER, 1)
- if cls_name in self:
- sub_fields.setdefault(cls_name, {})[attr_name] = base.pop(key)
+ if cls_name in self._attrs:
+ sub_fields.setdefault(cls_name, {})[attr_name] = self._attrs.pop(key)
return sub_fields
- def build_attributes(self, factory, create=False, extra=None):
- """Build the list of attributes based on class attributes."""
- if not extra:
- extra = {}
-
- factory.sequence = factory._generate_next_sequence()
+ def build(self, create):
+ self.factory.sequence = self.factory._generate_next_sequence()
attributes = {}
- sub_fields = {}
- for base in (self._unordered, self._ordered, extra):
- sub_fields.update(self._extract_sub_fields(base))
-
- def make_value(key, val):
- if key in extra:
- val = extra.pop(key)
+ for key, val in self._attrs.iteritems():
if isinstance(val, SubFactory):
- new_val = val.evaluate(factory, create, sub_fields.get(key, {}))
+ val = val.evaluate(self.factory, create, self._subfield.get(key, {}))
elif isinstance(val, OrderedDeclaration):
wrapper = ObjectParamsWrapper(attributes)
- new_val = val.evaluate(factory, wrapper)
- else:
- new_val = val
+ val = val.evaluate(self.factory, wrapper)
+ attributes[key] = val
- return new_val
-
- # For fields in _unordered, use the value from extra if any; otherwise,
- # use the default value.
- for key, value in self._unordered.iteritems():
- attributes[key] = make_value(key, value)
- for key, value in self._ordered.iteritems():
- attributes[key] = make_value(key, value)
- attributes.update(extra)
return attributes
+
class StubObject(object):
"""A generic container."""
diff --git a/factory/declarations.py b/factory/declarations.py
index 898493c..a1e9102 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -48,22 +48,24 @@ class OrderedDeclaration(object):
def __init__(self):
self.order = GlobalCounter.step()
- def evaluate(self, factory, attributes):
+ def evaluate(self, factory, obj):
"""Evaluate this declaration.
Args:
factory: The factory this declaration was defined in.
+ obj: The object holding currently computed attributes
attributes: The attributes created by the unordered and ordered declarations up to this point."""
raise NotImplementedError('This is an abstract method')
+
class LazyAttribute(OrderedDeclaration):
def __init__(self, function):
super(LazyAttribute, self).__init__()
self.function = function
- def evaluate(self, factory, attributes):
- return self.function(attributes)
+ def evaluate(self, factory, obj):
+ return self.function(obj)
class SelfAttribute(OrderedDeclaration):
@@ -71,8 +73,8 @@ class SelfAttribute(OrderedDeclaration):
super(SelfAttribute, self).__init__()
self.attribute_name = attribute_name
- def evaluate(self, factory, attributes):
- return getattr(attributes, self.attribute_name)
+ def evaluate(self, factory, obj):
+ return getattr(obj, self.attribute_name)
class Sequence(OrderedDeclaration):
@@ -81,12 +83,14 @@ class Sequence(OrderedDeclaration):
self.function = function
self.type = type
- def evaluate(self, factory, attributes):
+ def evaluate(self, factory, obj):
return self.function(self.type(factory.sequence))
+
class LazyAttributeSequence(Sequence):
- def evaluate(self, factory, attributes):
- return self.function(attributes, self.type(factory.sequence))
+ def evaluate(self, factory, obj):
+ return self.function(obj, self.type(factory.sequence))
+
class SubFactory(OrderedDeclaration):
"""Base class for attributes based upon a sub-factory.
@@ -98,20 +102,24 @@ class SubFactory(OrderedDeclaration):
def __init__(self, factory, **kwargs):
super(SubFactory, self).__init__()
- self.defaults = factory.declarations()
- self.defaults.update_base(kwargs)
+ self.defaults = kwargs
self.factory = factory
- def evaluate(self, factory, create, attributes):
+ def evaluate(self, factory, create, extra):
"""Evaluate the current definition and fill its attributes.
Uses attributes definition in the following order:
- attributes defined in the wrapped factory class
- values defined when defining the SubFactory
- - additional valued defined in attributes
+ - additional values defined in attributes
"""
- attrs = self.defaults.build_attributes(self.factory, create, attributes)
+ defaults = dict(self.defaults)
+ if extra:
+ defaults.update(extra)
+
+ attrs = self.factory.attributes(create, defaults)
+
if create:
return self.factory.create(**attrs)
else: