diff options
-rw-r--r-- | README.rst | 18 | ||||
-rw-r--r-- | factory/base.py | 27 | ||||
-rw-r--r-- | tests.py | 17 |
3 files changed, 59 insertions, 3 deletions
@@ -165,3 +165,21 @@ Sequences can be combined with lazy attributes:: email = factory.LazyAttributeSequence(lambda a, n: '{0}+{1}@example.com'.format(a.name, n).lower()) UserFactory().email # => mark+0@example.com + +Customizing creation +-------------------- + +Sometimes, the default build/create by keyword arguments doesn't allow for enough +customization of the generated objects. In such cases, you should override the +:meth:`base.Factory._prepare` method:: + + class UserFactory(factory.Factory): + @classmethod + def _prepare(cls, create, **kwargs): + password = kwargs.pop('password', None) + user = super(UserFactory, cls)._prepare(create, kwargs) + if password: + user.set_password(user) + if create: + user.save() + return user diff --git a/factory/base.py b/factory/base.py index 1d84b03..222c1bf 100644 --- a/factory/base.py +++ b/factory/base.py @@ -86,7 +86,7 @@ class BaseFactoryMetaClass(type): ordered_declarations = [(_name, declaration) for (_name, declaration) in ordered_declarations if _name != name] ordered_declarations.append((name, attrs[name])) del attrs[name] - elif not name.startswith('__'): + elif not name.startswith('_'): unordered_declarations = [(_name, value) for (_name, value) in unordered_declarations if _name != name] unordered_declarations.append((name, attrs[name])) del attrs[name] @@ -230,9 +230,30 @@ class Factory(BaseFactory): return cls._creation_function[0] @classmethod + def _prepare(cls, create, **kwargs): + """Prepare an object for this factory. + + Args: + create: bool, whether to create or to build the object + **kwargs: arguments to pass to the creation function + """ + if create: + return cls.get_creation_function()(getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS), **kwargs) + else: + return getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS)(**kwargs) + + @classmethod + def _build(cls, **kwargs): + return cls._prepare(create=False, **kwargs) + + @classmethod + def _create(cls, **kwargs): + return cls._prepare(create=True, **kwargs) + + @classmethod def build(cls, **kwargs): - return getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS)(**cls.attributes(**kwargs)) + return cls._build(**cls.attributes(**kwargs)) @classmethod def create(cls, **kwargs): - return cls.get_creation_function()(getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS), **cls.attributes(**kwargs)) + return cls._create(**cls.attributes(**kwargs)) @@ -258,6 +258,23 @@ class FactoryCreationTestCase(unittest.TestCase): self.assertEqual(TestFactory.default_strategy, STUB_STRATEGY) + def testCustomCreation(self): + class TestModelFactory(Factory): + @classmethod + def _prepare(cls, create, **kwargs): + kwargs['four'] = 4 + return super(TestModelFactory, cls)._prepare(create, **kwargs) + + b = TestModelFactory.build(one=1) + self.assertEqual(1, b.one) + self.assertEqual(4, b.four) + self.assertEqual(None, b.id) + + c = TestModelFactory(one=1) + self.assertEqual(1, c.one) + self.assertEqual(4, c.four) + self.assertEqual(1, c.id) + # Errors def testNoAssociatedClassWithAutodiscovery(self): |