diff options
author | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-09-17 01:12:48 +0200 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-09-17 01:12:48 +0200 |
commit | 7fe9dcaa8494e73d57613d1288b4f86c4cba5bf0 (patch) | |
tree | 5f7c3728caeb67609be6e7ba17e03b2fa93dbc22 /factory | |
parent | a8742c973db224968b74bb054027130b2ab458e0 (diff) | |
download | factory-boy-7fe9dcaa8494e73d57613d1288b4f86c4cba5bf0.tar factory-boy-7fe9dcaa8494e73d57613d1288b4f86c4cba5bf0.tar.gz |
Properly handle Sequence & inheritance (Closes #93).
There was also a nasty bug: with class FactoryB(FactoryA), FactoryB's sequence
counter started at the value of FactoryA's counter when FactoryB was first called.
Diffstat (limited to 'factory')
-rw-r--r-- | factory/base.py | 74 |
1 files changed, 61 insertions, 13 deletions
diff --git a/factory/base.py b/factory/base.py index 1b9fa0d..462a60c 100644 --- a/factory/base.py +++ b/factory/base.py @@ -186,7 +186,8 @@ class FactoryMetaClass(type): else: # If inheriting the factory from a parent, keep a link to it. # This allows to use the sequence counters from the parents. - if associated_class == inherited_associated_class: + if (inherited_associated_class is not None + and issubclass(associated_class, inherited_associated_class)): attrs['_base_factory'] = base # The CLASS_ATTRIBUTE_ASSOCIATED_CLASS must *not* be taken into @@ -212,6 +213,32 @@ class FactoryMetaClass(type): # Factory base classes + +class _Counter(object): + """Simple, naive counter. + + Attributes: + for_class (obj): the class this counter related to + seq (int): the next value + """ + + def __init__(self, seq, for_class): + self.seq = seq + self.for_class = for_class + + def next(self): + value = self.seq + self.seq += 1 + return value + + def reset(self, next_value=0): + self.seq = next_value + + def __repr__(self): + return '<_Counter for %s.%s, next=%d>' % ( + self.for_class.__module__, self.for_class.__name__, self.seq) + + class BaseFactory(object): """Factory base support for sequences, attributes and stubs.""" @@ -224,10 +251,10 @@ class BaseFactory(object): raise FactoryError('You cannot instantiate BaseFactory') # ID to use for the next 'declarations.Sequence' attribute. - _next_sequence = None + _counter = None # Base factory, if this class was inherited from another factory. This is - # used for sharing the _next_sequence counter among factories for the same + # used for sharing the sequence _counter among factories for the same # class. _base_factory = None @@ -245,7 +272,14 @@ class BaseFactory(object): @classmethod def reset_sequence(cls, value=None, force=False): - """Reset the sequence counter.""" + """Reset the sequence counter. + + Args: + value (int or None): the new 'next' sequence value; if None, + recompute the next value from _setup_next_sequence(). + force (bool): whether to force-reset parent sequence counters + in a factory inheritance chain. + """ if cls._base_factory: if force: cls._base_factory.reset_sequence(value=value) @@ -253,10 +287,13 @@ class BaseFactory(object): raise ValueError( "Cannot reset the sequence of a factory subclass. " "Please call reset_sequence() on the root factory, " - "or call reset_sequence(forward=True)." + "or call reset_sequence(force=True)." ) else: - cls._next_sequence = value + cls._setup_counter() + if value is None: + value = cls._setup_next_sequence() + cls._counter.reset(value) @classmethod def _setup_next_sequence(cls): @@ -268,6 +305,19 @@ class BaseFactory(object): return 0 @classmethod + def _setup_counter(cls): + """Ensures cls._counter is set for this class. + + Due to the way inheritance works in Python, we need to ensure that the + ``_counter`` attribute has been initialized for *this* Factory subclass, + not one of its parents. + """ + if cls._counter is None or cls._counter.for_class != cls: + first_seq = cls._setup_next_sequence() + cls._counter = _Counter(for_class=cls, seq=first_seq) + logger.debug("%r: Setting up next sequence (%d)", cls, first_seq) + + @classmethod def _generate_next_sequence(cls): """Retrieve a new sequence ID. @@ -279,16 +329,14 @@ class BaseFactory(object): # Rely upon our parents if cls._base_factory: + logger.debug("%r: reusing sequence from %r", cls, cls._base_factory) return cls._base_factory._generate_next_sequence() - # Make sure _next_sequence is initialized - if cls._next_sequence is None: - cls._next_sequence = cls._setup_next_sequence() + # Make sure _counter is initialized + cls._setup_counter() # Pick current value, then increase class counter for the next call. - next_sequence = cls._next_sequence - cls._next_sequence += 1 - return next_sequence + return cls._counter.next() @classmethod def attributes(cls, create=False, extra=None): @@ -577,7 +625,7 @@ Factory = FactoryMetaClass('Factory', (BaseFactory,), { This class has the ability to support multiple ORMs by using custom creation functions. """, - }) +}) # Backwards compatibility |