From 5000ddaaef582e7504babf4f8163de13b93e7459 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Wed, 6 Jan 2016 19:36:10 -0300 Subject: optional forced flush on SQLAlchemyModelFactory fixes rbarrois/factory_boy#81 --- docs/orms.rst | 4 ++++ factory/alchemy.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/docs/orms.rst b/docs/orms.rst index 9b209bc..bd481bd 100644 --- a/docs/orms.rst +++ b/docs/orms.rst @@ -333,6 +333,10 @@ To work, this class needs an `SQLAlchemy`_ session object affected to the :attr: SQLAlchemy session to use to communicate with the database when creating an object through this :class:`SQLAlchemyModelFactory`. + .. attribute:: force_flush + + Force a session flush() at the end of :func:`~factory.alchemy.SQLAlchemyModelFactory._create()`. + A (very) simple example: .. code-block:: python diff --git a/factory/alchemy.py b/factory/alchemy.py index 20da6cf..a9aab23 100644 --- a/factory/alchemy.py +++ b/factory/alchemy.py @@ -27,6 +27,7 @@ class SQLAlchemyOptions(base.FactoryOptions): def _build_default_options(self): return super(SQLAlchemyOptions, self)._build_default_options() + [ base.OptionDefault('sqlalchemy_session', None, inherit=True), + base.OptionDefault('force_flush', False, inherit=True), ] @@ -43,4 +44,6 @@ class SQLAlchemyModelFactory(base.Factory): session = cls._meta.sqlalchemy_session obj = model_class(*args, **kwargs) session.add(obj) + if cls._meta.force_flush: + session.flush() return obj -- cgit v1.2.3 From b8050b1d61cd3171c2640eeaa6b3f71a6cbef5f5 Mon Sep 17 00:00:00 2001 From: Alejandro Date: Thu, 7 Jan 2016 12:52:57 -0300 Subject: added unittests for rbarrois/factory_boy#81 --- tests/test_alchemy.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_alchemy.py b/tests/test_alchemy.py index 9d7288a..5d8f275 100644 --- a/tests/test_alchemy.py +++ b/tests/test_alchemy.py @@ -23,6 +23,7 @@ import factory from .compat import unittest +import mock try: @@ -55,6 +56,16 @@ class StandardFactory(SQLAlchemyModelFactory): foo = factory.Sequence(lambda n: 'foo%d' % n) +class ForceFlushingStandardFactory(SQLAlchemyModelFactory): + class Meta: + model = models.StandardModel + sqlalchemy_session = mock.MagicMock() + force_flush = True + + id = factory.Sequence(lambda n: n) + foo = factory.Sequence(lambda n: 'foo%d' % n) + + class NonIntegerPkFactory(SQLAlchemyModelFactory): class Meta: model = models.NonIntegerPk @@ -102,6 +113,27 @@ class SQLAlchemyPkSequenceTestCase(unittest.TestCase): self.assertEqual(0, std2.id) +@unittest.skipIf(sqlalchemy is None, "SQLalchemy not installed.") +class SQLAlchemyForceFlushTestCase(unittest.TestCase): + def setUp(self): + super(SQLAlchemyForceFlushTestCase, self).setUp() + ForceFlushingStandardFactory.reset_sequence(1) + ForceFlushingStandardFactory._meta.sqlalchemy_session.rollback() + ForceFlushingStandardFactory._meta.sqlalchemy_session.reset_mock() + + def test_force_flush_called(self): + self.assertFalse(ForceFlushingStandardFactory._meta.sqlalchemy_session.flush.called) + ForceFlushingStandardFactory.create() + self.assertTrue(ForceFlushingStandardFactory._meta.sqlalchemy_session.flush.called) + + def test_force_flush_not_called(self): + ForceFlushingStandardFactory._meta.force_flush = False + self.assertFalse(ForceFlushingStandardFactory._meta.sqlalchemy_session.flush.called) + ForceFlushingStandardFactory.create() + self.assertFalse(ForceFlushingStandardFactory._meta.sqlalchemy_session.flush.called) + ForceFlushingStandardFactory._meta.force_flush = True + + @unittest.skipIf(sqlalchemy is None, "SQLalchemy not installed.") class SQLAlchemyNonIntegerPkTestCase(unittest.TestCase): def setUp(self): -- cgit v1.2.3