34 lines
1.2 KiB
Python
34 lines
1.2 KiB
Python
import logging
from urllib.parse import parse_qs

from channels.db import database_sync_to_async
from channels.middleware import BaseMiddleware
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import AccessToken
|
|
|
|
logger = logging.getLogger(__name__)

# NOTE(review): evaluated at import time, so this module presumably has to be
# imported only after Django's app registry is ready (e.g. after
# get_asgi_application() in asgi.py) — confirm against the project's asgi.py.
User = get_user_model()


class TokenAuthMiddleware(BaseMiddleware):
    """Channels middleware that sets ``scope['user']`` from a JWT query parameter.

    The token is read from the ``token`` query-string parameter of the
    connection (e.g. ``ws://host/path/?token=<access token>``). When it is a
    valid simplejwt access token for an existing, active user, that user is
    stored in ``scope['user']``; otherwise ``AnonymousUser()`` is stored.
    Authentication is best-effort: this middleware never rejects the
    connection itself.
    """

    async def __call__(self, scope, receive, send):
        """Attach the resolved user to a copy of *scope* and run the inner app."""
        # Copy so that adding 'user' does not mutate the caller's scope dict.
        scope = dict(scope)
        scope['user'] = await self._resolve_user(scope)
        return await super().__call__(scope, receive, send)

    async def _resolve_user(self, scope):
        """Return the user authenticated by the scope's token, or AnonymousUser."""
        try:
            token = self.get_token_from_scope(scope)
            if not token:
                return AnonymousUser()
            user = await self.get_user_from_token(token)
        except (TokenError, KeyError, User.DoesNotExist):
            # Expected authentication failures: token rejected by simplejwt
            # (e.g. malformed or expired), token without the user-id claim,
            # or no user matching that claim.
            return AnonymousUser()
        except Exception:
            # Keep the best-effort fallback, but do not hide unexpected
            # errors (bugs, database outages) the way a bare swallow would.
            logger.exception('Unexpected error while authenticating token from query string')
            return AnonymousUser()
        # Match simplejwt's JWTAuthentication, which refuses inactive users.
        if not getattr(user, 'is_active', True):
            return AnonymousUser()
        return user

    def get_token_from_scope(self, scope):
        """Get the token from the query parameters.

        Returns the value of the ``token`` query-string parameter, or ``None``
        when it is absent or empty. Parsing uses ``urllib.parse.parse_qs``, so
        percent-escapes are decoded, parameters without ``=`` are skipped and
        values containing ``=`` are kept whole instead of raising. If
        ``token`` appears more than once, the last occurrence wins.
        """
        # query_string is bytes (see the b'' default); latin-1 maps every
        # byte to a character, so this decode cannot fail.
        query_string = scope.get('query_string', b'').decode('latin-1')
        values = parse_qs(query_string).get('token')
        return values[-1] if values else None

    @database_sync_to_async
    def get_user_from_token(self, token):
        """Validate *token* as an access token and load the user it names.

        Wrapped with ``database_sync_to_async`` so the synchronous ORM query
        can be awaited from ``__call__``. The claim name and lookup field come
        from simplejwt's ``USER_ID_CLAIM`` / ``USER_ID_FIELD`` settings
        (defaults ``'user_id'`` / ``'id'``), as in ``JWTAuthentication``.

        Raises:
            TokenError: if simplejwt rejects the token.
            KeyError: if the token has no user-id claim.
            User.DoesNotExist: if no user matches the claim.
        """
        access_token = AccessToken(token)
        user_id = access_token[api_settings.USER_ID_CLAIM]
        return User.objects.get(**{api_settings.USER_ID_FIELD: user_id})
|