You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
74 lines
2.6 KiB
74 lines
2.6 KiB
import base64
import imghdr
import uuid

import six
from django.core.files.base import ContentFile
from rest_framework import serializers, status, viewsets
from rest_framework.response import Response
|
|
|
|
# https://gist.github.com/ivlevdenis/a0c8f5b472b6b8550bbb016c6a30e0be
|
|
|
|
|
|
class ExtendViewSet:
    """Viewset mixin that picks serializer, permissions and throttle scope per action.

    List it before a DRF viewset class in the bases (as
    ``ExtendedModelViewSet`` does). Each override below adjusts an attribute
    and then delegates to the viewset's own implementation through ``super()``.

    Subclasses fill in any of these dicts, keyed by action name
    (``self.action``, e.g. ``'list'`` or ``'create'``):

    * ``serializer_class_map`` -- serializer class for the action;
    * ``permission_map`` -- a permission class, or a tuple/list of them;
    * ``throttle_scope_map`` -- throttle scope string for the action.

    An action missing from a map (or mapped to a falsy value) falls back to
    the viewset's regular ``serializer_class`` / ``permission_classes`` /
    ``throttle_scope``.
    """

    # Per-action overrides. The mixin only reads these dicts and never
    # mutates them, so sharing the empty class-level defaults is harmless.
    permission_map = {}
    throttle_scope_map = {}
    serializer_class_map = {}

    def get_serializer_class(self):
        """Return the serializer class for the current action.

        The mapped class (if any) is stored on ``self.serializer_class``
        before delegating, so the parent implementation works with it.
        """
        ser = self.serializer_class_map.get(self.action, None)
        self.serializer_class = ser or self.serializer_class
        return super().get_serializer_class()

    def initialize_request(self, request, *args, **kwargs):
        """Initialize the request, then set ``self.throttle_scope`` for the action.

        The parent call must run first. It presumably sets ``self.action``
        (DRF's ``ViewSetMixin.initialize_request`` does); confirm this if
        the base class changes.
        """
        request = super().initialize_request(request, *args, **kwargs)
        throttle_scope = self.throttle_scope_map.get(self.action, None)
        cls_throttle_scope = getattr(self, 'throttle_scope', None)
        # Fall back to the class-level scope, then to '' when none is set.
        self.throttle_scope = throttle_scope or cls_throttle_scope or ''
        return request

    def get_permissions(self):
        """Resolve permission classes for the action, then delegate.

        A single mapped class is wrapped in a list. The chosen classes are
        stored on ``self.permission_classes`` before the parent's
        ``get_permissions()`` is called.
        """
        perms = self.permission_map.get(self.action, None)
        if perms and not isinstance(perms, (tuple, list)):
            perms = [perms]
        self.permission_classes = perms or self.permission_classes
        return super().get_permissions()

    def options(self, request, *args, **kwargs):
        """Respond to OPTIONS with the view metadata, plus GET query metadata.

        This follows the shape of DRF's ``APIView.options``. When the viewset
        defines a ``query_metadata`` attribute, it is published as
        ``data['actions']['GET']``. Viewsets without it get the plain
        metadata.
        """
        if self.metadata_class is None:
            return self.http_method_not_allowed(request, *args, **kwargs)
        data = self.metadata_class().determine_metadata(request, self)
        if hasattr(self, 'query_metadata'):
            # The metadata class may omit 'actions' entirely (DRF's
            # SimpleMetadata does when no PUT/POST is permitted), so create
            # it on demand instead of indexing blindly.
            data.setdefault('actions', {})['GET'] = self.query_metadata
        return Response(data, status=status.HTTP_200_OK)
|
|
|
|
|
|
class ExtendedModelViewSet(ExtendViewSet, viewsets.ModelViewSet):
    # A standard DRF ModelViewSet with ExtendViewSet's per-action serializer,
    # permission and throttle-scope maps. ExtendViewSet comes first in the
    # bases so its overrides run before ModelViewSet's and delegate via super().
    ...
|
|
|
|
|
|
class Base64ImageField(serializers.ImageField):
    """Image field that also accepts base64-encoded data URIs.

    A string containing both ``data:`` and ``;base64,`` (for example
    ``"data:image/png;base64,iVBORw0..."``) is decoded into an in-memory
    ``ContentFile``. The file gets a name built from the first 12 characters
    of a random UUID, plus an extension guessed from the decoded bytes.
    Any other input (e.g. a regular multipart upload) is handed to
    ``serializers.ImageField`` unchanged, which performs the actual image
    validation in both cases.
    """

    def to_internal_value(self, data):
        """Decode a base64 data URI if given, then defer to ``ImageField``.

        Fails with the ``'invalid_image'`` error when the payload after
        ``;base64,`` cannot be base64-decoded.
        """
        if (isinstance(data, six.string_types)
                and 'data:' in data and ';base64,' in data):
            # Only the payload after the first ';base64,' is needed. The
            # "data:<mime>" header is ignored because the extension is
            # sniffed from the decoded bytes instead. partition() also avoids
            # the unpacking error split() raised if the separator repeats.
            _header, _sep, payload = data.partition(';base64,')
            try:
                decoded_file = base64.b64decode(payload)
            except (TypeError, ValueError):
                # On Python 3, malformed base64 (e.g. bad padding) raises
                # binascii.Error, a ValueError subclass, not TypeError.
                self.fail('invalid_image')

            file_name = str(uuid.uuid4())[:12]
            file_extension = self.get_file_extension(file_name, decoded_file)
            complete_file_name = "%s.%s" % (file_name, file_extension,)
            data = ContentFile(decoded_file, name=complete_file_name)

        return super().to_internal_value(data)

    def get_file_extension(self, file_name, decoded_file):
        """Guess a file extension for ``decoded_file`` from its contents.

        ``imghdr.what`` inspects the supplied bytes (``file_name`` is not
        read when bytes are given). ``'jpeg'`` is shortened to ``'jpg'``,
        and ``None`` is returned for unrecognised formats.
        """
        # NOTE(review): a None result makes the caller build a name ending in
        # '.None'; ImageField's own validation still decides acceptance.
        # imghdr was removed from the stdlib in Python 3.13 (PEP 594), so this
        # helper needs a replacement before upgrading.
        extension = imghdr.what(file_name, decoded_file)
        return "jpg" if extension == "jpeg" else extension
|
|
|