diff options
-rw-r--r-- | MANIFEST.in | 2 | ||||
-rw-r--r-- | Makefile | 3 | ||||
-rw-r--r-- | factory/django.py | 66 | ||||
-rw-r--r-- | tests/djapp/models.py | 8 | ||||
-rw-r--r-- | tests/djapp/settings.py | 10 | ||||
-rw-r--r-- | tests/test_django.py | 100 | ||||
-rw-r--r-- | tests/testdata/__init__.py | 26 | ||||
-rw-r--r-- | tests/testdata/example.data | 1 |
8 files changed, 215 insertions, 1 deletions
diff --git a/MANIFEST.in b/MANIFEST.in index 3912fee..3f09bc6 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,4 +3,4 @@ include docs/Makefile recursive-include docs *.py *.rst include docs/_static/.keep_dir prune docs/_build -recursive-include tests *.py +recursive-include tests *.py *.data @@ -13,6 +13,9 @@ default: clean: find . -type f -name '*.pyc' -delete + find . -type f -path '*/__pycache__/*' -delete + find . -type d -empty -delete + @rm -rf tmp_test/ test: diff --git a/factory/django.py b/factory/django.py index 3eabca1..e0a744a 100644 --- a/factory/django.py +++ b/factory/django.py @@ -21,13 +21,21 @@ # THE SOFTWARE. +from __future__ import absolute_import from __future__ import unicode_literals +import os """factory_boy extensions for use with the Django framework.""" +try: + from django.core import files as django_files +except ImportError as e: # pragma: no cover + django_files = None + import_failure = e from . import base +from . import declarations class DjangoModelFactory(base.Factory): @@ -100,3 +108,61 @@ class DjangoModelFactory(base.Factory): obj.save() +class FileField(declarations.PostGenerationDeclaration): + """Helper to fill in django.db.models.FileField from a Factory.""" + + def __init__(self, *args, **kwargs): + if django_files is None: # pragma: no cover + raise import_failure + super(FileField, self).__init__(*args, **kwargs) + + def _make_content(self, extraction_context): + path = '' + params = extraction_context.extra + + if params.get('from_path') and params.get('from_file'): + raise ValueError( + "At most one argument from 'from_file' and 'from_path' should " + "be non-empty when calling factory.django.FileField." + ) + + if extraction_context.did_extract: + # Should be a django.core.files.File + content = extraction_context.value + path = content.name + + elif params.get('from_path'): + path = params['from_path'] + f = open(path, 'rb') + content = django_files.File(f, name=path) + + elif params.get('from_file'): + f = params['from_file'] + content = django_files.File(f) + path = content.name + + else: + data = params.get('data', '') + content = django_files.base.ContentFile(data) + + if path: + default_filename = os.path.basename(path) + else: + default_filename = 'example.dat' + + filename = params.get('filename', default_filename) + return filename, content + + def call(self, obj, create, extraction_context): + """Fill in the field.""" + if extraction_context.did_extract and extraction_context.value is None: + # User passed an empty value, don't fill + return + + filename, content = self._make_content(extraction_context) + field_file = getattr(obj, extraction_context.for_field) + try: + field_file.save(filename, content, save=create) + finally: + content.file.close() + return field_file diff --git a/tests/djapp/models.py b/tests/djapp/models.py index c107add..52acebe 100644 --- a/tests/djapp/models.py +++ b/tests/djapp/models.py @@ -22,7 +22,9 @@ """Helpers for testing django apps.""" +import os.path +from django.conf import settings from django.db import models class StandardModel(models.Model): @@ -34,3 +36,9 @@ class NonIntegerPk(models.Model): bar = models.CharField(max_length=20, blank=True) +WITHFILE_UPLOAD_TO = 'django' +WITHFILE_UPLOAD_DIR = os.path.join(settings.MEDIA_ROOT, WITHFILE_UPLOAD_TO) + +class WithFile(models.Model): + afile = models.FileField(upload_to=WITHFILE_UPLOAD_TO) + diff --git a/tests/djapp/settings.py b/tests/djapp/settings.py index 787d3f3..c1b79b0 100644 --- a/tests/djapp/settings.py +++ b/tests/djapp/settings.py @@ -20,6 +20,16 @@ # THE SOFTWARE. """Settings for factory_boy/Django tests.""" +import os + +FACTORY_ROOT = os.path.join( + os.path.abspath(os.path.dirname(__file__)), # /path/to/fboy/tests/djapp/ + os.pardir, # /path/to/fboy/tests/ + os.pardir, # /path/to/fboy +) + +MEDIA_ROOT = os.path.join(FACTORY_ROOT, 'tmp_test') + DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', diff --git a/tests/test_django.py b/tests/test_django.py index 70bc376..eb35f7a 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -33,6 +33,7 @@ except ImportError: # pragma: no cover from .compat import is_python2, unittest +from . import testdata from . import tools @@ -40,6 +41,7 @@ if django is not None: os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tests.djapp.settings') from django import test as django_test + from django.db import models as django_models from django.test import simple as django_test_simple from django.test import utils as django_test_utils from .djapp import models @@ -91,6 +93,12 @@ class NonIntegerPkFactory(factory.django.DjangoModelFactory): bar = '' +class WithFileFactory(factory.django.DjangoModelFactory): + FACTORY_FOR = models.WithFile + + afile = factory.django.FileField() + + @unittest.skipIf(django is None, "Django not installed.") class DjangoPkSequenceTestCase(django_test.TestCase): def setUp(self): @@ -163,3 +171,95 @@ class DjangoNonIntegerPkTestCase(django_test.TestCase): nonint2 = NonIntegerPkFactory.create() self.assertEqual('foo1', nonint2.foo) self.assertEqual('foo1', nonint2.pk) + + +@unittest.skipIf(django is None, "Django not installed.") +class DjangoFileFieldTestCase(unittest.TestCase): + + def tearDown(self): + super(DjangoFileFieldTestCase, self).tearDown() + for path in os.listdir(models.WITHFILE_UPLOAD_DIR): + # Remove temporary files written during tests. + os.unlink(os.path.join(models.WITHFILE_UPLOAD_DIR, path)) + + def test_default_build(self): + o = WithFileFactory.build() + self.assertIsNone(o.pk) + self.assertEqual('', o.afile.read()) + self.assertEqual('django/example.dat', o.afile.name) + + def test_default_create(self): + o = WithFileFactory.create() + self.assertIsNotNone(o.pk) + self.assertEqual('', o.afile.read()) + self.assertEqual('django/example.dat', o.afile.name) + + def test_with_content(self): + o = WithFileFactory.build(afile__data='foo') + self.assertIsNone(o.pk) + self.assertEqual('foo', o.afile.read()) + self.assertEqual('django/example.dat', o.afile.name) + + def test_with_file(self): + with open(testdata.TESTFILE_PATH, 'rb') as f: + o = WithFileFactory.build(afile__from_file=f) + self.assertIsNone(o.pk) + self.assertEqual('example_data\n', o.afile.read()) + self.assertEqual('django/example.data', o.afile.name) + + def test_with_path(self): + o = WithFileFactory.build(afile__from_path=testdata.TESTFILE_PATH) + self.assertIsNone(o.pk) + self.assertEqual('example_data\n', o.afile.read()) + self.assertEqual('django/example.data', o.afile.name) + + def test_with_file_empty_path(self): + with open(testdata.TESTFILE_PATH, 'rb') as f: + o = WithFileFactory.build( + afile__from_file=f, + afile__from_path='' + ) + self.assertIsNone(o.pk) + self.assertEqual('example_data\n', o.afile.read()) + self.assertEqual('django/example.data', o.afile.name) + + def test_with_path_empty_file(self): + o = WithFileFactory.build( + afile__from_path=testdata.TESTFILE_PATH, + afile__from_file=None, + ) + self.assertIsNone(o.pk) + self.assertEqual('example_data\n', o.afile.read()) + self.assertEqual('django/example.data', o.afile.name) + + def test_error_both_file_and_path(self): + self.assertRaises(ValueError, WithFileFactory.build, + afile__from_file='fakefile', + afile__from_path=testdata.TESTFILE_PATH, + ) + + def test_override_filename_with_path(self): + o = WithFileFactory.build( + afile__from_path=testdata.TESTFILE_PATH, + afile__filename='example.foo', + ) + self.assertIsNone(o.pk) + self.assertEqual('example_data\n', o.afile.read()) + self.assertEqual('django/example.foo', o.afile.name) + + def test_existing_file(self): + o1 = WithFileFactory.build(afile__from_path=testdata.TESTFILE_PATH) + + o2 = WithFileFactory.build(afile=o1.afile) + self.assertIsNone(o2.pk) + self.assertEqual('example_data\n', o2.afile.read()) + self.assertEqual('django/example_1.data', o2.afile.name) + + def test_no_file(self): + o = WithFileFactory.build(afile=None) + self.assertIsNone(o.pk) + self.assertFalse(o.afile) + + +if __name__ == '__main__': # pragma: no cover + unittest.main() diff --git a/tests/testdata/__init__.py b/tests/testdata/__init__.py new file mode 100644 index 0000000..3d1d441 --- /dev/null +++ b/tests/testdata/__init__.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2011-2013 Raphaƫl Barrois +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. + + +import os.path + +TESTDATA_ROOT = os.path.abspath(os.path.dirname(__file__)) +TESTFILE_PATH = os.path.join(TESTDATA_ROOT, 'example.data') diff --git a/tests/testdata/example.data b/tests/testdata/example.data new file mode 100644 index 0000000..02ff8ec --- /dev/null +++ b/tests/testdata/example.data @@ -0,0 +1 @@ +example_data |