summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2012-02-24 23:58:54 +0100
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2012-02-24 23:58:54 +0100
commit8a459c5e26a14a531f78d740b325c996044df760 (patch)
tree32fa813787f885a8ac59970cd65d8350c0e91551
parentcbbe5cc359412c8e6c49e06d5d1f35680ad88c40 (diff)
downloadfactory-boy-8a459c5e26a14a531f78d740b325c996044df760.tar
factory-boy-8a459c5e26a14a531f78d740b325c996044df760.tar.gz
Add the Iterator and InfiniteIterator attribute kinds.
Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
-rw-r--r--factory/__init__.py4
-rw-r--r--factory/declarations.py39
-rw-r--r--tests/test_using.py46
3 files changed, 89 insertions, 0 deletions
diff --git a/factory/__init__.py b/factory/__init__.py
index cf9cc3e..12f297e 100644
--- a/factory/__init__.py
+++ b/factory/__init__.py
@@ -52,6 +52,8 @@ from base import (
from declarations import (
LazyAttribute,
+ Iterator,
+ InfiniteIterator,
Sequence,
LazyAttributeSequence,
SelfAttribute,
@@ -59,6 +61,8 @@ from declarations import (
SubFactory,
lazy_attribute,
+ iterator,
+ infinite_iterator,
sequence,
lazy_attribute_sequence,
container_attribute,
diff --git a/factory/declarations.py b/factory/declarations.py
index 5fe427c..41d99a3 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -21,6 +21,9 @@
# THE SOFTWARE.
+import itertools
+
+
class OrderedDeclaration(object):
"""A factory declaration.
@@ -110,6 +113,34 @@ class SelfAttribute(OrderedDeclaration):
return deepgetattr(obj, self.attribute_name, self.default)
+class Iterator(OrderedDeclaration):
+ """Fill this value using the values returned by an iterator.
+
+ Warning: the iterator should not end !
+
+ Attributes:
+ iterator (iterable): the iterator whose value should be used.
+ """
+
+ def __init__(self, iterator):
+ super(Iterator, self).__init__()
+ self.iterator = iter(iterator)
+
+ def evaluate(self, sequence, obj, containers=()):
+ return self.iterator.next()
+
+
+class InfiniteIterator(Iterator):
+ """Same as Iterator, but make the iterator infinite by cycling at the end.
+
+ Attributes:
+ iterator (iterable): the iterator, once made infinite.
+ """
+
+ def __init__(self, iterator):
+ return super(InfiniteIterator, self).__init__(itertools.cycle(iterator))
+
+
class Sequence(OrderedDeclaration):
"""Specific OrderedDeclaration to use for 'sequenced' fields.
@@ -224,6 +255,14 @@ class SubFactory(OrderedDeclaration):
def lazy_attribute(func):
return LazyAttribute(func)
+def iterator(func):
+ """Turn a generator function into an iterator attribute."""
+ return Iterator(func())
+
+def infinite_iterator(func):
+ """Turn a generator function into an infinite iterator attribute."""
+ return InfiniteIterator(func())
+
def sequence(func):
return Sequence(func)
diff --git a/tests/test_using.py b/tests/test_using.py
index 4e69212..41b34ea 100644
--- a/tests/test_using.py
+++ b/tests/test_using.py
@@ -760,5 +760,51 @@ class SubFactoryTestCase(unittest.TestCase):
self.assertEqual(outer.side_a.inner_from_a.b, 4)
+class IteratorTestCase(unittest.TestCase):
+
+ def test_iterator(self):
+ class TestObjectFactory(factory.Factory):
+ one = factory.Iterator(xrange(10, 30))
+
+ objs = TestObjectFactory.build_batch(20)
+
+ for i, obj in enumerate(objs):
+ self.assertEqual(i + 10, obj.one)
+
+ def test_infinite_iterator(self):
+ class TestObjectFactory(factory.Factory):
+ one = factory.InfiniteIterator(xrange(5))
+
+ objs = TestObjectFactory.build_batch(20)
+
+ for i, obj in enumerate(objs):
+ self.assertEqual(i % 5, obj.one)
+
+ def test_iterator_decorator(self):
+ class TestObjectFactory(factory.Factory):
+ @factory.iterator
+ def one():
+ for i in xrange(10, 50):
+ yield i
+
+ objs = TestObjectFactory.build_batch(20)
+
+ for i, obj in enumerate(objs):
+ self.assertEqual(i + 10, obj.one)
+
+ def test_infinite_iterator_decorator(self):
+ class TestObjectFactory(factory.Factory):
+ @factory.infinite_iterator
+ def one():
+ for i in xrange(5):
+ yield i
+
+ objs = TestObjectFactory.build_batch(20)
+
+ for i, obj in enumerate(objs):
+ self.assertEqual(i % 5, obj.one)
+
+
+
if __name__ == '__main__':
unittest.main()