summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--factory/declarations.py10
-rw-r--r--tests/test_declarations.py47
2 files changed, 56 insertions, 1 deletions
diff --git a/factory/declarations.py b/factory/declarations.py
index 366c2c8..1f1d2af 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -21,6 +21,7 @@
# THE SOFTWARE.
+import collections
import itertools
import warnings
@@ -498,10 +499,17 @@ class PostGenerationMethodCall(PostGenerationDeclaration):
self.method_kwargs = kwargs
def call(self, obj, create, extracted=None, **kwargs):
+ if extracted is not None:
+ passed_args = extracted
+ if isinstance(passed_args, basestring) or (
+ not isinstance(passed_args, collections.Iterable)):
+ passed_args = (passed_args,)
+ else:
+ passed_args = self.method_args
passed_kwargs = dict(self.method_kwargs)
passed_kwargs.update(kwargs)
method = getattr(obj, self.method_name)
- method(*self.method_args, **passed_kwargs)
+ method(*passed_args, **passed_kwargs)
# Decorators... in case lambdas don't cut it
diff --git a/tests/test_declarations.py b/tests/test_declarations.py
index cc921d4..59a3955 100644
--- a/tests/test_declarations.py
+++ b/tests/test_declarations.py
@@ -24,6 +24,8 @@ import datetime
import itertools
import warnings
+from mock import MagicMock
+
from factory import declarations
from .compat import unittest
@@ -295,6 +297,51 @@ class RelatedFactoryTestCase(unittest.TestCase):
datetime.date = orig_date
+class PostGenerationMethodCallTestCase(unittest.TestCase):
+ def setUp(self):
+ self.obj = MagicMock()
+
+ def test_simplest_setup_and_call(self):
+ decl = declarations.PostGenerationMethodCall('method')
+ decl.call(self.obj, False)
+ self.obj.method.assert_called_once_with()
+
+ def test_call_with_method_args(self):
+ decl = declarations.PostGenerationMethodCall(
+ 'method', None, 'data')
+ decl.call(self.obj, False)
+ self.obj.method.assert_called_once_with('data')
+
+ def test_call_with_passed_extracted_string(self):
+ decl = declarations.PostGenerationMethodCall(
+ 'method', None)
+ decl.call(self.obj, False, 'data')
+ self.obj.method.assert_called_once_with('data')
+
+ def test_call_with_passed_extracted_int(self):
+ decl = declarations.PostGenerationMethodCall('method')
+ decl.call(self.obj, False, 1)
+ self.obj.method.assert_called_once_with(1)
+
+ def test_call_with_passed_extracted_iterable(self):
+ decl = declarations.PostGenerationMethodCall('method')
+ decl.call(self.obj, False, (1, 2, 3))
+ self.obj.method.assert_called_once_with(1, 2, 3)
+
+ def test_call_with_method_kwargs(self):
+ decl = declarations.PostGenerationMethodCall(
+ 'method', None, data='data')
+ decl.call(self.obj, False)
+ self.obj.method.assert_called_once_with(data='data')
+
+ def test_call_with_passed_kwargs(self):
+ decl = declarations.PostGenerationMethodCall('method')
+ decl.call(self.obj, False, data='other')
+ self.obj.method.assert_called_once_with(data='other')
+
+
+
+
class CircularSubFactoryTestCase(unittest.TestCase):
def test_circularsubfactory_deprecated(self):