diff options
author | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-06-15 00:17:41 +0200 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-06-15 00:17:41 +0200 |
commit | fda40cb64041aacdb776e0b1f4f4a635bdc9d70b (patch) | |
tree | c0d30d702c6886e5f8542daffb7aef080137ed0b | |
parent | 1ba20b0ed7b920fa2d161df94a0dda3d93b1e14b (diff) | |
download | factory-boy-fda40cb64041aacdb776e0b1f4f4a635bdc9d70b.tar factory-boy-fda40cb64041aacdb776e0b1f4f4a635bdc9d70b.tar.gz |
Add Iterator.reset() (Closes #63).
-rw-r--r-- | docs/changelog.rst | 1 | ||||
-rw-r--r-- | docs/reference.rst | 26 | ||||
-rw-r--r-- | factory/declarations.py | 11 | ||||
-rw-r--r-- | factory/utils.py | 23 | ||||
-rw-r--r-- | tests/test_declarations.py | 21 | ||||
-rw-r--r-- | tests/test_utils.py | 106 |
6 files changed, 184 insertions, 4 deletions
diff --git a/docs/changelog.rst b/docs/changelog.rst index 6489176..98b177d 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -16,6 +16,7 @@ ChangeLog - Add the :meth:`~factory.Factory.reset_sequence` classmethod to :class:`~factory.Factory` to ease resetting the sequence counter for a given factory. - Add debug messages to ``factory`` logger. + - Add a :meth:`~factory.Iterator.reset` method to :class:`~factory.Iterator` (:issue:`63`) *Bugfix* diff --git a/docs/reference.rst b/docs/reference.rst index 74f2dbd..e98665f 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -884,6 +884,14 @@ Iterator .. versionadded:: 1.3.0 + .. method:: reset() + + Reset the internal iterator used by the attribute, so that the next value + will be the first value generated by the iterator. + + May be called several times. + + Each call to the factory will receive the next value from the iterable: .. code-block:: python @@ -953,6 +961,24 @@ use the :func:`iterator` decorator: yield line +Resetting +~~~~~~~~~ + +In order to start back at the first value in an :class:`Iterator`, +simply call the :meth:`~Iterator.reset` method of that attribute +(accessing it from the bare :class:`~Factory` subclass): + +.. code-block:: pycon + + >>> UserFactory().lang + 'en' + >>> UserFactory().lang + 'fr' + >>> UserFactory.lang.reset() + >>> UserFactory().lang + 'en' + + Dict and List """"""""""""" diff --git a/factory/declarations.py b/factory/declarations.py index f068c0d..552ddf2 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -163,17 +163,20 @@ class Iterator(OrderedDeclaration): self.getter = getter if cycle: - self.iterator = itertools.cycle(iterator) - else: - self.iterator = iter(iterator) + iterator = itertools.cycle(iterator) + self.iterator = utils.ResetableIterator(iterator) def evaluate(self, sequence, obj, create, extra=None, containers=()): logger.debug("Iterator: Fetching next value from %r", self.iterator) - value = next(self.iterator) + value = next(iter(self.iterator)) if self.getter is None: return value return self.getter(value) + def reset(self): + """Reset the internal iterator.""" + self.iterator.reset() + class Sequence(OrderedDeclaration): """Specific OrderedDeclaration to use for 'sequenced' fields. diff --git a/factory/utils.py b/factory/utils.py index e1b265f..48c6eed 100644 --- a/factory/utils.py +++ b/factory/utils.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +import collections #: String for splitting an attribute name into a #: (subfactory_name, subfactory_field) tuple. @@ -101,3 +102,25 @@ def log_pprint(args=(), kwargs=None): [str(arg) for arg in args] + ['%s=%r' % item for item in kwargs.items()] ) + + +class ResetableIterator(object): + """An iterator wrapper that can be 'reset()' to its start.""" + def __init__(self, iterator, **kwargs): + super(ResetableIterator, self).__init__(**kwargs) + self.iterator = iter(iterator) + self.past_elements = collections.deque() + self.next_elements = collections.deque() + + def __iter__(self): + while True: + if self.next_elements: + yield self.next_elements.popleft() + else: + value = next(self.iterator) + self.past_elements.append(value) + yield value + + def reset(self): + self.next_elements.clear() + self.next_elements.extend(self.past_elements) diff --git a/tests/test_declarations.py b/tests/test_declarations.py index cd38dd2..9d54c59 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -107,6 +107,27 @@ class IteratorTestCase(unittest.TestCase): self.assertEqual(2, it.evaluate(1, None, False)) self.assertRaises(StopIteration, it.evaluate, 2, None, False) + def test_reset_cycle(self): + it = declarations.Iterator([1, 2]) + self.assertEqual(1, it.evaluate(0, None, False)) + self.assertEqual(2, it.evaluate(1, None, False)) + self.assertEqual(1, it.evaluate(2, None, False)) + self.assertEqual(2, it.evaluate(3, None, False)) + self.assertEqual(1, it.evaluate(4, None, False)) + it.reset() + self.assertEqual(1, it.evaluate(5, None, False)) + self.assertEqual(2, it.evaluate(6, None, False)) + + def test_reset_no_cycling(self): + it = declarations.Iterator([1, 2], cycle=False) + self.assertEqual(1, it.evaluate(0, None, False)) + self.assertEqual(2, it.evaluate(1, None, False)) + self.assertRaises(StopIteration, it.evaluate, 2, None, False) + it.reset() + self.assertEqual(1, it.evaluate(0, None, False)) + self.assertEqual(2, it.evaluate(1, None, False)) + self.assertRaises(StopIteration, it.evaluate, 2, None, False) + def test_getter(self): it = declarations.Iterator([(1, 2), (1, 3)], getter=lambda p: p[1]) self.assertEqual(2, it.evaluate(0, None, False)) diff --git a/tests/test_utils.py b/tests/test_utils.py index b353c9d..8c73935 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. +import itertools from factory import utils @@ -230,3 +231,108 @@ class ImportObjectTestCase(unittest.TestCase): def test_invalid_module(self): self.assertRaises(ImportError, utils.import_object, 'this-is-an-invalid-module', '__name__') + + +class ResetableIteratorTestCase(unittest.TestCase): + def test_no_reset(self): + i = utils.ResetableIterator([1, 2, 3]) + self.assertEqual([1, 2, 3], list(i)) + + def test_no_reset_new_iterator(self): + i = utils.ResetableIterator([1, 2, 3]) + iterator = iter(i) + self.assertEqual(1, next(iterator)) + self.assertEqual(2, next(iterator)) + + iterator2 = iter(i) + self.assertEqual(3, next(iterator2)) + + def test_infinite(self): + i = utils.ResetableIterator(itertools.cycle([1, 2, 3])) + iterator = iter(i) + values = [next(iterator) for _i in range(10)] + self.assertEqual([1, 2, 3, 1, 2, 3, 1, 2, 3, 1], values) + + def test_reset_simple(self): + i = utils.ResetableIterator([1, 2, 3]) + iterator = iter(i) + self.assertEqual(1, next(iterator)) + self.assertEqual(2, next(iterator)) + + i.reset() + self.assertEqual(1, next(iterator)) + self.assertEqual(2, next(iterator)) + self.assertEqual(3, next(iterator)) + + def test_reset_at_begin(self): + i = utils.ResetableIterator([1, 2, 3]) + iterator = iter(i) + i.reset() + i.reset() + self.assertEqual(1, next(iterator)) + self.assertEqual(2, next(iterator)) + self.assertEqual(3, next(iterator)) + + def test_reset_at_end(self): + i = utils.ResetableIterator([1, 2, 3]) + iterator = iter(i) + self.assertEqual(1, next(iterator)) + self.assertEqual(2, next(iterator)) + self.assertEqual(3, next(iterator)) + + i.reset() + self.assertEqual(1, next(iterator)) + self.assertEqual(2, next(iterator)) + self.assertEqual(3, next(iterator)) + + def test_reset_after_end(self): + i = utils.ResetableIterator([1, 2, 3]) + iterator = iter(i) + self.assertEqual(1, next(iterator)) + self.assertEqual(2, next(iterator)) + self.assertEqual(3, next(iterator)) + self.assertRaises(StopIteration, next, iterator) + + i.reset() + # Previous iter() has stopped + iterator = iter(i) + self.assertEqual(1, next(iterator)) + self.assertEqual(2, next(iterator)) + self.assertEqual(3, next(iterator)) + + def test_reset_twice(self): + i = utils.ResetableIterator([1, 2, 3, 4, 5]) + iterator = iter(i) + self.assertEqual(1, next(iterator)) + self.assertEqual(2, next(iterator)) + + i.reset() + self.assertEqual(1, next(iterator)) + self.assertEqual(2, next(iterator)) + self.assertEqual(3, next(iterator)) + self.assertEqual(4, next(iterator)) + + i.reset() + self.assertEqual(1, next(iterator)) + self.assertEqual(2, next(iterator)) + self.assertEqual(3, next(iterator)) + self.assertEqual(4, next(iterator)) + + def test_reset_shorter(self): + i = utils.ResetableIterator([1, 2, 3, 4, 5]) + iterator = iter(i) + self.assertEqual(1, next(iterator)) + self.assertEqual(2, next(iterator)) + self.assertEqual(3, next(iterator)) + self.assertEqual(4, next(iterator)) + + i.reset() + self.assertEqual(1, next(iterator)) + self.assertEqual(2, next(iterator)) + + i.reset() + self.assertEqual(1, next(iterator)) + self.assertEqual(2, next(iterator)) + self.assertEqual(3, next(iterator)) + self.assertEqual(4, next(iterator)) + |