From 2bc0fc8413c02a7faf3a116fe875d76bc3403117 Mon Sep 17 00:00:00 2001 From: Raphaël Barrois Date: Tue, 5 Mar 2013 00:36:08 +0100 Subject: Cleanup argument extraction in PostGenMethod (See #36). MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This provides a consistent behaviour for extracting arguments to a PostGenerationMethodCall. Signed-off-by: Raphaël Barrois --- factory/declarations.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) (limited to 'factory') diff --git a/factory/declarations.py b/factory/declarations.py index 1f1d2af..efaadbe 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -492,20 +492,23 @@ class PostGenerationMethodCall(PostGenerationDeclaration): ... password = factory.PostGenerationMethodCall('set_password', password='') """ - def __init__(self, method_name, extract_prefix=None, *args, **kwargs): + def __init__(self, method_name, *args, **kwargs): + extract_prefix = kwargs.pop('extract_prefix', None) super(PostGenerationMethodCall, self).__init__(extract_prefix) self.method_name = method_name self.method_args = args self.method_kwargs = kwargs def call(self, obj, create, extracted=None, **kwargs): - if extracted is not None: - passed_args = extracted - if isinstance(passed_args, basestring) or ( - not isinstance(passed_args, collections.Iterable)): - passed_args = (passed_args,) - else: + if extracted is None: passed_args = self.method_args + + elif len(self.method_args) <= 1: + # Max one argument expected + passed_args = (extracted,) + else: + passed_args = tuple(extracted) + passed_kwargs = dict(self.method_kwargs) passed_kwargs.update(kwargs) method = getattr(obj, self.method_name) -- cgit v1.2.3