summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRaphaël Barrois <raphael.barrois@polytechnique.org>2013-06-10 02:00:56 +0200
committerRaphaël Barrois <raphael.barrois@polytechnique.org>2013-06-10 02:00:56 +0200
commita4460de7719eb8c8bb1f3aa72b2ce233b45d9a87 (patch)
tree2a026b06189a7ce87753d1e6a354bb27d90e82db
parent1f90c656089326228ef4aaf3d634cc843fad14b2 (diff)
downloadfactory-boy-a4460de7719eb8c8bb1f3aa72b2ce233b45d9a87.tar
factory-boy-a4460de7719eb8c8bb1f3aa72b2ce233b45d9a87.tar.gz
Factor lazy Factory import code.
-rw-r--r--factory/declarations.py73
-rw-r--r--tests/test_declarations.py61
2 files changed, 56 insertions, 78 deletions
diff --git a/factory/declarations.py b/factory/declarations.py
index 17f4434..abd384e 100644
--- a/factory/declarations.py
+++ b/factory/declarations.py
@@ -301,6 +301,40 @@ class ParameteredAttribute(OrderedDeclaration):
raise NotImplementedError()
+class _FactoryWrapper(object):
+ """Handle a 'factory' arg.
+
+ Such args can be either a Factory subclass, or a fully qualified import
+ path for that subclass (e.g 'myapp.factories.MyFactory').
+ """
+ def __init__(self, factory_or_path):
+ self.factory = None
+ self.module = self.name = ''
+ if isinstance(factory_or_path, type):
+ self.factory = factory_or_path
+ else:
+ if not (compat.is_string(factory_or_path) and '.' in factory_or_path):
+ raise ValueError(
+ "A factory= argument must receive either a class "
+ "or the fully qualified path to a Factory subclass; got "
+ "%r instead." % factory_or_path)
+ self.module, self.name = factory_or_path.rsplit('.', 1)
+
+ def get(self):
+ if self.factory is None:
+ self.factory = utils.import_object(
+ self.module,
+ self.name,
+ )
+ return self.factory
+
+ def __repr__(self):
+ if self.factory is None:
+ return '<_FactoryImport: %s.%s>' % (self.module, self.name)
+ else:
+ return '<_FactoryImport: %s>' % self.factory.__class__
+
+
class SubFactory(ParameteredAttribute):
"""Base class for attributes based upon a sub-factory.
@@ -314,26 +348,11 @@ class SubFactory(ParameteredAttribute):
def __init__(self, factory, **kwargs):
super(SubFactory, self).__init__(**kwargs)
- if isinstance(factory, type):
- self.factory = factory
- self.factory_module = self.factory_name = ''
- else:
- # Must be a string
- if not (compat.is_string(factory) and '.' in factory):
- raise ValueError(
- "The argument of a SubFactory must be either a class "
- "or the fully qualified path to a Factory class; got "
- "%r instead." % factory)
- self.factory = None
- self.factory_module, self.factory_name = factory.rsplit('.', 1)
+ self.factory_wrapper = _FactoryWrapper(factory)
def get_factory(self):
"""Retrieve the wrapped factory.Factory subclass."""
- if self.factory is None:
- # Must be a module path
- self.factory = utils.import_object(
- self.factory_module, self.factory_name)
- return self.factory
+ return self.factory_wrapper.get()
def generate(self, sequence, obj, create, params):
"""Evaluate the current definition and fill its attributes.
@@ -442,27 +461,11 @@ class RelatedFactory(PostGenerationDeclaration):
self.name = factory_related_name
self.defaults = defaults
-
- if isinstance(factory, type):
- self.factory = factory
- self.factory_module = self.factory_name = ''
- else:
- # Must be a string
- if not (compat.is_string(factory) and '.' in factory):
- raise ValueError(
- "The argument of a SubFactory must be either a class "
- "or the fully qualified path to a Factory class; got "
- "%r instead." % factory)
- self.factory = None
- self.factory_module, self.factory_name = factory.rsplit('.', 1)
+ self.factory_wrapper = _FactoryWrapper(factory)
def get_factory(self):
"""Retrieve the wrapped factory.Factory subclass."""
- if self.factory is None:
- # Must be a module path
- self.factory = utils.import_object(
- self.factory_module, self.factory_name)
- return self.factory
+ return self.factory_wrapper.get()
def call(self, obj, create, extracted=None, **kwargs):
passed_kwargs = dict(self.defaults)
diff --git a/tests/test_declarations.py b/tests/test_declarations.py
index 90e54c2..e0b2513 100644
--- a/tests/test_declarations.py
+++ b/tests/test_declarations.py
@@ -142,33 +142,40 @@ class PostGenerationDeclarationTestCase(unittest.TestCase):
self.assertEqual({'bar': 42}, call_params[1])
-class SubFactoryTestCase(unittest.TestCase):
+class FactoryWrapperTestCase(unittest.TestCase):
+ def test_invalid_path(self):
+ self.assertRaises(ValueError, declarations._FactoryWrapper, 'UnqualifiedSymbol')
+ self.assertRaises(ValueError, declarations._FactoryWrapper, 42)
- def test_arg(self):
- self.assertRaises(ValueError, declarations.SubFactory, 'UnqualifiedSymbol')
+ def test_class(self):
+ w = declarations._FactoryWrapper(datetime.date)
+ self.assertEqual(datetime.date, w.get())
+
+ def test_path(self):
+ w = declarations._FactoryWrapper('datetime.date')
+ self.assertEqual(datetime.date, w.get())
def test_lazyness(self):
- f = declarations.SubFactory('factory.declarations.Sequence', x=3)
+ f = declarations._FactoryWrapper('factory.declarations.Sequence')
self.assertEqual(None, f.factory)
- self.assertEqual({'x': 3}, f.defaults)
-
- factory_class = f.get_factory()
+ factory_class = f.get()
self.assertEqual(declarations.Sequence, factory_class)
def test_cache(self):
+ """Ensure that _FactoryWrapper tries to import only once."""
orig_date = datetime.date
- f = declarations.SubFactory('datetime.date')
- self.assertEqual(None, f.factory)
+ w = declarations._FactoryWrapper('datetime.date')
+ self.assertEqual(None, w.factory)
- factory_class = f.get_factory()
+ factory_class = w.get()
self.assertEqual(orig_date, factory_class)
try:
# Modify original value
datetime.date = None
# Repeat import
- factory_class = f.get_factory()
+ factory_class = w.get()
self.assertEqual(orig_date, factory_class)
finally:
@@ -178,38 +185,6 @@ class SubFactoryTestCase(unittest.TestCase):
class RelatedFactoryTestCase(unittest.TestCase):
- def test_arg(self):
- self.assertRaises(ValueError, declarations.RelatedFactory, 'UnqualifiedSymbol')
-
- def test_lazyness(self):
- f = declarations.RelatedFactory('factory.declarations.Sequence', x=3)
- self.assertEqual(None, f.factory)
-
- self.assertEqual({'x': 3}, f.defaults)
-
- factory_class = f.get_factory()
- self.assertEqual(declarations.Sequence, factory_class)
-
- def test_cache(self):
- """Ensure that RelatedFactory tries to import only once."""
- orig_date = datetime.date
- f = declarations.RelatedFactory('datetime.date')
- self.assertEqual(None, f.factory)
-
- factory_class = f.get_factory()
- self.assertEqual(orig_date, factory_class)
-
- try:
- # Modify original value
- datetime.date = None
- # Repeat import
- factory_class = f.get_factory()
- self.assertEqual(orig_date, factory_class)
-
- finally:
- # IMPORTANT: restore attribute.
- datetime.date = orig_date
-
def test_deprecate_name(self):
with warnings.catch_warnings(record=True) as w: