145 lines
6.0 KiB
Python
145 lines
6.0 KiB
Python
from django.utils.deprecation import MiddlewareMixin
|
||
from rest_framework.response import Response
|
||
from rest_framework.exceptions import APIException, ValidationError
|
||
import json
|
||
|
||
class StandardResponseMiddleware(MiddlewareMixin):
|
||
"""
|
||
统一API响应格式的中间件
|
||
|
||
为所有DRF响应添加标准格式:
|
||
{
|
||
"success": true/false,
|
||
"code": status_code,
|
||
"message": "响应消息",
|
||
"data": response_data
|
||
}
|
||
"""
|
||
|
||
def _extract_validation_error_message(self, data):
|
||
"""提取验证错误信息为可读字符串"""
|
||
if not isinstance(data, dict):
|
||
return str(data)
|
||
|
||
error_messages = []
|
||
for field, errors in data.items():
|
||
if isinstance(errors, list):
|
||
for error in errors:
|
||
if isinstance(error, dict):
|
||
# 嵌套错误
|
||
nested_msg = self._extract_validation_error_message(error)
|
||
error_messages.append(f"{field}: {nested_msg}")
|
||
else:
|
||
# 简单错误
|
||
error_messages.append(f"{field}: {error}")
|
||
elif isinstance(errors, dict):
|
||
# 嵌套对象错误
|
||
nested_msg = self._extract_validation_error_message(errors)
|
||
error_messages.append(f"{field}: {nested_msg}")
|
||
else:
|
||
# 直接错误
|
||
error_messages.append(f"{field}: {errors}")
|
||
|
||
return ";".join(error_messages)
|
||
|
||
def process_response(self, request, response):
|
||
# 排除Swagger文档的响应
|
||
if request.path.startswith('/swagger/') or request.path.startswith('/redoc/'):
|
||
return response
|
||
|
||
# 只处理REST framework的响应和JSON响应
|
||
if hasattr(response, 'data'):
|
||
# 检查是否已经是标准格式
|
||
if isinstance(response.data, dict) and all(k in response.data for k in ['success', 'code']):
|
||
# 已经是标准格式,无需处理
|
||
return response
|
||
|
||
# 获取状态码
|
||
status_code = response.status_code
|
||
|
||
# 判断请求是否成功
|
||
success = 200 <= status_code < 300
|
||
|
||
# 构建标准响应数据
|
||
standard_data = {
|
||
'success': success,
|
||
'code': status_code,
|
||
'message': ''
|
||
}
|
||
|
||
# 处理不同类型的响应数据
|
||
if isinstance(response.data, dict):
|
||
# 检查是否为验证错误(状态码400)
|
||
if status_code == 400 and not response.data.get('message') and not response.data.get('detail'):
|
||
# 尝试从字段错误中提取消息
|
||
error_message = self._extract_validation_error_message(response.data)
|
||
if error_message:
|
||
standard_data['message'] = error_message
|
||
# 这里可以决定是否保留原始错误数据
|
||
standard_data['data'] = response.data
|
||
else:
|
||
# 字典类型的响应
|
||
message = response.data.get('message', '')
|
||
if not message and 'detail' in response.data:
|
||
message = response.data.pop('detail')
|
||
|
||
standard_data['message'] = message
|
||
|
||
# 移除message,避免重复
|
||
if 'message' in response.data:
|
||
response.data.pop('message')
|
||
|
||
# 检查是否是分页数据
|
||
if 'count' in response.data and 'results' in response.data:
|
||
# 这是分页响应,保留分页信息
|
||
standard_data['data'] = response.data
|
||
elif response.data and response.data != {'detail': message}:
|
||
# 其他字典类型数据
|
||
standard_data['data'] = response.data
|
||
elif isinstance(response.data, list):
|
||
# 列表类型的响应(如视图集的list方法)
|
||
standard_data['message'] = '获取成功'
|
||
standard_data['data'] = response.data
|
||
else:
|
||
# 其他类型的响应(如字符串、数字等)
|
||
standard_data['message'] = '获取成功'
|
||
standard_data['data'] = response.data
|
||
|
||
# 更新响应数据
|
||
response.data = standard_data
|
||
|
||
# 确保content也更新
|
||
try:
|
||
response.content = json.dumps(response.data)
|
||
response['Content-Type'] = 'application/json'
|
||
except:
|
||
pass
|
||
|
||
return response
|
||
|
||
def process_exception(self, request, exception):
|
||
# 处理API异常,返回标准格式错误信息
|
||
if isinstance(exception, APIException):
|
||
# 提取详细的错误消息
|
||
message = str(exception)
|
||
data = None
|
||
|
||
# 如果是验证错误,尝试提取更友好的错误消息
|
||
if isinstance(exception, ValidationError) and isinstance(exception.detail, dict):
|
||
message = self._extract_validation_error_message(exception.detail)
|
||
data = exception.detail
|
||
|
||
# 构建标准响应
|
||
response_data = {
|
||
'success': False,
|
||
'code': exception.status_code,
|
||
'message': message
|
||
}
|
||
|
||
# 可选:包含原始的错误详情
|
||
if data:
|
||
response_data['data'] = data
|
||
|
||
return Response(response_data, status=exception.status_code)
|
||
|
||
return None |