diff options
-rw-r--r-- | factory/alchemy.py | 23 | ||||
-rw-r--r-- | tests/test_alchemy.py | 4 |
2 files changed, 20 insertions, 7 deletions
diff --git a/factory/alchemy.py b/factory/alchemy.py index cec15c9..b956d7e 100644 --- a/factory/alchemy.py +++ b/factory/alchemy.py @@ -24,17 +24,30 @@ from sqlalchemy.sql.functions import max from . import base +class SQLAlchemyOptions(base.FactoryOptions): + def _build_default_options(self): + return super(SQLAlchemyOptions, self)._build_default_options() + [ + base.OptionDefault('sqlalchemy_session', None, inherit=True), + ] + + class SQLAlchemyModelFactory(base.Factory): """Factory for SQLAlchemy models. """ - ABSTRACT_FACTORY = True - FACTORY_SESSION = None + _options_class = SQLAlchemyOptions + class Meta: + abstract = True + + _OLDSTYLE_ATTRIBUTES = base.Factory._OLDSTYLE_ATTRIBUTES.copy() + _OLDSTYLE_ATTRIBUTES.update({ + 'FACTORY_SESSION': 'sqlalchemy_session', + }) @classmethod def _setup_next_sequence(cls, *args, **kwargs): """Compute the next available PK, based on the 'pk' database field.""" - session = cls.FACTORY_SESSION - model = cls.FACTORY_FOR + session = cls._meta.sqlalchemy_session + model = cls._meta.target pk = getattr(model, model.__mapper__.primary_key[0].name) max_pk = session.query(max(pk)).one()[0] if isinstance(max_pk, int): @@ -45,7 +58,7 @@ class SQLAlchemyModelFactory(base.Factory): @classmethod def _create(cls, target_class, *args, **kwargs): """Create an instance of the model, and save it to the database.""" - session = cls.FACTORY_SESSION + session = cls._meta.sqlalchemy_session obj = target_class(*args, **kwargs) session.add(obj) return obj diff --git a/tests/test_alchemy.py b/tests/test_alchemy.py index 4255417..c94e425 100644 --- a/tests/test_alchemy.py +++ b/tests/test_alchemy.py @@ -66,7 +66,7 @@ class SQLAlchemyPkSequenceTestCase(unittest.TestCase): def setUp(self): super(SQLAlchemyPkSequenceTestCase, self).setUp() StandardFactory.reset_sequence(1) - NonIntegerPkFactory.FACTORY_SESSION.rollback() + NonIntegerPkFactory._meta.sqlalchemy_session.rollback() def test_pk_first(self): std = StandardFactory.build() @@ -104,7 +104,7 @@ class SQLAlchemyNonIntegerPkTestCase(unittest.TestCase): def setUp(self): super(SQLAlchemyNonIntegerPkTestCase, self).setUp() NonIntegerPkFactory.reset_sequence() - NonIntegerPkFactory.FACTORY_SESSION.rollback() + NonIntegerPkFactory._meta.sqlalchemy_session.rollback() def test_first(self): nonint = NonIntegerPkFactory.build() |