diff options
-rw-r--r-- | factory/declarations.py | 73 | ||||
-rw-r--r-- | tests/test_declarations.py | 61 |
2 files changed, 56 insertions, 78 deletions
diff --git a/factory/declarations.py b/factory/declarations.py index 17f4434..abd384e 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -301,6 +301,40 @@ class ParameteredAttribute(OrderedDeclaration): raise NotImplementedError() +class _FactoryWrapper(object): + """Handle a 'factory' arg. + + Such args can be either a Factory subclass, or a fully qualified import + path for that subclass (e.g 'myapp.factories.MyFactory'). + """ + def __init__(self, factory_or_path): + self.factory = None + self.module = self.name = '' + if isinstance(factory_or_path, type): + self.factory = factory_or_path + else: + if not (compat.is_string(factory_or_path) and '.' in factory_or_path): + raise ValueError( + "A factory= argument must receive either a class " + "or the fully qualified path to a Factory subclass; got " + "%r instead." % factory_or_path) + self.module, self.name = factory_or_path.rsplit('.', 1) + + def get(self): + if self.factory is None: + self.factory = utils.import_object( + self.module, + self.name, + ) + return self.factory + + def __repr__(self): + if self.factory is None: + return '<_FactoryImport: %s.%s>' % (self.module, self.name) + else: + return '<_FactoryImport: %s>' % self.factory.__class__ + + class SubFactory(ParameteredAttribute): """Base class for attributes based upon a sub-factory. @@ -314,26 +348,11 @@ class SubFactory(ParameteredAttribute): def __init__(self, factory, **kwargs): super(SubFactory, self).__init__(**kwargs) - if isinstance(factory, type): - self.factory = factory - self.factory_module = self.factory_name = '' - else: - # Must be a string - if not (compat.is_string(factory) and '.' in factory): - raise ValueError( - "The argument of a SubFactory must be either a class " - "or the fully qualified path to a Factory class; got " - "%r instead." % factory) - self.factory = None - self.factory_module, self.factory_name = factory.rsplit('.', 1) + self.factory_wrapper = _FactoryWrapper(factory) def get_factory(self): """Retrieve the wrapped factory.Factory subclass.""" - if self.factory is None: - # Must be a module path - self.factory = utils.import_object( - self.factory_module, self.factory_name) - return self.factory + return self.factory_wrapper.get() def generate(self, sequence, obj, create, params): """Evaluate the current definition and fill its attributes. @@ -442,27 +461,11 @@ class RelatedFactory(PostGenerationDeclaration): self.name = factory_related_name self.defaults = defaults - - if isinstance(factory, type): - self.factory = factory - self.factory_module = self.factory_name = '' - else: - # Must be a string - if not (compat.is_string(factory) and '.' in factory): - raise ValueError( - "The argument of a SubFactory must be either a class " - "or the fully qualified path to a Factory class; got " - "%r instead." % factory) - self.factory = None - self.factory_module, self.factory_name = factory.rsplit('.', 1) + self.factory_wrapper = _FactoryWrapper(factory) def get_factory(self): """Retrieve the wrapped factory.Factory subclass.""" - if self.factory is None: - # Must be a module path - self.factory = utils.import_object( - self.factory_module, self.factory_name) - return self.factory + return self.factory_wrapper.get() def call(self, obj, create, extracted=None, **kwargs): passed_kwargs = dict(self.defaults) diff --git a/tests/test_declarations.py b/tests/test_declarations.py index 90e54c2..e0b2513 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -142,33 +142,40 @@ class PostGenerationDeclarationTestCase(unittest.TestCase): self.assertEqual({'bar': 42}, call_params[1]) -class SubFactoryTestCase(unittest.TestCase): +class FactoryWrapperTestCase(unittest.TestCase): + def test_invalid_path(self): + self.assertRaises(ValueError, declarations._FactoryWrapper, 'UnqualifiedSymbol') + self.assertRaises(ValueError, declarations._FactoryWrapper, 42) - def test_arg(self): - self.assertRaises(ValueError, declarations.SubFactory, 'UnqualifiedSymbol') + def test_class(self): + w = declarations._FactoryWrapper(datetime.date) + self.assertEqual(datetime.date, w.get()) + + def test_path(self): + w = declarations._FactoryWrapper('datetime.date') + self.assertEqual(datetime.date, w.get()) def test_lazyness(self): - f = declarations.SubFactory('factory.declarations.Sequence', x=3) + f = declarations._FactoryWrapper('factory.declarations.Sequence') self.assertEqual(None, f.factory) - self.assertEqual({'x': 3}, f.defaults) - - factory_class = f.get_factory() + factory_class = f.get() self.assertEqual(declarations.Sequence, factory_class) def test_cache(self): + """Ensure that _FactoryWrapper tries to import only once.""" orig_date = datetime.date - f = declarations.SubFactory('datetime.date') - self.assertEqual(None, f.factory) + w = declarations._FactoryWrapper('datetime.date') + self.assertEqual(None, w.factory) - factory_class = f.get_factory() + factory_class = w.get() self.assertEqual(orig_date, factory_class) try: # Modify original value datetime.date = None # Repeat import - factory_class = f.get_factory() + factory_class = w.get() self.assertEqual(orig_date, factory_class) finally: @@ -178,38 +185,6 @@ class SubFactoryTestCase(unittest.TestCase): class RelatedFactoryTestCase(unittest.TestCase): - def test_arg(self): - self.assertRaises(ValueError, declarations.RelatedFactory, 'UnqualifiedSymbol') - - def test_lazyness(self): - f = declarations.RelatedFactory('factory.declarations.Sequence', x=3) - self.assertEqual(None, f.factory) - - self.assertEqual({'x': 3}, f.defaults) - - factory_class = f.get_factory() - self.assertEqual(declarations.Sequence, factory_class) - - def test_cache(self): - """Ensure that RelatedFactory tries to import only once.""" - orig_date = datetime.date - f = declarations.RelatedFactory('datetime.date') - self.assertEqual(None, f.factory) - - factory_class = f.get_factory() - self.assertEqual(orig_date, factory_class) - - try: - # Modify original value - datetime.date = None - # Repeat import - factory_class = f.get_factory() - self.assertEqual(orig_date, factory_class) - - finally: - # IMPORTANT: restore attribute. - datetime.date = orig_date - def test_deprecate_name(self): with warnings.catch_warnings(record=True) as w: |