summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--factory/containers.py2
-rw-r--r--factory/declarations.py42
-rw-r--r--setup.py2
-rw-r--r--tests/test_declarations.py43
-rw-r--r--tests/test_using.py12
5 files changed, 53 insertions, 48 deletions
diff --git a/factory/containers.py b/factory/containers.py
index fda9073..2f92f62 100644
--- a/factory/containers.py
+++ b/factory/containers.py
@@ -26,7 +26,7 @@ from factory import declarations
#: String for splitting an attribute name into a
#: (subfactory_name, subfactory_field) tuple.
-ATTR_SPLITTER = declarations.ATTR_SPLITTER
+ATTR_SPLITTER = '__'
class CyclicDefinitionError(Exception):
diff --git a/factory/declarations.py b/factory/declarations.py
index 08598e5..5fe427c 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -21,11 +21,6 @@
# THE SOFTWARE.
-#: String for splitting an attribute name into a
-#: (subfactory_name, subfactory_field) tuple.
-ATTR_SPLITTER = '__'
-
-
class OrderedDeclaration(object):
"""A factory declaration.
@@ -64,27 +59,37 @@ class LazyAttribute(OrderedDeclaration):
return self.function(obj)
-def dig(obj, name):
- """Try to retrieve the given attribute of an object, using ATTR_SPLITTER.
+class _UNSPECIFIED(object):
+ pass
+
- If ATTR_SPLITTER is '__', dig(foo, 'a__b__c') is equivalent to foo.a.b.c.
+def deepgetattr(obj, name, default=_UNSPECIFIED):
+ """Try to retrieve the given attribute of an object, digging on '.'.
+
+ This is an extended getattr, digging deeper if '.' is found.
Args:
obj (object): the object of which an attribute should be read
name (str): the name of an attribute to look up.
+ default (object): the default value to use if the attribute wasn't found
Returns:
- the attribute pointed to by 'name', according to ATTR_SPLITTER.
+ the attribute pointed to by 'name', splitting on '.'.
Raises:
AttributeError: if obj has no 'name' attribute.
"""
- may_split = (ATTR_SPLITTER in name and not name.startswith(ATTR_SPLITTER))
- if may_split and not hasattr(obj, name):
- attr, subname = name.split(ATTR_SPLITTER, 1)
- return dig(getattr(obj, attr), subname)
- else:
- return getattr(obj, name)
+ try:
+ if '.' in name:
+ attr, subname = name.split('.', 1)
+ return deepgetattr(getattr(obj, attr), subname, default)
+ else:
+ return getattr(obj, name)
+ except AttributeError:
+ if default is _UNSPECIFIED:
+ raise
+ else:
+ return default
class SelfAttribute(OrderedDeclaration):
@@ -92,14 +97,17 @@ class SelfAttribute(OrderedDeclaration):
Attributes:
attribute_name (str): the name of the attribute to copy.
+ default (object): the default value to use if the attribute doesn't
+ exist.
"""
- def __init__(self, attribute_name, *args, **kwargs):
+ def __init__(self, attribute_name, default=_UNSPECIFIED, *args, **kwargs):
super(SelfAttribute, self).__init__(*args, **kwargs)
self.attribute_name = attribute_name
+ self.default = default
def evaluate(self, sequence, obj, containers=()):
- return dig(obj, self.attribute_name)
+ return deepgetattr(obj, self.attribute_name, self.default)
class Sequence(OrderedDeclaration):
diff --git a/setup.py b/setup.py
index 61970fa..99aa123 100644
--- a/setup.py
+++ b/setup.py
@@ -40,7 +40,7 @@ class test(cmd.Command):
setup(
name='factory_boy',
version=VERSION,
- description="A test fixtures replacement based on thoughtbot's factory_girl for Ruby.",
+ description="A verstile test fixtures replacement based on thoughtbot's factory_girl for Ruby.",
author='Mark Sandstrom',
author_email='mark@deliciouslynerdy.com',
maintainer='Raphaƫl Barrois',
diff --git a/tests/test_declarations.py b/tests/test_declarations.py
index 0fcdf10..7215a54 100644
--- a/tests/test_declarations.py
+++ b/tests/test_declarations.py
@@ -22,7 +22,7 @@
import unittest
-from factory.declarations import dig, OrderedDeclaration, Sequence
+from factory.declarations import deepgetattr, OrderedDeclaration, Sequence
class OrderedDeclarationTestCase(unittest.TestCase):
def test_errors(self):
@@ -35,36 +35,23 @@ class DigTestCase(unittest.TestCase):
def __init__(self, n):
self.n = n
- def test_parentattr(self):
- obj = self.MyObj(1)
- obj.a__b__c = self.MyObj(2)
- obj.a = self.MyObj(3)
- obj.a.b = self.MyObj(4)
- obj.a.b.c = self.MyObj(5)
-
- self.assertEqual(2, dig(obj, 'a__b__c').n)
-
- def test_private(self):
- obj = self.MyObj(1)
- self.assertEqual(obj.__class__, dig(obj, '__class__'))
-
def test_chaining(self):
obj = self.MyObj(1)
obj.a = self.MyObj(2)
- obj.a__c = self.MyObj(3)
- obj.a.b = self.MyObj(4)
- obj.a.b.c = self.MyObj(5)
-
- self.assertEqual(2, dig(obj, 'a').n)
- self.assertRaises(AttributeError, dig, obj, 'b')
- self.assertEqual(2, dig(obj, 'a__n'))
- self.assertEqual(3, dig(obj, 'a__c').n)
- self.assertRaises(AttributeError, dig, obj, 'a__c__n')
- self.assertRaises(AttributeError, dig, obj, 'a__d')
- self.assertEqual(4, dig(obj, 'a__b').n)
- self.assertEqual(4, dig(obj, 'a__b__n'))
- self.assertEqual(5, dig(obj, 'a__b__c').n)
- self.assertEqual(5, dig(obj, 'a__b__c__n'))
+ obj.a.b = self.MyObj(3)
+ obj.a.b.c = self.MyObj(4)
+
+ self.assertEqual(2, deepgetattr(obj, 'a').n)
+ self.assertRaises(AttributeError, deepgetattr, obj, 'b')
+ self.assertEqual(2, deepgetattr(obj, 'a.n'))
+ self.assertEqual(3, deepgetattr(obj, 'a.c', 3))
+ self.assertRaises(AttributeError, deepgetattr, obj, 'a.c.n')
+ self.assertRaises(AttributeError, deepgetattr, obj, 'a.d')
+ self.assertEqual(3, deepgetattr(obj, 'a.b').n)
+ self.assertEqual(3, deepgetattr(obj, 'a.b.n'))
+ self.assertEqual(4, deepgetattr(obj, 'a.b.c').n)
+ self.assertEqual(4, deepgetattr(obj, 'a.b.c.n'))
+ self.assertEqual(42, deepgetattr(obj, 'a.b.c.n.x', 42))
diff --git a/tests/test_using.py b/tests/test_using.py
index a93c968..85a12ca 100644
--- a/tests/test_using.py
+++ b/tests/test_using.py
@@ -27,11 +27,12 @@ import factory
class TestObject(object):
- def __init__(self, one=None, two=None, three=None, four=None):
+ def __init__(self, one=None, two=None, three=None, four=None, five=None):
self.one = one
self.two = two
self.three = three
self.four = four
+ self.five = five
class FakeDjangoModel(object):
class FakeDjangoManager(object):
@@ -164,12 +165,21 @@ class FactoryTestCase(unittest.TestCase):
self.assertEqual(test_object.one, 'one')
def testSelfAttribute(self):
+ class TmpObj(object):
+ n = 3
+
class TestObjectFactory(factory.Factory):
one = 'xx'
two = factory.SelfAttribute('one')
+ three = TmpObj()
+ four = factory.SelfAttribute('three.n')
+ five = factory.SelfAttribute('three.nnn', 5)
test_object = TestObjectFactory.build(one=1)
self.assertEqual(1, test_object.two)
+ self.assertEqual(3, test_object.three.n)
+ self.assertEqual(3, test_object.four)
+ self.assertEqual(5, test_object.five)
def testSequenceDecorator(self):
class TestObjectFactory(factory.Factory):