516 lines
21 KiB
Python
516 lines
21 KiB
Python
from rest_framework.views import APIView
|
||
from rest_framework.response import Response
|
||
from rest_framework import status, viewsets, permissions
|
||
from django.contrib.auth.models import User
|
||
from .models import ChatMessage, Bot
|
||
from userapp.models import ParadiseUser
|
||
from .serializers import ChatMessageSerializer
|
||
from rest_framework.permissions import IsAuthenticated
|
||
from userapp.authentication import RedisTokenAuthentication
|
||
from rest_framework import serializers
|
||
from common.swagger_utils import swagger_schema
|
||
from common.responses import success_response, created_response, error_response
|
||
from drf_yasg import openapi
|
||
from drf_yasg.utils import swagger_auto_schema
|
||
import logging
|
||
import requests
|
||
import re
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class BotSerializer(serializers.ModelSerializer):
    """Serialize a ``Bot`` instance as its ``id``, ``name`` and ``description``."""

    class Meta:
        model = Bot
        fields = ("id", "name", "description")
|
||
|
||
|
||
class IsAdminOrReadOnly(permissions.BasePermission):
    """Grant safe (read-only) methods to authenticated users and all other methods to staff."""

    def has_permission(self, request, view):
        """Return a truthy value when ``request.user`` may perform ``request.method``."""
        read_only = request.method in permissions.SAFE_METHODS
        user = request.user
        # A falsy user short-circuits and is returned as-is, exactly like the
        # ``user and user.<flag>`` form; otherwise the relevant flag decides.
        return user and (user.is_authenticated if read_only else user.is_staff)
|
||
|
||
|
||
class BotViewSet(viewsets.ModelViewSet):
    """
    Bot (AI model) management endpoints.

    Staff users may create, read, update and delete bots; other
    authenticated users have read-only access.
    """
    authentication_classes = [RedisTokenAuthentication]
    permission_classes = [IsAdminOrReadOnly]
    serializer_class = BotSerializer
    queryset = Bot.objects.all()
|
||
|
||
from .kimi import KIMI
|
||
from .audio.AudioService import get_audio_service
|
||
|
||
# Swagger schema definitions
|
||
class ChatRequestSchema(serializers.Serializer):
    """Request body of the single-bot chat endpoint: one required text message."""

    message = serializers.CharField(
        help_text="用户发送的文本消息内容",
        required=True,
    )
|
||
|
||
class ChatMessageDetailSchema(serializers.Serializer):
    """Shape of one chat message as described in the Swagger responses of this module."""

    id = serializers.IntegerField(help_text="消息ID")
    user = serializers.IntegerField(help_text="用户ID")
    bot = serializers.IntegerField(help_text="机器人ID")
    message = serializers.CharField(help_text="消息内容")
    sender = serializers.CharField(help_text="发送者,user或bot")
    message_type = serializers.CharField(help_text="消息类型")
    # Media links are optional and may be null.
    message_audio_url = serializers.URLField(
        required=False,
        allow_null=True,
        help_text="音频链接",
    )
    message_video_url = serializers.URLField(
        required=False,
        allow_null=True,
        help_text="视频链接",
    )
    created_at = serializers.DateTimeField(help_text="创建时间")
|
||
|
||
class ChatResponseSchema(serializers.Serializer):
    """Response body of the single-bot chat endpoint: the user message and the bot reply."""

    user_message = ChatMessageDetailSchema(
        help_text="用户发送的消息详情",
    )
    bot_reply = ChatMessageDetailSchema(
        help_text="机器人的回复消息详情",
    )
