summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2014-05-18 12:17:19 +0200
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2014-05-18 14:18:49 +0200
commit92bb395e7f6d422ce239b2ef7303424fde43ab1d (patch)
treeedd6b7d56be4a093b12d987663c1d6085358dd95
parent80eaa0c8711f2c3ca82eb7953db49c7c61bd9ffa (diff)
downloadfactory-boy-92bb395e7f6d422ce239b2ef7303424fde43ab1d.tar
factory-boy-92bb395e7f6d422ce239b2ef7303424fde43ab1d.tar.gz
Migrate factory.alchemy to class Meta
-rw-r--r--factory/alchemy.py23
-rw-r--r--tests/test_alchemy.py4
2 files changed, 20 insertions, 7 deletions
diff --git a/factory/alchemy.py b/factory/alchemy.py
index cec15c9..b956d7e 100644
--- a/factory/alchemy.py
+++ b/factory/alchemy.py
@@ -24,17 +24,30 @@ from sqlalchemy.sql.functions import max
from . import base
+class SQLAlchemyOptions(base.FactoryOptions):
+ def _build_default_options(self):
+ return super(SQLAlchemyOptions, self)._build_default_options() + [
+ base.OptionDefault('sqlalchemy_session', None, inherit=True),
+ ]
+
+
class SQLAlchemyModelFactory(base.Factory):
"""Factory for SQLAlchemy models. """
- ABSTRACT_FACTORY = True
- FACTORY_SESSION = None
+ _options_class = SQLAlchemyOptions
+ class Meta:
+ abstract = True
+
+ _OLDSTYLE_ATTRIBUTES = base.Factory._OLDSTYLE_ATTRIBUTES.copy()
+ _OLDSTYLE_ATTRIBUTES.update({
+ 'FACTORY_SESSION': 'sqlalchemy_session',
+ })
@classmethod
def _setup_next_sequence(cls, *args, **kwargs):
"""Compute the next available PK, based on the 'pk' database field."""
- session = cls.FACTORY_SESSION
- model = cls.FACTORY_FOR
+ session = cls._meta.sqlalchemy_session
+ model = cls._meta.target
pk = getattr(model, model.__mapper__.primary_key[0].name)
max_pk = session.query(max(pk)).one()[0]
if isinstance(max_pk, int):
@@ -45,7 +58,7 @@ class SQLAlchemyModelFactory(base.Factory):
@classmethod
def _create(cls, target_class, *args, **kwargs):
"""Create an instance of the model, and save it to the database."""
- session = cls.FACTORY_SESSION
+ session = cls._meta.sqlalchemy_session
obj = target_class(*args, **kwargs)
session.add(obj)
return obj
diff --git a/tests/test_alchemy.py b/tests/test_alchemy.py
index 4255417..c94e425 100644
--- a/tests/test_alchemy.py
+++ b/tests/test_alchemy.py
@@ -66,7 +66,7 @@ class SQLAlchemyPkSequenceTestCase(unittest.TestCase):
def setUp(self):
super(SQLAlchemyPkSequenceTestCase, self).setUp()
StandardFactory.reset_sequence(1)
- NonIntegerPkFactory.FACTORY_SESSION.rollback()
+ NonIntegerPkFactory._meta.sqlalchemy_session.rollback()
def test_pk_first(self):
std = StandardFactory.build()
@@ -104,7 +104,7 @@ class SQLAlchemyNonIntegerPkTestCase(unittest.TestCase):
def setUp(self):
super(SQLAlchemyNonIntegerPkTestCase, self).setUp()
NonIntegerPkFactory.reset_sequence()
- NonIntegerPkFactory.FACTORY_SESSION.rollback()
+ NonIntegerPkFactory._meta.sqlalchemy_session.rollback()
def test_first(self):
nonint = NonIntegerPkFactory.build()