import logging
from urllib.parse import parse_qs

from channels.db import database_sync_to_async
from channels.middleware import BaseMiddleware
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.tokens import AccessToken

logger = logging.getLogger(__name__)

User = get_user_model()


class TokenAuthMiddleware(BaseMiddleware):
    """Authenticate Channels connections with a SimpleJWT access token.

    The token is read from the ``token`` query-string parameter
    (e.g. ``ws://host/path/?token=<jwt>``). On success ``scope['user']`` is
    the matching user. If the token is absent or fails validation, the
    ``user_id`` claim is missing, or the user does not exist or is inactive,
    ``scope['user']`` is ``AnonymousUser()``.
    """

    async def __call__(self, scope, receive, send):
        # Work on a shallow copy so the ``user`` key is not written back into
        # the scope dict owned by outer middleware / the server.
        scope = dict(scope)
        try:
            token = self.get_token_from_scope(scope)
            if token:
                scope['user'] = await self.get_user_from_token(token)
            else:
                scope['user'] = AnonymousUser()
        except (TokenError, KeyError, User.DoesNotExist):
            # Expected authentication failures: the token could not be
            # validated, it lacks the ``user_id`` claim, or the user is gone.
            scope['user'] = AnonymousUser()
        except Exception:
            # Unexpected failure (e.g. a database error). Keep the previous
            # best-effort behaviour of falling back to an anonymous user, but
            # make the failure visible instead of swallowing it silently.
            logger.exception("Unexpected error while authenticating token from query string")
            scope['user'] = AnonymousUser()
        return await super().__call__(scope, receive, send)

    def get_token_from_scope(self, scope):
        """Return the ``token`` query-string parameter from *scope*, or ``None``.

        ``parse_qs`` decodes percent-escapes and tolerates parameters without
        ``=`` (``?debug&token=...``) or with ``=`` inside the value, both of
        which broke the previous hand-rolled ``split('=')`` parsing. If
        ``token`` appears more than once, the last value wins. An empty value
        (``?token=``) yields ``None``.
        """
        # ASGI delivers the raw, percent-encoded query string as bytes.
        # latin-1 decodes any byte sequence losslessly, and parse_qs then
        # decodes the percent-escapes itself.
        query_string = scope.get('query_string', b'').decode('latin-1')
        values = parse_qs(query_string).get('token')
        return values[-1] if values else None

    @database_sync_to_async
    def get_user_from_token(self, token):
        """Validate *token* and return the user it identifies.

        Runs in a worker thread via ``database_sync_to_async`` because it
        performs an ORM query.

        Returns ``AnonymousUser()`` for an inactive user.

        Raises:
            TokenError: if ``AccessToken`` rejects the token.
            KeyError: if the token has no ``user_id`` claim.
            User.DoesNotExist: if no user has that id.
        """
        access_token = AccessToken(token)
        user_id = access_token['user_id']
        user = User.objects.get(id=user_id)
        # Mirror django.contrib.auth / SimpleJWT: inactive accounts must not
        # authenticate. getattr() keeps custom user models without the
        # attribute working.
        if not getattr(user, 'is_active', True):
            return AnonymousUser()
        return user