diff options
Diffstat (limited to 'factory/alchemy.py')
-rw-r--r-- | factory/alchemy.py | 34 |
1 files changed, 16 insertions, 18 deletions
diff --git a/factory/alchemy.py b/factory/alchemy.py index cec15c9..a9aab23 100644 --- a/factory/alchemy.py +++ b/factory/alchemy.py @@ -19,33 +19,31 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. from __future__ import unicode_literals -from sqlalchemy.sql.functions import max from . import base +class SQLAlchemyOptions(base.FactoryOptions): + def _build_default_options(self): + return super(SQLAlchemyOptions, self)._build_default_options() + [ + base.OptionDefault('sqlalchemy_session', None, inherit=True), + base.OptionDefault('force_flush', False, inherit=True), + ] + + class SQLAlchemyModelFactory(base.Factory): """Factory for SQLAlchemy models. """ - ABSTRACT_FACTORY = True - FACTORY_SESSION = None - - @classmethod - def _setup_next_sequence(cls, *args, **kwargs): - """Compute the next available PK, based on the 'pk' database field.""" - session = cls.FACTORY_SESSION - model = cls.FACTORY_FOR - pk = getattr(model, model.__mapper__.primary_key[0].name) - max_pk = session.query(max(pk)).one()[0] - if isinstance(max_pk, int): - return max_pk + 1 if max_pk else 1 - else: - return 1 + _options_class = SQLAlchemyOptions + class Meta: + abstract = True @classmethod - def _create(cls, target_class, *args, **kwargs): + def _create(cls, model_class, *args, **kwargs): """Create an instance of the model, and save it to the database.""" - session = cls.FACTORY_SESSION - obj = target_class(*args, **kwargs) + session = cls._meta.sqlalchemy_session + obj = model_class(*args, **kwargs) session.add(obj) + if cls._meta.force_flush: + session.flush() return obj |