import base64
import imghdr  # NOTE: deprecated since Python 3.11 and removed in 3.13 (PEP 594).
import uuid

import six
from django.core.files.base import ContentFile
from rest_framework import serializers, status, viewsets
from rest_framework.response import Response


# https://gist.github.com/ivlevdenis/a0c8f5b472b6b8550bbb016c6a30e0be
class ExtendViewSet(object):
    """Viewset mixin that lets permissions, throttle scope and serializer
    class vary per action, and extends the OPTIONS response.

    Place it before the DRF viewset class in the bases so that ``super()``
    resolves to the viewset's own implementations.

    Class attributes (the maps are keyed by ``self.action``):

    * ``permission_map`` -- a permission class, or a tuple/list of them.
    * ``throttle_scope_map`` -- a throttle scope string.
    * ``serializer_class_map`` -- a serializer class.
    * ``query_metadata`` -- value reported under ``actions['GET']`` in the
      OPTIONS response; nothing is added when it is ``None``.
    """

    # Read-only lookup tables; subclasses override them wholesale.
    permission_map = {}
    throttle_scope_map = {}
    serializer_class_map = {}
    query_metadata = None

    def get_serializer_class(self):
        """Return the serializer mapped to the current action, falling back
        to ``serializer_class``."""
        ser = self.serializer_class_map.get(self.action, None)
        # Stored on the instance so the parent's checks and later reads of
        # ``self.serializer_class`` see the per-action choice.
        self.serializer_class = ser or self.serializer_class
        return super().get_serializer_class()

    def initialize_request(self, request, *args, **kwargs):
        """Initialise the request, then select this action's throttle scope.

        The scope comes from ``throttle_scope_map``, falling back to any
        existing ``throttle_scope`` attribute and finally to ``''``.
        """
        # super() must run first: it is where the viewset sets ``self.action``.
        request = super().initialize_request(request, *args, **kwargs)
        throttle_scope = self.throttle_scope_map.get(self.action, None)
        cls_throttle_scope = getattr(self, 'throttle_scope', None)
        self.throttle_scope = throttle_scope or cls_throttle_scope or ''
        return request

    def get_permissions(self):
        """Return permission instances for the current action.

        A single class in ``permission_map`` is wrapped in a list; when the
        action has no mapping, ``permission_classes`` is used unchanged.
        """
        perms = self.permission_map.get(self.action, None)
        if perms and not isinstance(perms, (tuple, list)):
            perms = [perms]
        self.permission_classes = perms or self.permission_classes
        return super().get_permissions()

    def options(self, request, *args, **kwargs):
        """Return the OPTIONS metadata, adding ``query_metadata`` under
        ``actions['GET']`` when it is set."""
        if self.metadata_class is None:
            return self.http_method_not_allowed(request, *args, **kwargs)
        data = self.metadata_class().determine_metadata(request, self)
        if self.query_metadata is not None:
            # The metadata class may omit 'actions' entirely (DRF's
            # SimpleMetadata only adds it when PUT/POST are available),
            # so create it instead of assuming it exists.
            data.setdefault('actions', {})['GET'] = self.query_metadata
        return Response(data, status=status.HTTP_200_OK)


class ExtendedModelViewSet(ExtendViewSet, viewsets.ModelViewSet):
    """``ModelViewSet`` with the per-action options of ``ExtendViewSet``."""


class Base64ImageField(serializers.ImageField):
    """Image field that also accepts base64-encoded strings.

    A string value is either raw base64 or a data URI such as
    ``data:image/png;base64,<payload>``. It is decoded into a ``ContentFile``
    with a random name and an extension sniffed from its bytes, then
    validated by ``serializers.ImageField`` as usual. Non-string values
    are passed to the parent unchanged.
    """

    def to_internal_value(self, data):
        if isinstance(data, six.string_types):
            if 'data:' in data and ';base64,' in data:
                # Drop the "data:<mime>;base64," header and keep the payload.
                # maxsplit=1 avoids an unpacking error on repeated markers.
                data = data.split(';base64,', 1)[1]
            try:
                decoded_file = base64.b64decode(data)
            except (TypeError, ValueError):
                # On Python 3 malformed base64 raises binascii.Error,
                # which is a ValueError subclass.
                self.fail('invalid_image')
            file_name = str(uuid.uuid4())[:12]
            file_extension = self.get_file_extension(file_name, decoded_file)
            if file_extension is None:
                # imghdr could not identify the bytes; avoid a "<name>.None" file.
                self.fail('invalid_image')
            complete_file_name = '%s.%s' % (file_name, file_extension)
            data = ContentFile(decoded_file, name=complete_file_name)
        return super().to_internal_value(data)

    def get_file_extension(self, file_name, decoded_file):
        """Return the image type of ``decoded_file`` as a file extension.

        ``'jpeg'`` is reported as ``'jpg'``; ``None`` means ``imghdr`` did
        not recognise the data. ``imghdr.what`` ignores ``file_name`` when
        the bytes are supplied.
        """
        extension = imghdr.what(file_name, decoded_file)
        return 'jpg' if extension == 'jpeg' else extension