diff options
-rw-r--r-- | docs/changelog.rst | 4 | ||||
-rw-r--r-- | docs/reference.rst | 52 | ||||
-rw-r--r-- | factory/declarations.py | 17 | ||||
-rw-r--r-- | tests/test_declarations.py | 24 |
4 files changed, 81 insertions, 16 deletions
diff --git a/docs/changelog.rst b/docs/changelog.rst index eccc0a6..0671f8a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -50,6 +50,10 @@ New declarations - A :class:`~factory.Iterator` may be prevented from cycling by setting its :attr:`~factory.Iterator.cycle` argument to ``False`` + - Allow overriding default arguments in a :class:`~factory.PostGenerationMethodCall` + when generating an instance of the factory + - An object created by a :class:`~factory.DjangoModelFactory` will be saved + again after :class:`~factory.PostGeneration` hooks execution Pending deprecation diff --git a/docs/reference.rst b/docs/reference.rst index d100b40..85b299c 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -1003,10 +1003,22 @@ generated object just after it being called. Its sole argument is the name of the method to call. Extra arguments and keyword arguments for the target method may also be provided. -Once the object has been generated, the method will be called, with the arguments -provided in the :class:`PostGenerationMethodCall` declaration, and keyword -arguments taken from the combination of :class:`PostGenerationMethodCall` -declaration and prefix-based values: +Once the object has been generated, the method will be called, with arguments +taken from either the :class:`PostGenerationMethodCall` or prefix-based values: + +- If a value was extracted from kwargs (i.e an argument for the name the + :class:`PostGenerationMethodCall` was declared under): + + - If the declaration mentionned zero or one argument, the value is passed + directly to the method + - If the declaration used two or more arguments, the value is passed as + ``*args`` to the method + +- Otherwise, the arguments used when declaring the :class:`PostGenerationMethodCall` + are used + +- Keywords extracted from the factory arguments are merged into the defaults + present in the :class:`PostGenerationMethodCall` declaration. .. code-block:: python @@ -1018,10 +1030,40 @@ declaration and prefix-based values: .. code-block:: pycon >>> UserFactory() # Calls user.set_password(password='') - >>> UserFactory(password='test') # Calls user.set_password(password='test') + >>> UserFactory(password='test') # Calls user.set_password('test') >>> UserFactory(password__disabled=True) # Calls user.set_password(password='', disabled=True) +When the :class:`PostGenerationMethodCall` declaration uses two or more arguments, +the extracted value must be iterable: + +.. code-block:: python + + class UserFactory(factory.Factory): + FACTORY_FOR = User + + password = factory.PostGenerationMethodCall('set_password', '', 'sha1') + +.. code-block:: pycon + + >>> UserFactory() # Calls user.set_password('', 'sha1') + >>> UserFactory(password=('test', 'md5')) # Calls user.set_password('test', 'md5') + + >>> # Always pass in a good iterable: + >>> UserFactory(password=('test',)) # Calls user.set_password('test') + >>> UserFactory(password='test') # Calls user.set_password('t', 'e', 's', 't') + + +.. note:: While this setup provides sane and intuitive defaults for most users, + it prevents passing more than one argument when the declaration used + zero or one. + + In such cases, users are advised to either resort to the more powerful + :class:`PostGeneration` or to add the second expected argument default + value to the :class:`PostGenerationMethodCall` declaration + (``PostGenerationMethodCall('method', 'x', 'y_that_is_the_default')``) + + Module-level functions ---------------------- diff --git a/factory/declarations.py b/factory/declarations.py index 1f1d2af..efaadbe 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -492,20 +492,23 @@ class PostGenerationMethodCall(PostGenerationDeclaration): ... password = factory.PostGenerationMethodCall('set_password', password='') """ - def __init__(self, method_name, extract_prefix=None, *args, **kwargs): + def __init__(self, method_name, *args, **kwargs): + extract_prefix = kwargs.pop('extract_prefix', None) super(PostGenerationMethodCall, self).__init__(extract_prefix) self.method_name = method_name self.method_args = args self.method_kwargs = kwargs def call(self, obj, create, extracted=None, **kwargs): - if extracted is not None: - passed_args = extracted - if isinstance(passed_args, basestring) or ( - not isinstance(passed_args, collections.Iterable)): - passed_args = (passed_args,) - else: + if extracted is None: passed_args = self.method_args + + elif len(self.method_args) <= 1: + # Max one argument expected + passed_args = (extracted,) + else: + passed_args = tuple(extracted) + passed_kwargs = dict(self.method_kwargs) passed_kwargs.update(kwargs) method = getattr(obj, self.method_name) diff --git a/tests/test_declarations.py b/tests/test_declarations.py index 93e11d0..b11a4a8 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -306,13 +306,13 @@ class PostGenerationMethodCallTestCase(unittest.TestCase): def test_call_with_method_args(self): decl = declarations.PostGenerationMethodCall( - 'method', None, 'data') + 'method', 'data') decl.call(self.obj, False) self.obj.method.assert_called_once_with('data') def test_call_with_passed_extracted_string(self): decl = declarations.PostGenerationMethodCall( - 'method', None) + 'method') decl.call(self.obj, False, 'data') self.obj.method.assert_called_once_with('data') @@ -324,11 +324,11 @@ class PostGenerationMethodCallTestCase(unittest.TestCase): def test_call_with_passed_extracted_iterable(self): decl = declarations.PostGenerationMethodCall('method') decl.call(self.obj, False, (1, 2, 3)) - self.obj.method.assert_called_once_with(1, 2, 3) + self.obj.method.assert_called_once_with((1, 2, 3)) def test_call_with_method_kwargs(self): decl = declarations.PostGenerationMethodCall( - 'method', None, data='data') + 'method', data='data') decl.call(self.obj, False) self.obj.method.assert_called_once_with(data='data') @@ -337,7 +337,23 @@ class PostGenerationMethodCallTestCase(unittest.TestCase): decl.call(self.obj, False, data='other') self.obj.method.assert_called_once_with(data='other') + def test_multi_call_with_multi_method_args(self): + decl = declarations.PostGenerationMethodCall( + 'method', 'arg1', 'arg2') + decl.call(self.obj, False) + self.obj.method.assert_called_once_with('arg1', 'arg2') + def test_multi_call_with_passed_multiple_args(self): + decl = declarations.PostGenerationMethodCall( + 'method', 'arg1', 'arg2') + decl.call(self.obj, False, ('param1', 'param2', 'param3')) + self.obj.method.assert_called_once_with('param1', 'param2', 'param3') + + def test_multi_call_with_passed_tuple(self): + decl = declarations.PostGenerationMethodCall( + 'method', 'arg1', 'arg2') + decl.call(self.obj, False, (('param1', 'param2'),)) + self.obj.method.assert_called_once_with(('param1', 'param2')) class CircularSubFactoryTestCase(unittest.TestCase): |