summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorThomas Goirand <thomas@goirand.fr>2013-05-12 05:32:34 +0000
committerThomas Goirand <thomas@goirand.fr>2013-05-12 05:32:34 +0000
commit28991f9514e3cd78a528bbbe956d9b4536c416e0 (patch)
treea3871392d2382f60490824d79058f8a71ae1c34e /tests
parent57fa2e21aed37c1af2a87f36a998046b73092a21 (diff)
parent876845102c4a217496d0f6435bfe1e3726d31fe4 (diff)
downloadfactory-boy-28991f9514e3cd78a528bbbe956d9b4536c416e0.tar
factory-boy-28991f9514e3cd78a528bbbe956d9b4536c416e0.tar.gz
Merge tag '2.0.2' into debian/unstable
Release of factory_boy 2.0.2 Conflicts: docs/changelog.rst docs/index.rst docs/subfactory.rst tests/cyclic/bar.py tests/cyclic/foo.py
Diffstat (limited to 'tests')
-rw-r--r--tests/__init__.py3
-rw-r--r--tests/compat.py12
-rw-r--r--tests/test_base.py107
-rw-r--r--tests/test_containers.py12
-rw-r--r--tests/test_declarations.py251
-rw-r--r--tests/test_fuzzy.py104
-rw-r--r--tests/test_using.py918
-rw-r--r--tests/test_utils.py2
-rw-r--r--tests/tools.py36
9 files changed, 1252 insertions, 193 deletions
diff --git a/tests/__init__.py b/tests/__init__.py
index 80a96a4..3c620d6 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
-# Copyright (c) 2011 Raphaël Barrois
+# Copyright (c) 2011-2013 Raphaël Barrois
from .test_base import *
from .test_containers import *
from .test_declarations import *
+from .test_fuzzy import *
from .test_using import *
from .test_utils import *
diff --git a/tests/compat.py b/tests/compat.py
index 15fa3ae..6a1eb80 100644
--- a/tests/compat.py
+++ b/tests/compat.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright (c) 2011 Raphaël Barrois
+# Copyright (c) 2011-2013 Raphaël Barrois
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -21,7 +21,17 @@
"""Compatibility tools for tests"""
+import sys
+
+is_python2 = (sys.version_info[0] == 2)
+
try:
import unittest2 as unittest
except ImportError:
import unittest
+
+if sys.version_info[0:2] < (3, 3):
+ import mock
+else:
+ from unittest import mock
+
diff --git a/tests/test_base.py b/tests/test_base.py
index 7575ee2..216711a 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2010 Mark Sandstrom
-# Copyright (c) 2011 Raphaël Barrois
+# Copyright (c) 2011-2013 Raphaël Barrois
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -44,7 +44,7 @@ class FakeDjangoModel(object):
objects = FakeDjangoManager()
def __init__(self, **kwargs):
- for name, value in kwargs.iteritems():
+ for name, value in kwargs.items():
setattr(self, name, value)
self.id = None
@@ -53,25 +53,42 @@ class TestModel(FakeDjangoModel):
class SafetyTestCase(unittest.TestCase):
- def testBaseFactory(self):
- self.assertRaises(RuntimeError, base.BaseFactory)
+ def test_base_factory(self):
+ self.assertRaises(base.FactoryError, base.BaseFactory)
+
+
+class AbstractFactoryTestCase(unittest.TestCase):
+ def test_factory_for_optional(self):
+ """Ensure that FACTORY_FOR is optional for ABSTRACT_FACTORY."""
+ class TestObjectFactory(base.Factory):
+ ABSTRACT_FACTORY = True
+
+ # Passed
class FactoryTestCase(unittest.TestCase):
- def testDisplay(self):
+ def test_factory_for(self):
+ class TestObjectFactory(base.Factory):
+ FACTORY_FOR = TestObject
+
+ self.assertEqual(TestObject, TestObjectFactory.FACTORY_FOR)
+ obj = TestObjectFactory.build()
+ self.assertFalse(hasattr(obj, 'FACTORY_FOR'))
+
+ def test_display(self):
class TestObjectFactory(base.Factory):
FACTORY_FOR = FakeDjangoModel
self.assertIn('TestObjectFactory', str(TestObjectFactory))
self.assertIn('FakeDjangoModel', str(TestObjectFactory))
- def testLazyAttributeNonExistentParam(self):
+ def test_lazy_attribute_non_existent_param(self):
class TestObjectFactory(base.Factory):
one = declarations.LazyAttribute(lambda a: a.does_not_exist )
self.assertRaises(AttributeError, TestObjectFactory)
- def testInheritanceWithSequence(self):
+ def test_inheritance_with_sequence(self):
"""Tests that sequence IDs are shared between parent and son."""
class TestObjectFactory(base.Factory):
one = declarations.Sequence(lambda a: a)
@@ -86,15 +103,16 @@ class FactoryTestCase(unittest.TestCase):
ones = set([x.one for x in (parent, alt_parent, sub, alt_sub)])
self.assertEqual(4, len(ones))
+
class FactoryDefaultStrategyTestCase(unittest.TestCase):
def setUp(self):
- self.default_strategy = base.Factory.default_strategy
+ self.default_strategy = base.Factory.FACTORY_STRATEGY
def tearDown(self):
- base.Factory.default_strategy = self.default_strategy
+ base.Factory.FACTORY_STRATEGY = self.default_strategy
- def testBuildStrategy(self):
- base.Factory.default_strategy = base.BUILD_STRATEGY
+ def test_build_strategy(self):
+ base.Factory.FACTORY_STRATEGY = base.BUILD_STRATEGY
class TestModelFactory(base.Factory):
one = 'one'
@@ -103,8 +121,8 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase):
self.assertEqual(test_model.one, 'one')
self.assertFalse(test_model.id)
- def testCreateStrategy(self):
- # Default default_strategy
+ def test_create_strategy(self):
+ # Default FACTORY_STRATEGY
class TestModelFactory(base.Factory):
one = 'one'
@@ -113,8 +131,8 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase):
self.assertEqual(test_model.one, 'one')
self.assertTrue(test_model.id)
- def testStubStrategy(self):
- base.Factory.default_strategy = base.STUB_STRATEGY
+ def test_stub_strategy(self):
+ base.Factory.FACTORY_STRATEGY = base.STUB_STRATEGY
class TestModelFactory(base.Factory):
one = 'one'
@@ -123,23 +141,23 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase):
self.assertEqual(test_model.one, 'one')
self.assertFalse(hasattr(test_model, 'id')) # We should have a plain old object
- def testUnknownStrategy(self):
- base.Factory.default_strategy = 'unknown'
+ def test_unknown_strategy(self):
+ base.Factory.FACTORY_STRATEGY = 'unknown'
class TestModelFactory(base.Factory):
one = 'one'
self.assertRaises(base.Factory.UnknownStrategy, TestModelFactory)
- def testStubWithNonStubStrategy(self):
+ def test_stub_with_non_stub_strategy(self):
class TestModelFactory(base.StubFactory):
one = 'one'
- TestModelFactory.default_strategy = base.CREATE_STRATEGY
+ TestModelFactory.FACTORY_STRATEGY = base.CREATE_STRATEGY
self.assertRaises(base.StubFactory.UnsupportedStrategy, TestModelFactory)
- TestModelFactory.default_strategy = base.BUILD_STRATEGY
+ TestModelFactory.FACTORY_STRATEGY = base.BUILD_STRATEGY
self.assertRaises(base.StubFactory.UnsupportedStrategy, TestModelFactory)
def test_change_strategy(self):
@@ -147,54 +165,35 @@ class FactoryDefaultStrategyTestCase(unittest.TestCase):
class TestModelFactory(base.StubFactory):
one = 'one'
- self.assertEqual(base.CREATE_STRATEGY, TestModelFactory.default_strategy)
+ self.assertEqual(base.CREATE_STRATEGY, TestModelFactory.FACTORY_STRATEGY)
class FactoryCreationTestCase(unittest.TestCase):
- def testFactoryFor(self):
+ def test_factory_for(self):
class TestFactory(base.Factory):
FACTORY_FOR = TestObject
self.assertTrue(isinstance(TestFactory.build(), TestObject))
- def testAutomaticAssociatedClassDiscovery(self):
- class TestObjectFactory(base.Factory):
- pass
-
- self.assertTrue(isinstance(TestObjectFactory.build(), TestObject))
-
- def testDeprecationWarning(self):
- """Make sure the 'auto-discovery' deprecation warning is issued."""
-
- with warnings.catch_warnings(record=True) as w:
- # Clear the warning registry.
- if hasattr(base, '__warningregistry__'):
- base.__warningregistry__.clear()
-
- warnings.simplefilter('always')
- class TestObjectFactory(base.Factory):
- pass
-
- self.assertEqual(1, len(w))
- self.assertIn('deprecated', str(w[0].message))
-
- def testStub(self):
+ def test_stub(self):
class TestFactory(base.StubFactory):
pass
- self.assertEqual(TestFactory.default_strategy, base.STUB_STRATEGY)
+ self.assertEqual(TestFactory.FACTORY_STRATEGY, base.STUB_STRATEGY)
- def testInheritanceWithStub(self):
+ def test_inheritance_with_stub(self):
class TestObjectFactory(base.StubFactory):
pass
class TestFactory(TestObjectFactory):
pass
- self.assertEqual(TestFactory.default_strategy, base.STUB_STRATEGY)
+ self.assertEqual(TestFactory.FACTORY_STRATEGY, base.STUB_STRATEGY)
+
+ def test_custom_creation(self):
+ class TestModelFactory(FakeModelFactory):
+ FACTORY_FOR = TestModel
- def testCustomCreation(self):
- class TestModelFactory(base.Factory):
@classmethod
def _prepare(cls, create, **kwargs):
kwargs['four'] = 4
@@ -212,15 +211,7 @@ class FactoryCreationTestCase(unittest.TestCase):
# Errors
- def testNoAssociatedClassWithAutodiscovery(self):
- try:
- class TestFactory(base.Factory):
- pass
- self.fail()
- except base.Factory.AssociatedClassError as e:
- self.assertTrue('autodiscovery' in str(e))
-
- def testNoAssociatedClassWithoutAutodiscovery(self):
+ def test_no_associated_class(self):
try:
class Test(base.Factory):
pass
diff --git a/tests/test_containers.py b/tests/test_containers.py
index 797c480..7c8d829 100644
--- a/tests/test_containers.py
+++ b/tests/test_containers.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2010 Mark Sandstrom
-# Copyright (c) 2011 Raphaël Barrois
+# Copyright (c) 2011-2013 Raphaël Barrois
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -242,7 +242,7 @@ class AttributeBuilderTestCase(unittest.TestCase):
self.assertEqual({'one': 2}, ab.build(create=False))
def test_factory_defined_sequence(self):
- seq = declarations.Sequence(lambda n: 'xx' + n)
+ seq = declarations.Sequence(lambda n: 'xx%d' % n)
class FakeFactory(object):
@classmethod
@@ -259,7 +259,7 @@ class AttributeBuilderTestCase(unittest.TestCase):
self.assertEqual({'one': 'xx1'}, ab.build(create=False))
def test_additionnal_sequence(self):
- seq = declarations.Sequence(lambda n: 'xx' + n)
+ seq = declarations.Sequence(lambda n: 'xx%d' % n)
class FakeFactory(object):
@classmethod
@@ -276,8 +276,8 @@ class AttributeBuilderTestCase(unittest.TestCase):
self.assertEqual({'one': 1, 'two': 'xx1'}, ab.build(create=False))
def test_replaced_sequence(self):
- seq = declarations.Sequence(lambda n: 'xx' + n)
- seq2 = declarations.Sequence(lambda n: 'yy' + n)
+ seq = declarations.Sequence(lambda n: 'xx%d' % n)
+ seq2 = declarations.Sequence(lambda n: 'yy%d' % n)
class FakeFactory(object):
@classmethod
@@ -331,7 +331,7 @@ class AttributeBuilderTestCase(unittest.TestCase):
ab = containers.AttributeBuilder(FakeFactory, {'one__blah': 1, 'two__bar': 2})
self.assertTrue(ab.has_subfields(sf))
- self.assertEqual(['one'], ab._subfields.keys())
+ self.assertEqual(['one'], list(ab._subfields.keys()))
self.assertEqual(2, ab._attrs['two__bar'])
def test_sub_factory(self):
diff --git a/tests/test_declarations.py b/tests/test_declarations.py
index 1c0502b..4c08dfa 100644
--- a/tests/test_declarations.py
+++ b/tests/test_declarations.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2010 Mark Sandstrom
-# Copyright (c) 2011 Raphaël Barrois
+# Copyright (c) 2011-2013 Raphaël Barrois
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -20,16 +20,21 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
+import datetime
+import itertools
+import warnings
-from factory.declarations import deepgetattr, OrderedDeclaration, \
- PostGenerationDeclaration, Sequence
+from factory import declarations
+from factory import helpers
+
+from .compat import mock, unittest
+from . import tools
-from .compat import unittest
class OrderedDeclarationTestCase(unittest.TestCase):
def test_errors(self):
- decl = OrderedDeclaration()
- self.assertRaises(NotImplementedError, decl.evaluate, None, {})
+ decl = declarations.OrderedDeclaration()
+ self.assertRaises(NotImplementedError, decl.evaluate, None, {}, False)
class DigTestCase(unittest.TestCase):
@@ -43,34 +48,234 @@ class DigTestCase(unittest.TestCase):
obj.a.b = self.MyObj(3)
obj.a.b.c = self.MyObj(4)
- self.assertEqual(2, deepgetattr(obj, 'a').n)
- self.assertRaises(AttributeError, deepgetattr, obj, 'b')
- self.assertEqual(2, deepgetattr(obj, 'a.n'))
- self.assertEqual(3, deepgetattr(obj, 'a.c', 3))
- self.assertRaises(AttributeError, deepgetattr, obj, 'a.c.n')
- self.assertRaises(AttributeError, deepgetattr, obj, 'a.d')
- self.assertEqual(3, deepgetattr(obj, 'a.b').n)
- self.assertEqual(3, deepgetattr(obj, 'a.b.n'))
- self.assertEqual(4, deepgetattr(obj, 'a.b.c').n)
- self.assertEqual(4, deepgetattr(obj, 'a.b.c.n'))
- self.assertEqual(42, deepgetattr(obj, 'a.b.c.n.x', 42))
+ self.assertEqual(2, declarations.deepgetattr(obj, 'a').n)
+ self.assertRaises(AttributeError, declarations.deepgetattr, obj, 'b')
+ self.assertEqual(2, declarations.deepgetattr(obj, 'a.n'))
+ self.assertEqual(3, declarations.deepgetattr(obj, 'a.c', 3))
+ self.assertRaises(AttributeError, declarations.deepgetattr, obj, 'a.c.n')
+ self.assertRaises(AttributeError, declarations.deepgetattr, obj, 'a.d')
+ self.assertEqual(3, declarations.deepgetattr(obj, 'a.b').n)
+ self.assertEqual(3, declarations.deepgetattr(obj, 'a.b.n'))
+ self.assertEqual(4, declarations.deepgetattr(obj, 'a.b.c').n)
+ self.assertEqual(4, declarations.deepgetattr(obj, 'a.b.c.n'))
+ self.assertEqual(42, declarations.deepgetattr(obj, 'a.b.c.n.x', 42))
+
+
+class SelfAttributeTestCase(unittest.TestCase):
+ def test_standard(self):
+ a = declarations.SelfAttribute('foo.bar.baz')
+ self.assertEqual(0, a.depth)
+ self.assertEqual('foo.bar.baz', a.attribute_name)
+ self.assertEqual(declarations._UNSPECIFIED, a.default)
+
+ def test_dot(self):
+ a = declarations.SelfAttribute('.bar.baz')
+ self.assertEqual(1, a.depth)
+ self.assertEqual('bar.baz', a.attribute_name)
+ self.assertEqual(declarations._UNSPECIFIED, a.default)
+
+ def test_default(self):
+ a = declarations.SelfAttribute('bar.baz', 42)
+ self.assertEqual(0, a.depth)
+ self.assertEqual('bar.baz', a.attribute_name)
+ self.assertEqual(42, a.default)
+
+ def test_parent(self):
+ a = declarations.SelfAttribute('..bar.baz')
+ self.assertEqual(2, a.depth)
+ self.assertEqual('bar.baz', a.attribute_name)
+ self.assertEqual(declarations._UNSPECIFIED, a.default)
+
+ def test_grandparent(self):
+ a = declarations.SelfAttribute('...bar.baz')
+ self.assertEqual(3, a.depth)
+ self.assertEqual('bar.baz', a.attribute_name)
+ self.assertEqual(declarations._UNSPECIFIED, a.default)
+
+
+class IteratorTestCase(unittest.TestCase):
+ def test_cycle(self):
+ it = declarations.Iterator([1, 2])
+ self.assertEqual(1, it.evaluate(0, None, False))
+ self.assertEqual(2, it.evaluate(1, None, False))
+ self.assertEqual(1, it.evaluate(2, None, False))
+ self.assertEqual(2, it.evaluate(3, None, False))
+
+ def test_no_cycling(self):
+ it = declarations.Iterator([1, 2], cycle=False)
+ self.assertEqual(1, it.evaluate(0, None, False))
+ self.assertEqual(2, it.evaluate(1, None, False))
+ self.assertRaises(StopIteration, it.evaluate, 2, None, False)
+
+ def test_getter(self):
+ it = declarations.Iterator([(1, 2), (1, 3)], getter=lambda p: p[1])
+ self.assertEqual(2, it.evaluate(0, None, False))
+ self.assertEqual(3, it.evaluate(1, None, False))
+ self.assertEqual(2, it.evaluate(2, None, False))
+ self.assertEqual(3, it.evaluate(3, None, False))
class PostGenerationDeclarationTestCase(unittest.TestCase):
def test_extract_no_prefix(self):
- decl = PostGenerationDeclaration()
+ decl = declarations.PostGenerationDeclaration()
extracted, kwargs = decl.extract('foo', {'foo': 13, 'foo__bar': 42})
self.assertEqual(extracted, 13)
self.assertEqual(kwargs, {'bar': 42})
- def test_extract_with_prefix(self):
- decl = PostGenerationDeclaration(extract_prefix='blah')
+ def test_decorator_simple(self):
+ call_params = []
+ @helpers.post_generation
+ def foo(*args, **kwargs):
+ call_params.append(args)
+ call_params.append(kwargs)
- extracted, kwargs = decl.extract('foo',
+ extracted, kwargs = foo.extract('foo',
{'foo': 13, 'foo__bar': 42, 'blah': 42, 'blah__baz': 1})
- self.assertEqual(extracted, 42)
- self.assertEqual(kwargs, {'baz': 1})
+ self.assertEqual(13, extracted)
+ self.assertEqual({'bar': 42}, kwargs)
+
+ # No value returned.
+ foo.call(None, False, extracted, **kwargs)
+ self.assertEqual(2, len(call_params))
+ self.assertEqual((None, False, 13), call_params[0])
+ self.assertEqual({'bar': 42}, call_params[1])
+
+
+class SubFactoryTestCase(unittest.TestCase):
+
+ def test_arg(self):
+ self.assertRaises(ValueError, declarations.SubFactory, 'UnqualifiedSymbol')
+
+ def test_lazyness(self):
+ f = declarations.SubFactory('factory.declarations.Sequence', x=3)
+ self.assertEqual(None, f.factory)
+
+ self.assertEqual({'x': 3}, f.defaults)
+
+ factory_class = f.get_factory()
+ self.assertEqual(declarations.Sequence, factory_class)
+
+ def test_cache(self):
+ orig_date = datetime.date
+ f = declarations.SubFactory('datetime.date')
+ self.assertEqual(None, f.factory)
+
+ factory_class = f.get_factory()
+ self.assertEqual(orig_date, factory_class)
+
+ try:
+ # Modify original value
+ datetime.date = None
+ # Repeat import
+ factory_class = f.get_factory()
+ self.assertEqual(orig_date, factory_class)
+
+ finally:
+ # IMPORTANT: restore attribute.
+ datetime.date = orig_date
+
+
+class RelatedFactoryTestCase(unittest.TestCase):
+
+ def test_arg(self):
+ self.assertRaises(ValueError, declarations.RelatedFactory, 'UnqualifiedSymbol')
+
+ def test_lazyness(self):
+ f = declarations.RelatedFactory('factory.declarations.Sequence', x=3)
+ self.assertEqual(None, f.factory)
+
+ self.assertEqual({'x': 3}, f.defaults)
+
+ factory_class = f.get_factory()
+ self.assertEqual(declarations.Sequence, factory_class)
+
+ def test_cache(self):
+ """Ensure that RelatedFactory tries to import only once."""
+ orig_date = datetime.date
+ f = declarations.RelatedFactory('datetime.date')
+ self.assertEqual(None, f.factory)
+
+ factory_class = f.get_factory()
+ self.assertEqual(orig_date, factory_class)
+
+ try:
+ # Modify original value
+ datetime.date = None
+ # Repeat import
+ factory_class = f.get_factory()
+ self.assertEqual(orig_date, factory_class)
+
+ finally:
+ # IMPORTANT: restore attribute.
+ datetime.date = orig_date
+
+
+class PostGenerationMethodCallTestCase(unittest.TestCase):
+ def setUp(self):
+ self.obj = mock.MagicMock()
+
+ def test_simplest_setup_and_call(self):
+ decl = declarations.PostGenerationMethodCall('method')
+ decl.call(self.obj, False)
+ self.obj.method.assert_called_once_with()
+
+ def test_call_with_method_args(self):
+ decl = declarations.PostGenerationMethodCall(
+ 'method', 'data')
+ decl.call(self.obj, False)
+ self.obj.method.assert_called_once_with('data')
+
+ def test_call_with_passed_extracted_string(self):
+ decl = declarations.PostGenerationMethodCall(
+ 'method')
+ decl.call(self.obj, False, 'data')
+ self.obj.method.assert_called_once_with('data')
+
+ def test_call_with_passed_extracted_int(self):
+ decl = declarations.PostGenerationMethodCall('method')
+ decl.call(self.obj, False, 1)
+ self.obj.method.assert_called_once_with(1)
+
+ def test_call_with_passed_extracted_iterable(self):
+ decl = declarations.PostGenerationMethodCall('method')
+ decl.call(self.obj, False, (1, 2, 3))
+ self.obj.method.assert_called_once_with((1, 2, 3))
+
+ def test_call_with_method_kwargs(self):
+ decl = declarations.PostGenerationMethodCall(
+ 'method', data='data')
+ decl.call(self.obj, False)
+ self.obj.method.assert_called_once_with(data='data')
+
+ def test_call_with_passed_kwargs(self):
+ decl = declarations.PostGenerationMethodCall('method')
+ decl.call(self.obj, False, data='other')
+ self.obj.method.assert_called_once_with(data='other')
+
+ def test_multi_call_with_multi_method_args(self):
+ decl = declarations.PostGenerationMethodCall(
+ 'method', 'arg1', 'arg2')
+ decl.call(self.obj, False)
+ self.obj.method.assert_called_once_with('arg1', 'arg2')
+
+ def test_multi_call_with_passed_multiple_args(self):
+ decl = declarations.PostGenerationMethodCall(
+ 'method', 'arg1', 'arg2')
+ decl.call(self.obj, False, ('param1', 'param2', 'param3'))
+ self.obj.method.assert_called_once_with('param1', 'param2', 'param3')
+
+ def test_multi_call_with_passed_tuple(self):
+ decl = declarations.PostGenerationMethodCall(
+ 'method', 'arg1', 'arg2')
+ decl.call(self.obj, False, (('param1', 'param2'),))
+ self.obj.method.assert_called_once_with(('param1', 'param2'))
+
+ def test_multi_call_with_kwargs(self):
+ decl = declarations.PostGenerationMethodCall(
+ 'method', 'arg1', 'arg2')
+ decl.call(self.obj, False, x=2)
+ self.obj.method.assert_called_once_with('arg1', 'arg2', x=2)
diff --git a/tests/test_fuzzy.py b/tests/test_fuzzy.py
new file mode 100644
index 0000000..70a2095
--- /dev/null
+++ b/tests/test_fuzzy.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2010 Mark Sandstrom
+# Copyright (c) 2011-2013 Raphaël Barrois
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+
+from factory import fuzzy
+
+from .compat import mock, unittest
+
+
+class FuzzyAttributeTestCase(unittest.TestCase):
+ def test_simple_call(self):
+ d = fuzzy.FuzzyAttribute(lambda: 10)
+
+ res = d.evaluate(2, None, False)
+ self.assertEqual(10, res)
+
+ res = d.evaluate(2, None, False)
+ self.assertEqual(10, res)
+
+
+class FuzzyChoiceTestCase(unittest.TestCase):
+ def test_unbiased(self):
+ options = [1, 2, 3]
+ d = fuzzy.FuzzyChoice(options)
+ res = d.evaluate(2, None, False)
+ self.assertIn(res, options)
+
+ def test_mock(self):
+ options = [1, 2, 3]
+ fake_choice = lambda d: sum(d)
+
+ d = fuzzy.FuzzyChoice(options)
+
+ with mock.patch('random.choice', fake_choice):
+ res = d.evaluate(2, None, False)
+
+ self.assertEqual(6, res)
+
+ def test_generator(self):
+ def options():
+ for i in range(3):
+ yield i
+
+ d = fuzzy.FuzzyChoice(options())
+
+ res = d.evaluate(2, None, False)
+ self.assertIn(res, [0, 1, 2])
+
+ # And repeat
+ res = d.evaluate(2, None, False)
+ self.assertIn(res, [0, 1, 2])
+
+
+class FuzzyIntegerTestCase(unittest.TestCase):
+ def test_definition(self):
+ """Tests all ways of defining a FuzzyInteger."""
+ fuzz = fuzzy.FuzzyInteger(2, 3)
+ for _i in range(20):
+ res = fuzz.evaluate(2, None, False)
+ self.assertIn(res, [2, 3])
+
+ fuzz = fuzzy.FuzzyInteger(4)
+ for _i in range(20):
+ res = fuzz.evaluate(2, None, False)
+ self.assertIn(res, [0, 1, 2, 3, 4])
+
+ def test_biased(self):
+ fake_randint = lambda low, high: low + high
+
+ fuzz = fuzzy.FuzzyInteger(2, 8)
+
+ with mock.patch('random.randint', fake_randint):
+ res = fuzz.evaluate(2, None, False)
+
+ self.assertEqual(10, res)
+
+ def test_biased_high_only(self):
+ fake_randint = lambda low, high: low + high
+
+ fuzz = fuzzy.FuzzyInteger(8)
+
+ with mock.patch('random.randint', fake_randint):
+ res = fuzz.evaluate(2, None, False)
+
+ self.assertEqual(8, res)
diff --git a/tests/test_using.py b/tests/test_using.py
index f4d5440..497e206 100644
--- a/tests/test_using.py
+++ b/tests/test_using.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright (c) 2011 Raphaël Barrois
+# Copyright (c) 2011-2013 Raphaël Barrois
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
@@ -23,7 +23,8 @@
import factory
-from .compat import unittest
+from .compat import is_python2, unittest
+from . import tools
class TestObject(object):
@@ -34,17 +35,39 @@ class TestObject(object):
self.four = four
self.five = five
-class FakeDjangoModel(object):
- class FakeDjangoManager(object):
+
+class FakeModel(object):
+ @classmethod
+ def create(cls, **kwargs):
+ instance = cls(**kwargs)
+ instance.id = 1
+ return instance
+
+ class FakeModelManager(object):
+ def get_or_create(self, **kwargs):
+ defaults = kwargs.pop('defaults', {})
+ kwargs.update(defaults)
+ instance = FakeModel.create(**kwargs)
+ instance.id = 2
+ instance._defaults = defaults
+ return instance, True
+
def create(self, **kwargs):
- fake_model = FakeDjangoModel(**kwargs)
- fake_model.id = 1
- return fake_model
+ instance = FakeModel.create(**kwargs)
+ instance.id = 2
+ instance._defaults = None
+ return instance
+
+ def values_list(self, *args, **kwargs):
+ return self
+
+ def order_by(self, *args, **kwargs):
+ return [1]
- objects = FakeDjangoManager()
+ objects = FakeModelManager()
def __init__(self, **kwargs):
- for name, value in kwargs.iteritems():
+ for name, value in kwargs.items():
setattr(self, name, value)
self.id = None
@@ -83,8 +106,13 @@ class SimpleBuildTestCase(unittest.TestCase):
self.assertEqual(obj.four, None)
def test_create(self):
- obj = factory.create(FakeDjangoModel, foo='bar')
- self.assertEqual(obj.id, 1)
+ obj = factory.create(FakeModel, foo='bar')
+ self.assertEqual(obj.id, None)
+ self.assertEqual(obj.foo, 'bar')
+
+ def test_create_custom_base(self):
+ obj = factory.create(FakeModel, foo='bar', FACTORY_CLASS=factory.DjangoModelFactory)
+ self.assertEqual(obj.id, 2)
self.assertEqual(obj.foo, 'bar')
def test_create_batch(self):
@@ -94,7 +122,18 @@ class SimpleBuildTestCase(unittest.TestCase):
self.assertEqual(4, len(set(objs)))
for obj in objs:
- self.assertEqual(obj.id, 1)
+ self.assertEqual(obj.id, None)
+ self.assertEqual(obj.foo, 'bar')
+
+ def test_create_batch_custom_base(self):
+ objs = factory.create_batch(FakeModel, 4, foo='bar',
+ FACTORY_CLASS=factory.DjangoModelFactory)
+
+ self.assertEqual(4, len(objs))
+ self.assertEqual(4, len(set(objs)))
+
+ for obj in objs:
+ self.assertEqual(obj.id, 2)
self.assertEqual(obj.foo, 'bar')
def test_stub(self):
@@ -118,8 +157,14 @@ class SimpleBuildTestCase(unittest.TestCase):
self.assertEqual(obj.foo, 'bar')
def test_generate_create(self):
- obj = factory.generate(FakeDjangoModel, factory.CREATE_STRATEGY, foo='bar')
- self.assertEqual(obj.id, 1)
+ obj = factory.generate(FakeModel, factory.CREATE_STRATEGY, foo='bar')
+ self.assertEqual(obj.id, None)
+ self.assertEqual(obj.foo, 'bar')
+
+ def test_generate_create_custom_base(self):
+ obj = factory.generate(FakeModel, factory.CREATE_STRATEGY, foo='bar',
+ FACTORY_CLASS=factory.DjangoModelFactory)
+ self.assertEqual(obj.id, 2)
self.assertEqual(obj.foo, 'bar')
def test_generate_stub(self):
@@ -144,7 +189,18 @@ class SimpleBuildTestCase(unittest.TestCase):
self.assertEqual(20, len(set(objs)))
for obj in objs:
- self.assertEqual(obj.id, 1)
+ self.assertEqual(obj.id, None)
+ self.assertEqual(obj.foo, 'bar')
+
+ def test_generate_batch_create_custom_base(self):
+ objs = factory.generate_batch(FakeModel, factory.CREATE_STRATEGY, 20, foo='bar',
+ FACTORY_CLASS=factory.DjangoModelFactory)
+
+ self.assertEqual(20, len(objs))
+ self.assertEqual(20, len(set(objs)))
+
+ for obj in objs:
+ self.assertEqual(obj.id, 2)
self.assertEqual(obj.foo, 'bar')
def test_generate_batch_stub(self):
@@ -163,8 +219,13 @@ class SimpleBuildTestCase(unittest.TestCase):
self.assertEqual(obj.foo, 'bar')
def test_simple_generate_create(self):
- obj = factory.simple_generate(FakeDjangoModel, True, foo='bar')
- self.assertEqual(obj.id, 1)
+ obj = factory.simple_generate(FakeModel, True, foo='bar')
+ self.assertEqual(obj.id, None)
+ self.assertEqual(obj.foo, 'bar')
+
+ def test_simple_generate_create_custom_base(self):
+ obj = factory.simple_generate(FakeModel, True, foo='bar', FACTORY_CLASS=factory.DjangoModelFactory)
+ self.assertEqual(obj.id, 2)
self.assertEqual(obj.foo, 'bar')
def test_simple_generate_batch_build(self):
@@ -184,7 +245,18 @@ class SimpleBuildTestCase(unittest.TestCase):
self.assertEqual(20, len(set(objs)))
for obj in objs:
- self.assertEqual(obj.id, 1)
+ self.assertEqual(obj.id, None)
+ self.assertEqual(obj.foo, 'bar')
+
+ def test_simple_generate_batch_create_custom_base(self):
+ objs = factory.simple_generate_batch(FakeModel, True, 20, foo='bar',
+ FACTORY_CLASS=factory.DjangoModelFactory)
+
+ self.assertEqual(20, len(objs))
+ self.assertEqual(20, len(set(objs)))
+
+ for obj in objs:
+ self.assertEqual(obj.id, 2)
self.assertEqual(obj.foo, 'bar')
def test_make_factory(self):
@@ -204,13 +276,23 @@ class SimpleBuildTestCase(unittest.TestCase):
class UsingFactoryTestCase(unittest.TestCase):
- def testAttribute(self):
+ def test_attribute(self):
class TestObjectFactory(factory.Factory):
one = 'one'
test_object = TestObjectFactory.build()
self.assertEqual(test_object.one, 'one')
+ def test_inheritance(self):
+ @factory.use_strategy(factory.BUILD_STRATEGY)
+ class TestObjectFactory(factory.Factory, TestObject):
+ FACTORY_FOR = TestObject
+
+ one = 'one'
+
+ test_object = TestObjectFactory()
+ self.assertEqual(test_object.one, 'one')
+
def test_abstract(self):
class SomeAbstractFactory(factory.Factory):
ABSTRACT_FACTORY = True
@@ -222,10 +304,12 @@ class UsingFactoryTestCase(unittest.TestCase):
test_object = InheritedFactory.build()
self.assertEqual(test_object.one, 'one')
- def testSequence(self):
+ def test_sequence(self):
class TestObjectFactory(factory.Factory):
- one = factory.Sequence(lambda n: 'one' + n)
- two = factory.Sequence(lambda n: 'two' + n)
+ FACTORY_FOR = TestObject
+
+ one = factory.Sequence(lambda n: 'one%d' % n)
+ two = factory.Sequence(lambda n: 'two%d' % n)
test_object0 = TestObjectFactory.build()
self.assertEqual(test_object0.one, 'one0')
@@ -235,14 +319,14 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertEqual(test_object1.one, 'one1')
self.assertEqual(test_object1.two, 'two1')
- def testSequenceCustomBegin(self):
+ def test_sequence_custom_begin(self):
class TestObjectFactory(factory.Factory):
@classmethod
def _setup_next_sequence(cls):
return 42
- one = factory.Sequence(lambda n: 'one' + n)
- two = factory.Sequence(lambda n: 'two' + n)
+ one = factory.Sequence(lambda n: 'one%d' % n)
+ two = factory.Sequence(lambda n: 'two%d' % n)
test_object0 = TestObjectFactory.build()
self.assertEqual('one42', test_object0.one)
@@ -252,10 +336,61 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertEqual('one43', test_object1.one)
self.assertEqual('two43', test_object1.two)
+ def test_sequence_override(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+
+ one = factory.Sequence(lambda n: 'one%d' % n)
+
+ o1 = TestObjectFactory()
+ o2 = TestObjectFactory()
+ o3 = TestObjectFactory(__sequence=42)
+ o4 = TestObjectFactory()
+
+ self.assertEqual('one0', o1.one)
+ self.assertEqual('one1', o2.one)
+ self.assertEqual('one42', o3.one)
+ self.assertEqual('one2', o4.one)
+
+ def test_custom_create(self):
+ class TestModelFactory(factory.Factory):
+ FACTORY_FOR = TestModel
+
+ two = 2
+
+ @classmethod
+ def _create(cls, target_class, *args, **kwargs):
+ obj = target_class.create(**kwargs)
+ obj.properly_created = True
+ return obj
+
+ obj = TestModelFactory.create(one=1)
+ self.assertEqual(1, obj.one)
+ self.assertEqual(2, obj.two)
+ self.assertEqual(1, obj.id)
+ self.assertTrue(obj.properly_created)
+
+ def test_non_django_create(self):
+ class NonDjango(object):
+ def __init__(self, x, y=2):
+ self.x = x
+ self.y = y
+
+ class NonDjangoFactory(factory.Factory):
+ FACTORY_FOR = NonDjango
+
+ x = 3
+
+ obj = NonDjangoFactory.create()
+ self.assertEqual(3, obj.x)
+ self.assertEqual(2, obj.y)
+
def test_sequence_batch(self):
class TestObjectFactory(factory.Factory):
- one = factory.Sequence(lambda n: 'one' + n)
- two = factory.Sequence(lambda n: 'two' + n)
+ FACTORY_FOR = TestObject
+
+ one = factory.Sequence(lambda n: 'one%d' % n)
+ two = factory.Sequence(lambda n: 'two%d' % n)
objs = TestObjectFactory.build_batch(20)
@@ -265,7 +400,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertEqual('one%d' % i, obj.one)
self.assertEqual('two%d' % i, obj.two)
- def testLazyAttribute(self):
+ def test_lazy_attribute(self):
class TestObjectFactory(factory.Factory):
one = factory.LazyAttribute(lambda a: 'abc' )
two = factory.LazyAttribute(lambda a: a.one + ' xyz')
@@ -274,10 +409,12 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertEqual(test_object.one, 'abc')
self.assertEqual(test_object.two, 'abc xyz')
- def testLazyAttributeSequence(self):
+ def test_lazy_attribute_sequence(self):
class TestObjectFactory(factory.Factory):
- one = factory.LazyAttributeSequence(lambda a, n: 'abc' + n)
- two = factory.LazyAttributeSequence(lambda a, n: a.one + ' xyz' + n)
+ FACTORY_FOR = TestObject
+
+ one = factory.LazyAttributeSequence(lambda a, n: 'abc%d' % n)
+ two = factory.LazyAttributeSequence(lambda a, n: a.one + ' xyz%d' % n)
test_object0 = TestObjectFactory.build()
self.assertEqual(test_object0.one, 'abc0')
@@ -287,7 +424,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertEqual(test_object1.one, 'abc1')
self.assertEqual(test_object1.two, 'abc1 xyz1')
- def testLazyAttributeDecorator(self):
+ def test_lazy_attribute_decorator(self):
class TestObjectFactory(factory.Factory):
@factory.lazy_attribute
def one(a):
@@ -296,7 +433,7 @@ class UsingFactoryTestCase(unittest.TestCase):
test_object = TestObjectFactory.build()
self.assertEqual(test_object.one, 'one')
- def testSelfAttribute(self):
+ def test_self_attribute(self):
class TmpObj(object):
n = 3
@@ -313,32 +450,51 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertEqual(3, test_object.four)
self.assertEqual(5, test_object.five)
- def testSequenceDecorator(self):
+ def test_self_attribute_parent(self):
+ class TestModel2(FakeModel):
+ pass
+
+ class TestModelFactory(FakeModelFactory):
+ FACTORY_FOR = TestModel
+ one = 3
+ three = factory.SelfAttribute('..bar')
+
+ class TestModel2Factory(FakeModelFactory):
+ FACTORY_FOR = TestModel2
+ bar = 4
+ two = factory.SubFactory(TestModelFactory, one=1)
+
+ test_model = TestModel2Factory()
+ self.assertEqual(4, test_model.two.three)
+
+ def test_sequence_decorator(self):
class TestObjectFactory(factory.Factory):
@factory.sequence
def one(n):
- return 'one' + n
+ return 'one%d' % n
test_object = TestObjectFactory.build()
self.assertEqual(test_object.one, 'one0')
- def testLazyAttributeSequenceDecorator(self):
+ def test_lazy_attribute_sequence_decorator(self):
class TestObjectFactory(factory.Factory):
@factory.lazy_attribute_sequence
def one(a, n):
- return 'one' + n
+ return 'one%d' % n
@factory.lazy_attribute_sequence
def two(a, n):
- return a.one + ' two' + n
+ return a.one + ' two%d' % n
test_object = TestObjectFactory.build()
self.assertEqual(test_object.one, 'one0')
self.assertEqual(test_object.two, 'one0 two0')
- def testBuildWithParameters(self):
+ def test_build_with_parameters(self):
class TestObjectFactory(factory.Factory):
- one = factory.Sequence(lambda n: 'one' + n)
- two = factory.Sequence(lambda n: 'two' + n)
+ FACTORY_FOR = TestObject
+
+ one = factory.Sequence(lambda n: 'one%d' % n)
+ two = factory.Sequence(lambda n: 'two%d' % n)
test_object0 = TestObjectFactory.build(three='three')
self.assertEqual(test_object0.one, 'one0')
@@ -349,8 +505,10 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertEqual(test_object1.one, 'other')
self.assertEqual(test_object1.two, 'two1')
- def testCreate(self):
- class TestModelFactory(factory.Factory):
+ def test_create(self):
+ class TestModelFactory(FakeModelFactory):
+ FACTORY_FOR = TestModel
+
one = 'one'
test_model = TestModelFactory.create()
@@ -488,7 +646,7 @@ class UsingFactoryTestCase(unittest.TestCase):
three = factory.Sequence(lambda n: int(n))
objs = TestObjectFactory.stub_batch(20,
- one=factory.Sequence(lambda n: n))
+ one=factory.Sequence(lambda n: str(n)))
self.assertEqual(20, len(objs))
self.assertEqual(20, len(set(objs)))
@@ -498,7 +656,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertEqual('%d two' % i, obj.two)
self.assertEqual(i, obj.three)
- def testInheritance(self):
+ def test_inheritance(self):
class TestObjectFactory(factory.Factory):
one = 'one'
two = factory.LazyAttribute(lambda a: a.one + ' two')
@@ -518,7 +676,7 @@ class UsingFactoryTestCase(unittest.TestCase):
test_object_alt = TestObjectFactory.build()
self.assertEqual(None, test_object_alt.three)
- def testInheritanceWithInheritedClass(self):
+ def test_inheritance_with_inherited_class(self):
class TestObjectFactory(factory.Factory):
one = 'one'
two = factory.LazyAttribute(lambda a: a.one + ' two')
@@ -533,7 +691,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertEqual(test_object.three, 'three')
self.assertEqual(test_object.four, 'three four')
- def testDualInheritance(self):
+ def test_dual_inheritance(self):
class TestObjectFactory(factory.Factory):
one = 'one'
@@ -551,19 +709,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertEqual('three', obj.three)
self.assertEqual('four', obj.four)
- def testSetCreationFunction(self):
- def creation_function(class_to_create, **kwargs):
- return "This doesn't even return an instance of {0}".format(class_to_create.__name__)
-
- class TestModelFactory(factory.Factory):
- pass
-
- TestModelFactory.set_creation_function(creation_function)
-
- test_object = TestModelFactory.create()
- self.assertEqual(test_object, "This doesn't even return an instance of TestModel")
-
- def testClassMethodAccessible(self):
+ def test_class_method_accessible(self):
class TestObjectFactory(factory.Factory):
@classmethod
def alt_create(cls, **kwargs):
@@ -571,7 +717,7 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertEqual(TestObjectFactory.alt_create(foo=1), {"foo": 1})
- def testStaticMethodAccessible(self):
+ def test_static_method_accessible(self):
class TestObjectFactory(factory.Factory):
@staticmethod
def alt_create(**kwargs):
@@ -579,10 +725,140 @@ class UsingFactoryTestCase(unittest.TestCase):
self.assertEqual(TestObjectFactory.alt_create(foo=1), {"foo": 1})
+ def test_arg_parameters(self):
+ class TestObject(object):
+ def __init__(self, *args, **kwargs):
+ self.args = args
+ self.kwargs = kwargs
+
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ FACTORY_ARG_PARAMETERS = ('x', 'y')
+
+ x = 1
+ y = 2
+ z = 3
+ t = 4
+
+ obj = TestObjectFactory.build(x=42, z=5)
+ self.assertEqual((42, 2), obj.args)
+ self.assertEqual({'z': 5, 't': 4}, obj.kwargs)
+
+ def test_hidden_args(self):
+ class TestObject(object):
+ def __init__(self, *args, **kwargs):
+ self.args = args
+ self.kwargs = kwargs
+
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ FACTORY_HIDDEN_ARGS = ('x', 'z')
+
+ x = 1
+ y = 2
+ z = 3
+ t = 4
+
+ obj = TestObjectFactory.build(x=42, z=5)
+ self.assertEqual((), obj.args)
+ self.assertEqual({'y': 2, 't': 4}, obj.kwargs)
+
+ def test_hidden_args_and_arg_parameters(self):
+ class TestObject(object):
+ def __init__(self, *args, **kwargs):
+ self.args = args
+ self.kwargs = kwargs
+
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ FACTORY_HIDDEN_ARGS = ('x', 'z')
+ FACTORY_ARG_PARAMETERS = ('y',)
+
+ x = 1
+ y = 2
+ z = 3
+ t = 4
+
+ obj = TestObjectFactory.build(x=42, z=5)
+ self.assertEqual((2,), obj.args)
+ self.assertEqual({'t': 4}, obj.kwargs)
+
+
+
+class NonKwargParametersTestCase(unittest.TestCase):
+ def test_build(self):
+ class TestObject(object):
+ def __init__(self, *args, **kwargs):
+ self.args = args
+ self.kwargs = kwargs
+
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ FACTORY_ARG_PARAMETERS = ('one', 'two',)
+
+ one = 1
+ two = 2
+ three = 3
+
+ obj = TestObjectFactory.build()
+ self.assertEqual((1, 2), obj.args)
+ self.assertEqual({'three': 3}, obj.kwargs)
+
+ def test_create(self):
+ class TestObject(object):
+ def __init__(self, *args, **kwargs):
+ self.args = None
+ self.kwargs = None
+
+ @classmethod
+ def create(cls, *args, **kwargs):
+ inst = cls()
+ inst.args = args
+ inst.kwargs = kwargs
+ return inst
+
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ FACTORY_ARG_PARAMETERS = ('one', 'two')
+
+ one = 1
+ two = 2
+ three = 3
+
+ @classmethod
+ def _create(cls, target_class, *args, **kwargs):
+ return target_class.create(*args, **kwargs)
+
+ obj = TestObjectFactory.create()
+ self.assertEqual((1, 2), obj.args)
+ self.assertEqual({'three': 3}, obj.kwargs)
+
+
+class KwargAdjustTestCase(unittest.TestCase):
+ """Tests for the _adjust_kwargs method."""
+
+ def test_build(self):
+ class TestObject(object):
+ def __init__(self, *args, **kwargs):
+ self.args = args
+ self.kwargs = kwargs
+
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+
+ @classmethod
+ def _adjust_kwargs(cls, **kwargs):
+ kwargs['foo'] = len(kwargs)
+ return kwargs
+
+ obj = TestObjectFactory.build(x=1, y=2, z=3)
+ self.assertEqual({'x': 1, 'y': 2, 'z': 3, 'foo': 3}, obj.kwargs)
+ self.assertEqual((), obj.args)
+
class SubFactoryTestCase(unittest.TestCase):
- def testSubFactory(self):
- class TestModel2(FakeDjangoModel):
+ def test_sub_factory(self):
+ class TestModel2(FakeModel):
pass
class TestModelFactory(factory.Factory):
@@ -598,8 +874,8 @@ class SubFactoryTestCase(unittest.TestCase):
self.assertEqual(1, test_model.id)
self.assertEqual(1, test_model.two.id)
- def testSubFactoryWithLazyFields(self):
- class TestModel2(FakeDjangoModel):
+ def test_sub_factory_with_lazy_fields(self):
+ class TestModel2(FakeModel):
pass
class TestModelFactory(factory.Factory):
@@ -608,7 +884,7 @@ class SubFactoryTestCase(unittest.TestCase):
class TestModel2Factory(factory.Factory):
FACTORY_FOR = TestModel2
two = factory.SubFactory(TestModelFactory,
- one=factory.Sequence(lambda n: 'x%sx' % n),
+ one=factory.Sequence(lambda n: 'x%dx' % n),
two=factory.LazyAttribute(
lambda o: '%s%s' % (o.one, o.one)))
@@ -616,10 +892,10 @@ class SubFactoryTestCase(unittest.TestCase):
self.assertEqual('x0x', test_model.two.one)
self.assertEqual('x0xx0x', test_model.two.two)
- def testSubFactoryAndSequence(self):
+ def test_sub_factory_and_sequence(self):
class TestObject(object):
def __init__(self, **kwargs):
- for k, v in kwargs.iteritems():
+ for k, v in kwargs.items():
setattr(self, k, v)
class TestObjectFactory(factory.Factory):
@@ -637,10 +913,10 @@ class SubFactoryTestCase(unittest.TestCase):
wrapping = WrappingTestObjectFactory.build()
self.assertEqual(1, wrapping.wrapped.one)
- def testSubFactoryOverriding(self):
+ def test_sub_factory_overriding(self):
class TestObject(object):
def __init__(self, **kwargs):
- for k, v in kwargs.iteritems():
+ for k, v in kwargs.items():
setattr(self, k, v)
class TestObjectFactory(factory.Factory):
@@ -649,7 +925,7 @@ class SubFactoryTestCase(unittest.TestCase):
class OtherTestObject(object):
def __init__(self, **kwargs):
- for k, v in kwargs.iteritems():
+ for k, v in kwargs.items():
setattr(self, k, v)
class WrappingTestObjectFactory(factory.Factory):
@@ -664,12 +940,12 @@ class SubFactoryTestCase(unittest.TestCase):
self.assertEqual(wrapping.wrapped.three, 3)
self.assertEqual(wrapping.wrapped.four, 4)
- def testNestedSubFactory(self):
+ def test_nested_sub_factory(self):
"""Test nested sub-factories."""
class TestObject(object):
def __init__(self, **kwargs):
- for k, v in kwargs.iteritems():
+ for k, v in kwargs.items():
setattr(self, k, v)
class TestObjectFactory(factory.Factory):
@@ -690,12 +966,12 @@ class SubFactoryTestCase(unittest.TestCase):
self.assertEqual(outer.wrap.wrapped.two, 2)
self.assertEqual(outer.wrap.wrapped_bis.one, 1)
- def testNestedSubFactoryWithOverriddenSubFactories(self):
+ def test_nested_sub_factory_with_overridden_sub_factories(self):
"""Test nested sub-factories, with attributes overridden with subfactories."""
class TestObject(object):
def __init__(self, **kwargs):
- for k, v in kwargs.iteritems():
+ for k, v in kwargs.items():
setattr(self, k, v)
class TestObjectFactory(factory.Factory):
@@ -718,11 +994,11 @@ class SubFactoryTestCase(unittest.TestCase):
self.assertEqual(outer.wrap.wrapped.two.four, 4)
self.assertEqual(outer.wrap.friend, 5)
- def testSubFactoryAndInheritance(self):
+ def test_sub_factory_and_inheritance(self):
"""Test inheriting from a factory with subfactories, overriding."""
class TestObject(object):
def __init__(self, **kwargs):
- for k, v in kwargs.iteritems():
+ for k, v in kwargs.items():
setattr(self, k, v)
class TestObjectFactory(factory.Factory):
@@ -742,7 +1018,7 @@ class SubFactoryTestCase(unittest.TestCase):
self.assertEqual(wrapping.wrapped.two, 4)
self.assertEqual(wrapping.friend, 5)
- def testDiamondSubFactory(self):
+ def test_diamond_sub_factory(self):
"""Tests the case where an object has two fields with a common field."""
class InnerMost(object):
def __init__(self, a, b):
@@ -802,33 +1078,33 @@ class IteratorTestCase(unittest.TestCase):
def test_iterator(self):
class TestObjectFactory(factory.Factory):
- one = factory.Iterator(xrange(10, 30))
+ FACTORY_FOR = TestObject
+
+ one = factory.Iterator(range(10, 30))
objs = TestObjectFactory.build_batch(20)
for i, obj in enumerate(objs):
self.assertEqual(i + 10, obj.one)
- def test_infinite_iterator(self):
+ @unittest.skipUnless(is_python2, "Scope bleeding fixed in Python3+")
+ @tools.disable_warnings
+ def test_iterator_list_comprehension_scope_bleeding(self):
class TestObjectFactory(factory.Factory):
- one = factory.InfiniteIterator(xrange(5))
-
- objs = TestObjectFactory.build_batch(20)
-
- for i, obj in enumerate(objs):
- self.assertEqual(i % 5, obj.one)
+ FACTORY_FOR = TestObject
- def test_infinite_iterator_list_comprehension(self):
- class TestObjectFactory(factory.Factory):
- one = factory.InfiniteIterator([j * 3 for j in xrange(5)])
+ one = factory.Iterator([j * 3 for j in range(5)])
# Scope bleeding: j will end up in TestObjectFactory's scope.
self.assertRaises(TypeError, TestObjectFactory.build)
- def test_infinite_iterator_list_comprehension_protected(self):
+ @tools.disable_warnings
+ def test_iterator_list_comprehension_protected(self):
class TestObjectFactory(factory.Factory):
- one = factory.InfiniteIterator([_j * 3 for _j in xrange(5)])
+ FACTORY_FOR = TestObject
+
+ one = factory.Iterator([_j * 3 for _j in range(5)])
# Scope bleeding : _j will end up in TestObjectFactory's scope.
# But factory_boy ignores it, as a protected variable.
@@ -841,7 +1117,7 @@ class IteratorTestCase(unittest.TestCase):
class TestObjectFactory(factory.Factory):
@factory.iterator
def one():
- for i in xrange(10, 50):
+ for i in range(10, 50):
yield i
objs = TestObjectFactory.build_batch(20)
@@ -849,17 +1125,133 @@ class IteratorTestCase(unittest.TestCase):
for i, obj in enumerate(objs):
self.assertEqual(i + 10, obj.one)
- def test_infinite_iterator_decorator(self):
- class TestObjectFactory(factory.Factory):
- @factory.infinite_iterator
- def one():
- for i in xrange(5):
- yield i
- objs = TestObjectFactory.build_batch(20)
+class BetterFakeModelManager(object):
+ def __init__(self, keys, instance):
+ self.keys = keys
+ self.instance = instance
- for i, obj in enumerate(objs):
- self.assertEqual(i % 5, obj.one)
+ def get_or_create(self, **kwargs):
+ defaults = kwargs.pop('defaults', {})
+ if kwargs == self.keys:
+ return self.instance, False
+ kwargs.update(defaults)
+ instance = FakeModel.create(**kwargs)
+ instance.id = 2
+ return instance, True
+
+ def values_list(self, *args, **kwargs):
+ return self
+
+ def order_by(self, *args, **kwargs):
+ return [1]
+
+
+class BetterFakeModel(object):
+ @classmethod
+ def create(cls, **kwargs):
+ instance = cls(**kwargs)
+ instance.id = 1
+ return instance
+
+ def __init__(self, **kwargs):
+ for name, value in kwargs.items():
+ setattr(self, name, value)
+ self.id = None
+
+
+class DjangoModelFactoryTestCase(unittest.TestCase):
+ def test_simple(self):
+ class FakeModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = FakeModel
+
+ obj = FakeModelFactory(one=1)
+ self.assertEqual(1, obj.one)
+ self.assertEqual(2, obj.id)
+
+ def test_existing_instance(self):
+ prev = BetterFakeModel.create(x=1, y=2, z=3)
+ prev.id = 42
+
+ class MyFakeModel(BetterFakeModel):
+ objects = BetterFakeModelManager({'x': 1}, prev)
+
+ class MyFakeModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = MyFakeModel
+ FACTORY_DJANGO_GET_OR_CREATE = ('x',)
+ x = 1
+ y = 4
+ z = 6
+
+ obj = MyFakeModelFactory()
+ self.assertEqual(prev, obj)
+ self.assertEqual(1, obj.x)
+ self.assertEqual(2, obj.y)
+ self.assertEqual(3, obj.z)
+ self.assertEqual(42, obj.id)
+
+ def test_existing_instance_complex_key(self):
+ prev = BetterFakeModel.create(x=1, y=2, z=3)
+ prev.id = 42
+
+ class MyFakeModel(BetterFakeModel):
+ objects = BetterFakeModelManager({'x': 1, 'y': 2, 'z': 3}, prev)
+
+ class MyFakeModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = MyFakeModel
+ FACTORY_DJANGO_GET_OR_CREATE = ('x', 'y', 'z')
+ x = 1
+ y = 4
+ z = 6
+
+ obj = MyFakeModelFactory(y=2, z=3)
+ self.assertEqual(prev, obj)
+ self.assertEqual(1, obj.x)
+ self.assertEqual(2, obj.y)
+ self.assertEqual(3, obj.z)
+ self.assertEqual(42, obj.id)
+
+ def test_new_instance(self):
+ prev = BetterFakeModel.create(x=1, y=2, z=3)
+ prev.id = 42
+
+ class MyFakeModel(BetterFakeModel):
+ objects = BetterFakeModelManager({'x': 1}, prev)
+
+ class MyFakeModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = MyFakeModel
+ FACTORY_DJANGO_GET_OR_CREATE = ('x',)
+ x = 1
+ y = 4
+ z = 6
+
+ obj = MyFakeModelFactory(x=2)
+ self.assertNotEqual(prev, obj)
+ self.assertEqual(2, obj.x)
+ self.assertEqual(4, obj.y)
+ self.assertEqual(6, obj.z)
+ self.assertEqual(2, obj.id)
+
+ def test_new_instance_complex_key(self):
+ prev = BetterFakeModel.create(x=1, y=2, z=3)
+ prev.id = 42
+
+ class MyFakeModel(BetterFakeModel):
+ objects = BetterFakeModelManager({'x': 1, 'y': 2, 'z': 3}, prev)
+
+ class MyFakeModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = MyFakeModel
+ FACTORY_DJANGO_GET_OR_CREATE = ('x', 'y', 'z')
+ x = 1
+ y = 4
+ z = 6
+
+ obj = MyFakeModelFactory(y=2, z=4)
+ self.assertNotEqual(prev, obj)
+ self.assertEqual(1, obj.x)
+ self.assertEqual(2, obj.y)
+ self.assertEqual(4, obj.z)
+ self.assertEqual(2, obj.id)
class PostGenerationTestCase(unittest.TestCase):
@@ -867,7 +1259,7 @@ class PostGenerationTestCase(unittest.TestCase):
class TestObjectFactory(factory.Factory):
one = 1
- @factory.post_generation()
+ @factory.post_generation
def incr_one(self, _create, _increment):
self.one += 1
@@ -879,11 +1271,32 @@ class PostGenerationTestCase(unittest.TestCase):
self.assertEqual(3, obj.one)
self.assertFalse(hasattr(obj, 'incr_one'))
+ def test_post_generation_hook(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+
+ one = 1
+
+ @factory.post_generation
+ def incr_one(self, _create, _increment):
+ self.one += 1
+ return 42
+
+ @classmethod
+ def _after_postgeneration(cls, obj, create, results):
+ obj.create = create
+ obj.results = results
+
+ obj = TestObjectFactory.build()
+ self.assertEqual(2, obj.one)
+ self.assertFalse(obj.create)
+ self.assertEqual({'incr_one': 42}, obj.results)
+
def test_post_generation_extraction(self):
class TestObjectFactory(factory.Factory):
one = 1
- @factory.post_generation()
+ @factory.post_generation
def incr_one(self, _create, increment=1):
self.one += increment
@@ -952,6 +1365,305 @@ class PostGenerationTestCase(unittest.TestCase):
# RelatedFactory received "parent" object
self.assertEqual(obj, obj.related.three)
+ def test_related_factory_no_name(self):
+ relateds = []
+ class TestRelatedObject(object):
+ def __init__(self, obj=None, one=None, two=None):
+ relateds.append(self)
+ self.one = one
+ self.two = two
+ self.three = obj
+
+ class TestRelatedObjectFactory(factory.Factory):
+ FACTORY_FOR = TestRelatedObject
+ one = 1
+ two = factory.LazyAttribute(lambda o: o.one + 1)
+
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ one = 3
+ two = 2
+ three = factory.RelatedFactory(TestRelatedObjectFactory)
+
+ obj = TestObjectFactory.build()
+ # Normal fields
+ self.assertEqual(3, obj.one)
+ self.assertEqual(2, obj.two)
+ # RelatedFactory was built
+ self.assertIsNone(obj.three)
+ self.assertEqual(1, len(relateds))
+ related = relateds[0]
+ self.assertEqual(1, related.one)
+ self.assertEqual(2, related.two)
+ self.assertIsNone(related.three)
+
+ obj = TestObjectFactory.build(three__one=3)
+ # Normal fields
+ self.assertEqual(3, obj.one)
+ self.assertEqual(2, obj.two)
+ # RelatedFactory was build
+ self.assertIsNone(obj.three)
+ self.assertEqual(2, len(relateds))
+
+ related = relateds[1]
+ self.assertEqual(3, related.one)
+ self.assertEqual(4, related.two)
+
+
+class CircularTestCase(unittest.TestCase):
+ def test_example(self):
+ sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
+
+ from .cyclic import foo
+ f = foo.FooFactory.build(bar__foo=None)
+ self.assertEqual(42, f.x)
+ self.assertEqual(13, f.bar.y)
+ self.assertIsNone(f.bar.foo)
+
+ from .cyclic import bar
+ b = bar.BarFactory.build(foo__bar__foo__bar=None)
+ self.assertEqual(13, b.y)
+ self.assertEqual(42, b.foo.x)
+ self.assertEqual(13, b.foo.bar.y)
+ self.assertEqual(42, b.foo.bar.foo.x)
+ self.assertIsNone(b.foo.bar.foo.bar)
+
+
+class DictTestCase(unittest.TestCase):
+ def test_empty_dict(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ one = factory.Dict({})
+
+ o = TestObjectFactory()
+ self.assertEqual({}, o.one)
+
+ def test_naive_dict(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ one = factory.Dict({'a': 1})
+
+ o = TestObjectFactory()
+ self.assertEqual({'a': 1}, o.one)
+
+ def test_sequence_dict(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ one = factory.Dict({'a': factory.Sequence(lambda n: n + 2)})
+
+ o1 = TestObjectFactory()
+ o2 = TestObjectFactory()
+
+ self.assertEqual({'a': 2}, o1.one)
+ self.assertEqual({'a': 3}, o2.one)
+
+ def test_dict_override(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ one = factory.Dict({'a': 1})
+
+ o = TestObjectFactory(one__a=2)
+ self.assertEqual({'a': 2}, o.one)
+
+ def test_dict_extra_key(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ one = factory.Dict({'a': 1})
+
+ o = TestObjectFactory(one__b=2)
+ self.assertEqual({'a': 1, 'b': 2}, o.one)
+
+ def test_dict_merged_fields(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ two = 13
+ one = factory.Dict({
+ 'one': 1,
+ 'two': 2,
+ 'three': factory.SelfAttribute('two'),
+ })
+
+ o = TestObjectFactory(one__one=42)
+ self.assertEqual({'one': 42, 'two': 2, 'three': 2}, o.one)
+
+ def test_nested_dicts(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ one = 1
+ two = factory.Dict({
+ 'one': 3,
+ 'two': factory.SelfAttribute('one'),
+ 'three': factory.Dict({
+ 'one': 5,
+ 'two': factory.SelfAttribute('..one'),
+ 'three': factory.SelfAttribute('...one'),
+ }),
+ })
+
+ o = TestObjectFactory()
+ self.assertEqual(1, o.one)
+ self.assertEqual({
+ 'one': 3,
+ 'two': 3,
+ 'three': {
+ 'one': 5,
+ 'two': 3,
+ 'three': 1,
+ },
+ }, o.two)
+
+
+class ListTestCase(unittest.TestCase):
+ def test_empty_list(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ one = factory.List([])
+
+ o = TestObjectFactory()
+ self.assertEqual([], o.one)
+
+ def test_naive_list(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ one = factory.List([1])
+
+ o = TestObjectFactory()
+ self.assertEqual([1], o.one)
+
+ def test_sequence_list(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ one = factory.List([factory.Sequence(lambda n: n + 2)])
+
+ o1 = TestObjectFactory()
+ o2 = TestObjectFactory()
+
+ self.assertEqual([2], o1.one)
+ self.assertEqual([3], o2.one)
+
+ def test_list_override(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ one = factory.List([1])
+
+ o = TestObjectFactory(one__0=2)
+ self.assertEqual([2], o.one)
+
+ def test_list_extra_key(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ one = factory.List([1])
+
+ o = TestObjectFactory(one__1=2)
+ self.assertEqual([1, 2], o.one)
+
+ def test_list_merged_fields(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ two = 13
+ one = factory.List([
+ 1,
+ 2,
+ factory.SelfAttribute('1'),
+ ])
+
+ o = TestObjectFactory(one__0=42)
+ self.assertEqual([42, 2, 2], o.one)
+
+ def test_nested_lists(self):
+ class TestObjectFactory(factory.Factory):
+ FACTORY_FOR = TestObject
+ one = 1
+ two = factory.List([
+ 3,
+ factory.SelfAttribute('0'),
+ factory.List([
+ 5,
+ factory.SelfAttribute('..0'),
+ factory.SelfAttribute('...one'),
+ ]),
+ ])
+
+ o = TestObjectFactory()
+ self.assertEqual(1, o.one)
+ self.assertEqual([
+ 3,
+ 3,
+ [
+ 5,
+ 3,
+ 1,
+ ],
+ ], o.two)
+
+
+class DjangoModelFactoryTestCase(unittest.TestCase):
+ def test_sequence(self):
+ class TestModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = TestModel
+
+ a = factory.Sequence(lambda n: 'foo_%s' % n)
+
+ o1 = TestModelFactory()
+ o2 = TestModelFactory()
+
+ self.assertEqual('foo_2', o1.a)
+ self.assertEqual('foo_3', o2.a)
+
+ o3 = TestModelFactory.build()
+ o4 = TestModelFactory.build()
+
+ self.assertEqual('foo_4', o3.a)
+ self.assertEqual('foo_5', o4.a)
+
+ def test_no_get_or_create(self):
+ class TestModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = TestModel
+
+ a = factory.Sequence(lambda n: 'foo_%s' % n)
+
+ o = TestModelFactory()
+ self.assertEqual(None, o._defaults)
+ self.assertEqual('foo_2', o.a)
+ self.assertEqual(2, o.id)
+
+ def test_get_or_create(self):
+ class TestModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = TestModel
+ FACTORY_DJANGO_GET_OR_CREATE = ('a', 'b')
+
+ a = factory.Sequence(lambda n: 'foo_%s' % n)
+ b = 2
+ c = 3
+ d = 4
+
+ o = TestModelFactory()
+ self.assertEqual({'c': 3, 'd': 4}, o._defaults)
+ self.assertEqual('foo_2', o.a)
+ self.assertEqual(2, o.b)
+ self.assertEqual(3, o.c)
+ self.assertEqual(4, o.d)
+ self.assertEqual(2, o.id)
+
+ def test_full_get_or_create(self):
+ """Test a DjangoModelFactory with all fields in get_or_create."""
+ class TestModelFactory(factory.DjangoModelFactory):
+ FACTORY_FOR = TestModel
+ FACTORY_DJANGO_GET_OR_CREATE = ('a', 'b', 'c', 'd')
+
+ a = factory.Sequence(lambda n: 'foo_%s' % n)
+ b = 2
+ c = 3
+ d = 4
+
+ o = TestModelFactory()
+ self.assertEqual({}, o._defaults)
+ self.assertEqual('foo_2', o.a)
+ self.assertEqual(2, o.b)
+ self.assertEqual(3, o.c)
+ self.assertEqual(4, o.d)
+ self.assertEqual(2, o.id)
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 6fd6ee2..787164a 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2010 Mark Sandstrom
-# Copyright (c) 2011 Raphaël Barrois
+# Copyright (c) 2011-2013 Raphaël Barrois
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
diff --git a/tests/tools.py b/tests/tools.py
new file mode 100644
index 0000000..571899b
--- /dev/null
+++ b/tests/tools.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2010 Mark Sandstrom
+# Copyright (c) 2011-2013 Raphaël Barrois
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+
+
+import functools
+import warnings
+
+
+def disable_warnings(fun):
+ @functools.wraps(fun)
+ def decorated(*args, **kwargs):
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ return fun(*args, **kwargs)
+ return decorated
+
+