|
||
|
||
class MultiChatRequestSchema(serializers.Serializer):
    """Request body of the multi-turn chat endpoint (text, audio or video message)."""

    botId = serializers.IntegerField(help_text="机器人ID", required=True)
    message = serializers.CharField(help_text="用户发送的消息内容", required=True)
    messageType = serializers.ChoiceField(
        help_text="消息类型(text, audio, video)",
        choices=[
            ChatMessage.MESSAGE_TYPE_TEXT,
            ChatMessage.MESSAGE_TYPE_AUDIO,
            ChatMessage.MESSAGE_TYPE_VIDEO,
        ],
    )
    # Audio payload: either a URL or Base64-encoded data.
    messageAudioUrl = serializers.URLField(
        help_text="音频消息的URL(当messageType为audio时)",
        required=False,
    )
    messageAudioBase64 = serializers.CharField(
        help_text="Base64编码的音频数据(当messageType为audio时)",
        required=False,
    )
    # Video payload: either a URL or Base64-encoded data.
    messageVideoUrl = serializers.URLField(
        help_text="视频消息的URL(当messageType为video时)",
        required=False,
    )
    messageVideoBase64 = serializers.CharField(
        help_text="Base64编码的视频数据(当messageType为video时)",
        required=False,
    )
    returnAudioAsBase64 = serializers.BooleanField(
        help_text="是否返回Base64编码的音频而不是URL",
        default=False,
        required=False,
    )
|
||
|
||
|
||
class ChatBotAPIView(APIView):
    """
    AI chatbot endpoint.

    Accepts one text message for the bot named in the URL, stores it, sends
    the stored conversation of this user/bot pair to the model and returns
    both the stored user message and the bot reply.
    """
    authentication_classes = [RedisTokenAuthentication]  # project token authentication
    permission_classes = [IsAuthenticated]  # authenticated users only
    tags = ['ai-chat']

    @swagger_schema(
        request_schema=ChatRequestSchema,
        responses={
            201: openapi.Response('对话成功', ChatResponseSchema),
            400: openapi.Response('请求参数错误', openapi.Schema(
                type=openapi.TYPE_OBJECT,
                properties={
                    'error': openapi.Schema(type=openapi.TYPE_STRING, description='错误信息')
                }
            )),
            404: openapi.Response('机器人不存在', openapi.Schema(
                type=openapi.TYPE_OBJECT,
                properties={
                    'error': openapi.Schema(type=openapi.TYPE_STRING, description='错误信息')
                }
            ))
        },
        operation_description="与AI机器人进行单轮对话,发送文本消息并获取回复",
        tags=['AI聊天'],
        security=[{'Bearer': []}],
        manual_parameters=[
            openapi.Parameter(
                'bot_id',
                openapi.IN_PATH,
                description="AI机器人的唯一标识",
                type=openapi.TYPE_INTEGER,
                required=True
            )
        ]
    )
    def post(self, request, bot_id):
        """
        Send a text message to the bot and return its reply.

        Args:
            request: DRF request; ``request.data["message"]`` is the user text.
            bot_id: primary key of the target ``Bot``.

        Returns:
            201 with ``user_message`` and ``bot_reply`` (serialized chat
            messages), 400 when ``message`` is missing or blank, 404 when the
            bot does not exist.
        """
        self.user = request.user
        try:
            self.bot = Bot.objects.get(id=bot_id)
        except (Bot.DoesNotExist, ValueError, TypeError):
            # Only "no such bot" / malformed id map to 404; other failures
            # (e.g. database errors) must not be reported as a missing bot.
            return Response(
                {"error": "Bot not found"}, status=status.HTTP_404_NOT_FOUND
            )

        # Reject an empty request before anything is written to the database.
        user_message = request.data.get("message")
        if not isinstance(user_message, str) or not user_message.strip():
            return Response(
                {"error": "Message is required"}, status=status.HTTP_400_BAD_REQUEST
            )

        # Seed the conversation with its initial message on first access.
        handle_first_access(self.user, self.bot)

        # Store the message of the current request (this endpoint is text-only).
        user_chat_message = create_user_message(
            self.user,
            self.bot,
            ChatMessage.MESSAGE_TYPE_TEXT,
            user_message,
            '',
            '',
        )

        # Build the model input in chronological order; ``id`` breaks ties
        # between rows created within the same timestamp.
        stored_messages = ChatMessage.objects.filter(
            bot=self.bot, user=self.user
        ).order_by('timestamp', 'id')
        history = [
            {'role': stored.sender, 'content': stored.message}
            for stored in stored_messages
        ]

        bot_chat_message = ask_kimi(self.user, self.bot, history)

        # Serialize both sides of the exchange for the response.
        user_message_serializer = ChatMessageSerializer(user_chat_message)
        bot_message_serializer = ChatMessageSerializer(bot_chat_message)
        logger.info(bot_chat_message)
        return Response(
            {
                "user_message": user_message_serializer.data,
                "bot_reply": bot_message_serializer.data,
            },
            status=status.HTTP_201_CREATED,
        )
