diff options
author | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2012-02-24 23:58:54 +0100 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2012-02-24 23:58:54 +0100 |
commit | 8a459c5e26a14a531f78d740b325c996044df760 (patch) | |
tree | 32fa813787f885a8ac59970cd65d8350c0e91551 | |
parent | cbbe5cc359412c8e6c49e06d5d1f35680ad88c40 (diff) | |
download | factory-boy-8a459c5e26a14a531f78d740b325c996044df760.tar factory-boy-8a459c5e26a14a531f78d740b325c996044df760.tar.gz |
Add the Iterator and InfiniteIterator attribute kinds.
Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
-rw-r--r-- | factory/__init__.py | 4 | ||||
-rw-r--r-- | factory/declarations.py | 39 | ||||
-rw-r--r-- | tests/test_using.py | 46 |
3 files changed, 89 insertions, 0 deletions
diff --git a/factory/__init__.py b/factory/__init__.py index cf9cc3e..12f297e 100644 --- a/factory/__init__.py +++ b/factory/__init__.py @@ -52,6 +52,8 @@ from base import ( from declarations import ( LazyAttribute, + Iterator, + InfiniteIterator, Sequence, LazyAttributeSequence, SelfAttribute, @@ -59,6 +61,8 @@ from declarations import ( SubFactory, lazy_attribute, + iterator, + infinite_iterator, sequence, lazy_attribute_sequence, container_attribute, diff --git a/factory/declarations.py b/factory/declarations.py index 5fe427c..41d99a3 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -21,6 +21,9 @@ # THE SOFTWARE. +import itertools + + class OrderedDeclaration(object): """A factory declaration. @@ -110,6 +113,34 @@ class SelfAttribute(OrderedDeclaration): return deepgetattr(obj, self.attribute_name, self.default) +class Iterator(OrderedDeclaration): + """Fill this value using the values returned by an iterator. + + Warning: the iterator should not end ! + + Attributes: + iterator (iterable): the iterator whose value should be used. + """ + + def __init__(self, iterator): + super(Iterator, self).__init__() + self.iterator = iter(iterator) + + def evaluate(self, sequence, obj, containers=()): + return self.iterator.next() + + +class InfiniteIterator(Iterator): + """Same as Iterator, but make the iterator infinite by cycling at the end. + + Attributes: + iterator (iterable): the iterator, once made infinite. + """ + + def __init__(self, iterator): + return super(InfiniteIterator, self).__init__(itertools.cycle(iterator)) + + class Sequence(OrderedDeclaration): """Specific OrderedDeclaration to use for 'sequenced' fields. @@ -224,6 +255,14 @@ class SubFactory(OrderedDeclaration): def lazy_attribute(func): return LazyAttribute(func) +def iterator(func): + """Turn a generator function into an iterator attribute.""" + return Iterator(func()) + +def infinite_iterator(func): + """Turn a generator function into an infinite iterator attribute.""" + return InfiniteIterator(func()) + def sequence(func): return Sequence(func) diff --git a/tests/test_using.py b/tests/test_using.py index 4e69212..41b34ea 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -760,5 +760,51 @@ class SubFactoryTestCase(unittest.TestCase): self.assertEqual(outer.side_a.inner_from_a.b, 4) +class IteratorTestCase(unittest.TestCase): + + def test_iterator(self): + class TestObjectFactory(factory.Factory): + one = factory.Iterator(xrange(10, 30)) + + objs = TestObjectFactory.build_batch(20) + + for i, obj in enumerate(objs): + self.assertEqual(i + 10, obj.one) + + def test_infinite_iterator(self): + class TestObjectFactory(factory.Factory): + one = factory.InfiniteIterator(xrange(5)) + + objs = TestObjectFactory.build_batch(20) + + for i, obj in enumerate(objs): + self.assertEqual(i % 5, obj.one) + + def test_iterator_decorator(self): + class TestObjectFactory(factory.Factory): + @factory.iterator + def one(): + for i in xrange(10, 50): + yield i + + objs = TestObjectFactory.build_batch(20) + + for i, obj in enumerate(objs): + self.assertEqual(i + 10, obj.one) + + def test_infinite_iterator_decorator(self): + class TestObjectFactory(factory.Factory): + @factory.infinite_iterator + def one(): + for i in xrange(5): + yield i + + objs = TestObjectFactory.build_batch(20) + + for i, obj in enumerate(objs): + self.assertEqual(i % 5, obj.one) + + + if __name__ == '__main__': unittest.main() |