From 924d8a6ac279ca6ad560a3cf5efa1b141cedd253 Mon Sep 17 00:00:00 2001 From: Ivan Miric Date: Wed, 16 Oct 2013 22:25:26 +0200 Subject: Added SubFactory support for MongoEngine's EmbeddedDocument --- factory/mongoengine.py | 3 ++- tests/test_mongoengine.py | 17 +++++++++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/factory/mongoengine.py b/factory/mongoengine.py index 8cd3a67..462f5f2 100644 --- a/factory/mongoengine.py +++ b/factory/mongoengine.py @@ -41,5 +41,6 @@ class MongoEngineFactory(base.Factory): @classmethod def _create(cls, target_class, *args, **kwargs): instance = target_class(*args, **kwargs) - instance.save() + if instance._is_document: + instance.save() return instance diff --git a/tests/test_mongoengine.py b/tests/test_mongoengine.py index f26eb85..803607a 100644 --- a/tests/test_mongoengine.py +++ b/tests/test_mongoengine.py @@ -19,7 +19,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -"""Tests for factory_boy/SQLAlchemy interactions.""" +"""Tests for factory_boy/MongoEngine interactions.""" import factory import os @@ -34,14 +34,23 @@ except ImportError: if mongoengine: from factory.mongoengine import MongoEngineFactory + class Address(mongoengine.EmbeddedDocument): + street = mongoengine.StringField() + class Person(mongoengine.Document): name = mongoengine.StringField() + address = mongoengine.EmbeddedDocumentField(Address) + + class AddressFactory(MongoEngineFactory): + FACTORY_FOR = Address + + street = factory.Sequence(lambda n: 'street%d' % n) class PersonFactory(MongoEngineFactory): FACTORY_FOR = Person name = factory.Sequence(lambda n: 'name%d' % n) - + address = factory.SubFactory(AddressFactory) @unittest.skipIf(mongoengine is None, "mongoengine not installed.") @@ -65,11 +74,11 @@ class MongoEngineTestCase(unittest.TestCase): def test_build(self): std = PersonFactory.build() self.assertEqual('name0', std.name) + self.assertEqual('street0', std.address.street) self.assertIsNone(std.id) def test_creation(self): std1 = PersonFactory.create() self.assertEqual('name1', std1.name) + self.assertEqual('street1', std1.address.street) self.assertIsNotNone(std1.id) - - -- cgit v1.2.3