diff options
Diffstat (limited to 'factory/declarations.py')
-rw-r--r-- | factory/declarations.py | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/factory/declarations.py b/factory/declarations.py index 1f1d2af..efaadbe 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -492,20 +492,23 @@ class PostGenerationMethodCall(PostGenerationDeclaration): ... password = factory.PostGenerationMethodCall('set_password', password='') """ - def __init__(self, method_name, extract_prefix=None, *args, **kwargs): + def __init__(self, method_name, *args, **kwargs): + extract_prefix = kwargs.pop('extract_prefix', None) super(PostGenerationMethodCall, self).__init__(extract_prefix) self.method_name = method_name self.method_args = args self.method_kwargs = kwargs def call(self, obj, create, extracted=None, **kwargs): - if extracted is not None: - passed_args = extracted - if isinstance(passed_args, basestring) or ( - not isinstance(passed_args, collections.Iterable)): - passed_args = (passed_args,) - else: + if extracted is None: passed_args = self.method_args + + elif len(self.method_args) <= 1: + # Max one argument expected + passed_args = (extracted,) + else: + passed_args = tuple(extracted) + passed_kwargs = dict(self.method_kwargs) passed_kwargs.update(kwargs) method = getattr(obj, self.method_name) |