summaryrefslogtreecommitdiff
path: root/factory
diff options
context:
space:
mode:
authorrcommande <commande.romain@gmail.com>2013-06-17 23:20:57 +0200
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2013-06-20 01:40:14 +0200
commit79ee9d24e203ad8e7daf2181437fe7132972529e (patch)
tree5fd18f176fb800024ed64129c39fae9590464db1 /factory
parent5137998a856519cd11cd743d6c567497600139e8 (diff)
downloadfactory-boy-79ee9d24e203ad8e7daf2181437fe7132972529e.tar
factory-boy-79ee9d24e203ad8e7daf2181437fe7132972529e.tar.gz
Small next sequence refactoring
Diffstat (limited to 'factory')
-rw-r--r--factory/__init__.py2
-rw-r--r--factory/alchemy.py12
2 files changed, 8 insertions, 6 deletions
diff --git a/factory/__init__.py b/factory/__init__.py
index 058fd76..f90f40a 100644
--- a/factory/__init__.py
+++ b/factory/__init__.py
@@ -23,6 +23,7 @@
__version__ = '2.1.0-dev'
__author__ = 'Raphaël Barrois <raphael.barrois+fboy@polytechnique.org>'
+
from .base import (
Factory,
BaseDictFactory,
@@ -39,7 +40,6 @@ from .base import (
from .mogo import MogoFactory
from .django import DjangoModelFactory
-from .alchemy import SQLAlchemyModelFactory
from .declarations import (
LazyAttribute,
diff --git a/factory/alchemy.py b/factory/alchemy.py
index 2bfaf81..ca7aefa 100644
--- a/factory/alchemy.py
+++ b/factory/alchemy.py
@@ -19,6 +19,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from __future__ import unicode_literals
+from sqlalchemy.sql.functions import max
+
from . import base
@@ -30,12 +32,12 @@ class SQLAlchemyModelFactory(base.Factory):
@classmethod
def _setup_next_sequence(cls, *args, **kwargs):
"""Compute the next available PK, based on the 'pk' database field."""
- from sqlalchemy.sql.functions import max
session = cls.FACTORY_SESSION
- pk = cls.FACTORY_FOR.__table__.primary_key.columns.values()[0].key
- max_pk = session.query(max(getattr(cls.FACTORY_FOR, pk))).one()
- if isinstance(max_pk[0], int):
- return max_pk[0] + 1 if max_pk[0] else 1
+ model = cls.FACTORY_FOR
+ pk = getattr(model, model.__mapper__.primary_key[0].name)
+ max_pk = session.query(max(pk)).one()[0]
+ if isinstance(max_pk, int):
+ return max_pk + 1 if max_pk else 1
else:
return 1