diff options
-rw-r--r-- | factory/base.py | 40 | ||||
-rw-r--r-- | factory/containers.py | 52 |
2 files changed, 59 insertions, 33 deletions
diff --git a/factory/base.py b/factory/base.py index 013b58f..e78ec6e 100644 --- a/factory/base.py +++ b/factory/base.py @@ -21,7 +21,7 @@ import re import sys -from containers import ObjectParamsWrapper, OrderedDeclarationDict, StubObject +from containers import DeclarationsHolder, ObjectParamsWrapper, OrderedDeclarationDict, StubObject from declarations import OrderedDeclaration # Strategies @@ -40,8 +40,7 @@ FACTORY_CLASS_DECLARATION = 'FACTORY_FOR' # Factory class attributes -CLASS_ATTRIBUTE_ORDERED_DECLARATIONS = '_ordered_declarations' -CLASS_ATTRIBUTE_UNORDERED_DECLARATIONS = '_unordered_declarations' +CLASS_ATTRIBUTE_DECLARATIONS = '_declarations' CLASS_ATTRIBUTE_ASSOCIATED_CLASS = '_associated_class' # Factory metaclasses @@ -78,21 +77,12 @@ class BaseFactoryMetaClass(type): # If this isn't a subclass of Factory, don't do anything special. return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs) - ordered_declarations = getattr(base, CLASS_ATTRIBUTE_ORDERED_DECLARATIONS, - OrderedDeclarationDict()) - unordered_declarations = getattr(base, CLASS_ATTRIBUTE_UNORDERED_DECLARATIONS, {}) + declarations = getattr(base, CLASS_ATTRIBUTE_DECLARATIONS, DeclarationsHolder()) + attrs = declarations.update_base(attrs) - for name in list(attrs): - if isinstance(attrs[name], OrderedDeclaration): - ordered_declarations[name] = attrs.pop(name) - elif not name.startswith('_'): - unordered_declarations[name] = attrs.pop(name) + attrs[CLASS_ATTRIBUTE_DECLARATIONS] = declarations - attrs[CLASS_ATTRIBUTE_ORDERED_DECLARATIONS] = ordered_declarations - attrs[CLASS_ATTRIBUTE_UNORDERED_DECLARATIONS] = unordered_declarations - - for name, value in extra_attrs.iteritems(): - attrs[name] = value + attrs.update(extra_attrs) return super(BaseFactoryMetaClass, cls).__new__(cls, class_name, bases, attrs) @@ -176,23 +166,7 @@ class BaseFactory(object): attributes = {} cls.sequence = cls._generate_next_sequence() - for name, value in getattr(cls, CLASS_ATTRIBUTE_UNORDERED_DECLARATIONS).iteritems(): - if name in kwargs: - attributes[name] = kwargs.pop(name) - else: - attributes[name] = value - - for name, ordered_declaration in getattr(cls, CLASS_ATTRIBUTE_ORDERED_DECLARATIONS).iteritems(): - if name in kwargs: - attributes[name] = kwargs.pop(name) - else: - a = ObjectParamsWrapper(attributes) - attributes[name] = ordered_declaration.evaluate(cls, a) - - for name in kwargs: - attributes[name] = kwargs[name] - - return attributes + return getattr(cls, CLASS_ATTRIBUTE_DECLARATIONS).build_attributes(cls, kwargs) @classmethod def build(cls, **kwargs): diff --git a/factory/containers.py b/factory/containers.py index a117d5c..63be161 100644 --- a/factory/containers.py +++ b/factory/containers.py @@ -18,6 +18,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +from declarations import OrderedDeclaration + class ObjectParamsWrapper(object): '''A generic container that allows for getting but not setting of attributes. @@ -83,6 +85,56 @@ class OrderedDeclarationDict(object): for i in order: yield self._order[i] +class DeclarationsHolder(object): + """Holds all declarations, ordered and unordered.""" + + def __init__(self): + self._ordered = OrderedDeclarationDict() + self._unordered = {} + + def update_base(self, attrs): + """Updates the DeclarationsHolder from a class definition. + + Takes into account all public attributes and OrderedDeclaration + instances; ignores all attributes starting with '_'. + + Returns a dict containing all remaining elements. + """ + remaining = {} + for key, value in attrs.iteritems(): + if isinstance(value, OrderedDeclaration): + self._ordered[key] = value + elif not key.startswith('_'): + self._unordered[key] = value + else: + remaining[key] = value + return remaining + + def __contains__(self, key): + return key in self._ordered or key in self._unordered + + def __getitem__(self, key): + try: + return self._unordered[key] + except KeyError: + return self._ordered[key] + + def build_attributes(self, factory, extra): + """Build the list of attributes based on class attributes.""" + attributes = {} + # For fields in _unordered, use the value from attrs if any; otherwise, + # use the default value. + for key, value in self._unordered.iteritems(): + attributes[key] = extra.get(key, value) + for key, value in self._ordered.iteritems(): + if key in extra: + attributes[key] = extra[key] + else: + wrapper = ObjectParamsWrapper(attributes) + attributes[key] = value.evaluate(factory, wrapper) + attributes.update(extra) + return attributes + class StubObject(object): '''A generic container.''' |