summaryrefslogtreecommitdiff
path: root/factory/containers.py
diff options
context:
space:
mode:
Diffstat (limited to 'factory/containers.py')
-rw-r--r--factory/containers.py142
1 files changed, 83 insertions, 59 deletions
diff --git a/factory/containers.py b/factory/containers.py
index 48a4015..0620c82 100644
--- a/factory/containers.py
+++ b/factory/containers.py
@@ -47,7 +47,7 @@ class ObjectParamsWrapper(object):
raise AttributeError("The param '{0}' does not exist. Perhaps your declarations are out of order?".format(name))
-class OrderedDeclarationDict(object):
+class OrderedDict(object):
def __init__(self, **kwargs):
self._order = {}
self._values = {}
@@ -88,18 +88,51 @@ class OrderedDeclarationDict(object):
for i in order:
yield self._order[i]
-class DeclarationsHolder(object):
- """Holds all declarations, ordered and unordered."""
- def __init__(self, defaults=None):
- if not defaults:
- defaults = {}
- self._ordered = OrderedDeclarationDict()
+class DeclarationDict(object):
+ """Holds a dict of declarations, keeping OrderedDeclaration at the end."""
+ def __init__(self, extra=None):
+ if not extra:
+ extra = {}
+ self._ordered = OrderedDict()
self._unordered = {}
- self.update_base(defaults)
+ self.update(extra)
+
+ def __setitem__(self, key, value):
+ if key in self:
+ del self[key]
+
+ if isinstance(value, OrderedDeclaration):
+ self._ordered[key] = value
+ else:
+ self._unordered[key] = value
+
+ def __getitem__(self, key):
+ """Try in _unordered first, then in _ordered."""
+ try:
+ return self._unordered[key]
+ except KeyError:
+ return self._ordered[key]
- def update_base(self, attrs):
- """Updates the DeclarationsHolder from a class definition.
+ def __delitem__(self, key):
+ if key in self._unordered:
+ del self._unordered[key]
+ else:
+ del self._ordered[key]
+
+ def pop(self, key, *args):
+ assert len(args) <= 1
+ try:
+ return self._unordered.pop(key)
+ except KeyError:
+ return self._ordered.pop(key, *args)
+
+ def update(self, d):
+ for k in d:
+ self[k] = d[k]
+
+ def update_with_public(self, d):
+ """Updates the DeclarationDict from a class definition dict.
Takes into account all public attributes and OrderedDeclaration
instances; ignores all attributes starting with '_'.
@@ -107,23 +140,25 @@ class DeclarationsHolder(object):
Returns a dict containing all remaining elements.
"""
remaining = {}
- for key, value in attrs.iteritems():
- if isinstance(value, OrderedDeclaration):
- self._ordered[key] = value
- elif not key.startswith('_'):
- self._unordered[key] = value
+ for k, v in d.iteritems():
+ if k.startswith('_') and not isinstance(v, OrderedDeclaration):
+ remaining[k] = v
else:
- remaining[key] = value
+ self[k] = v
return remaining
+ def copy(self, extra=None):
+ new = DeclarationDict()
+ new.update(self)
+ if extra:
+ new.update(extra)
+ return new
+
def __contains__(self, key):
- return key in self._ordered or key in self._unordered
+ return key in self._unordered or key in self._ordered
- def __getitem__(self, key):
- try:
- return self._unordered[key]
- except KeyError:
- return self._ordered[key]
+ def items(self):
+ return list(self.iteritems())
def iteritems(self):
for pair in self._unordered.iteritems():
@@ -131,58 +166,47 @@ class DeclarationsHolder(object):
for pair in self._ordered.iteritems():
yield pair
- def items(self):
- return list(self.iteritems())
+ def __iter__(self):
+ for k in self._unordered:
+ yield k
+ for k in self._ordered:
+ yield k
- def _extract_sub_fields(self, base):
- """Extract all subfields declaration from a given dict-like object.
- Will compare with attributes declared in the current object, and
- will pop() values from the given base.
- """
- sub_fields = dict()
+class AttributeBuilder(object):
+ """Builds attributes from a factory and extra data."""
+
+ def __init__(self, factory, extra=None):
+ if not extra:
+ extra = {}
+ self.factory = factory
+ self._attrs = factory.declarations(extra)
+ self._subfield = self._extract_subfields()
- for key in list(base):
+ def _extract_subfields(self):
+ sub_fields = {}
+ for key in list(self._attrs):
if ATTR_SPLITTER in key:
cls_name, attr_name = key.split(ATTR_SPLITTER, 1)
- if cls_name in self:
- sub_fields.setdefault(cls_name, {})[attr_name] = base.pop(key)
+ if cls_name in self._attrs:
+ sub_fields.setdefault(cls_name, {})[attr_name] = self._attrs.pop(key)
return sub_fields
- def build_attributes(self, factory, create=False, extra=None):
- """Build the list of attributes based on class attributes."""
- if not extra:
- extra = {}
-
- factory.sequence = factory._generate_next_sequence()
+ def build(self, create):
+ self.factory.sequence = self.factory._generate_next_sequence()
attributes = {}
- sub_fields = {}
- for base in (self._unordered, self._ordered, extra):
- sub_fields.update(self._extract_sub_fields(base))
-
- def make_value(key, val):
- if key in extra:
- val = extra.pop(key)
+ for key, val in self._attrs.iteritems():
if isinstance(val, SubFactory):
- new_val = val.evaluate(factory, create, sub_fields.get(key, {}))
+ val = val.evaluate(self.factory, create, self._subfield.get(key, {}))
elif isinstance(val, OrderedDeclaration):
wrapper = ObjectParamsWrapper(attributes)
- new_val = val.evaluate(factory, wrapper)
- else:
- new_val = val
+ val = val.evaluate(self.factory, wrapper)
+ attributes[key] = val
- return new_val
-
- # For fields in _unordered, use the value from extra if any; otherwise,
- # use the default value.
- for key, value in self._unordered.iteritems():
- attributes[key] = make_value(key, value)
- for key, value in self._ordered.iteritems():
- attributes[key] = make_value(key, value)
- attributes.update(extra)
return attributes
+
class StubObject(object):
"""A generic container."""