diff options
author | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2012-11-14 23:42:46 +0100 |
---|---|---|
committer | Raphaël Barrois <raphael.barrois@polytechnique.org> | 2012-11-15 00:03:37 +0100 |
commit | 5ec4a50edc67073e54218549d6985f934f94b88f (patch) | |
tree | 4a68d1367731a29e198090715cfe1cf7b039bba9 | |
parent | a19f64cbfadc0e36b2ff9812980e23955276632c (diff) | |
download | factory-boy-5ec4a50edc67073e54218549d6985f934f94b88f.tar factory-boy-5ec4a50edc67073e54218549d6985f934f94b88f.tar.gz |
Add an extension point for kwargs mangling.
Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
-rw-r--r-- | factory/base.py | 6 | ||||
-rw-r--r-- | tests/test_using.py | 22 |
2 files changed, 28 insertions, 0 deletions
diff --git a/factory/base.py b/factory/base.py index 20a3a6b..59f37eb 100644 --- a/factory/base.py +++ b/factory/base.py @@ -570,6 +570,11 @@ class Factory(BaseFactory): return building_function[0] @classmethod + def _adjust_kwargs(cls, **kwargs): + """Extension point for custom kwargs adjustment.""" + return kwargs + + @classmethod def _prepare(cls, create, **kwargs): """Prepare an object for this factory. @@ -578,6 +583,7 @@ class Factory(BaseFactory): **kwargs: arguments to pass to the creation function """ target_class = getattr(cls, CLASS_ATTRIBUTE_ASSOCIATED_CLASS) + kwargs = cls._adjust_kwargs(**kwargs) # Extract *args from **kwargs args = tuple(kwargs.pop(key) for key in cls.FACTORY_ARG_PARAMETERS) diff --git a/tests/test_using.py b/tests/test_using.py index 38c9e9e..f489f28 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -741,6 +741,28 @@ class NonKwargParametersTestCase(unittest.TestCase): self.assertEqual({'three': 3}, obj.kwargs) +class KwargAdjustTestCase(unittest.TestCase): + """Tests for the _adjust_kwargs method.""" + + def test_build(self): + class TestObject(object): + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + class TestObjectFactory(factory.Factory): + FACTORY_FOR = TestObject + + @classmethod + def _adjust_kwargs(cls, **kwargs): + kwargs['foo'] = len(kwargs) + return kwargs + + obj = TestObjectFactory.build(x=1, y=2, z=3) + self.assertEqual({'x': 1, 'y': 2, 'z': 3, 'foo': 3}, obj.kwargs) + self.assertEqual((), obj.args) + + class SubFactoryTestCase(unittest.TestCase): def testSubFactory(self): class TestModel2(FakeModel): |