diff options
-rw-r--r-- | factory/__init__.py | 4 | ||||
-rw-r--r-- | factory/base.py | 42 | ||||
-rw-r--r-- | tests/test_base.py | 21 |
3 files changed, 56 insertions, 11 deletions
diff --git a/factory/__init__.py b/factory/__init__.py index efc0001..f2652f2 100644 --- a/factory/__init__.py +++ b/factory/__init__.py @@ -28,6 +28,10 @@ from base import ( StubFactory, DjangoModelFactory, + build, + create, + stub, + BUILD_STRATEGY, CREATE_STRATEGY, STUB_STRATEGY, diff --git a/factory/base.py b/factory/base.py index 3f3261b..80c49ca 100644 --- a/factory/base.py +++ b/factory/base.py @@ -156,17 +156,18 @@ class FactoryMetaClass(BaseFactoryMetaClass): if FACTORY_CLASS_DECLARATION in attrs: return attrs[FACTORY_CLASS_DECLARATION] - factory_module = sys.modules[attrs['__module__']] - if class_name.endswith('Factory'): - # Try a module lookup - used_auto_discovery = True - associated_class_name = class_name[:-len('Factory')] - if associated_class_name: - # Class name was longer than just 'Factory'. - try: - return getattr(factory_module, associated_class_name) - except AttributeError: - pass + if '__module__' in attrs: + factory_module = sys.modules[attrs['__module__']] + if class_name.endswith('Factory'): + # Try a module lookup + used_auto_discovery = True + associated_class_name = class_name[:-len('Factory')] + if associated_class_name: + # Class name was longer than just 'Factory'. + try: + return getattr(factory_module, associated_class_name) + except AttributeError: + pass # Unable to guess a good option; return the inherited class. if inherited is not None: @@ -439,3 +440,22 @@ class DjangoModelFactory(Factory): ).order_by('-id')[0] except IndexError: return 1 + + +def _make_factory(klass, **kwargs): + factory_name = '%sFactory' % klass.__name__ + kwargs[FACTORY_CLASS_DECLARATION] = klass + factory_class = type(Factory).__new__(type(Factory), factory_name, (Factory,), kwargs) + factory_class.__name__ = '%sFactory' % klass.__name__ + factory_class.__doc__ = 'Auto-generated factory for class %s' % klass + return factory_class + + +def build(klass, **kwargs): + return _make_factory(klass, **kwargs).build() + +def create(klass, **kwargs): + return _make_factory(klass, **kwargs).create() + +def stub(klass, **kwargs): + return _make_factory(klass, **kwargs).stub() diff --git a/tests/test_base.py b/tests/test_base.py index 01855c3..5855682 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -55,6 +55,27 @@ class SafetyTestCase(unittest.TestCase): self.assertRaises(RuntimeError, base.BaseFactory) +class SimpleBuildTestCase(unittest.TestCase): + """Tests the minimalist 'factory.build/create' functions.""" + + def test_build(self): + obj = base.build(TestObject, two=2) + self.assertEqual(obj.one, None) + self.assertEqual(obj.two, 2) + self.assertEqual(obj.three, None) + self.assertEqual(obj.four, None) + + def test_create(self): + obj = base.create(FakeDjangoModel, foo='bar') + self.assertEqual(obj.id, 1) + self.assertEqual(obj.foo, 'bar') + + def test_stub(self): + obj = base.stub(TestObject, three=3) + self.assertEqual(obj.three, 3) + self.assertFalse(hasattr(obj, 'two')) + + class FactoryTestCase(unittest.TestCase): def testAttribute(self): class TestObjectFactory(base.Factory): |