diff options
-rw-r--r-- | docs/changelog.rst | 7 | ||||
-rw-r--r-- | factory/base.py | 74 | ||||
-rw-r--r-- | tests/test_alchemy.py | 2 | ||||
-rw-r--r-- | tests/test_using.py | 72 |
4 files changed, 139 insertions, 16 deletions
diff --git a/docs/changelog.rst b/docs/changelog.rst index 326b245..25d6a06 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,15 +1,18 @@ ChangeLog ========= -.. _v2.1.3: +.. _v2.2.0: -2.1.3 (current) +2.2.0 (current) --------------- *Bugfix:* - Removed duplicated :class:`~factory.alchemy.SQLAlchemyModelFactory` lurking in :mod:`factory` (:issue:`83`) + - Properly handle sequences within object inheritance chains. + If FactoryA inherits from FactoryB, and their associated classes share the same link, + sequence counters will be shared (:issue:`93`) *New:* diff --git a/factory/base.py b/factory/base.py index 1b9fa0d..462a60c 100644 --- a/factory/base.py +++ b/factory/base.py @@ -186,7 +186,8 @@ class FactoryMetaClass(type): else: # If inheriting the factory from a parent, keep a link to it. # This allows to use the sequence counters from the parents. - if associated_class == inherited_associated_class: + if (inherited_associated_class is not None + and issubclass(associated_class, inherited_associated_class)): attrs['_base_factory'] = base # The CLASS_ATTRIBUTE_ASSOCIATED_CLASS must *not* be taken into @@ -212,6 +213,32 @@ class FactoryMetaClass(type): # Factory base classes + +class _Counter(object): + """Simple, naive counter. + + Attributes: + for_class (obj): the class this counter related to + seq (int): the next value + """ + + def __init__(self, seq, for_class): + self.seq = seq + self.for_class = for_class + + def next(self): + value = self.seq + self.seq += 1 + return value + + def reset(self, next_value=0): + self.seq = next_value + + def __repr__(self): + return '<_Counter for %s.%s, next=%d>' % ( + self.for_class.__module__, self.for_class.__name__, self.seq) + + class BaseFactory(object): """Factory base support for sequences, attributes and stubs.""" @@ -224,10 +251,10 @@ class BaseFactory(object): raise FactoryError('You cannot instantiate BaseFactory') # ID to use for the next 'declarations.Sequence' attribute. - _next_sequence = None + _counter = None # Base factory, if this class was inherited from another factory. This is - # used for sharing the _next_sequence counter among factories for the same + # used for sharing the sequence _counter among factories for the same # class. _base_factory = None @@ -245,7 +272,14 @@ class BaseFactory(object): @classmethod def reset_sequence(cls, value=None, force=False): - """Reset the sequence counter.""" + """Reset the sequence counter. + + Args: + value (int or None): the new 'next' sequence value; if None, + recompute the next value from _setup_next_sequence(). + force (bool): whether to force-reset parent sequence counters + in a factory inheritance chain. + """ if cls._base_factory: if force: cls._base_factory.reset_sequence(value=value) @@ -253,10 +287,13 @@ class BaseFactory(object): raise ValueError( "Cannot reset the sequence of a factory subclass. " "Please call reset_sequence() on the root factory, " - "or call reset_sequence(forward=True)." + "or call reset_sequence(force=True)." ) else: - cls._next_sequence = value + cls._setup_counter() + if value is None: + value = cls._setup_next_sequence() + cls._counter.reset(value) @classmethod def _setup_next_sequence(cls): @@ -268,6 +305,19 @@ class BaseFactory(object): return 0 @classmethod + def _setup_counter(cls): + """Ensures cls._counter is set for this class. + + Due to the way inheritance works in Python, we need to ensure that the + ``_counter`` attribute has been initialized for *this* Factory subclass, + not one of its parents. + """ + if cls._counter is None or cls._counter.for_class != cls: + first_seq = cls._setup_next_sequence() + cls._counter = _Counter(for_class=cls, seq=first_seq) + logger.debug("%r: Setting up next sequence (%d)", cls, first_seq) + + @classmethod def _generate_next_sequence(cls): """Retrieve a new sequence ID. @@ -279,16 +329,14 @@ class BaseFactory(object): # Rely upon our parents if cls._base_factory: + logger.debug("%r: reusing sequence from %r", cls, cls._base_factory) return cls._base_factory._generate_next_sequence() - # Make sure _next_sequence is initialized - if cls._next_sequence is None: - cls._next_sequence = cls._setup_next_sequence() + # Make sure _counter is initialized + cls._setup_counter() # Pick current value, then increase class counter for the next call. - next_sequence = cls._next_sequence - cls._next_sequence += 1 - return next_sequence + return cls._counter.next() @classmethod def attributes(cls, create=False, extra=None): @@ -577,7 +625,7 @@ Factory = FactoryMetaClass('Factory', (BaseFactory,), { This class has the ability to support multiple ORMs by using custom creation functions. """, - }) +}) # Backwards compatibility diff --git a/tests/test_alchemy.py b/tests/test_alchemy.py index cfbc835..4255417 100644 --- a/tests/test_alchemy.py +++ b/tests/test_alchemy.py @@ -65,7 +65,7 @@ class SQLAlchemyPkSequenceTestCase(unittest.TestCase): def setUp(self): super(SQLAlchemyPkSequenceTestCase, self).setUp() - StandardFactory.reset_sequence() + StandardFactory.reset_sequence(1) NonIntegerPkFactory.FACTORY_SESSION.rollback() def test_pk_first(self): diff --git a/tests/test_using.py b/tests/test_using.py index 0898a13..01e950f 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -730,6 +730,78 @@ class UsingFactoryTestCase(unittest.TestCase): test_object_alt = TestObjectFactory.build() self.assertEqual(None, test_object_alt.three) + def test_inheritance_and_sequences(self): + """Sequence counters should be kept within an inheritance chain.""" + class TestObjectFactory(factory.Factory): + FACTORY_FOR = TestObject + + one = factory.Sequence(lambda n: n) + + class TestObjectFactory2(TestObjectFactory): + FACTORY_FOR = TestObject + + to1a = TestObjectFactory() + self.assertEqual(0, to1a.one) + to2a = TestObjectFactory2() + self.assertEqual(1, to2a.one) + to1b = TestObjectFactory() + self.assertEqual(2, to1b.one) + to2b = TestObjectFactory2() + self.assertEqual(3, to2b.one) + + def test_inheritance_sequence_inheriting_objects(self): + """Sequence counters are kept with inheritance, incl. misc objects.""" + class TestObject2(TestObject): + pass + + class TestObjectFactory(factory.Factory): + FACTORY_FOR = TestObject + + one = factory.Sequence(lambda n: n) + + class TestObjectFactory2(TestObjectFactory): + FACTORY_FOR = TestObject2 + + to1a = TestObjectFactory() + self.assertEqual(0, to1a.one) + to2a = TestObjectFactory2() + self.assertEqual(1, to2a.one) + to1b = TestObjectFactory() + self.assertEqual(2, to1b.one) + to2b = TestObjectFactory2() + self.assertEqual(3, to2b.one) + + def test_inheritance_sequence_unrelated_objects(self): + """Sequence counters are kept with inheritance, unrelated objects. + + See issue https://github.com/rbarrois/factory_boy/issues/93 + + Problem: sequence counter is somewhat shared between factories + until the "slave" factory has been called. + """ + + class TestObject2(object): + def __init__(self, one): + self.one = one + + class TestObjectFactory(factory.Factory): + FACTORY_FOR = TestObject + + one = factory.Sequence(lambda n: n) + + class TestObjectFactory2(TestObjectFactory): + FACTORY_FOR = TestObject2 + + to1a = TestObjectFactory() + self.assertEqual(0, to1a.one) + to2a = TestObjectFactory2() + self.assertEqual(0, to2a.one) + to1b = TestObjectFactory() + self.assertEqual(1, to1b.one) + to2b = TestObjectFactory2() + self.assertEqual(1, to2b.one) + + def test_inheritance_with_inherited_class(self): class TestObjectFactory(factory.Factory): FACTORY_FOR = TestObject |