summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/changelog.rst1
-rw-r--r--docs/reference.rst26
-rw-r--r--factory/declarations.py11
-rw-r--r--factory/utils.py23
-rw-r--r--tests/test_declarations.py21
-rw-r--r--tests/test_utils.py106
6 files changed, 184 insertions, 4 deletions
diff --git a/docs/changelog.rst b/docs/changelog.rst
index 6489176..98b177d 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -16,6 +16,7 @@ ChangeLog
- Add the :meth:`~factory.Factory.reset_sequence` classmethod to :class:`~factory.Factory`
to ease resetting the sequence counter for a given factory.
- Add debug messages to ``factory`` logger.
+ - Add a :meth:`~factory.Iterator.reset` method to :class:`~factory.Iterator` (:issue:`63`)
*Bugfix*
diff --git a/docs/reference.rst b/docs/reference.rst
index 74f2dbd..e98665f 100644
--- a/docs/reference.rst
+++ b/docs/reference.rst
@@ -884,6 +884,14 @@ Iterator
.. versionadded:: 1.3.0
+ .. method:: reset()
+
+ Reset the internal iterator used by the attribute, so that the next value
+ will be the first value generated by the iterator.
+
+ May be called several times.
+
+
Each call to the factory will receive the next value from the iterable:
.. code-block:: python
@@ -953,6 +961,24 @@ use the :func:`iterator` decorator:
yield line
+Resetting
+~~~~~~~~~
+
+In order to start back at the first value in an :class:`Iterator`,
+simply call the :meth:`~Iterator.reset` method of that attribute
+(accessing it from the bare :class:`~Factory` subclass):
+
+.. code-block:: pycon
+
+ >>> UserFactory().lang
+ 'en'
+ >>> UserFactory().lang
+ 'fr'
+ >>> UserFactory.lang.reset()
+ >>> UserFactory().lang
+ 'en'
+
+
Dict and List
"""""""""""""
diff --git a/factory/declarations.py b/factory/declarations.py
index f068c0d..552ddf2 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -163,17 +163,20 @@ class Iterator(OrderedDeclaration):
self.getter = getter
if cycle:
- self.iterator = itertools.cycle(iterator)
- else:
- self.iterator = iter(iterator)
+ iterator = itertools.cycle(iterator)
+ self.iterator = utils.ResetableIterator(iterator)
def evaluate(self, sequence, obj, create, extra=None, containers=()):
logger.debug("Iterator: Fetching next value from %r", self.iterator)
- value = next(self.iterator)
+ value = next(iter(self.iterator))
if self.getter is None:
return value
return self.getter(value)
+ def reset(self):
+ """Reset the internal iterator."""
+ self.iterator.reset()
+
class Sequence(OrderedDeclaration):
"""Specific OrderedDeclaration to use for 'sequenced' fields.
diff --git a/factory/utils.py b/factory/utils.py
index e1b265f..48c6eed 100644
--- a/factory/utils.py
+++ b/factory/utils.py
@@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
+import collections
#: String for splitting an attribute name into a
#: (subfactory_name, subfactory_field) tuple.
@@ -101,3 +102,25 @@ def log_pprint(args=(), kwargs=None):
[str(arg) for arg in args] +
['%s=%r' % item for item in kwargs.items()]
)
+
+
+class ResetableIterator(object):
+ """An iterator wrapper that can be 'reset()' to its start."""
+ def __init__(self, iterator, **kwargs):
+ super(ResetableIterator, self).__init__(**kwargs)
+ self.iterator = iter(iterator)
+ self.past_elements = collections.deque()
+ self.next_elements = collections.deque()
+
+ def __iter__(self):
+ while True:
+ if self.next_elements:
+ yield self.next_elements.popleft()
+ else:
+ value = next(self.iterator)
+ self.past_elements.append(value)
+ yield value
+
+ def reset(self):
+ self.next_elements.clear()
+ self.next_elements.extend(self.past_elements)
diff --git a/tests/test_declarations.py b/tests/test_declarations.py
index cd38dd2..9d54c59 100644
--- a/tests/test_declarations.py
+++ b/tests/test_declarations.py
@@ -107,6 +107,27 @@ class IteratorTestCase(unittest.TestCase):
self.assertEqual(2, it.evaluate(1, None, False))
self.assertRaises(StopIteration, it.evaluate, 2, None, False)
+ def test_reset_cycle(self):
+ it = declarations.Iterator([1, 2])
+ self.assertEqual(1, it.evaluate(0, None, False))
+ self.assertEqual(2, it.evaluate(1, None, False))
+ self.assertEqual(1, it.evaluate(2, None, False))
+ self.assertEqual(2, it.evaluate(3, None, False))
+ self.assertEqual(1, it.evaluate(4, None, False))
+ it.reset()
+ self.assertEqual(1, it.evaluate(5, None, False))
+ self.assertEqual(2, it.evaluate(6, None, False))
+
+ def test_reset_no_cycling(self):
+ it = declarations.Iterator([1, 2], cycle=False)
+ self.assertEqual(1, it.evaluate(0, None, False))
+ self.assertEqual(2, it.evaluate(1, None, False))
+ self.assertRaises(StopIteration, it.evaluate, 2, None, False)
+ it.reset()
+ self.assertEqual(1, it.evaluate(0, None, False))
+ self.assertEqual(2, it.evaluate(1, None, False))
+ self.assertRaises(StopIteration, it.evaluate, 2, None, False)
+
def test_getter(self):
it = declarations.Iterator([(1, 2), (1, 3)], getter=lambda p: p[1])
self.assertEqual(2, it.evaluate(0, None, False))
diff --git a/tests/test_utils.py b/tests/test_utils.py
index b353c9d..8c73935 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -20,6 +20,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
+import itertools
from factory import utils
@@ -230,3 +231,108 @@ class ImportObjectTestCase(unittest.TestCase):
def test_invalid_module(self):
self.assertRaises(ImportError, utils.import_object,
'this-is-an-invalid-module', '__name__')
+
+
+class ResetableIteratorTestCase(unittest.TestCase):
+ def test_no_reset(self):
+ i = utils.ResetableIterator([1, 2, 3])
+ self.assertEqual([1, 2, 3], list(i))
+
+ def test_no_reset_new_iterator(self):
+ i = utils.ResetableIterator([1, 2, 3])
+ iterator = iter(i)
+ self.assertEqual(1, next(iterator))
+ self.assertEqual(2, next(iterator))
+
+ iterator2 = iter(i)
+ self.assertEqual(3, next(iterator2))
+
+ def test_infinite(self):
+ i = utils.ResetableIterator(itertools.cycle([1, 2, 3]))
+ iterator = iter(i)
+ values = [next(iterator) for _i in range(10)]
+ self.assertEqual([1, 2, 3, 1, 2, 3, 1, 2, 3, 1], values)
+
+ def test_reset_simple(self):
+ i = utils.ResetableIterator([1, 2, 3])
+ iterator = iter(i)
+ self.assertEqual(1, next(iterator))
+ self.assertEqual(2, next(iterator))
+
+ i.reset()
+ self.assertEqual(1, next(iterator))
+ self.assertEqual(2, next(iterator))
+ self.assertEqual(3, next(iterator))
+
+ def test_reset_at_begin(self):
+ i = utils.ResetableIterator([1, 2, 3])
+ iterator = iter(i)
+ i.reset()
+ i.reset()
+ self.assertEqual(1, next(iterator))
+ self.assertEqual(2, next(iterator))
+ self.assertEqual(3, next(iterator))
+
+ def test_reset_at_end(self):
+ i = utils.ResetableIterator([1, 2, 3])
+ iterator = iter(i)
+ self.assertEqual(1, next(iterator))
+ self.assertEqual(2, next(iterator))
+ self.assertEqual(3, next(iterator))
+
+ i.reset()
+ self.assertEqual(1, next(iterator))
+ self.assertEqual(2, next(iterator))
+ self.assertEqual(3, next(iterator))
+
+ def test_reset_after_end(self):
+ i = utils.ResetableIterator([1, 2, 3])
+ iterator = iter(i)
+ self.assertEqual(1, next(iterator))
+ self.assertEqual(2, next(iterator))
+ self.assertEqual(3, next(iterator))
+ self.assertRaises(StopIteration, next, iterator)
+
+ i.reset()
+ # Previous iter() has stopped
+ iterator = iter(i)
+ self.assertEqual(1, next(iterator))
+ self.assertEqual(2, next(iterator))
+ self.assertEqual(3, next(iterator))
+
+ def test_reset_twice(self):
+ i = utils.ResetableIterator([1, 2, 3, 4, 5])
+ iterator = iter(i)
+ self.assertEqual(1, next(iterator))
+ self.assertEqual(2, next(iterator))
+
+ i.reset()
+ self.assertEqual(1, next(iterator))
+ self.assertEqual(2, next(iterator))
+ self.assertEqual(3, next(iterator))
+ self.assertEqual(4, next(iterator))
+
+ i.reset()
+ self.assertEqual(1, next(iterator))
+ self.assertEqual(2, next(iterator))
+ self.assertEqual(3, next(iterator))
+ self.assertEqual(4, next(iterator))
+
+ def test_reset_shorter(self):
+ i = utils.ResetableIterator([1, 2, 3, 4, 5])
+ iterator = iter(i)
+ self.assertEqual(1, next(iterator))
+ self.assertEqual(2, next(iterator))
+ self.assertEqual(3, next(iterator))
+ self.assertEqual(4, next(iterator))
+
+ i.reset()
+ self.assertEqual(1, next(iterator))
+ self.assertEqual(2, next(iterator))
+
+ i.reset()
+ self.assertEqual(1, next(iterator))
+ self.assertEqual(2, next(iterator))
+ self.assertEqual(3, next(iterator))
+ self.assertEqual(4, next(iterator))
+