|
||
|
||
class MultiChatAPIView(APIView):
    """
    Multi-turn AI chat endpoint.

    Stores the incoming text, audio or video message, sends the stored
    conversation of this user/bot pair to the model and returns the reply.
    For audio requests the reply is also synthesized to speech and returned
    either as a URL or as Base64 data.
    """
    authentication_classes = [RedisTokenAuthentication]  # project token authentication
    permission_classes = [IsAuthenticated]  # authenticated users only
    tags = ['ai-multichat']

    @staticmethod
    def _as_bool(value):
        """Interpret a JSON boolean or a form-encoded flag such as "true" / "1"."""
        if isinstance(value, str):
            return value.strip().lower() in ("1", "true", "yes", "on")
        return bool(value)

    @staticmethod
    def _bad_request(message):
        """Build the 400 response used for every request-validation failure."""
        return Response({"error": message}, status=status.HTTP_400_BAD_REQUEST)

    # NOTE(review): the documented 201 schema (ChatMessageDetailSchema) does not
    # match the payload built in ``post`` (text / audio_url / audio_base64 / ...).
    @swagger_schema(
        request_schema=MultiChatRequestSchema,
        responses={
            201: openapi.Response('对话成功', ChatMessageDetailSchema),
            400: openapi.Response('请求参数错误', openapi.Schema(
                type=openapi.TYPE_OBJECT,
                properties={
                    'error': openapi.Schema(type=openapi.TYPE_STRING, description='错误信息')
                }
            )),
            404: openapi.Response('机器人不存在', openapi.Schema(
                type=openapi.TYPE_OBJECT,
                properties={
                    'error': openapi.Schema(type=openapi.TYPE_STRING, description='错误信息')
                }
            ))
        },
        operation_description="与AI进行多轮对话,支持文本、音频和视频消息,AI能记住对话历史",
        tags=['AI聊天'],
        security=[{'Bearer': []}]
    )
    def post(self, request):
        """
        Handle one turn of a multi-turn conversation.

        Reads ``botId``, ``messageType``, ``message`` and the optional audio /
        video fields from ``request.data``.

        Returns:
            201 with ``text`` plus, depending on the request, ``user_message``,
            ``user_audio_url`` / ``user_audio_base64``, ``audio_url``,
            ``audio_base64`` and ``video_url``; 400 for invalid input; 404 when
            the bot does not exist.
        """
        self.user = request.user
        bot_id = request.data.get("botId")
        if bot_id is None:
            logger.warning("Bot id is missing in request")
            # Without an explicit status DRF would answer 200 for this error.
            return self._bad_request("Bot id is required")
        try:
            self.bot = Bot.objects.get(id=bot_id)
            logger.info(f"Found bot with id {bot_id}")
        except (Bot.DoesNotExist, ValueError, TypeError) as e:
            logger.error(f"Bot not found: {str(e)}")
            return Response(
                {"error": "Bot not found"}, status=status.HTTP_404_NOT_FOUND
            )

        message_type = request.data.get("messageType")
        user_message = request.data.get("message")
        message_audio_url = request.data.get("messageAudioUrl")
        message_audio_base64 = request.data.get("messageAudioBase64")
        message_video_url = request.data.get("messageVideoUrl")
        message_video_base64 = request.data.get("messageVideoBase64")
        return_audio_as_base64 = self._as_bool(
            request.data.get("returnAudioAsBase64", False)
        )

        # Validate before anything is written: an unsupported type would
        # otherwise store no user message but still query the model, and a
        # missing media payload would surface as a server error.
        if message_type == ChatMessage.MESSAGE_TYPE_TEXT:
            if not isinstance(user_message, str) or not user_message.strip():
                return self._bad_request("Message is required")
        elif message_type == ChatMessage.MESSAGE_TYPE_AUDIO:
            if not (message_audio_url or message_audio_base64):
                return self._bad_request("Audio data is required")
        elif message_type == ChatMessage.MESSAGE_TYPE_VIDEO:
            if not (message_video_url or message_video_base64):
                return self._bad_request("Video data is required")
        else:
            return self._bad_request("Unsupported messageType")

        # Seed the conversation with its initial message on first access.
        handle_first_access(self.user, self.bot)

        logger.info(f"Processing message of type {message_type} from user {self.user.id}")

        # Store the message of the current request.
        user_chat_message = create_user_message(
            self.user,
            self.bot,
            message_type,
            user_message,
            message_audio_url,
            message_video_url,
            message_audio_base64,
            message_video_base64
        )

        # Build the model input in chronological order; ``id`` breaks ties
        # between rows created within the same timestamp.
        stored_messages = ChatMessage.objects.filter(
            bot=self.bot, user=self.user
        ).order_by('timestamp', 'id')
        history = [
            {'role': stored.sender, 'content': stored.message}
            for stored in stored_messages
        ]

        bot_chat_message = ask_kimi(self.user, self.bot, history)

        # For audio requests, synthesize the reply to speech.
        response_message_audio_base64 = None
        if message_type == ChatMessage.MESSAGE_TYPE_AUDIO:
            audio_ser = get_audio_service()
            logger.info(f"Processing audio message: {bot_chat_message.message}")

            if return_audio_as_base64:
                # Return the raw audio inline as Base64; no URL is stored.
                import base64
                audio_data = audio_ser.synthesize_speech_raw(bot_chat_message.message)
                response_message_audio_base64 = base64.b64encode(audio_data).decode('utf-8')
                bot_chat_message.message_audio_url = None
                logger.info("Generated audio as Base64 data")
            else:
                # Store and return the URL of the synthesized audio.
                response_message_audio_url = audio_ser.synthesize_speech(bot_chat_message.message)
                logger.info(f"Generated audio URL: {response_message_audio_url}")
                bot_chat_message.message_audio_url = response_message_audio_url

            bot_chat_message.message_type = ChatMessage.MESSAGE_TYPE_AUDIO
            bot_chat_message.save()

        logger.info(f"Bot response: {bot_chat_message.message}")

        # Assemble the response payload.
        response_data = {
            "text": bot_chat_message.message
        }

        # For audio requests, echo the recognized text and the user's audio.
        if message_type == ChatMessage.MESSAGE_TYPE_AUDIO:
            response_data["user_message"] = user_chat_message.message

            if message_audio_url:
                response_data["user_audio_url"] = message_audio_url
            elif message_audio_base64:
                response_data["user_audio_base64"] = message_audio_base64

        if bot_chat_message.message_audio_url:
            response_data["audio_url"] = bot_chat_message.message_audio_url

        # Only set when Base64 audio was requested for an audio message.
        if response_message_audio_base64 is not None:
            response_data["audio_base64"] = response_message_audio_base64

        if bot_chat_message.message_video_url:
            response_data["video_url"] = bot_chat_message.message_video_url

        return Response(
            response_data,
            status=status.HTTP_201_CREATED,
        )
