django-rest-framework serializers: speed up serializer queries


Example

Let's say we have a model, Travel, with many related fields:

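from django.db import models

# Tag, RoutePlace and Coordinate are models defined elsewhere in the app.

class Travel(models.Model):
    tags = models.ManyToManyField(
        Tag,
        related_name='travels')
    route_places = models.ManyToManyField(
        RoutePlace,
        related_name='travels')
    coordinate = models.ForeignKey(
        Coordinate,
        related_name='travels',
        on_delete=models.CASCADE)  # on_delete is required since Django 2.0
    date_start = models.DateField()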

We want to build CRUD endpoints at /travels with a ViewSet.
Here is a simple viewset:

from rest_framework import viewsets


class TravelViewset(viewsets.ModelViewSet):
    queryset = Travel.objects.all()
    serializer_class = TravelSerializer

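The problem with this ViewSet is that Travel has many related fields: every time the serializer reads tags, route_places or coordinate, Django hits the database again for that Travel instance, so listing N travels costs one query for the list plus extra queries for every instance. We could call select_related and prefetch_related directly on the queryset attribute, but what if we want separate serializers for the list, retrieve, create, ... actions of the ViewSet, each needing a different set of related fields?
We can put this logic in a mixin and have each serializer inherit from it: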

class QuerySerializerMixin(object):
    PREFETCH_FIELDS = []  # many-to-many relations, passed to prefetch_related
    RELATED_FIELDS = []  # ForeignKey relations, passed to select_related

    @classmethod
    def get_related_queries(cls, queryset):
        # Called from the ViewSet to optimize the queryset
        # according to RELATED_FIELDS and PREFETCH_FIELDS.
        if cls.RELATED_FIELDS:
            queryset = queryset.select_related(*cls.RELATED_FIELDS)
        if cls.PREFETCH_FIELDS:
            queryset = queryset.prefetch_related(*cls.PREFETCH_FIELDS)
        return queryset

Each serializer then declares the relations it needs:

from rest_framework import serializers


class TravelListSerializer(QuerySerializerMixin, serializers.ModelSerializer):
    PREFETCH_FIELDS = ['tags']
    RELATED_FIELDS = ['coordinate']
    # Fields and Meta are omitted for this example


class TravelRetrieveSerializer(QuerySerializerMixin, serializers.ModelSerializer):
    PREFETCH_FIELDS = ['tags', 'route_places']
    # Fields and Meta are omitted for this example
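To see what the two lists buy us, here is a rough, Django-free sketch using Python's built-in sqlite3 module. The schema, the hypothetical `CountingDB` helper and the query shapes are illustrative stand-ins, not the SQL Django actually generates. It contrasts one query per row (what happens when a serializer touches each relation on a plain queryset) with a JOIN for the ForeignKey (the idea behind select_related) plus one IN query for the many-to-many, stitched together in Python (the idea behind prefetch_related).

```python
import sqlite3


def make_db(n=5):
    # Hypothetical schema loosely mirroring Travel -> Coordinate (ForeignKey)
    # and Travel <-> Tag (many-to-many through a join table).
    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE coordinate (id INTEGER PRIMARY KEY, name TEXT);
        CREATE TABLE tag (id INTEGER PRIMARY KEY, name TEXT);
        CREATE TABLE travel (id INTEGER PRIMARY KEY, coordinate_id INTEGER);
        CREATE TABLE travel_tags (travel_id INTEGER, tag_id INTEGER);
    """)
    for i in range(1, n + 1):
        conn.execute("INSERT INTO coordinate VALUES (?, ?)", (i, f"c{i}"))
        conn.execute("INSERT INTO tag VALUES (?, ?)", (i, f"t{i}"))
        conn.execute("INSERT INTO travel VALUES (?, ?)", (i, i))
        conn.execute("INSERT INTO travel_tags VALUES (?, ?)", (i, i))
    conn.commit()
    return conn


class CountingDB:
    """Wraps a connection and counts every query sent through run()."""

    def __init__(self, conn):
        self.conn = conn
        self.queries = 0

    def run(self, sql, params=()):
        self.queries += 1
        return self.conn.execute(sql, params).fetchall()


TAGS_SQL = ("SELECT travel_tags.travel_id, tag.name FROM tag "
            "JOIN travel_tags ON tag.id = travel_tags.tag_id ")


def naive(db):
    # One query for the travels, then one per travel for the ForeignKey
    # and one per travel for the many-to-many: 1 + 2N queries.
    result = []
    for travel_id, coord_id in db.run("SELECT id, coordinate_id FROM travel ORDER BY id"):
        coord = db.run("SELECT name FROM coordinate WHERE id = ?", (coord_id,))[0][0]
        tags = [name for _, name in db.run(
            TAGS_SQL + "WHERE travel_tags.travel_id = ?", (travel_id,))]
        result.append((travel_id, coord, tags))
    return result


def optimized(db):
    # ForeignKey via a JOIN (roughly what select_related does) and the
    # many-to-many via one extra IN query grouped in Python
    # (roughly what prefetch_related does): 2 queries regardless of N.
    rows = db.run("SELECT travel.id, coordinate.name FROM travel "
                  "JOIN coordinate ON coordinate.id = travel.coordinate_id "
                  "ORDER BY travel.id")
    tags = {travel_id: [] for travel_id, _ in rows}
    placeholders = ",".join("?" * len(tags))
    for travel_id, name in db.run(
            TAGS_SQL + f"WHERE travel_tags.travel_id IN ({placeholders})", list(tags)):
        tags[travel_id].append(name)
    return [(travel_id, coord, tags[travel_id]) for travel_id, coord in rows]


slow, fast = CountingDB(make_db()), CountingDB(make_db())
assert naive(slow) == optimized(fast)
print(slow.queries, fast.queries)  # 11 2
```

With 5 travels both versions return the same data, but the naive one sends 11 queries and the optimized one sends 2; the naive count grows with the number of rows while the optimized one stays constant.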

Now let's rewrite our ViewSet to use the new serializers:

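class TravelViewset(viewsets.ModelViewSet):
    queryset = Travel.objects.all()

    def get_serializer_class(self):
        if self.action == 'retrieve':
            return TravelRetrieveSerializer
        elif self.action == 'list':
            return TravelListSerializer
        else:
            return SomeDefaultSerializer

    def get_queryset(self):
        # Take the base queryset and let the serializer class chosen for
        # the current action add its select_related/prefetch_related calls.
        # get_related_queries is a classmethod, so there is no need to
        # instantiate the serializer with get_serializer().
        queryset = super(TravelViewset, self).get_queryset()
        serializer_class = self.get_serializer_class()
        # Serializers without the mixin (e.g. SomeDefaultSerializer) are skipped.
        if issubclass(serializer_class, QuerySerializerMixin):
            queryset = serializer_class.get_related_queries(queryset)
        return queryset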
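To check the dispatch logic without a Django project, here is a stdlib-only sketch. `RecordingQuerySet`, `PlainSerializer` and `FakeTravelViewSet` are hypothetical stand-ins for a Django QuerySet, a serializer that does not use the mixin, and the real ViewSet; only the mixin and the per-action field lists are taken from above. In DRF, `action` is set on the view for each request; here it is simply passed to the constructor.

```python
class RecordingQuerySet:
    """Hypothetical stand-in for a QuerySet that records optimization calls."""

    def __init__(self, calls=()):
        self.calls = tuple(calls)

    def select_related(self, *fields):
        return RecordingQuerySet(self.calls + (("select_related", fields),))

    def prefetch_related(self, *fields):
        return RecordingQuerySet(self.calls + (("prefetch_related", fields),))


class QuerySerializerMixin:
    PREFETCH_FIELDS = []
    RELATED_FIELDS = []

    @classmethod
    def get_related_queries(cls, queryset):
        if cls.RELATED_FIELDS:
            queryset = queryset.select_related(*cls.RELATED_FIELDS)
        if cls.PREFETCH_FIELDS:
            queryset = queryset.prefetch_related(*cls.PREFETCH_FIELDS)
        return queryset


class TravelListSerializer(QuerySerializerMixin):
    PREFETCH_FIELDS = ['tags']
    RELATED_FIELDS = ['coordinate']


class TravelRetrieveSerializer(QuerySerializerMixin):
    PREFETCH_FIELDS = ['tags', 'route_places']


class PlainSerializer:
    """Hypothetical default serializer that does not use the mixin."""


class FakeTravelViewSet:
    """Mimics only the two ViewSet hooks used above."""

    def __init__(self, action):
        self.action = action

    def get_serializer_class(self):
        return {'list': TravelListSerializer,
                'retrieve': TravelRetrieveSerializer}.get(self.action, PlainSerializer)

    def get_queryset(self):
        queryset = RecordingQuerySet()  # stands in for super().get_queryset()
        serializer_class = self.get_serializer_class()
        if issubclass(serializer_class, QuerySerializerMixin):
            queryset = serializer_class.get_related_queries(queryset)
        return queryset


for action in ('list', 'retrieve', 'create'):
    print(action, FakeTravelViewSet(action).get_queryset().calls)
# list (('select_related', ('coordinate',)), ('prefetch_related', ('tags',)))
# retrieve (('prefetch_related', ('tags', 'route_places')),)
# create ()
```

The `issubclass` check lets actions whose serializer does not use the mixin fall back to the unmodified queryset instead of raising AttributeError.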