summaryrefslogtreecommitdiff
path: root/factory/declarations.py
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2013-03-05 00:36:08 +0100
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2013-03-05 00:39:25 +0100
commit2bc0fc8413c02a7faf3a116fe875d76bc3403117 (patch)
tree4dc134b66bee6bf280b9d4f31766c8b3efef5117 /factory/declarations.py
parent3c011a3c6e97e40410ad88a734605759fb247301 (diff)
downloadfactory-boy-2bc0fc8413c02a7faf3a116fe875d76bc3403117.tar
factory-boy-2bc0fc8413c02a7faf3a116fe875d76bc3403117.tar.gz
Cleanup argument extraction in PostGenMethod (See #36).
This provides a consistent behaviour for extracting arguments to a PostGenerationMethodCall. Signed-off-by: Raphaël Barrois <raphael.barrois@polytechnique.org>
Diffstat (limited to 'factory/declarations.py')
-rw-r--r--factory/declarations.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/factory/declarations.py b/factory/declarations.py
index 1f1d2af..efaadbe 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -492,20 +492,23 @@ class PostGenerationMethodCall(PostGenerationDeclaration):
...
password = factory.PostGenerationMethodCall('set_password', password='')
"""
- def __init__(self, method_name, extract_prefix=None, *args, **kwargs):
+ def __init__(self, method_name, *args, **kwargs):
+ extract_prefix = kwargs.pop('extract_prefix', None)
super(PostGenerationMethodCall, self).__init__(extract_prefix)
self.method_name = method_name
self.method_args = args
self.method_kwargs = kwargs
def call(self, obj, create, extracted=None, **kwargs):
- if extracted is not None:
- passed_args = extracted
- if isinstance(passed_args, basestring) or (
- not isinstance(passed_args, collections.Iterable)):
- passed_args = (passed_args,)
- else:
+ if extracted is None:
passed_args = self.method_args
+
+ elif len(self.method_args) <= 1:
+ # Max one argument expected
+ passed_args = (extracted,)
+ else:
+ passed_args = tuple(extracted)
+
passed_kwargs = dict(self.method_kwargs)
passed_kwargs.update(kwargs)
method = getattr(obj, self.method_name)