|
||
|
||
def ask_kimi(user: ParadiseUser, bot: Bot, history: list) -> ChatMessage:
    """
    Ask the KIMI model for a reply and store it as a bot chat message.

    Args:
        user: owner of the conversation.
        bot: bot the reply is attributed to.
        history: list of ``{'role': ..., 'content': ...}`` dicts passed to ``KIMI``.

    Returns:
        The newly created ``ChatMessage`` with ``sender=ChatMessage.SENDER_BOT``.

    NOTE(review): assumes ``KIMI(...).get_response()`` returns a ``str``;
    ``re.sub`` raises ``TypeError`` otherwise — confirm against ``.kimi``.
    """
    response = KIMI(history).get_response()

    # Strip one parenthesized segment (ASCII or full-width brackets) at the
    # very start of the reply, together with the whitespace that follows it.
    response = re.sub(r'^[\((][^))]*[\))][\s]*', '', response)
    # Strip one parenthesized segment at the very end, together with the
    # whitespace that precedes it.
    response = re.sub(r'[\s]*[\((][^))]*[\))]$', '', response)

    # ``objects.create`` already persists the row; no extra ``save()`` needed.
    return ChatMessage.objects.create(
        user=user,
        bot=bot,
        message=response,
        sender=ChatMessage.SENDER_BOT,
    )
