diff options
-rw-r--r-- | docs/changelog.rst | 1 | ||||
-rw-r--r-- | factory/django.py | 31 | ||||
-rw-r--r-- | tests/test_django.py | 53 |
3 files changed, 57 insertions, 28 deletions
diff --git a/docs/changelog.rst b/docs/changelog.rst index a6ca79e..c2731ef 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -12,6 +12,7 @@ ChangeLog - Add support for getting/setting :mod:`factory.fuzzy`'s random state (see :issue:`175`, :issue:`185`). - Support lazy evaluation of iterables in :class:`factory.fuzzy.FuzzyChoice` (see :issue:`184`). - Support non-default databases at the factory level (see :issue:`171`) + - Make :class:`factory.django.FileField` and :class:`factory.django.ImageField` non-post_generation, i.e normal fields also available in ``save()`` (see :issue:`141`). *Bugfix:* diff --git a/factory/django.py b/factory/django.py index ee5749a..9d4cde9 100644 --- a/factory/django.py +++ b/factory/django.py @@ -144,24 +144,23 @@ class DjangoModelFactory(base.Factory): obj.save() -class FileField(declarations.PostGenerationDeclaration): +class FileField(declarations.ParameteredAttribute): """Helper to fill in django.db.models.FileField from a Factory.""" DEFAULT_FILENAME = 'example.dat' def __init__(self, **defaults): require_django() - self.defaults = defaults - super(FileField, self).__init__() + super(FileField, self).__init__(**defaults) def _make_data(self, params): """Create data for the field.""" return params.get('data', b'') - def _make_content(self, extraction_context): + def _make_content(self, extra): path = '' params = dict(self.defaults) - params.update(extraction_context.extra) + params.update(extra) if params.get('from_path') and params.get('from_file'): raise ValueError( @@ -169,12 +168,7 @@ class FileField(declarations.PostGenerationDeclaration): "be non-empty when calling factory.django.FileField." ) - if extraction_context.did_extract: - # Should be a django.core.files.File - content = extraction_context.value - path = content.name - - elif params.get('from_path'): + if params.get('from_path'): path = params['from_path'] f = open(path, 'rb') content = django_files.File(f, name=path) @@ -196,19 +190,12 @@ class FileField(declarations.PostGenerationDeclaration): filename = params.get('filename', default_filename) return filename, content - def call(self, obj, create, extraction_context): + def evaluate(self, sequence, obj, create, extra=None, containers=()): """Fill in the field.""" - if extraction_context.did_extract and extraction_context.value is None: - # User passed an empty value, don't fill - return - filename, content = self._make_content(extraction_context) - field_file = getattr(obj, extraction_context.for_field) - try: - field_file.save(filename, content, save=create) - finally: - content.file.close() - return field_file + filename, content = self._make_content(extra) + print("Returning file with filename=%r, contents=%r" % (filename, content)) + return django_files.File(content.file, filename) class ImageField(FileField): diff --git a/tests/test_django.py b/tests/test_django.py index a8f1f77..cf80edb 100644 --- a/tests/test_django.py +++ b/tests/test_django.py @@ -364,6 +364,9 @@ class DjangoFileFieldTestCase(unittest.TestCase): o = WithFileFactory.build() self.assertIsNone(o.pk) self.assertEqual(b'', o.afile.read()) + self.assertEqual('example.dat', o.afile.name) + + o.save() self.assertEqual('django/example.dat', o.afile.name) def test_default_create(self): @@ -375,19 +378,26 @@ class DjangoFileFieldTestCase(unittest.TestCase): def test_with_content(self): o = WithFileFactory.build(afile__data='foo') self.assertIsNone(o.pk) + + # Django only allocates the full path on save() + o.save() self.assertEqual(b'foo', o.afile.read()) self.assertEqual('django/example.dat', o.afile.name) def test_with_file(self): with open(testdata.TESTFILE_PATH, 'rb') as f: o = WithFileFactory.build(afile__from_file=f) - self.assertIsNone(o.pk) + o.save() + self.assertEqual(b'example_data\n', o.afile.read()) self.assertEqual('django/example.data', o.afile.name) def test_with_path(self): o = WithFileFactory.build(afile__from_path=testdata.TESTFILE_PATH) self.assertIsNone(o.pk) + + # Django only allocates the full path on save() + o.save() self.assertEqual(b'example_data\n', o.afile.read()) self.assertEqual('django/example.data', o.afile.name) @@ -397,7 +407,9 @@ class DjangoFileFieldTestCase(unittest.TestCase): afile__from_file=f, afile__from_path='' ) - self.assertIsNone(o.pk) + # Django only allocates the full path on save() + o.save() + self.assertEqual(b'example_data\n', o.afile.read()) self.assertEqual('django/example.data', o.afile.name) @@ -407,6 +419,9 @@ class DjangoFileFieldTestCase(unittest.TestCase): afile__from_file=None, ) self.assertIsNone(o.pk) + + # Django only allocates the full path on save() + o.save() self.assertEqual(b'example_data\n', o.afile.read()) self.assertEqual('django/example.data', o.afile.name) @@ -422,14 +437,21 @@ class DjangoFileFieldTestCase(unittest.TestCase): afile__filename='example.foo', ) self.assertIsNone(o.pk) + + # Django only allocates the full path on save() + o.save() self.assertEqual(b'example_data\n', o.afile.read()) self.assertEqual('django/example.foo', o.afile.name) def test_existing_file(self): o1 = WithFileFactory.build(afile__from_path=testdata.TESTFILE_PATH) + o1.save() + self.assertEqual('django/example.data', o1.afile.name) - o2 = WithFileFactory.build(afile=o1.afile) + o2 = WithFileFactory.build(afile__from_file=o1.afile) self.assertIsNone(o2.pk) + o2.save() + self.assertEqual(b'example_data\n', o2.afile.read()) self.assertNotEqual('django/example.data', o2.afile.name) self.assertRegexpMatches(o2.afile.name, r'django/example_\w+.data') @@ -453,6 +475,8 @@ class DjangoImageFieldTestCase(unittest.TestCase): def test_default_build(self): o = WithImageFactory.build() self.assertIsNone(o.pk) + o.save() + self.assertEqual(100, o.animage.width) self.assertEqual(100, o.animage.height) self.assertEqual('django/example.jpg', o.animage.name) @@ -460,6 +484,8 @@ class DjangoImageFieldTestCase(unittest.TestCase): def test_default_create(self): o = WithImageFactory.create() self.assertIsNotNone(o.pk) + o.save() + self.assertEqual(100, o.animage.width) self.assertEqual(100, o.animage.height) self.assertEqual('django/example.jpg', o.animage.name) @@ -467,6 +493,8 @@ class DjangoImageFieldTestCase(unittest.TestCase): def test_with_content(self): o = WithImageFactory.build(animage__width=13, animage__color='red') self.assertIsNone(o.pk) + o.save() + self.assertEqual(13, o.animage.width) self.assertEqual(13, o.animage.height) self.assertEqual('django/example.jpg', o.animage.name) @@ -480,6 +508,8 @@ class DjangoImageFieldTestCase(unittest.TestCase): def test_gif(self): o = WithImageFactory.build(animage__width=13, animage__color='blue', animage__format='GIF') self.assertIsNone(o.pk) + o.save() + self.assertEqual(13, o.animage.width) self.assertEqual(13, o.animage.height) self.assertEqual('django/example.jpg', o.animage.name) @@ -493,7 +523,8 @@ class DjangoImageFieldTestCase(unittest.TestCase): def test_with_file(self): with open(testdata.TESTIMAGE_PATH, 'rb') as f: o = WithImageFactory.build(animage__from_file=f) - self.assertIsNone(o.pk) + o.save() + # Image file for a 42x42 green jpeg: 301 bytes long. self.assertEqual(301, len(o.animage.read())) self.assertEqual('django/example.jpeg', o.animage.name) @@ -501,6 +532,8 @@ class DjangoImageFieldTestCase(unittest.TestCase): def test_with_path(self): o = WithImageFactory.build(animage__from_path=testdata.TESTIMAGE_PATH) self.assertIsNone(o.pk) + o.save() + # Image file for a 42x42 green jpeg: 301 bytes long. self.assertEqual(301, len(o.animage.read())) self.assertEqual('django/example.jpeg', o.animage.name) @@ -511,7 +544,8 @@ class DjangoImageFieldTestCase(unittest.TestCase): animage__from_file=f, animage__from_path='' ) - self.assertIsNone(o.pk) + o.save() + # Image file for a 42x42 green jpeg: 301 bytes long. self.assertEqual(301, len(o.animage.read())) self.assertEqual('django/example.jpeg', o.animage.name) @@ -522,6 +556,8 @@ class DjangoImageFieldTestCase(unittest.TestCase): animage__from_file=None, ) self.assertIsNone(o.pk) + o.save() + # Image file for a 42x42 green jpeg: 301 bytes long. self.assertEqual(301, len(o.animage.read())) self.assertEqual('django/example.jpeg', o.animage.name) @@ -538,15 +574,20 @@ class DjangoImageFieldTestCase(unittest.TestCase): animage__filename='example.foo', ) self.assertIsNone(o.pk) + o.save() + # Image file for a 42x42 green jpeg: 301 bytes long. self.assertEqual(301, len(o.animage.read())) self.assertEqual('django/example.foo', o.animage.name) def test_existing_file(self): o1 = WithImageFactory.build(animage__from_path=testdata.TESTIMAGE_PATH) + o1.save() - o2 = WithImageFactory.build(animage=o1.animage) + o2 = WithImageFactory.build(animage__from_file=o1.animage) self.assertIsNone(o2.pk) + o2.save() + # Image file for a 42x42 green jpeg: 301 bytes long. self.assertEqual(301, len(o2.animage.read())) self.assertNotEqual('django/example.jpeg', o2.animage.name) |