diff options
author | Raphaël Barrois <raphael.barrois@polyconseil.fr> | 2011-09-06 12:04:51 +0200 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polyconseil.fr> | 2011-09-06 12:04:51 +0200 |
commit | 5c770cb1918d8067db69d3c880a9a6ad1d31ddb8 (patch) | |
tree | a1fa05eb784a8ddb1554e12e0ecbae42bad08c67 | |
parent | 2821684dec8a71e93728fc3e036f83592324b518 (diff) | |
download | factory-boy-5c770cb1918d8067db69d3c880a9a6ad1d31ddb8.tar factory-boy-5c770cb1918d8067db69d3c880a9a6ad1d31ddb8.tar.gz |
Huge refactoring and code cleanup.
Signed-off-by: Raphaël Barrois <raphael.barrois@polyconseil.fr>
-rw-r--r-- | factory/base.py | 24 | ||||
-rw-r--r-- | factory/containers.py | 142 | ||||
-rw-r--r-- | factory/declarations.py | 34 |
3 files changed, 117 insertions, 83 deletions
diff --git a/factory/base.py b/factory/base.py index 380e015..34c8470 100644 --- a/factory/base.py +++ b/factory/base.py @@ -21,7 +21,7 @@ import re import sys -from containers import DeclarationsHolder, ObjectParamsWrapper, OrderedDeclarationDict, StubObject +from containers import AttributeBuilder, DeclarationDict, ObjectParamsWrapper, StubObject from declarations import OrderedDeclaration # Strategies @@ -73,14 +73,14 @@ class BaseFactoryMetaClass(type): # If this isn't a subclass of Factory, don't do anything special. return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs) - declarations = DeclarationsHolder() - for base in parent_factories: - declarations.update_base(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {})) + declarations = DeclarationDict() - non_factory_attrs = declarations.update_base(attrs) + #Add parent declarations in reverse order. + for base in reversed(parent_factories): + declarations.update_with_public(getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, {})) + non_factory_attrs = declarations.update_with_public(attrs) non_factory_attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations - non_factory_attrs.update(extra_attrs) return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, non_factory_attrs) @@ -167,21 +167,21 @@ class BaseFactory(object): return next_sequence @classmethod - def attributes(cls, create=False, **kwargs): + def attributes(cls, create=False, extra=None): """Build a dict of attribute values, respecting declaration order. The process is: - Handle 'orderless' attributes, overriding defaults with provided kwargs when applicable - Handle ordered attributes, overriding them with provided kwargs when - applicable; the current list of computed attributes is available for + applicable; the current list of computed attributes is available to the currently processed object. """ - return getattr(cls, CLASS_ATTRIBUTE_DECLARATIONS).build_attributes(cls, create, kwargs) + return AttributeBuilder(cls, extra).build(create) @classmethod - def declarations(cls): - return DeclarationsHolder(getattr(cls, CLASS_ATTRIBUTE_DECLARATIONS)) + def declarations(cls, extra_defs=None): + return getattr(cls, CLASS_ATTRIBUTE_DECLARATIONS).copy(extra_defs) @classmethod def build(cls, **kwargs): @@ -198,11 +198,13 @@ class BaseFactory(object): setattr(stub_object, name, value) return stub_object + class StubFactory(BaseFactory): __metaclass__ = BaseFactoryMetaClass default_strategy = STUB_STRATEGY + class Factory(BaseFactory): """Factory base with build and create support. diff --git a/factory/containers.py b/factory/containers.py index 48a4015..0620c82 100644 --- a/factory/containers.py +++ b/factory/containers.py @@ -47,7 +47,7 @@ class ObjectParamsWrapper(object): raise AttributeError("The param '{0}' does not exist. Perhaps your declarations are out of order?".format(name)) -class OrderedDeclarationDict(object): +class OrderedDict(object): def __init__(self, **kwargs): self._order = {} self._values = {} @@ -88,18 +88,51 @@ class OrderedDeclarationDict(object): for i in order: yield self._order[i] -class DeclarationsHolder(object): - """Holds all declarations, ordered and unordered.""" - def __init__(self, defaults=None): - if not defaults: - defaults = {} - self._ordered = OrderedDeclarationDict() +class DeclarationDict(object): + """Holds a dict of declarations, keeping OrderedDeclaration at the end.""" + def __init__(self, extra=None): + if not extra: + extra = {} + self._ordered = OrderedDict() self._unordered = {} - self.update_base(defaults) + self.update(extra) + + def __setitem__(self, key, value): + if key in self: + del self[key] + + if isinstance(value, OrderedDeclaration): + self._ordered[key] = value + else: + self._unordered[key] = value + + def __getitem__(self, key): + """Try in _unordered first, then in _ordered.""" + try: + return self._unordered[key] + except KeyError: + return self._ordered[key] - def update_base(self, attrs): - """Updates the DeclarationsHolder from a class definition. + def __delitem__(self, key): + if key in self._unordered: + del self._unordered[key] + else: + del self._ordered[key] + + def pop(self, key, *args): + assert len(args) <= 1 + try: + return self._unordered.pop(key) + except KeyError: + return self._ordered.pop(key, *args) + + def update(self, d): + for k in d: + self[k] = d[k] + + def update_with_public(self, d): + """Updates the DeclarationDict from a class definition dict. Takes into account all public attributes and OrderedDeclaration instances; ignores all attributes starting with '_'. @@ -107,23 +140,25 @@ class DeclarationsHolder(object): Returns a dict containing all remaining elements. """ remaining = {} - for key, value in attrs.iteritems(): - if isinstance(value, OrderedDeclaration): - self._ordered[key] = value - elif not key.startswith('_'): - self._unordered[key] = value + for k, v in d.iteritems(): + if k.startswith('_') and not isinstance(v, OrderedDeclaration): + remaining[k] = v else: - remaining[key] = value + self[k] = v return remaining + def copy(self, extra=None): + new = DeclarationDict() + new.update(self) + if extra: + new.update(extra) + return new + def __contains__(self, key): - return key in self._ordered or key in self._unordered + return key in self._unordered or key in self._ordered - def __getitem__(self, key): - try: - return self._unordered[key] - except KeyError: - return self._ordered[key] + def items(self): + return list(self.iteritems()) def iteritems(self): for pair in self._unordered.iteritems(): @@ -131,58 +166,47 @@ class DeclarationsHolder(object): for pair in self._ordered.iteritems(): yield pair - def items(self): - return list(self.iteritems()) + def __iter__(self): + for k in self._unordered: + yield k + for k in self._ordered: + yield k - def _extract_sub_fields(self, base): - """Extract all subfields declaration from a given dict-like object. - Will compare with attributes declared in the current object, and - will pop() values from the given base. - """ - sub_fields = dict() +class AttributeBuilder(object): + """Builds attributes from a factory and extra data.""" + + def __init__(self, factory, extra=None): + if not extra: + extra = {} + self.factory = factory + self._attrs = factory.declarations(extra) + self._subfield = self._extract_subfields() - for key in list(base): + def _extract_subfields(self): + sub_fields = {} + for key in list(self._attrs): if ATTR_SPLITTER in key: cls_name, attr_name = key.split(ATTR_SPLITTER, 1) - if cls_name in self: - sub_fields.setdefault(cls_name, {})[attr_name] = base.pop(key) + if cls_name in self._attrs: + sub_fields.setdefault(cls_name, {})[attr_name] = self._attrs.pop(key) return sub_fields - def build_attributes(self, factory, create=False, extra=None): - """Build the list of attributes based on class attributes.""" - if not extra: - extra = {} - - factory.sequence = factory._generate_next_sequence() + def build(self, create): + self.factory.sequence = self.factory._generate_next_sequence() attributes = {} - sub_fields = {} - for base in (self._unordered, self._ordered, extra): - sub_fields.update(self._extract_sub_fields(base)) - - def make_value(key, val): - if key in extra: - val = extra.pop(key) + for key, val in self._attrs.iteritems(): if isinstance(val, SubFactory): - new_val = val.evaluate(factory, create, sub_fields.get(key, {})) + val = val.evaluate(self.factory, create, self._subfield.get(key, {})) elif isinstance(val, OrderedDeclaration): wrapper = ObjectParamsWrapper(attributes) - new_val = val.evaluate(factory, wrapper) - else: - new_val = val + val = val.evaluate(self.factory, wrapper) + attributes[key] = val - return new_val - - # For fields in _unordered, use the value from extra if any; otherwise, - # use the default value. - for key, value in self._unordered.iteritems(): - attributes[key] = make_value(key, value) - for key, value in self._ordered.iteritems(): - attributes[key] = make_value(key, value) - attributes.update(extra) return attributes + class StubObject(object): """A generic container.""" diff --git a/factory/declarations.py b/factory/declarations.py index 898493c..a1e9102 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -48,22 +48,24 @@ class OrderedDeclaration(object): def __init__(self): self.order = GlobalCounter.step() - def evaluate(self, factory, attributes): + def evaluate(self, factory, obj): """Evaluate this declaration. Args: factory: The factory this declaration was defined in. + obj: The object holding currently computed attributes attributes: The attributes created by the unordered and ordered declarations up to this point.""" raise NotImplementedError('This is an abstract method') + class LazyAttribute(OrderedDeclaration): def __init__(self, function): super(LazyAttribute, self).__init__() self.function = function - def evaluate(self, factory, attributes): - return self.function(attributes) + def evaluate(self, factory, obj): + return self.function(obj) class SelfAttribute(OrderedDeclaration): @@ -71,8 +73,8 @@ class SelfAttribute(OrderedDeclaration): super(SelfAttribute, self).__init__() self.attribute_name = attribute_name - def evaluate(self, factory, attributes): - return getattr(attributes, self.attribute_name) + def evaluate(self, factory, obj): + return getattr(obj, self.attribute_name) class Sequence(OrderedDeclaration): @@ -81,12 +83,14 @@ class Sequence(OrderedDeclaration): self.function = function self.type = type - def evaluate(self, factory, attributes): + def evaluate(self, factory, obj): return self.function(self.type(factory.sequence)) + class LazyAttributeSequence(Sequence): - def evaluate(self, factory, attributes): - return self.function(attributes, self.type(factory.sequence)) + def evaluate(self, factory, obj): + return self.function(obj, self.type(factory.sequence)) + class SubFactory(OrderedDeclaration): """Base class for attributes based upon a sub-factory. @@ -98,20 +102,24 @@ class SubFactory(OrderedDeclaration): def __init__(self, factory, **kwargs): super(SubFactory, self).__init__() - self.defaults = factory.declarations() - self.defaults.update_base(kwargs) + self.defaults = kwargs self.factory = factory - def evaluate(self, factory, create, attributes): + def evaluate(self, factory, create, extra): """Evaluate the current definition and fill its attributes. Uses attributes definition in the following order: - attributes defined in the wrapped factory class - values defined when defining the SubFactory - - additional valued defined in attributes + - additional values defined in attributes """ - attrs = self.defaults.build_attributes(self.factory, create, attributes) + defaults = dict(self.defaults) + if extra: + defaults.update(extra) + + attrs = self.factory.attributes(create, defaults) + if create: return self.factory.create(**attrs) else: |