|
||
|
||
def handle_first_access(user: ParadiseUser, bot: Bot):
    """
    Seed the conversation between ``user`` and ``bot`` on first access.

    When no ``ChatMessage`` exists yet for the pair, one text message is
    created with ``sender=ChatMessage.SENDER_SYSTEM`` and the bot's
    description as its content. Otherwise nothing happens.
    """
    if ChatMessage.objects.filter(user=user, bot=bot).exists():
        return

    # ``objects.create`` already persists the row; no extra ``save()`` needed.
    ChatMessage.objects.create(
        user=user,
        bot=bot,
        message=bot.description,
        sender=ChatMessage.SENDER_SYSTEM,
        message_type=ChatMessage.MESSAGE_TYPE_TEXT,
    )
|
||
|
||
# (connect, read) timeout in seconds for downloading a user-supplied audio URL.
_AUDIO_DOWNLOAD_TIMEOUT = (5, 30)


def create_user_message(user: ParadiseUser, bot: Bot, message_type: str, user_message: str, message_audio_url: str, message_video_url: str, message_audio_base64: str = None, message_video_base64: str = None) -> ChatMessage:
    """
    Store the user's message for the given bot and return it.

    Args:
        user: author of the message.
        bot: bot the message is addressed to.
        message_type: one of ``ChatMessage.MESSAGE_TYPE_TEXT`` / ``_AUDIO`` / ``_VIDEO``.
        user_message: text content (stored for text and video messages).
        message_audio_url: URL of the audio to transcribe (audio messages).
        message_video_url: URL of the video (video messages).
        message_audio_base64: Base64 audio, used when no audio URL is given.
        message_video_base64: Base64 video, used when no video URL is given;
            it is written under ``MEDIA_ROOT/video`` and referenced by URL.

    Returns:
        The created ``ChatMessage``, or ``None`` when ``message_type`` is not
        one of the three supported types (nothing is stored in that case).

    Raises:
        ValueError: audio/video data is missing, cannot be downloaded
            (non-200 response) or cannot be decoded/stored.
        requests.RequestException: the audio download itself fails or times out.
    """
    if message_type == ChatMessage.MESSAGE_TYPE_TEXT:
        # ``objects.create`` already persists the row; no extra ``save()`` needed.
        user_chat_message = ChatMessage.objects.create(
            user=user,
            bot=bot,
            message=user_message,
            sender=ChatMessage.SENDER_USER,
            message_type=ChatMessage.MESSAGE_TYPE_TEXT,
        )
        logger.info(f"Created text message for user {user.id} and bot {bot.id}")
        return user_chat_message

    if message_type == ChatMessage.MESSAGE_TYPE_AUDIO:
        # Obtain the raw audio either from the URL or from the Base64 payload.
        audio_data = None
        audio_service = get_audio_service()

        if message_audio_url:
            logger.info(f"Processing audio message from URL: {message_audio_url}")
            # NOTE(security): the URL comes from the request body, so fetching
            # it server-side allows SSRF; consider restricting allowed hosts.
            try:
                # A timeout keeps a slow or unresponsive host from blocking
                # the worker indefinitely.
                response = requests.get(message_audio_url, timeout=_AUDIO_DOWNLOAD_TIMEOUT)
            except requests.RequestException:
                logger.exception("Failed to download audio file")
                raise
            if response.status_code != 200:
                logger.error(f"Failed to download audio file: {response.status_code}")
                raise ValueError("无法下载音频文件")
            audio_data = response.content
        elif message_audio_base64:
            logger.info("Processing audio message from Base64 data")
            import base64
            try:
                audio_data = base64.b64decode(message_audio_base64)
            except (ValueError, TypeError) as e:
                logger.error(f"Failed to decode base64 audio data: {str(e)}")
                raise ValueError("无法解码Base64音频数据") from e

        if not audio_data:
            logger.error("No audio data provided (neither URL nor Base64)")
            raise ValueError("未提供音频数据")

        # The stored text of an audio message is its speech-recognition result.
        message = audio_service.recognize_speech(audio_data)
        logger.info(f"Recognized speech message: {message}")

        message_fields = {
            "user": user,
            "bot": bot,
            "message": message,
            "sender": ChatMessage.SENDER_USER,
            "message_type": ChatMessage.MESSAGE_TYPE_AUDIO,
        }
        # Only record the URL when one was supplied, so the model default is
        # kept for Base64 uploads; a single write replaces create-then-save.
        if message_audio_url:
            message_fields["message_audio_url"] = message_audio_url
        user_chat_message = ChatMessage.objects.create(**message_fields)
        logger.info(f"Created audio message for user {user.id} and bot {bot.id}")
        return user_chat_message

    if message_type == ChatMessage.MESSAGE_TYPE_VIDEO:
        # Video messages are only stored by URL; Base64 uploads are first
        # written to the media directory and then referenced by URL.
        video_url = message_video_url

        if not video_url and message_video_base64:
            logger.info("Processing video message from Base64 data")
            import base64
            import os
            import uuid
            from django.conf import settings

            # NOTE(review): no size limit is enforced on the decoded upload.
            try:
                video_data = base64.b64decode(message_video_base64)

                video_dir = os.path.join(settings.MEDIA_ROOT, 'video')
                os.makedirs(video_dir, exist_ok=True)

                # Random file name so uploads never overwrite each other.
                video_filename = f"{uuid.uuid4()}.mp4"
                video_path = os.path.join(video_dir, video_filename)
                with open(video_path, 'wb') as f:
                    f.write(video_data)
            except (ValueError, TypeError, OSError) as e:
                logger.error(f"Failed to process base64 video data: {str(e)}")
                raise ValueError("无法处理Base64视频数据") from e

            video_url = f"{settings.MEDIA_URL}video/{video_filename}"
            logger.info(f"Generated video URL from Base64: {video_url}")

        if not video_url:
            logger.error("No video URL or data provided")
            raise ValueError("未提供视频数据")

        user_chat_message = ChatMessage.objects.create(
            user=user,
            bot=bot,
            message=user_message,  # text may later be extracted from the video instead
            sender=ChatMessage.SENDER_USER,
            message_type=ChatMessage.MESSAGE_TYPE_VIDEO,
            message_video_url=video_url,
        )
        logger.info(f"Created video message for user {user.id} and bot {bot.id}")
        return user_chat_message

    # Unsupported message type: nothing is stored.
    return None
