diff options
Diffstat (limited to 'factory/declarations.py')
-rw-r--r-- | factory/declarations.py | 42 |
1 files changed, 25 insertions, 17 deletions
diff --git a/factory/declarations.py b/factory/declarations.py index 08598e5..5fe427c 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -21,11 +21,6 @@ # THE SOFTWARE. -#: String for splitting an attribute name into a -#: (subfactory_name, subfactory_field) tuple. -ATTR_SPLITTER = '__' - - class OrderedDeclaration(object): """A factory declaration. @@ -64,27 +59,37 @@ class LazyAttribute(OrderedDeclaration): return self.function(obj) -def dig(obj, name): - """Try to retrieve the given attribute of an object, using ATTR_SPLITTER. +class _UNSPECIFIED(object): + pass + - If ATTR_SPLITTER is '__', dig(foo, 'a__b__c') is equivalent to foo.a.b.c. +def deepgetattr(obj, name, default=_UNSPECIFIED): + """Try to retrieve the given attribute of an object, digging on '.'. + + This is an extended getattr, digging deeper if '.' is found. Args: obj (object): the object of which an attribute should be read name (str): the name of an attribute to look up. + default (object): the default value to use if the attribute wasn't found Returns: - the attribute pointed to by 'name', according to ATTR_SPLITTER. + the attribute pointed to by 'name', splitting on '.'. Raises: AttributeError: if obj has no 'name' attribute. """ - may_split = (ATTR_SPLITTER in name and not name.startswith(ATTR_SPLITTER)) - if may_split and not hasattr(obj, name): - attr, subname = name.split(ATTR_SPLITTER, 1) - return dig(getattr(obj, attr), subname) - else: - return getattr(obj, name) + try: + if '.' in name: + attr, subname = name.split('.', 1) + return deepgetattr(getattr(obj, attr), subname, default) + else: + return getattr(obj, name) + except AttributeError: + if default is _UNSPECIFIED: + raise + else: + return default class SelfAttribute(OrderedDeclaration): @@ -92,14 +97,17 @@ class SelfAttribute(OrderedDeclaration): Attributes: attribute_name (str): the name of the attribute to copy. + default (object): the default value to use if the attribute doesn't + exist. """ - def __init__(self, attribute_name, *args, **kwargs): + def __init__(self, attribute_name, default=_UNSPECIFIED, *args, **kwargs): super(SelfAttribute, self).__init__(*args, **kwargs) self.attribute_name = attribute_name + self.default = default def evaluate(self, sequence, obj, containers=()): - return dig(obj, self.attribute_name) + return deepgetattr(obj, self.attribute_name, self.default) class Sequence(OrderedDeclaration): |