50 lines
1.9 KiB
Python
50 lines
1.9 KiB
Python
from channels.middleware import BaseMiddleware
|
|
from channels.db import database_sync_to_async
|
|
|
|
class TokenAuthMiddleware(BaseMiddleware):
    """Populate ``scope['user']`` from a token in the Authorization header.

    The header is read from the ASGI ``scope['headers']`` list. An optional
    ``Bearer`` scheme prefix is removed and the remaining token is validated
    with ``userapp.authentication.RedisTokenAuthentication``. When no usable
    token is present, or validation fails, the connection proceeds with an
    ``AnonymousUser``.
    """

    async def __call__(self, scope, receive, send):
        """Authenticate the connection, then delegate to the inner application.

        Args:
            scope: ASGI connection scope. It is copied before ``'user'`` is
                set so the change does not leak back to outer middleware.
            receive: ASGI receive callable, passed through unchanged.
            send: ASGI send callable, passed through unchanged.

        Returns:
            Whatever ``BaseMiddleware.__call__`` returns for the inner app.
        """
        # Copy first (as Channels' own middleware does) so that setting
        # 'user' does not mutate the scope object owned by the caller.
        scope = dict(scope)

        token = self._extract_token(scope)
        if token:
            # Validate the token and resolve the user.
            scope['user'] = await self.get_user(token)
        else:
            scope['user'] = self._anonymous_user()

        return await super().__call__(scope, receive, send)

    @staticmethod
    def _extract_token(scope):
        """Return the token from the Authorization header, or ``''`` if none.

        A leading ``Bearer`` scheme (matched case-insensitively, per RFC 7235)
        is removed; any other value is returned whole, as before.
        """
        # ASGI header names are lowercased byte strings; dict() keeps the
        # last value if the header is repeated.
        headers = dict(scope.get('headers', []))
        # latin-1 maps every byte, so a malformed header cannot raise
        # UnicodeDecodeError (Django's ASGI handler decodes headers this way).
        raw = headers.get(b'authorization', b'').decode('latin-1').strip()
        scheme, sep, rest = raw.partition(' ')
        if sep and scheme.lower() == 'bearer':
            return rest.strip()
        return raw

    @staticmethod
    def _anonymous_user():
        """Return a new ``AnonymousUser``.

        Imported lazily, following this module's existing practice of keeping
        Django model imports inside functions (presumably so the module can be
        imported before Django's app registry is ready).
        """
        from django.contrib.auth.models import AnonymousUser
        return AnonymousUser()

    @database_sync_to_async
    def get_user(self, token):
        """Resolve *token* to a user using ``RedisTokenAuthentication``.

        Wrapped with ``database_sync_to_async`` so the synchronous
        authenticator is run off the event loop and awaited by the caller.

        Args:
            token: Bare token string, without a scheme prefix.

        Returns:
            The authenticated user, or an ``AnonymousUser`` when the
            authenticator returns ``None`` or raises any exception.
        """
        import logging

        try:
            # Use userapp's token authentication.
            from userapp.authentication import RedisTokenAuthentication

            class _TokenRequest:
                """Minimal request stand-in carrying only the Authorization header."""

                def __init__(self, token):
                    header = f'Bearer {token}'
                    self.headers = {'Authorization': header}
                    # NOTE(review): DRF-style authenticators read the header
                    # from request.META; provided in case
                    # RedisTokenAuthentication does — confirm and drop if not.
                    self.META = {'HTTP_AUTHORIZATION': header}

            result = RedisTokenAuthentication().authenticate(_TokenRequest(token))
            if result is None:
                return self._anonymous_user()

            user, _ = result
            return user

        except Exception as exc:
            # Deliberate best-effort boundary: any failure from the
            # authenticator degrades to an anonymous user instead of
            # tearing down the connection, but is now logged, not printed.
            logging.getLogger(__name__).warning(
                'Token authentication error: %s', exc
            )
            return self._anonymous_user()