|
||
|
||
|
||
class RTCChatHistoryAPIView(APIView):
    """
    Chat-history endpoint for the RTC voice agent.

    GET: return the current user's most recent RTC chat messages.
    POST: store one RTC chat message.
    DELETE: delete all RTC chat messages of the current user.
    """
    authentication_classes = [RedisTokenAuthentication]
    permission_classes = [IsAuthenticated]

    RTC_BOT_NAME = 'RTC_Voice_Agent'
    # Page size used when the client sends none, and the upper bound applied
    # to client-supplied values.
    DEFAULT_PAGE_SIZE = 50
    MAX_PAGE_SIZE = 200

    def _get_rtc_bot(self):
        """Return the ``Bot`` named ``RTC_BOT_NAME``, or ``None`` when it does not exist."""
        try:
            return Bot.objects.get(name=self.RTC_BOT_NAME)
        except Bot.DoesNotExist:
            return None

    @staticmethod
    def _bot_missing_response():
        """Build the error response returned when the RTC bot row is not configured."""
        return error_response(message='RTC Bot 未配置', code=500, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)

    def get(self, request):
        """
        Return the latest ``page_size`` messages in chronological order.

        ``page_size`` (query parameter, default 50) is clamped to 1..200; a
        non-integer value yields an error response. The payload contains
        ``messages``, ``total`` and ``has_more``.
        """
        bot = self._get_rtc_bot()
        if bot is None:
            return self._bot_missing_response()

        raw_page_size = request.query_params.get('page_size', self.DEFAULT_PAGE_SIZE)
        try:
            page_size = int(raw_page_size)
        except (TypeError, ValueError):
            # A malformed query parameter is a client error, not a crash.
            return error_response(message='page_size 必须是整数')
        # Zero or negative sizes would produce an empty page with a misleading
        # ``has_more``; keep the value inside 1..MAX_PAGE_SIZE.
        page_size = max(1, min(page_size, self.MAX_PAGE_SIZE))

        queryset = ChatMessage.objects.filter(
            user=request.user,
            bot=bot
        ).order_by('timestamp')

        total = queryset.count()
        # Take the tail of the ascending list: the newest ``page_size`` rows.
        messages = queryset[max(0, total - page_size):]

        data = {
            'messages': [
                {
                    'id': msg.id,
                    'message': msg.message,
                    'sender': msg.sender,
                    'timestamp': msg.timestamp.isoformat(),
                }
                for msg in messages
            ],
            'total': total,
            'has_more': total > page_size,
        }
        return success_response(data=data)

    def post(self, request):
        """
        Store one RTC chat message for the current user.

        Expects a non-blank ``message`` and a ``sender`` of ``user`` or
        ``assistant`` in ``request.data``; returns the new row's ``id`` and
        ``timestamp``.
        """
        bot = self._get_rtc_bot()
        if bot is None:
            return self._bot_missing_response()

        raw_message = request.data.get('message', '')
        raw_sender = request.data.get('sender', '')
        # Non-string values (e.g. JSON null or numbers) are treated as missing
        # instead of raising AttributeError on ``.strip()``.
        message_text = raw_message.strip() if isinstance(raw_message, str) else ''
        sender = raw_sender.strip() if isinstance(raw_sender, str) else ''

        if not message_text:
            return error_response(message='消息内容不能为空')
        if sender not in ('user', 'assistant'):
            return error_response(message='sender 必须是 user 或 assistant')

        chat_msg = ChatMessage.objects.create(
            user=request.user,
            bot=bot,
            message=message_text,
            sender=sender,
            message_type=ChatMessage.MESSAGE_TYPE_TEXT,
        )

        return created_response(data={
            'id': chat_msg.id,
            'timestamp': chat_msg.timestamp.isoformat(),
        })

    def delete(self, request):
        """Delete every RTC chat message of the current user and report the count."""
        bot = self._get_rtc_bot()
        if bot is None:
            return self._bot_missing_response()

        count, _ = ChatMessage.objects.filter(user=request.user, bot=bot).delete()
        return success_response(data={'deleted': count}, message=f'已删除 {count} 条记录')