summaryrefslogtreecommitdiff
path: root/factory
diff options
context:
space:
mode:
Diffstat (limited to 'factory')
-rw-r--r--factory/django.py66
1 files changed, 66 insertions, 0 deletions
diff --git a/factory/django.py b/factory/django.py
index 3eabca1..e0a744a 100644
--- a/factory/django.py
+++ b/factory/django.py
@@ -21,13 +21,21 @@
# THE SOFTWARE.
+from __future__ import absolute_import
from __future__ import unicode_literals
+import os
"""factory_boy extensions for use with the Django framework."""
+try:
+ from django.core import files as django_files
+except ImportError as e: # pragma: no cover
+ django_files = None
+ import_failure = e
from . import base
+from . import declarations
class DjangoModelFactory(base.Factory):
@@ -100,3 +108,61 @@ class DjangoModelFactory(base.Factory):
obj.save()
+class FileField(declarations.PostGenerationDeclaration):
+ """Helper to fill in django.db.models.FileField from a Factory."""
+
+ def __init__(self, *args, **kwargs):
+ if django_files is None: # pragma: no cover
+ raise import_failure
+ super(FileField, self).__init__(*args, **kwargs)
+
+ def _make_content(self, extraction_context):
+ path = ''
+ params = extraction_context.extra
+
+ if params.get('from_path') and params.get('from_file'):
+ raise ValueError(
+ "At most one argument from 'from_file' and 'from_path' should "
+ "be non-empty when calling factory.django.FileField."
+ )
+
+ if extraction_context.did_extract:
+ # Should be a django.core.files.File
+ content = extraction_context.value
+ path = content.name
+
+ elif params.get('from_path'):
+ path = params['from_path']
+ f = open(path, 'rb')
+ content = django_files.File(f, name=path)
+
+ elif params.get('from_file'):
+ f = params['from_file']
+ content = django_files.File(f)
+ path = content.name
+
+ else:
+ data = params.get('data', '')
+ content = django_files.base.ContentFile(data)
+
+ if path:
+ default_filename = os.path.basename(path)
+ else:
+ default_filename = 'example.dat'
+
+ filename = params.get('filename', default_filename)
+ return filename, content
+
+ def call(self, obj, create, extraction_context):
+ """Fill in the field."""
+ if extraction_context.did_extract and extraction_context.value is None:
+ # User passed an empty value, don't fill
+ return
+
+ filename, content = self._make_content(extraction_context)
+ field_file = getattr(obj, extraction_context.for_field)
+ try:
+ field_file.save(filename, content, save=create)
+ finally:
+ content.file.close()
+ return field_file