diff options
-rw-r--r-- | factory/containers.py | 2 | ||||
-rw-r--r-- | factory/declarations.py | 42 | ||||
-rw-r--r-- | setup.py | 2 | ||||
-rw-r--r-- | tests/test_declarations.py | 43 | ||||
-rw-r--r-- | tests/test_using.py | 12 |
5 files changed, 53 insertions, 48 deletions
diff --git a/factory/containers.py b/factory/containers.py index fda9073..2f92f62 100644 --- a/factory/containers.py +++ b/factory/containers.py @@ -26,7 +26,7 @@ from factory import declarations #: String for splitting an attribute name into a #: (subfactory_name, subfactory_field) tuple. -ATTR_SPLITTER = declarations.ATTR_SPLITTER +ATTR_SPLITTER = '__' class CyclicDefinitionError(Exception): diff --git a/factory/declarations.py b/factory/declarations.py index 08598e5..5fe427c 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -21,11 +21,6 @@ # THE SOFTWARE. -#: String for splitting an attribute name into a -#: (subfactory_name, subfactory_field) tuple. -ATTR_SPLITTER = '__' - - class OrderedDeclaration(object): """A factory declaration. @@ -64,27 +59,37 @@ class LazyAttribute(OrderedDeclaration): return self.function(obj) -def dig(obj, name): - """Try to retrieve the given attribute of an object, using ATTR_SPLITTER. +class _UNSPECIFIED(object): + pass + - If ATTR_SPLITTER is '__', dig(foo, 'a__b__c') is equivalent to foo.a.b.c. +def deepgetattr(obj, name, default=_UNSPECIFIED): + """Try to retrieve the given attribute of an object, digging on '.'. + + This is an extended getattr, digging deeper if '.' is found. Args: obj (object): the object of which an attribute should be read name (str): the name of an attribute to look up. + default (object): the default value to use if the attribute wasn't found Returns: - the attribute pointed to by 'name', according to ATTR_SPLITTER. + the attribute pointed to by 'name', splitting on '.'. Raises: AttributeError: if obj has no 'name' attribute. """ - may_split = (ATTR_SPLITTER in name and not name.startswith(ATTR_SPLITTER)) - if may_split and not hasattr(obj, name): - attr, subname = name.split(ATTR_SPLITTER, 1) - return dig(getattr(obj, attr), subname) - else: - return getattr(obj, name) + try: + if '.' in name: + attr, subname = name.split('.', 1) + return deepgetattr(getattr(obj, attr), subname, default) + else: + return getattr(obj, name) + except AttributeError: + if default is _UNSPECIFIED: + raise + else: + return default class SelfAttribute(OrderedDeclaration): @@ -92,14 +97,17 @@ class SelfAttribute(OrderedDeclaration): Attributes: attribute_name (str): the name of the attribute to copy. + default (object): the default value to use if the attribute doesn't + exist. """ - def __init__(self, attribute_name, *args, **kwargs): + def __init__(self, attribute_name, default=_UNSPECIFIED, *args, **kwargs): super(SelfAttribute, self).__init__(*args, **kwargs) self.attribute_name = attribute_name + self.default = default def evaluate(self, sequence, obj, containers=()): - return dig(obj, self.attribute_name) + return deepgetattr(obj, self.attribute_name, self.default) class Sequence(OrderedDeclaration): @@ -40,7 +40,7 @@ class test(cmd.Command): setup( name='factory_boy', version=VERSION, - description="A test fixtures replacement based on thoughtbot's factory_girl for Ruby.", + description="A verstile test fixtures replacement based on thoughtbot's factory_girl for Ruby.", author='Mark Sandstrom', author_email='mark@deliciouslynerdy.com', maintainer='Raphaƫl Barrois', diff --git a/tests/test_declarations.py b/tests/test_declarations.py index 0fcdf10..7215a54 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -22,7 +22,7 @@ import unittest -from factory.declarations import dig, OrderedDeclaration, Sequence +from factory.declarations import deepgetattr, OrderedDeclaration, Sequence class OrderedDeclarationTestCase(unittest.TestCase): def test_errors(self): @@ -35,36 +35,23 @@ class DigTestCase(unittest.TestCase): def __init__(self, n): self.n = n - def test_parentattr(self): - obj = self.MyObj(1) - obj.a__b__c = self.MyObj(2) - obj.a = self.MyObj(3) - obj.a.b = self.MyObj(4) - obj.a.b.c = self.MyObj(5) - - self.assertEqual(2, dig(obj, 'a__b__c').n) - - def test_private(self): - obj = self.MyObj(1) - self.assertEqual(obj.__class__, dig(obj, '__class__')) - def test_chaining(self): obj = self.MyObj(1) obj.a = self.MyObj(2) - obj.a__c = self.MyObj(3) - obj.a.b = self.MyObj(4) - obj.a.b.c = self.MyObj(5) - - self.assertEqual(2, dig(obj, 'a').n) - self.assertRaises(AttributeError, dig, obj, 'b') - self.assertEqual(2, dig(obj, 'a__n')) - self.assertEqual(3, dig(obj, 'a__c').n) - self.assertRaises(AttributeError, dig, obj, 'a__c__n') - self.assertRaises(AttributeError, dig, obj, 'a__d') - self.assertEqual(4, dig(obj, 'a__b').n) - self.assertEqual(4, dig(obj, 'a__b__n')) - self.assertEqual(5, dig(obj, 'a__b__c').n) - self.assertEqual(5, dig(obj, 'a__b__c__n')) + obj.a.b = self.MyObj(3) + obj.a.b.c = self.MyObj(4) + + self.assertEqual(2, deepgetattr(obj, 'a').n) + self.assertRaises(AttributeError, deepgetattr, obj, 'b') + self.assertEqual(2, deepgetattr(obj, 'a.n')) + self.assertEqual(3, deepgetattr(obj, 'a.c', 3)) + self.assertRaises(AttributeError, deepgetattr, obj, 'a.c.n') + self.assertRaises(AttributeError, deepgetattr, obj, 'a.d') + self.assertEqual(3, deepgetattr(obj, 'a.b').n) + self.assertEqual(3, deepgetattr(obj, 'a.b.n')) + self.assertEqual(4, deepgetattr(obj, 'a.b.c').n) + self.assertEqual(4, deepgetattr(obj, 'a.b.c.n')) + self.assertEqual(42, deepgetattr(obj, 'a.b.c.n.x', 42)) diff --git a/tests/test_using.py b/tests/test_using.py index a93c968..85a12ca 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -27,11 +27,12 @@ import factory class TestObject(object): - def __init__(self, one=None, two=None, three=None, four=None): + def __init__(self, one=None, two=None, three=None, four=None, five=None): self.one = one self.two = two self.three = three self.four = four + self.five = five class FakeDjangoModel(object): class FakeDjangoManager(object): @@ -164,12 +165,21 @@ class FactoryTestCase(unittest.TestCase): self.assertEqual(test_object.one, 'one') def testSelfAttribute(self): + class TmpObj(object): + n = 3 + class TestObjectFactory(factory.Factory): one = 'xx' two = factory.SelfAttribute('one') + three = TmpObj() + four = factory.SelfAttribute('three.n') + five = factory.SelfAttribute('three.nnn', 5) test_object = TestObjectFactory.build(one=1) self.assertEqual(1, test_object.two) + self.assertEqual(3, test_object.three.n) + self.assertEqual(3, test_object.four) + self.assertEqual(5, test_object.five) def testSequenceDecorator(self): class TestObjectFactory(factory.Factory): |