From c86d32b892c383fb18b0a5d7cebc7671e4e88ab1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Barrois?= Date: Wed, 14 Nov 2012 23:15:55 +0100 Subject: Mix SelfAttribute with ContainerAttribute. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit With a very simple syntax. Signed-off-by: Raphaƫl Barrois --- factory/declarations.py | 15 +++++++++- tests/test_declarations.py | 69 +++++++++++++++++++++++++++++++++------------- tests/test_using.py | 17 ++++++++++++ 3 files changed, 81 insertions(+), 20 deletions(-) diff --git a/factory/declarations.py b/factory/declarations.py index 50a826f..fe1afa4 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -100,7 +100,11 @@ def deepgetattr(obj, name, default=_UNSPECIFIED): class SelfAttribute(OrderedDeclaration): """Specific OrderedDeclaration copying values from other fields. + If the field name starts with two dots or more, the lookup will be anchored + in the related 'parent'. + Attributes: + depth (int): the number of steps to go up in the containers chain attribute_name (str): the name of the attribute to copy. default (object): the default value to use if the attribute doesn't exist. @@ -108,11 +112,20 @@ class SelfAttribute(OrderedDeclaration): def __init__(self, attribute_name, default=_UNSPECIFIED, *args, **kwargs): super(SelfAttribute, self).__init__(*args, **kwargs) + depth = len(attribute_name) - len(attribute_name.lstrip('.')) + attribute_name = attribute_name[depth:] + + self.depth = depth self.attribute_name = attribute_name self.default = default def evaluate(self, sequence, obj, containers=()): - return deepgetattr(obj, self.attribute_name, self.default) + if self.depth > 1: + # Fetching from a parent + target = containers[self.depth - 2] + else: + target = obj + return deepgetattr(target, self.attribute_name, self.default) class Iterator(OrderedDeclaration): diff --git a/tests/test_declarations.py b/tests/test_declarations.py index c0b3539..4dea429 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -22,14 +22,13 @@ import datetime -from factory.declarations import deepgetattr, CircularSubFactory, OrderedDeclaration, \ - PostGenerationDeclaration, Sequence +from factory import declarations from .compat import unittest class OrderedDeclarationTestCase(unittest.TestCase): def test_errors(self): - decl = OrderedDeclaration() + decl = declarations.OrderedDeclaration() self.assertRaises(NotImplementedError, decl.evaluate, None, {}) @@ -44,29 +43,61 @@ class DigTestCase(unittest.TestCase): obj.a.b = self.MyObj(3) obj.a.b.c = self.MyObj(4) - self.assertEqual(2, deepgetattr(obj, 'a').n) - self.assertRaises(AttributeError, deepgetattr, obj, 'b') - self.assertEqual(2, deepgetattr(obj, 'a.n')) - self.assertEqual(3, deepgetattr(obj, 'a.c', 3)) - self.assertRaises(AttributeError, deepgetattr, obj, 'a.c.n') - self.assertRaises(AttributeError, deepgetattr, obj, 'a.d') - self.assertEqual(3, deepgetattr(obj, 'a.b').n) - self.assertEqual(3, deepgetattr(obj, 'a.b.n')) - self.assertEqual(4, deepgetattr(obj, 'a.b.c').n) - self.assertEqual(4, deepgetattr(obj, 'a.b.c.n')) - self.assertEqual(42, deepgetattr(obj, 'a.b.c.n.x', 42)) + self.assertEqual(2, declarations.deepgetattr(obj, 'a').n) + self.assertRaises(AttributeError, declarations.deepgetattr, obj, 'b') + self.assertEqual(2, declarations.deepgetattr(obj, 'a.n')) + self.assertEqual(3, declarations.deepgetattr(obj, 'a.c', 3)) + self.assertRaises(AttributeError, declarations.deepgetattr, obj, 'a.c.n') + self.assertRaises(AttributeError, declarations.deepgetattr, obj, 'a.d') + self.assertEqual(3, declarations.deepgetattr(obj, 'a.b').n) + self.assertEqual(3, declarations.deepgetattr(obj, 'a.b.n')) + self.assertEqual(4, declarations.deepgetattr(obj, 'a.b.c').n) + self.assertEqual(4, declarations.deepgetattr(obj, 'a.b.c.n')) + self.assertEqual(42, declarations.deepgetattr(obj, 'a.b.c.n.x', 42)) + + +class SelfAttributeTestCase(unittest.TestCase): + def test_standard(self): + a = declarations.SelfAttribute('foo.bar.baz') + self.assertEqual(0, a.depth) + self.assertEqual('foo.bar.baz', a.attribute_name) + self.assertEqual(declarations._UNSPECIFIED, a.default) + + def test_dot(self): + a = declarations.SelfAttribute('.bar.baz') + self.assertEqual(1, a.depth) + self.assertEqual('bar.baz', a.attribute_name) + self.assertEqual(declarations._UNSPECIFIED, a.default) + + def test_default(self): + a = declarations.SelfAttribute('bar.baz', 42) + self.assertEqual(0, a.depth) + self.assertEqual('bar.baz', a.attribute_name) + self.assertEqual(42, a.default) + + def test_parent(self): + a = declarations.SelfAttribute('..bar.baz') + self.assertEqual(2, a.depth) + self.assertEqual('bar.baz', a.attribute_name) + self.assertEqual(declarations._UNSPECIFIED, a.default) + + def test_grandparent(self): + a = declarations.SelfAttribute('...bar.baz') + self.assertEqual(3, a.depth) + self.assertEqual('bar.baz', a.attribute_name) + self.assertEqual(declarations._UNSPECIFIED, a.default) class PostGenerationDeclarationTestCase(unittest.TestCase): def test_extract_no_prefix(self): - decl = PostGenerationDeclaration() + decl = declarations.PostGenerationDeclaration() extracted, kwargs = decl.extract('foo', {'foo': 13, 'foo__bar': 42}) self.assertEqual(extracted, 13) self.assertEqual(kwargs, {'bar': 42}) def test_extract_with_prefix(self): - decl = PostGenerationDeclaration(extract_prefix='blah') + decl = declarations.PostGenerationDeclaration(extract_prefix='blah') extracted, kwargs = decl.extract('foo', {'foo': 13, 'foo__bar': 42, 'blah': 42, 'blah__baz': 1}) @@ -76,17 +107,17 @@ class PostGenerationDeclarationTestCase(unittest.TestCase): class CircularSubFactoryTestCase(unittest.TestCase): def test_lazyness(self): - f = CircularSubFactory('factory.declarations', 'Sequence', x=3) + f = declarations.CircularSubFactory('factory.declarations', 'Sequence', x=3) self.assertEqual(None, f.factory) self.assertEqual({'x': 3}, f.defaults) factory_class = f.get_factory() - self.assertEqual(Sequence, factory_class) + self.assertEqual(declarations.Sequence, factory_class) def test_cache(self): orig_date = datetime.date - f = CircularSubFactory('datetime', 'date') + f = declarations.CircularSubFactory('datetime', 'date') self.assertEqual(None, f.factory) factory_class = f.get_factory() diff --git a/tests/test_using.py b/tests/test_using.py index ad62113..38c9e9e 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -365,6 +365,23 @@ class UsingFactoryTestCase(unittest.TestCase): self.assertEqual(3, test_object.four) self.assertEqual(5, test_object.five) + def testSelfAttributeParent(self): + class TestModel2(FakeModel): + pass + + class TestModelFactory(FakeModelFactory): + FACTORY_FOR = TestModel + one = 3 + three = factory.SelfAttribute('..bar') + + class TestModel2Factory(FakeModelFactory): + FACTORY_FOR = TestModel2 + bar = 4 + two = factory.SubFactory(TestModelFactory, one=1) + + test_model = TestModel2Factory() + self.assertEqual(4, test_model.two.three) + def testSequenceDecorator(self): class TestObjectFactory(factory.Factory): FACTORY_FOR = TestObject -- cgit v1.2.3