diff options
author | rcommande <commande.romain@gmail.com> | 2013-06-17 23:20:57 +0200 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2013-06-20 01:40:14 +0200 |
commit | 79ee9d24e203ad8e7daf2181437fe7132972529e (patch) | |
tree | 5fd18f176fb800024ed64129c39fae9590464db1 /factory | |
parent | 5137998a856519cd11cd743d6c567497600139e8 (diff) | |
download | factory-boy-79ee9d24e203ad8e7daf2181437fe7132972529e.tar factory-boy-79ee9d24e203ad8e7daf2181437fe7132972529e.tar.gz |
Small next sequence refactoring
Diffstat (limited to 'factory')
-rw-r--r-- | factory/__init__.py | 2 | ||||
-rw-r--r-- | factory/alchemy.py | 12 |
2 files changed, 8 insertions, 6 deletions
diff --git a/factory/__init__.py b/factory/__init__.py index 058fd76..f90f40a 100644 --- a/factory/__init__.py +++ b/factory/__init__.py @@ -23,6 +23,7 @@ __version__ = '2.1.0-dev' __author__ = 'Raphaël Barrois <raphael.barrois+fboy@polytechnique.org>' + from .base import ( Factory, BaseDictFactory, @@ -39,7 +40,6 @@ from .base import ( from .mogo import MogoFactory from .django import DjangoModelFactory -from .alchemy import SQLAlchemyModelFactory from .declarations import ( LazyAttribute, diff --git a/factory/alchemy.py b/factory/alchemy.py index 2bfaf81..ca7aefa 100644 --- a/factory/alchemy.py +++ b/factory/alchemy.py @@ -19,6 +19,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. from __future__ import unicode_literals +from sqlalchemy.sql.functions import max + from . import base @@ -30,12 +32,12 @@ class SQLAlchemyModelFactory(base.Factory): @classmethod def _setup_next_sequence(cls, *args, **kwargs): """Compute the next available PK, based on the 'pk' database field.""" - from sqlalchemy.sql.functions import max session = cls.FACTORY_SESSION - pk = cls.FACTORY_FOR.__table__.primary_key.columns.values()[0].key - max_pk = session.query(max(getattr(cls.FACTORY_FOR, pk))).one() - if isinstance(max_pk[0], int): - return max_pk[0] + 1 if max_pk[0] else 1 + model = cls.FACTORY_FOR + pk = getattr(model, model.__mapper__.primary_key[0].name) + max_pk = session.query(max(pk)).one()[0] + if isinstance(max_pk, int): + return max_pk + 1 if max_pk else 1 else: return 1 |