summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2013-06-14 23:36:25 +0200
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2013-06-14 23:36:25 +0200
commit1ba20b0ed7b920fa2d161df94a0dda3d93b1e14b (patch)
treebcf4d4291b0f451dce533ff244f87fb088b90a34
parent251ae29b4beedd7e9af721ceabb82a03f2d55bab (diff)
downloadfactory-boy-1ba20b0ed7b920fa2d161df94a0dda3d93b1e14b.tar
factory-boy-1ba20b0ed7b920fa2d161df94a0dda3d93b1e14b.tar.gz
Properly handle passed-in None in RelatedFactory (Closes #62).
Thanks to @Dhekke for the help!
-rw-r--r--docs/reference.rst16
-rw-r--r--factory/base.py4
-rw-r--r--factory/declarations.py39
-rw-r--r--tests/test_declarations.py19
4 files changed, 59 insertions, 19 deletions
diff --git a/docs/reference.rst b/docs/reference.rst
index a2d6c9a..74f2dbd 100644
--- a/docs/reference.rst
+++ b/docs/reference.rst
@@ -1157,6 +1157,22 @@ Extra kwargs may be passed to the related factory, through the usual ``ATTR__SUB
>>> City.objects.get(capital_of=england)
<City: London>
+If a value if passed for the :class:`RelatedFactory` attribute, this disables
+:class:`RelatedFactory` generation:
+
+.. code-block:: pycon
+
+ >>> france = CountryFactory()
+ >>> paris = City.objects.get()
+ >>> paris
+ <City: Paris>
+ >>> reunion = CountryFactory(capital_city=paris)
+ >>> City.objects.count() # No new capital_city generated
+ 1
+ >>> guyane = CountryFactory(capital_city=paris, capital_city__name='Kourou')
+ >>> City.objects.count() # No new capital_city generated, ``name`` ignored.
+ 1
+
PostGeneration
""""""""""""""
diff --git a/factory/base.py b/factory/base.py
index 60aa218..76d3d4a 100644
--- a/factory/base.py
+++ b/factory/base.py
@@ -380,9 +380,9 @@ class BaseFactory(object):
# Handle post-generation attributes
results = {}
for name, decl in sorted(postgen_declarations.items()):
- extracted, extracted_kwargs = postgen_attributes[name]
+ did_extract, extracted, extracted_kwargs = postgen_attributes[name]
results[name] = decl.call(obj, create, extracted,
- **extracted_kwargs)
+ factory_extracted=did_extract, **extracted_kwargs)
cls._after_postgeneration(obj, create, results)
diff --git a/factory/declarations.py b/factory/declarations.py
index dbade97..f068c0d 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -430,11 +430,18 @@ class PostGenerationDeclaration(object):
(object, dict): a tuple containing the attribute at 'name' (if
provided) and a dict of extracted attributes
"""
- extracted = attrs.pop(name, None)
+ try:
+ extracted = attrs.pop(name)
+ did_extract = True
+ except KeyError:
+ extracted = None
+ did_extract = False
+
kwargs = utils.extract_dict(name, attrs)
- return extracted, kwargs
+ return did_extract, extracted, kwargs
- def call(self, obj, create, extracted=None, **kwargs): # pragma: no cover
+ def call(self, obj, create,
+ extracted=None, factory_extracted=False, **kwargs): # pragma: no cover
"""Call this hook; no return value is expected.
Args:
@@ -454,7 +461,8 @@ class PostGeneration(PostGenerationDeclaration):
super(PostGeneration, self).__init__()
self.function = function
- def call(self, obj, create, extracted=None, **kwargs):
+ def call(self, obj, create,
+ extracted=None, factory_extracted=False, **kwargs):
logger.debug('PostGeneration: Calling %s.%s(%s)',
self.function.__module__,
self.function.__name__,
@@ -492,19 +500,29 @@ class RelatedFactory(PostGenerationDeclaration):
"""Retrieve the wrapped factory.Factory subclass."""
return self.factory_wrapper.get()
- def call(self, obj, create, extracted=None, **kwargs):
+ def call(self, obj, create,
+ extracted=None, factory_extracted=False, **kwargs):
+ factory = self.get_factory()
+
+ if factory_extracted:
+ # The user passed in a custom value
+ logger.debug('RelatedFactory: Using provided %r instead of '
+ 'generating %s.%s.',
+ extracted, factory.__module__, factory.__name__,
+ )
+ return extracted
+
passed_kwargs = dict(self.defaults)
passed_kwargs.update(kwargs)
if self.name:
passed_kwargs[self.name] = obj
- factory = self.get_factory()
logger.debug('RelatedFactory: Generating %s.%s(%s)',
factory.__module__,
factory.__name__,
utils.log_pprint((create,), passed_kwargs),
)
- factory.simple_generate(create, **passed_kwargs)
+ return factory.simple_generate(create, **passed_kwargs)
class PostGenerationMethodCall(PostGenerationDeclaration):
@@ -526,8 +544,9 @@ class PostGenerationMethodCall(PostGenerationDeclaration):
self.method_args = args
self.method_kwargs = kwargs
- def call(self, obj, create, extracted=None, **kwargs):
- if extracted is None:
+ def call(self, obj, create,
+ extracted=None, factory_extracted=False, **kwargs):
+ if not factory_extracted:
passed_args = self.method_args
elif len(self.method_args) <= 1:
@@ -544,4 +563,4 @@ class PostGenerationMethodCall(PostGenerationDeclaration):
self.method_name,
utils.log_pprint(passed_args, passed_kwargs),
)
- method(*passed_args, **passed_kwargs)
+ return method(*passed_args, **passed_kwargs)
diff --git a/tests/test_declarations.py b/tests/test_declarations.py
index e0b2513..cd38dd2 100644
--- a/tests/test_declarations.py
+++ b/tests/test_declarations.py
@@ -119,7 +119,9 @@ class PostGenerationDeclarationTestCase(unittest.TestCase):
def test_extract_no_prefix(self):
decl = declarations.PostGenerationDeclaration()
- extracted, kwargs = decl.extract('foo', {'foo': 13, 'foo__bar': 42})
+ did_extract, extracted, kwargs = decl.extract('foo',
+ {'foo': 13, 'foo__bar': 42})
+ self.assertTrue(did_extract)
self.assertEqual(extracted, 13)
self.assertEqual(kwargs, {'bar': 42})
@@ -130,8 +132,9 @@ class PostGenerationDeclarationTestCase(unittest.TestCase):
call_params.append(args)
call_params.append(kwargs)
- extracted, kwargs = foo.extract('foo',
+ did_extract, extracted, kwargs = foo.extract('foo',
{'foo': 13, 'foo__bar': 42, 'blah': 42, 'blah__baz': 1})
+ self.assertTrue(did_extract)
self.assertEqual(13, extracted)
self.assertEqual({'bar': 42}, kwargs)
@@ -215,17 +218,17 @@ class PostGenerationMethodCallTestCase(unittest.TestCase):
def test_call_with_passed_extracted_string(self):
decl = declarations.PostGenerationMethodCall(
'method')
- decl.call(self.obj, False, 'data')
+ decl.call(self.obj, False, 'data', factory_extracted=True)
self.obj.method.assert_called_once_with('data')
def test_call_with_passed_extracted_int(self):
decl = declarations.PostGenerationMethodCall('method')
- decl.call(self.obj, False, 1)
+ decl.call(self.obj, False, 1, factory_extracted=True)
self.obj.method.assert_called_once_with(1)
def test_call_with_passed_extracted_iterable(self):
decl = declarations.PostGenerationMethodCall('method')
- decl.call(self.obj, False, (1, 2, 3))
+ decl.call(self.obj, False, (1, 2, 3), factory_extracted=True)
self.obj.method.assert_called_once_with((1, 2, 3))
def test_call_with_method_kwargs(self):
@@ -248,13 +251,15 @@ class PostGenerationMethodCallTestCase(unittest.TestCase):
def test_multi_call_with_passed_multiple_args(self):
decl = declarations.PostGenerationMethodCall(
'method', 'arg1', 'arg2')
- decl.call(self.obj, False, ('param1', 'param2', 'param3'))
+ decl.call(self.obj, False, ('param1', 'param2', 'param3'),
+ factory_extracted=True)
self.obj.method.assert_called_once_with('param1', 'param2', 'param3')
def test_multi_call_with_passed_tuple(self):
decl = declarations.PostGenerationMethodCall(
'method', 'arg1', 'arg2')
- decl.call(self.obj, False, (('param1', 'param2'),))
+ decl.call(self.obj, False, (('param1', 'param2'),),
+ factory_extracted=True)
self.obj.method.assert_called_once_with(('param1', 'param2'))
def test_multi_call_with_kwargs(self):