diff options
author | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2012-02-24 00:07:53 +0100 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2012-02-24 00:07:53 +0100 |
commit | 2a1138550b3220b6f8cd23bae5fed03f0fb448cf (patch) | |
tree | 38b6b352c3efead96a410dddb5b524af2d22a9e2 | |
parent | ff9d0f536bc443b81e6c95cf31644b6e19236538 (diff) | |
download | factory-boy-2a1138550b3220b6f8cd23bae5fed03f0fb448cf.tar factory-boy-2a1138550b3220b6f8cd23bae5fed03f0fb448cf.tar.gz |
Allow using '__' in factory.SelfAttribute.
Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
-rw-r--r-- | factory/containers.py | 2 | ||||
-rw-r--r-- | factory/declarations.py | 33 | ||||
-rw-r--r-- | tests/test_declarations.py | 41 |
3 files changed, 71 insertions, 5 deletions
diff --git a/factory/containers.py b/factory/containers.py index dd11f5f..ef97548 100644 --- a/factory/containers.py +++ b/factory/containers.py @@ -26,7 +26,7 @@ from factory import declarations #: String for splitting an attribute name into a #: (subfactory_name, subfactory_field) tuple. -ATTR_SPLITTER = '__' +ATTR_SPLITTER = declarations.ATTR_SPLITTER class CyclicDefinitionError(Exception): diff --git a/factory/declarations.py b/factory/declarations.py index 0ce7071..60425c3 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -20,6 +20,12 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. + +#: String for splitting an attribute name into a +#: (subfactory_name, subfactory_field) tuple. +ATTR_SPLITTER = '__' + + class OrderedDeclaration(object): """A factory declaration. @@ -58,6 +64,29 @@ class LazyAttribute(OrderedDeclaration): return self.function(obj) +def dig(obj, name): + """Try to retrieve the given attribute of an object, using ATTR_SPLITTER. + + If ATTR_SPLITTER is '__', dig(foo, 'a__b__c') is equivalent to foo.a.b.c. + + Args: + obj (object): the object of which an attribute should be read + name (str): the name of an attribute to look up. + + Returns: + the attribute pointed to by 'name', according to ATTR_SPLITTER. + + Raises: + AttributeError: if obj has no 'name' attribute. + """ + may_split = (ATTR_SPLITTER in name and not name.startswith(ATTR_SPLITTER)) + if may_split and not hasattr(obj, name): + attr, subname = name.split(ATTR_SPLITTER, 1) + return dig(getattr(obj, attr), subname) + else: + return getattr(obj, name) + + class SelfAttribute(OrderedDeclaration): """Specific OrderedDeclaration copying values from other fields. @@ -70,9 +99,7 @@ class SelfAttribute(OrderedDeclaration): self.attribute_name = attribute_name def evaluate(self, sequence, obj, containers=()): - # TODO(rbarrois): allow the use of ATTR_SPLITTER to fetch fields of - # subfactories. - return getattr(obj, self.attribute_name) + return dig(obj, self.attribute_name) class Sequence(OrderedDeclaration): diff --git a/tests/test_declarations.py b/tests/test_declarations.py index dcee38b..0fcdf10 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -22,12 +22,51 @@ import unittest -from factory.declarations import OrderedDeclaration, Sequence +from factory.declarations import dig, OrderedDeclaration, Sequence class OrderedDeclarationTestCase(unittest.TestCase): def test_errors(self): decl = OrderedDeclaration() self.assertRaises(NotImplementedError, decl.evaluate, None, {}) + +class DigTestCase(unittest.TestCase): + class MyObj(object): + def __init__(self, n): + self.n = n + + def test_parentattr(self): + obj = self.MyObj(1) + obj.a__b__c = self.MyObj(2) + obj.a = self.MyObj(3) + obj.a.b = self.MyObj(4) + obj.a.b.c = self.MyObj(5) + + self.assertEqual(2, dig(obj, 'a__b__c').n) + + def test_private(self): + obj = self.MyObj(1) + self.assertEqual(obj.__class__, dig(obj, '__class__')) + + def test_chaining(self): + obj = self.MyObj(1) + obj.a = self.MyObj(2) + obj.a__c = self.MyObj(3) + obj.a.b = self.MyObj(4) + obj.a.b.c = self.MyObj(5) + + self.assertEqual(2, dig(obj, 'a').n) + self.assertRaises(AttributeError, dig, obj, 'b') + self.assertEqual(2, dig(obj, 'a__n')) + self.assertEqual(3, dig(obj, 'a__c').n) + self.assertRaises(AttributeError, dig, obj, 'a__c__n') + self.assertRaises(AttributeError, dig, obj, 'a__d') + self.assertEqual(4, dig(obj, 'a__b').n) + self.assertEqual(4, dig(obj, 'a__b__n')) + self.assertEqual(5, dig(obj, 'a__b__c').n) + self.assertEqual(5, dig(obj, 'a__b__c__n')) + + + if __name__ == '__main__': unittest.main() |