summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--factory/base.py40
-rw-r--r--factory/containers.py52
2 files changed, 59 insertions, 33 deletions
diff --git a/factory/base.py b/factory/base.py
index 013b58f..e78ec6e 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -21,7 +21,7 @@
import re
import sys
-from containers import ObjectParamsWrapper, OrderedDeclarationDict, StubObject
+from containers import DeclarationsHolder, ObjectParamsWrapper, OrderedDeclarationDict, StubObject
from declarations import OrderedDeclaration
# Strategies
@@ -40,8 +40,7 @@ FACTORY_CLASS_DECLARATION = 'FACTORY_FOR'
# Factory class attributes
-CLASS_ATTRIBUTE_ORDERED_DECLARATIONS = '_ordered_declarations'
-CLASS_ATTRIBUTE_UNORDERED_DECLARATIONS = '_unordered_declarations'
+CLASS_ATTRIBUTE_DECLARATIONS = '_declarations'
CLASS_ATTRIBUTE_ASSOCIATED_CLASS = '_associated_class'
# Factory metaclasses
@@ -78,21 +77,12 @@ class BaseFactoryMetaClass(type):
# If this isn't a subclass of Factory, don't do anything special.
return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs)
- ordered_declarations = getattr(base, CLASS_ATTRIBUTE_ORDERED_DECLARATIONS,
- OrderedDeclarationDict())
- unordered_declarations = getattr(base, CLASS_ATTRIBUTE_UNORDERED_DECLARATIONS, {})
+ declarations = getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, DeclarationsHolder())
+ attrs = declarations.update_base(attrs)
- for name in list(attrs):
- if isinstance(attrs[name], OrderedDeclaration):
- ordered_declarations[name] = attrs.pop(name)
- elif not name.startswith('_'):
- unordered_declarations[name] = attrs.pop(name)
+ attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations
- attrs[CLASS_ATTRIBUTE_ORDERED_DECLARATIONS] = ordered_declarations
- attrs[CLASS_ATTRIBUTE_UNORDERED_DECLARATIONS] = unordered_declarations
-
- for name, value in extra_attrs.iteritems():
- attrs[name] = value
+ attrs.update(extra_attrs)
return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs)
@@ -176,23 +166,7 @@ class BaseFactory(object):
attributes = {}
cls.sequence = cls._generate_next_sequence()
- for name, value in getattr(cls, CLASS_ATTRIBUTE_UNORDERED_DECLARATIONS).iteritems():
- if name in kwargs:
- attributes[name] = kwargs.pop(name)
- else:
- attributes[name] = value
-
- for name, ordered_declaration in getattr(cls, CLASS_ATTRIBUTE_ORDERED_DECLARATIONS).iteritems():
- if name in kwargs:
- attributes[name] = kwargs.pop(name)
- else:
- a = ObjectParamsWrapper(attributes)
- attributes[name] = ordered_declaration.evaluate(cls, a)
-
- for name in kwargs:
- attributes[name] = kwargs[name]
-
- return attributes
+ return getattr(cls, CLASS_ATTRIBUTE_DECLARATIONS).build_attributes(cls, kwargs)
@classmethod
def build(cls, **kwargs):
diff --git a/factory/containers.py b/factory/containers.py
index a117d5c..63be161 100644
--- a/factory/containers.py
+++ b/factory/containers.py
@@ -18,6 +18,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
+from declarations import OrderedDeclaration
+
class ObjectParamsWrapper(object):
'''A generic container that allows for getting but not setting of attributes.
@@ -83,6 +85,56 @@ class OrderedDeclarationDict(object):
for i in order:
yield self._order[i]
+class DeclarationsHolder(object):
+ """Holds all declarations, ordered and unordered."""
+
+ def __init__(self):
+ self._ordered = OrderedDeclarationDict()
+ self._unordered = {}
+
+ def update_base(self, attrs):
+ """Updates the DeclarationsHolder from a class definition.
+
+ Takes into account all public attributes and OrderedDeclaration
+ instances; ignores all attributes starting with '_'.
+
+ Returns a dict containing all remaining elements.
+ """
+ remaining = {}
+ for key, value in attrs.iteritems():
+ if isinstance(value, OrderedDeclaration):
+ self._ordered[key] = value
+ elif not key.startswith('_'):
+ self._unordered[key] = value
+ else:
+ remaining[key] = value
+ return remaining
+
+ def __contains__(self, key):
+ return key in self._ordered or key in self._unordered
+
+ def __getitem__(self, key):
+ try:
+ return self._unordered[key]
+ except KeyError:
+ return self._ordered[key]
+
+ def build_attributes(self, factory, extra):
+ """Build the list of attributes based on class attributes."""
+ attributes = {}
+ # For fields in _unordered, use the value from attrs if any; otherwise,
+ # use the default value.
+ for key, value in self._unordered.iteritems():
+ attributes[key] = extra.get(key, value)
+ for key, value in self._ordered.iteritems():
+ if key in extra:
+ attributes[key] = extra[key]
+ else:
+ wrapper = ObjectParamsWrapper(attributes)
+ attributes[key] = value.evaluate(factory, wrapper)
+ attributes.update(extra)
+ return attributes
+
class StubObject(object):
'''A generic container.'''