aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/changelog.rst7
-rw-r--r--factory/base.py74
-rw-r--r--tests/test_alchemy.py2
-rw-r--r--tests/test_using.py72
4 files changed, 139 insertions, 16 deletions
diff --git a/docs/changelog.rst b/docs/changelog.rst
index 326b245..25d6a06 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -1,15 +1,18 @@
ChangeLog
=========
-.. _v2.1.3:
+.. _v2.2.0:
-2.1.3 (current)
+2.2.0 (current)
---------------
*Bugfix:*
- Removed duplicated :class:`~factory.alchemy.SQLAlchemyModelFactory` lurking in :mod:`factory`
(:issue:`83`)
+ - Properly handle sequences within object inheritance chains.
+ If FactoryA inherits from FactoryB, and their associated classes share the same link,
+ sequence counters will be shared (:issue:`93`)
*New:*
diff --git a/factory/base.py b/factory/base.py
index 1b9fa0d..462a60c 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -186,7 +186,8 @@ class FactoryMetaClass(type):
else:
# If inheriting the factory from a parent, keep a link to it.
# This allows to use the sequence counters from the parents.
- if associated_class == inherited_associated_class:
+ if (inherited_associated_class is not None
+ and issubclass(associated_class, inherited_associated_class)):
attrs['_base_factory'] = base
# The CLASS_ATTRIBUTE_ASSOCIATED_CLASS must *not* be taken into
@@ -212,6 +213,32 @@ class FactoryMetaClass(type):
# Factory base classes
+
+class _Counter(object):
+ """Simple, naive counter.
+
+ Attributes:
+ for_class (obj): the class this counter related to
+ seq (int): the next value
+ """
+
+ def __init__(self, seq, for_class):
+ self.seq = seq
+ self.for_class = for_class
+
+ def next(self):
+ value = self.seq
+ self.seq += 1
+ return value
+
+ def reset(self, next_value=0):
+ self.seq = next_value
+
+ def __repr__(self):
+ return '<_Counter for %s.%s, next=%d>' % (
+ self.for_class.__module__, self.for_class.__name__, self.seq)
+
+
class BaseFactory(object):
"""Factory base support for sequences, attributes and stubs."""
@@ -224,10 +251,10 @@ class BaseFactory(object):
raise FactoryError('You cannot instantiate BaseFactory')
# ID to use for the next 'declarations.Sequence' attribute.
- _next_sequence = None
+ _counter = None
# Base factory, if this class was inherited from another factory. This is
- # used for sharing the _next_sequence counter among factories for the same
+ # used for sharing the sequence _counter among factories for the same
# class.
_base_factory = None
@@ -245,7 +272,14 @@ class BaseFactory(object):
@classmethod
def reset_sequence(cls, value=None, force=False):
- """Reset the sequence counter."""
+ """Reset the sequence counter.
+
+ Args:
+ value (int or None): the new 'next' sequence value; if None,
+ recompute the next value from _setup_next_sequence().
+ force (bool): whether to force-reset parent sequence counters
+ in a factory inheritance chain.
+ """
if cls._base_factory:
if force:
cls._base_factory.reset_sequence(value=value)
@@ -253,10 +287,13 @@ class BaseFactory(object):
raise ValueError(
"Cannot reset the sequence of a factory subclass. "
"Please call reset_sequence() on the root factory, "
- "or call reset_sequence(forward=True)."
+ "or call reset_sequence(force=True)."
)
else:
- cls._next_sequence = value
+ cls._setup_counter()
+ if value is None:
+ value = cls._setup_next_sequence()
+ cls._counter.reset(value)
@classmethod
def _setup_next_sequence(cls):
@@ -268,6 +305,19 @@ class BaseFactory(object):
return 0
@classmethod
+ def _setup_counter(cls):
+ """Ensures cls._counter is set for this class.
+
+ Due to the way inheritance works in Python, we need to ensure that the
+ ``_counter`` attribute has been initialized for *this* Factory subclass,
+ not one of its parents.
+ """
+ if cls._counter is None or cls._counter.for_class != cls:
+ first_seq = cls._setup_next_sequence()
+ cls._counter = _Counter(for_class=cls, seq=first_seq)
+ logger.debug("%r: Setting up next sequence (%d)", cls, first_seq)
+
+ @classmethod
def _generate_next_sequence(cls):
"""Retrieve a new sequence ID.
@@ -279,16 +329,14 @@ class BaseFactory(object):
# Rely upon our parents
if cls._base_factory:
+ logger.debug("%r: reusing sequence from %r", cls, cls._base_factory)
return cls._base_factory._generate_next_sequence()
- # Make sure _next_sequence is initialized
- if cls._next_sequence is None:
- cls._next_sequence = cls._setup_next_sequence()
+ # Make sure _counter is initialized
+ cls._setup_counter()
# Pick current value, then increase class counter for the next call.
- next_sequence = cls._next_sequence
- cls._next_sequence += 1
- return next_sequence
+ return cls._counter.next()
@classmethod
def attributes(cls, create=False, extra=None):
@@ -577,7 +625,7 @@ Factory = FactoryMetaClass('Factory', (BaseFactory,), {
This class has the ability to support multiple ORMs by using custom creation
functions.
""",
- })
+})
# Backwards compatibility
diff --git a/tests/test_alchemy.py b/tests/test_alchemy.py
index cfbc835..4255417 100644
--- a/tests/test_alchemy.py
+++ b/tests/test_alchemy.py
@@ -65,7 +65,7 @@ class SQLAlchemyPkSequenceTestCase(unittest.TestCase):
def setUp(self):
super(SQLAlchemyPkSequenceTestCase, self).setUp()
- StandardFactory.reset_sequence()
+ StandardFactory.reset_sequence(1)
NonIntegerPkFactory.FACTORY_SESSION.rollback()
def test_pk_first(self):
diff --git a/tests/test_using.py b/tests/test_using.py
index 0898a13..01e950f 100644
--- a/tests/test_using.py
+++ b/tests/test_using.py
@@ -730,6 +730,78 @@ class UsingFactoryTestCase(unittest.TestCase):
test_object_alt = TestObjectFactory.build()
self.assertEqual(None, test_object_alt.three)
+ def test_inheritance_and_sequences(self):
+ """Sequence counters should be kept within an inheritance chain."""
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+
+ one = factory.Sequence(lambda n: n)
+
+ class TestObjectFactory2(TestObjectFactory):
+ FACTORY_FOR = TestObject
+
+ to1a = TestObjectFactory()
+ self.assertEqual(0, to1a.one)
+ to2a = TestObjectFactory2()
+ self.assertEqual(1, to2a.one)
+ to1b = TestObjectFactory()
+ self.assertEqual(2, to1b.one)
+ to2b = TestObjectFactory2()
+ self.assertEqual(3, to2b.one)
+
+ def test_inheritance_sequence_inheriting_objects(self):
+ """Sequence counters are kept with inheritance, incl. misc objects."""
+ class TestObject2(TestObject):
+ pass
+
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+
+ one = factory.Sequence(lambda n: n)
+
+ class TestObjectFactory2(TestObjectFactory):
+ FACTORY_FOR = TestObject2
+
+ to1a = TestObjectFactory()
+ self.assertEqual(0, to1a.one)
+ to2a = TestObjectFactory2()
+ self.assertEqual(1, to2a.one)
+ to1b = TestObjectFactory()
+ self.assertEqual(2, to1b.one)
+ to2b = TestObjectFactory2()
+ self.assertEqual(3, to2b.one)
+
+ def test_inheritance_sequence_unrelated_objects(self):
+ """Sequence counters are kept with inheritance, unrelated objects.
+
+ See issue https://github.com/rbarrois/factory_boy/issues/93
+
+ Problem: sequence counter is somewhat shared between factories
+ until the "slave" factory has been called.
+ """
+
+ class TestObject2(object):
+ def __init__(self, one):
+ self.one = one
+
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+
+ one = factory.Sequence(lambda n: n)
+
+ class TestObjectFactory2(TestObjectFactory):
+ FACTORY_FOR = TestObject2
+
+ to1a = TestObjectFactory()
+ self.assertEqual(0, to1a.one)
+ to2a = TestObjectFactory2()
+ self.assertEqual(0, to2a.one)
+ to1b = TestObjectFactory()
+ self.assertEqual(1, to1b.one)
+ to2b = TestObjectFactory2()
+ self.assertEqual(1, to2b.one)
+
+
def test_inheritance_with_inherited_class(self):
class TestObjectFactory(factory.Factory):
FACTORY_FOR = TestObject