summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--factory/__init__.py4
-rw-r--r--factory/base.py42
-rw-r--r--tests/test_base.py21
3 files changed, 56 insertions, 11 deletions
diff --git a/factory/__init__.py b/factory/__init__.py
index efc0001..f2652f2 100644
--- a/factory/__init__.py
+++ b/factory/__init__.py
@@ -28,6 +28,10 @@ from base import (
StubFactory,
DjangoModelFactory,
+ build,
+ create,
+ stub,
+
BUILD_STRATEGY,
CREATE_STRATEGY,
STUB_STRATEGY,
diff --git a/factory/base.py b/factory/base.py
index 3f3261b..80c49ca 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -156,17 +156,18 @@ class FactoryMetaClass(BaseFactoryMetaClass):
if FACTORY_CLASS_DECLARATION in attrs:
return attrs[FACTORY_CLASS_DECLARATION]
- factory_module = sys.modules[attrs['__module__']]
- if class_name.endswith('Factory'):
- # Try a module lookup
- used_auto_discovery = True
- associated_class_name = class_name[:-len('Factory')]
- if associated_class_name:
- # Class name was longer than just 'Factory'.
- try:
- return getattr(factory_module, associated_class_name)
- except AttributeError:
- pass
+ if '__module__' in attrs:
+ factory_module = sys.modules[attrs['__module__']]
+ if class_name.endswith('Factory'):
+ # Try a module lookup
+ used_auto_discovery = True
+ associated_class_name = class_name[:-len('Factory')]
+ if associated_class_name:
+ # Class name was longer than just 'Factory'.
+ try:
+ return getattr(factory_module, associated_class_name)
+ except AttributeError:
+ pass
# Unable to guess a good option; return the inherited class.
if inherited is not None:
@@ -439,3 +440,22 @@ class DjangoModelFactory(Factory):
).order_by('-id')[0]
except IndexError:
return 1
+
+
+def _make_factory(klass, **kwargs):
+ factory_name = '%sFactory' % klass.__name__
+ kwargs[FACTORY_CLASS_DECLARATION] = klass
+ factory_class = type(Factory).__new__(type(Factory), factory_name, (Factory,), kwargs)
+ factory_class.__name__ = '%sFactory' % klass.__name__
+ factory_class.__doc__ = 'Auto-generated factory for class %s' % klass
+ return factory_class
+
+
+def build(klass, **kwargs):
+ return _make_factory(klass, **kwargs).build()
+
+def create(klass, **kwargs):
+ return _make_factory(klass, **kwargs).create()
+
+def stub(klass, **kwargs):
+ return _make_factory(klass, **kwargs).stub()
diff --git a/tests/test_base.py b/tests/test_base.py
index 01855c3..5855682 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -55,6 +55,27 @@ class SafetyTestCase(unittest.TestCase):
self.assertRaises(RuntimeError, base.BaseFactory)
+class SimpleBuildTestCase(unittest.TestCase):
+ """Tests the minimalist 'factory.build/create' functions."""
+
+ def test_build(self):
+ obj = base.build(TestObject, two=2)
+ self.assertEqual(obj.one, None)
+ self.assertEqual(obj.two, 2)
+ self.assertEqual(obj.three, None)
+ self.assertEqual(obj.four, None)
+
+ def test_create(self):
+ obj = base.create(FakeDjangoModel, foo='bar')
+ self.assertEqual(obj.id, 1)
+ self.assertEqual(obj.foo, 'bar')
+
+ def test_stub(self):
+ obj = base.stub(TestObject, three=3)
+ self.assertEqual(obj.three, 3)
+ self.assertFalse(hasattr(obj, 'two'))
+
+
class FactoryTestCase(unittest.TestCase):
def testAttribute(self):
class TestObjectFactory(base.Factory):