diff options
Diffstat (limited to 'factory')
-rw-r--r-- | factory/django.py | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/factory/django.py b/factory/django.py index 3eabca1..e0a744a 100644 --- a/factory/django.py +++ b/factory/django.py @@ -21,13 +21,21 @@ # THE SOFTWARE. +from __future__ import absolute_import from __future__ import unicode_literals +import os """factory_boy extensions for use with the Django framework.""" +try: + from django.core import files as django_files +except ImportError as e: # pragma: no cover + django_files = None + import_failure = e from . import base +from . import declarations class DjangoModelFactory(base.Factory): @@ -100,3 +108,61 @@ class DjangoModelFactory(base.Factory): obj.save() +class FileField(declarations.PostGenerationDeclaration): + """Helper to fill in django.db.models.FileField from a Factory.""" + + def __init__(self, *args, **kwargs): + if django_files is None: # pragma: no cover + raise import_failure + super(FileField, self).__init__(*args, **kwargs) + + def _make_content(self, extraction_context): + path = '' + params = extraction_context.extra + + if params.get('from_path') and params.get('from_file'): + raise ValueError( + "At most one argument from 'from_file' and 'from_path' should " + "be non-empty when calling factory.django.FileField." + ) + + if extraction_context.did_extract: + # Should be a django.core.files.File + content = extraction_context.value + path = content.name + + elif params.get('from_path'): + path = params['from_path'] + f = open(path, 'rb') + content = django_files.File(f, name=path) + + elif params.get('from_file'): + f = params['from_file'] + content = django_files.File(f) + path = content.name + + else: + data = params.get('data', '') + content = django_files.base.ContentFile(data) + + if path: + default_filename = os.path.basename(path) + else: + default_filename = 'example.dat' + + filename = params.get('filename', default_filename) + return filename, content + + def call(self, obj, create, extraction_context): + """Fill in the field.""" + if extraction_context.did_extract and extraction_context.value is None: + # User passed an empty value, don't fill + return + + filename, content = self._make_content(extraction_context) + field_file = getattr(obj, extraction_context.for_field) + try: + field_file.save(filename, content, save=create) + finally: + content.file.close() + return field_file |