summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--factory/declarations.py15
-rw-r--r--tests/test_declarations.py69
-rw-r--r--tests/test_using.py17
3 files changed, 81 insertions, 20 deletions
diff --git a/factory/declarations.py b/factory/declarations.py
index 50a826f..fe1afa4 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -100,7 +100,11 @@ def deepgetattr(obj, name, default=_UNSPECIFIED):
class SelfAttribute(OrderedDeclaration):
"""Specific OrderedDeclaration copying values from other fields.
+ If the field name starts with two dots or more, the lookup will be anchored
+ in the related 'parent'.
+
Attributes:
+ depth (int): the number of steps to go up in the containers chain
attribute_name (str): the name of the attribute to copy.
default (object): the default value to use if the attribute doesn't
exist.
@@ -108,11 +112,20 @@ class SelfAttribute(OrderedDeclaration):
def __init__(self, attribute_name, default=_UNSPECIFIED, *args, **kwargs):
super(SelfAttribute, self).__init__(*args, **kwargs)
+ depth = len(attribute_name) - len(attribute_name.lstrip('.'))
+ attribute_name = attribute_name[depth:]
+
+ self.depth = depth
self.attribute_name = attribute_name
self.default = default
def evaluate(self, sequence, obj, containers=()):
- return deepgetattr(obj, self.attribute_name, self.default)
+ if self.depth > 1:
+ # Fetching from a parent
+ target = containers[self.depth - 2]
+ else:
+ target = obj
+ return deepgetattr(target, self.attribute_name, self.default)
class Iterator(OrderedDeclaration):
diff --git a/tests/test_declarations.py b/tests/test_declarations.py
index c0b3539..4dea429 100644
--- a/tests/test_declarations.py
+++ b/tests/test_declarations.py
@@ -22,14 +22,13 @@
import datetime
-from factory.declarations import deepgetattr, CircularSubFactory, OrderedDeclaration, \
- PostGenerationDeclaration, Sequence
+from factory import declarations
from .compat import unittest
class OrderedDeclarationTestCase(unittest.TestCase):
def test_errors(self):
- decl = OrderedDeclaration()
+ decl = declarations.OrderedDeclaration()
self.assertRaises(NotImplementedError, decl.evaluate, None, {})
@@ -44,29 +43,61 @@ class DigTestCase(unittest.TestCase):
obj.a.b = self.MyObj(3)
obj.a.b.c = self.MyObj(4)
- self.assertEqual(2, deepgetattr(obj, 'a').n)
- self.assertRaises(AttributeError, deepgetattr, obj, 'b')
- self.assertEqual(2, deepgetattr(obj, 'a.n'))
- self.assertEqual(3, deepgetattr(obj, 'a.c', 3))
- self.assertRaises(AttributeError, deepgetattr, obj, 'a.c.n')
- self.assertRaises(AttributeError, deepgetattr, obj, 'a.d')
- self.assertEqual(3, deepgetattr(obj, 'a.b').n)
- self.assertEqual(3, deepgetattr(obj, 'a.b.n'))
- self.assertEqual(4, deepgetattr(obj, 'a.b.c').n)
- self.assertEqual(4, deepgetattr(obj, 'a.b.c.n'))
- self.assertEqual(42, deepgetattr(obj, 'a.b.c.n.x', 42))
+ self.assertEqual(2, declarations.deepgetattr(obj, 'a').n)
+ self.assertRaises(AttributeError, declarations.deepgetattr, obj, 'b')
+ self.assertEqual(2, declarations.deepgetattr(obj, 'a.n'))
+ self.assertEqual(3, declarations.deepgetattr(obj, 'a.c', 3))
+ self.assertRaises(AttributeError, declarations.deepgetattr, obj, 'a.c.n')
+ self.assertRaises(AttributeError, declarations.deepgetattr, obj, 'a.d')
+ self.assertEqual(3, declarations.deepgetattr(obj, 'a.b').n)
+ self.assertEqual(3, declarations.deepgetattr(obj, 'a.b.n'))
+ self.assertEqual(4, declarations.deepgetattr(obj, 'a.b.c').n)
+ self.assertEqual(4, declarations.deepgetattr(obj, 'a.b.c.n'))
+ self.assertEqual(42, declarations.deepgetattr(obj, 'a.b.c.n.x', 42))
+
+
+class SelfAttributeTestCase(unittest.TestCase):
+ def test_standard(self):
+ a = declarations.SelfAttribute('foo.bar.baz')
+ self.assertEqual(0, a.depth)
+ self.assertEqual('foo.bar.baz', a.attribute_name)
+ self.assertEqual(declarations._UNSPECIFIED, a.default)
+
+ def test_dot(self):
+ a = declarations.SelfAttribute('.bar.baz')
+ self.assertEqual(1, a.depth)
+ self.assertEqual('bar.baz', a.attribute_name)
+ self.assertEqual(declarations._UNSPECIFIED, a.default)
+
+ def test_default(self):
+ a = declarations.SelfAttribute('bar.baz', 42)
+ self.assertEqual(0, a.depth)
+ self.assertEqual('bar.baz', a.attribute_name)
+ self.assertEqual(42, a.default)
+
+ def test_parent(self):
+ a = declarations.SelfAttribute('..bar.baz')
+ self.assertEqual(2, a.depth)
+ self.assertEqual('bar.baz', a.attribute_name)
+ self.assertEqual(declarations._UNSPECIFIED, a.default)
+
+ def test_grandparent(self):
+ a = declarations.SelfAttribute('...bar.baz')
+ self.assertEqual(3, a.depth)
+ self.assertEqual('bar.baz', a.attribute_name)
+ self.assertEqual(declarations._UNSPECIFIED, a.default)
class PostGenerationDeclarationTestCase(unittest.TestCase):
def test_extract_no_prefix(self):
- decl = PostGenerationDeclaration()
+ decl = declarations.PostGenerationDeclaration()
extracted, kwargs = decl.extract('foo', {'foo': 13, 'foo__bar': 42})
self.assertEqual(extracted, 13)
self.assertEqual(kwargs, {'bar': 42})
def test_extract_with_prefix(self):
- decl = PostGenerationDeclaration(extract_prefix='blah')
+ decl = declarations.PostGenerationDeclaration(extract_prefix='blah')
extracted, kwargs = decl.extract('foo',
{'foo': 13, 'foo__bar': 42, 'blah': 42, 'blah__baz': 1})
@@ -76,17 +107,17 @@ class PostGenerationDeclarationTestCase(unittest.TestCase):
class CircularSubFactoryTestCase(unittest.TestCase):
def test_lazyness(self):
- f = CircularSubFactory('factory.declarations', 'Sequence', x=3)
+ f = declarations.CircularSubFactory('factory.declarations', 'Sequence', x=3)
self.assertEqual(None, f.factory)
self.assertEqual({'x': 3}, f.defaults)
factory_class = f.get_factory()
- self.assertEqual(Sequence, factory_class)
+ self.assertEqual(declarations.Sequence, factory_class)
def test_cache(self):
orig_date = datetime.date
- f = CircularSubFactory('datetime', 'date')
+ f = declarations.CircularSubFactory('datetime', 'date')
self.assertEqual(None, f.factory)
factory_class = f.get_factory()
diff --git a/tests/test_using.py b/tests/test_using.py
index ad62113..38c9e9e 100644
--- a/tests/test_using.py
+++ b/tests/test_using.py
@@ -365,6 +365,23 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertEqual(3, test_object.four)
self.assertEqual(5, test_object.five)
+ def testSelfAttributeParent(self):
+ class TestModel2(FakeModel):
+ pass
+
+ class TestModelFactory(FakeModelFactory):
+ FACTORY_FOR = TestModel
+ one = 3
+ three = factory.SelfAttribute('..bar')
+
+ class TestModel2Factory(FakeModelFactory):
+ FACTORY_FOR = TestModel2
+ bar = 4
+ two = factory.SubFactory(TestModelFactory, one=1)
+
+ test_model = TestModel2Factory()
+ self.assertEqual(4, test_model.two.three)
+
def testSequenceDecorator(self):
class TestObjectFactory(factory.Factory):
FACTORY_FOR = TestObject