summaryrefslogtreecommitdiff
path: root/factory/declarations.py
diff options
context:
space:
mode:
Diffstat (limited to 'factory/declarations.py')
-rw-r--r--factory/declarations.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/factory/declarations.py b/factory/declarations.py
index 1f1d2af..efaadbe 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -492,20 +492,23 @@ class PostGenerationMethodCall(PostGenerationDeclaration):
...
password = factory.PostGenerationMethodCall('set_password', password='')
"""
- def __init__(self, method_name, extract_prefix=None, *args, **kwargs):
+ def __init__(self, method_name, *args, **kwargs):
+ extract_prefix = kwargs.pop('extract_prefix', None)
super(PostGenerationMethodCall, self).__init__(extract_prefix)
self.method_name = method_name
self.method_args = args
self.method_kwargs = kwargs
def call(self, obj, create, extracted=None, **kwargs):
- if extracted is not None:
- passed_args = extracted
- if isinstance(passed_args, basestring) or (
- not isinstance(passed_args, collections.Iterable)):
- passed_args = (passed_args,)
- else:
+ if extracted is None:
passed_args = self.method_args
+
+ elif len(self.method_args) <= 1:
+ # Max one argument expected
+ passed_args = (extracted,)
+ else:
+ passed_args = tuple(extracted)
+
passed_kwargs = dict(self.method_kwargs)
passed_kwargs.update(kwargs)
method = getattr(obj, self.method_name)