summaryrefslogtreecommitdiff
path: root/factory/declarations.py
diff options
context:
space:
mode:
Diffstat (limited to 'factory/declarations.py')
-rw-r--r--factory/declarations.py42
1 files changed, 25 insertions, 17 deletions
diff --git a/factory/declarations.py b/factory/declarations.py
index 08598e5..5fe427c 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -21,11 +21,6 @@
# THE SOFTWARE.
-#: String for splitting an attribute name into a
-#: (subfactory_name, subfactory_field) tuple.
-ATTR_SPLITTER = '__'
-
-
class OrderedDeclaration(object):
"""A factory declaration.
@@ -64,27 +59,37 @@ class LazyAttribute(OrderedDeclaration):
return self.function(obj)
-def dig(obj, name):
- """Try to retrieve the given attribute of an object, using ATTR_SPLITTER.
+class _UNSPECIFIED(object):
+ pass
+
- If ATTR_SPLITTER is '__', dig(foo, 'a__b__c') is equivalent to foo.a.b.c.
+def deepgetattr(obj, name, default=_UNSPECIFIED):
+ """Try to retrieve the given attribute of an object, digging on '.'.
+
+ This is an extended getattr, digging deeper if '.' is found.
Args:
obj (object): the object of which an attribute should be read
name (str): the name of an attribute to look up.
+ default (object): the default value to use if the attribute wasn't found
Returns:
- the attribute pointed to by 'name', according to ATTR_SPLITTER.
+ the attribute pointed to by 'name', splitting on '.'.
Raises:
AttributeError: if obj has no 'name' attribute.
"""
- may_split = (ATTR_SPLITTER in name and not name.startswith(ATTR_SPLITTER))
- if may_split and not hasattr(obj, name):
- attr, subname = name.split(ATTR_SPLITTER, 1)
- return dig(getattr(obj, attr), subname)
- else:
- return getattr(obj, name)
+ try:
+ if '.' in name:
+ attr, subname = name.split('.', 1)
+ return deepgetattr(getattr(obj, attr), subname, default)
+ else:
+ return getattr(obj, name)
+ except AttributeError:
+ if default is _UNSPECIFIED:
+ raise
+ else:
+ return default
class SelfAttribute(OrderedDeclaration):
@@ -92,14 +97,17 @@ class SelfAttribute(OrderedDeclaration):
Attributes:
attribute_name (str): the name of the attribute to copy.
+ default (object): the default value to use if the attribute doesn't
+ exist.
"""
- def __init__(self, attribute_name, *args, **kwargs):
+ def __init__(self, attribute_name, default=_UNSPECIFIED, *args, **kwargs):
super(SelfAttribute, self).__init__(*args, **kwargs)
self.attribute_name = attribute_name
+ self.default = default
def evaluate(self, sequence, obj, containers=()):
- return dig(obj, self.attribute_name)
+ return deepgetattr(obj, self.attribute_name, self.default)
class Sequence(OrderedDeclaration):