diff options
Diffstat (limited to 'factory/containers.py')
-rw-r--r-- | factory/containers.py | 142 |
1 files changed, 83 insertions, 59 deletions
diff --git a/factory/containers.py b/factory/containers.py index 48a4015..0620c82 100644 --- a/factory/containers.py +++ b/factory/containers.py @@ -47,7 +47,7 @@ class ObjectParamsWrapper(object): raise AttributeError("The param '{0}' does not exist. Perhaps your declarations are out of order?".format(name)) -class OrderedDeclarationDict(object): +class OrderedDict(object): def __init__(self, **kwargs): self._order = {} self._values = {} @@ -88,18 +88,51 @@ class OrderedDeclarationDict(object): for i in order: yield self._order[i] -class DeclarationsHolder(object): - """Holds all declarations, ordered and unordered.""" - def __init__(self, defaults=None): - if not defaults: - defaults = {} - self._ordered = OrderedDeclarationDict() +class DeclarationDict(object): + """Holds a dict of declarations, keeping OrderedDeclaration at the end.""" + def __init__(self, extra=None): + if not extra: + extra = {} + self._ordered = OrderedDict() self._unordered = {} - self.update_base(defaults) + self.update(extra) + + def __setitem__(self, key, value): + if key in self: + del self[key] + + if isinstance(value, OrderedDeclaration): + self._ordered[key] = value + else: + self._unordered[key] = value + + def __getitem__(self, key): + """Try in _unordered first, then in _ordered.""" + try: + return self._unordered[key] + except KeyError: + return self._ordered[key] - def update_base(self, attrs): - """Updates the DeclarationsHolder from a class definition. + def __delitem__(self, key): + if key in self._unordered: + del self._unordered[key] + else: + del self._ordered[key] + + def pop(self, key, *args): + assert len(args) <= 1 + try: + return self._unordered.pop(key) + except KeyError: + return self._ordered.pop(key, *args) + + def update(self, d): + for k in d: + self[k] = d[k] + + def update_with_public(self, d): + """Updates the DeclarationDict from a class definition dict. Takes into account all public attributes and OrderedDeclaration instances; ignores all attributes starting with '_'. @@ -107,23 +140,25 @@ class DeclarationsHolder(object): Returns a dict containing all remaining elements. """ remaining = {} - for key, value in attrs.iteritems(): - if isinstance(value, OrderedDeclaration): - self._ordered[key] = value - elif not key.startswith('_'): - self._unordered[key] = value + for k, v in d.iteritems(): + if k.startswith('_') and not isinstance(v, OrderedDeclaration): + remaining[k] = v else: - remaining[key] = value + self[k] = v return remaining + def copy(self, extra=None): + new = DeclarationDict() + new.update(self) + if extra: + new.update(extra) + return new + def __contains__(self, key): - return key in self._ordered or key in self._unordered + return key in self._unordered or key in self._ordered - def __getitem__(self, key): - try: - return self._unordered[key] - except KeyError: - return self._ordered[key] + def items(self): + return list(self.iteritems()) def iteritems(self): for pair in self._unordered.iteritems(): @@ -131,58 +166,47 @@ class DeclarationsHolder(object): for pair in self._ordered.iteritems(): yield pair - def items(self): - return list(self.iteritems()) + def __iter__(self): + for k in self._unordered: + yield k + for k in self._ordered: + yield k - def _extract_sub_fields(self, base): - """Extract all subfields declaration from a given dict-like object. - Will compare with attributes declared in the current object, and - will pop() values from the given base. - """ - sub_fields = dict() +class AttributeBuilder(object): + """Builds attributes from a factory and extra data.""" + + def __init__(self, factory, extra=None): + if not extra: + extra = {} + self.factory = factory + self._attrs = factory.declarations(extra) + self._subfield = self._extract_subfields() - for key in list(base): + def _extract_subfields(self): + sub_fields = {} + for key in list(self._attrs): if ATTR_SPLITTER in key: cls_name, attr_name = key.split(ATTR_SPLITTER, 1) - if cls_name in self: - sub_fields.setdefault(cls_name, {})[attr_name] = base.pop(key) + if cls_name in self._attrs: + sub_fields.setdefault(cls_name, {})[attr_name] = self._attrs.pop(key) return sub_fields - def build_attributes(self, factory, create=False, extra=None): - """Build the list of attributes based on class attributes.""" - if not extra: - extra = {} - - factory.sequence = factory._generate_next_sequence() + def build(self, create): + self.factory.sequence = self.factory._generate_next_sequence() attributes = {} - sub_fields = {} - for base in (self._unordered, self._ordered, extra): - sub_fields.update(self._extract_sub_fields(base)) - - def make_value(key, val): - if key in extra: - val = extra.pop(key) + for key, val in self._attrs.iteritems(): if isinstance(val, SubFactory): - new_val = val.evaluate(factory, create, sub_fields.get(key, {})) + val = val.evaluate(self.factory, create, self._subfield.get(key, {})) elif isinstance(val, OrderedDeclaration): wrapper = ObjectParamsWrapper(attributes) - new_val = val.evaluate(factory, wrapper) - else: - new_val = val + val = val.evaluate(self.factory, wrapper) + attributes[key] = val - return new_val - - # For fields in _unordered, use the value from extra if any; otherwise, - # use the default value. - for key, value in self._unordered.iteritems(): - attributes[key] = make_value(key, value) - for key, value in self._ordered.iteritems(): - attributes[key] = make_value(key, value) - attributes.update(extra) return attributes + class StubObject(object): """A generic container.""" |