From 7a6d7814e08f815230b1b320cd0c49045ec91306 Mon Sep 17 00:00:00 2001 From: repair-agent Date: Thu, 12 Feb 2026 14:05:51 +0800 Subject: [PATCH] fix store bug --- .env.example | 5 + apps/devices/serializers.py | 13 + apps/devices/views.py | 52 +- ...yshelf_options_story_audio_url_and_more.py | 33 + .../migrations/0003_story_shelf_nullable.py | 26 + apps/stories/models.py | 13 +- apps/stories/serializers.py | 37 +- apps/stories/services/__init__.py | 0 apps/stories/services/llm_service.py | 165 +++++ apps/stories/services/tts_service.py | 117 ++++ apps/stories/urls.py | 2 +- apps/stories/views.py | 154 ++++- .../migrations/0003_pointsrecord_and_more.py | 73 ++ apps/users/models.py | 31 + apps/users/serializers.py | 14 +- apps/users/views.py | 29 + config/settings.py | 15 + requirements.txt | 2 + tests.py | 622 +++++++++++++++++- utils/exceptions.py | 1 + utils/middleware.py | 123 ++++ utils/routers.py | 9 + 22 files changed, 1494 insertions(+), 42 deletions(-) create mode 100644 apps/stories/migrations/0002_alter_storyshelf_options_story_audio_url_and_more.py create mode 100644 apps/stories/migrations/0003_story_shelf_nullable.py create mode 100644 apps/stories/services/__init__.py create mode 100644 apps/stories/services/llm_service.py create mode 100644 apps/stories/services/tts_service.py create mode 100644 apps/users/migrations/0003_pointsrecord_and_more.py create mode 100644 utils/middleware.py create mode 100644 utils/routers.py diff --git a/.env.example b/.env.example index c92fd4b..54f97b6 100644 --- a/.env.example +++ b/.env.example @@ -20,5 +20,10 @@ OSS_ENDPOINT=oss-cn-hangzhou.aliyuncs.com OSS_BUCKET_NAME=your-bucket-name OSS_CUSTOM_DOMAIN= +# Volcengine / 火山引擎豆包 (Story Generation) +VOLCENGINE_API_KEY=your-volcengine-api-key +VOLCENGINE_API_BASE_URL=https://ark.cn-beijing.volces.com/api/v3 +VOLCENGINE_MODEL_NAME=doubao-seed-1-6-lite-251015 + # CORS (production only) CORS_ALLOWED_ORIGINS=https://your-domain.com diff --git a/apps/devices/serializers.py b/apps/devices/serializers.py index f9bc1e2..c08d7b4 100644 --- a/apps/devices/serializers.py +++ b/apps/devices/serializers.py @@ -136,3 +136,16 @@ class DeviceSettingsUpdateSerializer(serializers.Serializer): brightness = serializers.IntegerField(min_value=0, max_value=100, required=False) allow_interrupt = serializers.BooleanField(required=False) privacy_mode = serializers.BooleanField(required=False) + + +class DeviceReportStatusSerializer(serializers.Serializer): + """设备状态上报序列化器(硬件端使用)""" + mac_address = serializers.CharField(max_length=20, help_text='MAC地址') + is_online = serializers.BooleanField(required=False, help_text='是否在线') + battery = serializers.IntegerField(min_value=0, max_value=100, required=False, help_text='电量百分比') + volume = serializers.IntegerField(min_value=0, max_value=100, required=False, help_text='音量') + brightness = serializers.IntegerField(min_value=0, max_value=100, required=False, help_text='亮度') + firmware_version = serializers.CharField(max_length=20, required=False, help_text='固件版本') + + def validate_mac_address(self, value): + return value.upper().replace('-', ':') diff --git a/apps/devices/views.py b/apps/devices/views.py index 15566ac..eaf9777 100644 --- a/apps/devices/views.py +++ b/apps/devices/views.py @@ -18,6 +18,7 @@ from .serializers import ( DeviceTypeSerializer, DeviceDetailSerializer, DeviceSettingsUpdateSerializer, + DeviceReportStatusSerializer, ) @@ -182,8 +183,8 @@ class DeviceViewSet(viewsets.ViewSet): return success(data=UserDeviceSerializer(user_device).data, message='更新成功') - @action(detail=True, methods=['get']) - def detail(self, request, pk=None): + @action(detail=True, methods=['get'], url_path='detail') + def device_detail(self, request, pk=None): """ 获取设备详情 GET /api/v1/devices/{user_device_id}/detail/ @@ -260,3 +261,50 @@ class DeviceViewSet(viewsets.ViewSet): # TODO: 通过设备通信协议下发 WiFi 配置(password 不存库) return success(message='WiFi 配置成功') + + @action(detail=False, methods=['post'], url_path='report-status', + authentication_classes=[], permission_classes=[AllowAny]) + def report_status(self, request): + """ + 设备状态上报(硬件端调用,无需认证) + POST /api/v1/devices/report-status + """ + serializer = DeviceReportStatusSerializer(data=request.data) + if not serializer.is_valid(): + return error(message=str(serializer.errors)) + + data = serializer.validated_data + mac = data.pop('mac_address') + + try: + device = Device.objects.get(mac_address=mac) + except Device.DoesNotExist: + return error(code=ErrorCode.DEVICE_NOT_FOUND, message='设备不存在') + + # 更新 Device 表字段 + device_fields_updated = False + for field in ('is_online', 'battery', 'firmware_version'): + if field in data: + setattr(device, field, data[field]) + device_fields_updated = True + + if data.get('is_online'): + from django.utils import timezone + device.last_online_at = timezone.now() + device_fields_updated = True + + if device_fields_updated: + device.save() + + # 更新 DeviceSettings 表字段(仅 volume / brightness) + settings_data = {k: data[k] for k in ('volume', 'brightness') if k in data} + if settings_data: + settings_obj, _ = DeviceSettings.objects.get_or_create(device=device) + for field, value in settings_data.items(): + setattr(settings_obj, field, value) + settings_obj.save() + + return success( + data={'device_id': device.id, 'sn': device.sn}, + message='状态上报成功' + ) diff --git a/apps/stories/migrations/0002_alter_storyshelf_options_story_audio_url_and_more.py b/apps/stories/migrations/0002_alter_storyshelf_options_story_audio_url_and_more.py new file mode 100644 index 0000000..431944d --- /dev/null +++ b/apps/stories/migrations/0002_alter_storyshelf_options_story_audio_url_and_more.py @@ -0,0 +1,33 @@ +# Generated by Django 4.2 on 2026-02-12 02:53 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("stories", "0001_initial"), + ] + + operations = [ + migrations.AlterModelOptions( + name="storyshelf", + options={ + "ordering": ["created_at"], + "verbose_name": "故事书架", + "verbose_name_plural": "故事书架", + }, + ), + migrations.AddField( + model_name="story", + name="audio_url", + field=models.URLField( + blank=True, default="", max_length=500, verbose_name="音频URL" + ), + ), + migrations.AddField( + model_name="storyshelf", + name="capacity", + field=models.IntegerField(default=10, verbose_name="容量上限"), + ), + ] diff --git a/apps/stories/migrations/0003_story_shelf_nullable.py b/apps/stories/migrations/0003_story_shelf_nullable.py new file mode 100644 index 0000000..b959c1c --- /dev/null +++ b/apps/stories/migrations/0003_story_shelf_nullable.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2 on 2026-02-12 03:14 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ("stories", "0002_alter_storyshelf_options_story_audio_url_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="story", + name="shelf", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="stories", + to="stories.storyshelf", + verbose_name="所属书架", + ), + ), + ] diff --git a/apps/stories/models.py b/apps/stories/models.py index 7f58a43..8c2d651 100644 --- a/apps/stories/models.py +++ b/apps/stories/models.py @@ -13,6 +13,7 @@ class StoryShelf(models.Model): related_name='story_shelves', verbose_name='用户' ) name = models.CharField('书架名称', max_length=100) + capacity = models.IntegerField('容量上限', default=10) is_locked = models.BooleanField('是否加锁', default=False) unlock_cost = models.IntegerField('解锁积分', default=0) created_at = models.DateTimeField('创建时间', auto_now_add=True) @@ -22,11 +23,15 @@ class StoryShelf(models.Model): db_table = 'story_shelf' verbose_name = '故事书架' verbose_name_plural = verbose_name - ordering = ['-created_at'] + ordering = ['created_at'] def __str__(self): return self.name + @property + def is_full(self): + return self.stories.count() >= self.capacity + class Story(models.Model): """故事""" @@ -41,12 +46,14 @@ class Story(models.Model): related_name='stories', verbose_name='用户' ) shelf = models.ForeignKey( - StoryShelf, on_delete=models.CASCADE, - related_name='stories', verbose_name='所属书架' + StoryShelf, on_delete=models.SET_NULL, + related_name='stories', verbose_name='所属书架', + null=True, blank=True, ) title = models.CharField('标题', max_length=200) content = models.TextField('内容', blank=True, default='') cover_url = models.URLField('封面URL', max_length=500, blank=True, default='') + audio_url = models.URLField('音频URL', max_length=500, blank=True, default='') has_video = models.BooleanField('是否有视频', default=False) video_url = models.URLField('视频URL', max_length=500, blank=True, default='') generation_mode = models.CharField( diff --git a/apps/stories/serializers.py b/apps/stories/serializers.py index a26c393..94ed49f 100644 --- a/apps/stories/serializers.py +++ b/apps/stories/serializers.py @@ -11,7 +11,7 @@ class StoryShelfSerializer(serializers.ModelSerializer): class Meta: model = StoryShelf - fields = ['id', 'name', 'is_locked', 'unlock_cost', 'story_count', 'created_at'] + fields = ['id', 'name', 'capacity', 'is_locked', 'unlock_cost', 'story_count', 'created_at'] read_only_fields = ['id', 'created_at'] @@ -25,8 +25,8 @@ class StoryListSerializer(serializers.ModelSerializer): class Meta: model = Story - fields = ['id', 'title', 'cover_url', 'content', 'has_video', - 'video_url', 'created_at'] + fields = ['id', 'title', 'cover_url', 'content', 'audio_url', + 'has_video', 'video_url', 'created_at'] class StoryDetailSerializer(serializers.ModelSerializer): @@ -34,17 +34,32 @@ class StoryDetailSerializer(serializers.ModelSerializer): class Meta: model = Story - fields = ['id', 'title', 'content', 'cover_url', 'has_video', - 'video_url', 'generation_mode', 'prompt', 'shelf', - 'created_at', 'updated_at'] + fields = ['id', 'title', 'content', 'cover_url', 'audio_url', + 'has_video', 'video_url', 'generation_mode', 'prompt', + 'shelf', 'created_at', 'updated_at'] + + +class CreateStorySerializer(serializers.Serializer): + """保存故事序列化器""" + title = serializers.CharField(max_length=200) + content = serializers.CharField() + shelf_id = serializers.IntegerField() + cover_url = serializers.URLField(required=False, allow_blank=True, default='') + generation_mode = serializers.ChoiceField( + choices=['ai', 'manual'], default='ai' + ) + prompt = serializers.CharField(required=False, allow_blank=True, default='') class GenerateStorySerializer(serializers.Serializer): """生成故事序列化器""" - mode = serializers.ChoiceField( - choices=['random', 'keyword', 'theme'], - default='random' + characters = serializers.ListField( + child=serializers.CharField(), default=[] + ) + scenes = serializers.ListField( + child=serializers.CharField(), default=[] + ) + props = serializers.ListField( + child=serializers.CharField(), default=[] ) - prompt = serializers.CharField(required=False, allow_blank=True, default='') - theme = serializers.CharField(required=False, allow_blank=True, default='') shelf_id = serializers.IntegerField(required=False, allow_null=True) diff --git a/apps/stories/services/__init__.py b/apps/stories/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/stories/services/llm_service.py b/apps/stories/services/llm_service.py new file mode 100644 index 0000000..d9ca74a --- /dev/null +++ b/apps/stories/services/llm_service.py @@ -0,0 +1,165 @@ +""" +LLM 故事生成服务 — 基于火山引擎豆包大模型 +""" +import json +import logging +from django.conf import settings + +try: + from openai import OpenAI + OPENAI_AVAILABLE = True +except ImportError: + OPENAI_AVAILABLE = False + +logger = logging.getLogger(__name__) + +STORY_SYSTEM_PROMPT = """# 角色 + +你是「卡皮巴拉故事工坊」的首席故事大师。你为 3-8 岁的小朋友创作原创童话故事。 + +# 任务 + +根据用户提供的**角色、场景、道具**素材,创作一个完整的儿童故事。 + +# 输出格式 + +你 **必须** 只返回如下 JSON,不要返回任何其他内容(不要 markdown 代码块,不要解释): + +{"title": "故事标题(6字以内)", "content": "故事正文"} + +# 故事创作规范 + +1. **字数**:正文 400-600 字,不要太短也不要太长 +2. **段落**:用 `\\n\\n` 分段,每段 2-4 句话 +3. **语言**:简单易懂,适合给小朋友朗读;可以包含拟声词("哗啦啦"、"咕噜噜")和语气词("哇!"、"嘿嘿") +4. **结构**:开头引入角色和场景 → 中间遇到挑战或趣事 → 结尾温馨圆满 +5. **情感**:温暖、有趣、充满想象力,带一点小幽默 +6. **教育**:自然融入一个小道理(勇气、友谊、分享等),不要说教 +7. **创意**:即使收到相同的素材组合,每次也要创作全新的、不同的故事情节 +8. **角色融合**:所有用户选择的角色、场景、道具都必须在故事中出现并发挥作用 +9. **标题**:简短有趣,6 个字以内,能引起小朋友的好奇心""" + + +def build_user_prompt(characters, scenes, props): + """构建用户提示词""" + parts = [] + if characters: + parts.append(f"角色:{', '.join(characters)}") + if scenes: + parts.append(f"场景:{', '.join(scenes)}") + if props: + parts.append(f"道具:{', '.join(props)}") + return '请根据以下元素创作一个儿童故事:\n' + '\n'.join(parts) + + +def generate_story_stream(characters, scenes, props): + """ + 流式生成故事,yield SSE 事件字符串。 + 使用火山引擎豆包大模型(OpenAI 兼容接口)。 + + Yields: + str: SSE 格式的事件数据行 + """ + config = settings.LLM_CONFIG + + if not config.get('API_KEY'): + yield _sse_event('error', {'message': 'Volcengine API Key 未配置'}) + return + + if not OPENAI_AVAILABLE: + yield _sse_event('error', {'message': 'openai 库未安装,请运行 pip install openai'}) + return + + yield _sse_event('stage', { + 'stage': 'connecting', + 'progress': 0, + 'message': '正在收集灵感碎片...', + }) + + client = OpenAI( + api_key=config['API_KEY'], + base_url=config['API_BASE_URL'], + ) + + user_prompt = build_user_prompt(characters, scenes, props) + + try: + yield _sse_event('stage', { + 'stage': 'generating', + 'progress': 10, + 'message': '故事正在诞生...', + }) + + stream = client.chat.completions.create( + model=config['MODEL_NAME'], + messages=[ + {'role': 'system', 'content': STORY_SYSTEM_PROMPT}, + {'role': 'user', 'content': user_prompt}, + ], + max_tokens=2048, + stream=True, + ) + + full_content = '' + chunk_count = 0 + + for chunk in stream: + delta = chunk.choices[0].delta if chunk.choices else None + if delta and delta.content: + full_content += delta.content + chunk_count += 1 + + if chunk_count % 5 == 0: + progress = min(10 + int(chunk_count * 0.5), 80) + yield _sse_event('stage', { + 'stage': 'generating', + 'progress': progress, + 'message': '故事正在诞生...', + }) + + yield _sse_event('stage', { + 'stage': 'parsing', + 'progress': 85, + 'message': '正在编制最后的魔法...', + }) + + result = _parse_story_json(full_content) + + yield _sse_event('done', { + 'stage': 'done', + 'progress': 100, + 'message': '大功告成!', + 'title': result['title'], + 'content': result['content'], + }) + + except Exception as e: + logger.error(f'LLM story generation failed: {e}') + yield _sse_event('error', {'message': f'故事生成失败: {str(e)}'}) + + +def _parse_story_json(text): + """从 LLM 输出中解析故事 JSON""" + text = text.strip() + if text.startswith('```'): + text = text.split('\n', 1)[1] if '\n' in text else text[3:] + if text.endswith('```'): + text = text[:-3] + text = text.strip() + + try: + data = json.loads(text) + return { + 'title': data.get('title', '未命名故事'), + 'content': data.get('content', text), + } + except json.JSONDecodeError: + return { + 'title': '新故事', + 'content': text, + } + + +def _sse_event(event, data): + """格式化 SSE 事件""" + return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" diff --git a/apps/stories/services/tts_service.py b/apps/stories/services/tts_service.py new file mode 100644 index 0000000..d0ebd7c --- /dev/null +++ b/apps/stories/services/tts_service.py @@ -0,0 +1,117 @@ +""" +TTS 语音合成服务 + OSS 上传 +""" +import io +import json +import uuid +import logging +from datetime import datetime +from django.conf import settings + +logger = logging.getLogger(__name__) + +# TTS 提供商可在此切换,当前预留 edge-tts(免费) +TTS_VOICE = 'zh-CN-XiaoxiaoNeural' + + +def generate_tts_stream(story): + """ + 为故事生成 TTS 音频并上传 OSS,通过 SSE 推送进度。 + + Args: + story: Story model instance + + Yields: + str: SSE 格式的事件数据行 + """ + yield _sse_event('stage', { + 'stage': 'connecting', + 'progress': 0, + 'message': '正在准备语音合成...', + }) + + try: + import edge_tts + except ImportError: + yield _sse_event('error', {'message': 'edge-tts 库未安装,请运行 pip install edge-tts'}) + return + + # 如果已有音频,直接返回 + if story.audio_url: + yield _sse_event('done', { + 'stage': 'done', + 'progress': 100, + 'message': '音频已存在', + 'audio_url': story.audio_url, + }) + return + + yield _sse_event('stage', { + 'stage': 'generating', + 'progress': 10, + 'message': '正在合成语音...', + }) + + try: + # edge-tts 是异步的,需要在同步上下文中运行 + import asyncio + audio_data = asyncio.run(_synthesize(story.content)) + except Exception as e: + logger.error(f'TTS synthesis failed: {e}') + yield _sse_event('error', {'message': f'语音合成失败: {str(e)}'}) + return + + yield _sse_event('stage', { + 'stage': 'saving', + 'progress': 70, + 'message': '正在保存音频文件...', + }) + + # 上传到 OSS + try: + from utils.oss import get_oss_client + oss_client = get_oss_client() + + filename = f"{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}.mp3" + key = f"stories/audio/{filename}" + + oss_client.bucket.put_object(key, audio_data) + + oss_config = settings.ALIYUN_OSS + if oss_config.get('CUSTOM_DOMAIN'): + audio_url = f"https://{oss_config['CUSTOM_DOMAIN']}/{key}" + else: + audio_url = f"https://{oss_config['BUCKET_NAME']}.{oss_config['ENDPOINT']}/{key}" + + # 更新故事记录 + story.audio_url = audio_url + story.save(update_fields=['audio_url']) + + except Exception as e: + logger.error(f'OSS upload failed: {e}') + yield _sse_event('error', {'message': f'音频上传失败: {str(e)}'}) + return + + yield _sse_event('done', { + 'stage': 'done', + 'progress': 100, + 'message': '语音合成完成!', + 'audio_url': audio_url, + }) + + +async def _synthesize(text): + """使用 edge-tts 合成语音,返回音频 bytes""" + import edge_tts + + communicate = edge_tts.Communicate(text, TTS_VOICE) + audio_chunks = [] + async for chunk in communicate.stream(): + if chunk['type'] == 'audio': + audio_chunks.append(chunk['data']) + return b''.join(audio_chunks) + + +def _sse_event(event, data): + """格式化 SSE 事件""" + return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" diff --git a/apps/stories/urls.py b/apps/stories/urls.py index 7d931f2..408ee73 100644 --- a/apps/stories/urls.py +++ b/apps/stories/urls.py @@ -10,5 +10,5 @@ router.register('shelves', ShelfViewSet, basename='shelves') router.register('', StoryViewSet, basename='stories') urlpatterns = [ - path('', include(router.urls)), + path('stories/', include(router.urls)), ] diff --git a/apps/stories/views.py b/apps/stories/views.py index 2ca2896..c8ec4bf 100644 --- a/apps/stories/views.py +++ b/apps/stories/views.py @@ -1,7 +1,9 @@ """ 故事模块视图 - App端 """ +from django.db import transaction from django.db.models import Count +from django.http import StreamingHttpResponse from rest_framework import viewsets from rest_framework.decorators import action from rest_framework.permissions import IsAuthenticated @@ -15,10 +17,18 @@ from .serializers import ( StoryShelfSerializer, CreateShelfSerializer, StoryListSerializer, + StoryDetailSerializer, + CreateStorySerializer, GenerateStorySerializer, ) +def ensure_default_shelf(user): + """确保用户有默认书架,没有则创建""" + if not StoryShelf.objects.filter(user=user).exists(): + StoryShelf.objects.create(user=user, name='我的书架') + + @extend_schema(tags=['故事']) class StoryViewSet(viewsets.ViewSet): """故事视图集(App端)""" @@ -49,6 +59,41 @@ class StoryViewSet(viewsets.ViewSet): 'items': StoryListSerializer(items, many=True).data, }) + def create(self, request): + """ + 保存故事 + POST /api/v1/stories/ + """ + serializer = CreateStorySerializer(data=request.data) + if not serializer.is_valid(): + return error(message=str(serializer.errors)) + + shelf_id = serializer.validated_data['shelf_id'] + try: + shelf = StoryShelf.objects.get( + id=shelf_id, user=request.user, is_locked=False + ) + except StoryShelf.DoesNotExist: + return error(code=ErrorCode.SHELF_NOT_FOUND, message='书架不存在或未解锁') + + if shelf.is_full: + return error(code=ErrorCode.SHELF_FULL, message='书架已满,请解锁新书架') + + story = Story.objects.create( + user=request.user, + shelf=shelf, + title=serializer.validated_data['title'], + content=serializer.validated_data['content'], + cover_url=serializer.validated_data.get('cover_url', ''), + generation_mode=serializer.validated_data.get('generation_mode', 'ai'), + prompt=serializer.validated_data.get('prompt', ''), + ) + + return success( + data=StoryDetailSerializer(story).data, + message='保存成功' + ) + def destroy(self, request, pk=None): """ 删除故事 @@ -64,35 +109,55 @@ class StoryViewSet(viewsets.ViewSet): @action(detail=False, methods=['post'], url_path='generate') def generate(self, request): """ - 生成故事 (SSE 流式 - 占位) + 生成故事 (SSE 流式) POST /api/v1/stories/generate/ """ serializer = GenerateStorySerializer(data=request.data) if not serializer.is_valid(): return error(message=str(serializer.errors)) - # TODO: 接入 LLM API 实现 SSE 流式生成 - shelf_id = serializer.validated_data.get('shelf_id') - shelf = None - if shelf_id: - try: - shelf = StoryShelf.objects.get(id=shelf_id, user=request.user) - except StoryShelf.DoesNotExist: - return error(code=ErrorCode.SHELF_NOT_FOUND, message='书架不存在') + characters = serializer.validated_data.get('characters', []) + scenes = serializer.validated_data.get('scenes', []) + props = serializer.validated_data.get('props', []) - story = Story.objects.create( - user=request.user, - shelf=shelf, - title='生成中...', - content='', - generation_mode=serializer.validated_data.get('mode', 'random'), - prompt=serializer.validated_data.get('prompt', ''), + from .services.llm_service import generate_story_stream + + response = StreamingHttpResponse( + generate_story_stream(characters, scenes, props), + content_type='text/event-stream', ) + response['Cache-Control'] = 'no-cache' + response['X-Accel-Buffering'] = 'no' + return response - return success(data={ - 'id': story.id, - 'message': '故事生成功能待接入 LLM API', - }) + @action(detail=True, methods=['get', 'post'], url_path='tts') + def tts(self, request, pk=None): + """ + TTS 音频接口 + GET /api/v1/stories/{id}/tts/ - 查询音频状态 + POST /api/v1/stories/{id}/tts/ - 生成 TTS 音频 (SSE 流式) + """ + try: + story = Story.objects.get(id=pk, user=request.user) + except Story.DoesNotExist: + return error(code=ErrorCode.STORY_NOT_FOUND, message='故事不存在') + + if request.method == 'GET': + return success(data={ + 'exists': bool(story.audio_url), + 'audio_url': story.audio_url, + }) + + # POST: 生成音频 + from .services.tts_service import generate_tts_stream + + response = StreamingHttpResponse( + generate_tts_stream(story), + content_type='text/event-stream', + ) + response['Cache-Control'] = 'no-cache' + response['X-Accel-Buffering'] = 'no' + return response @extend_schema(tags=['故事']) @@ -107,6 +172,8 @@ class ShelfViewSet(viewsets.ViewSet): 书架列表 GET /api/v1/stories/shelves/ """ + ensure_default_shelf(request.user) + shelves = StoryShelf.objects.filter( user=request.user ).annotate(story_count=Count('stories')) @@ -140,3 +207,50 @@ class ShelfViewSet(viewsets.ViewSet): Story.objects.filter(shelf=shelf).update(shelf=None) shelf.delete() return success(message='删除成功') + + @action(detail=False, methods=['post'], url_path='unlock') + def unlock(self, request): + """ + 积分解锁新书架 + POST /api/v1/stories/shelves/unlock/ + """ + from apps.users.models import PointsRecord + + # 解锁费用(可后续改为从配置读取) + unlock_cost = 100 + user = request.user + + if user.points < unlock_cost: + return error( + code=ErrorCode.POINTS_NOT_ENOUGH, + message=f'积分不足,需要 {unlock_cost} 积分,当前 {user.points} 积分' + ) + + shelf_count = StoryShelf.objects.filter(user=user).count() + shelf_name = f'书架 {shelf_count + 1}' + + with transaction.atomic(): + # 扣除积分 + user.points -= unlock_cost + user.save(update_fields=['points']) + + # 创建书架 + shelf = StoryShelf.objects.create( + user=user, + name=shelf_name, + unlock_cost=unlock_cost, + ) + + # 记录积分流水 + PointsRecord.objects.create( + user=user, + amount=-unlock_cost, + type='unlock_shelf', + description=f'解锁书架「{shelf_name}」', + ) + + shelf.story_count = 0 + return success(data={ + 'shelf': StoryShelfSerializer(shelf).data, + 'remaining_points': user.points, + }, message='解锁成功') diff --git a/apps/users/migrations/0003_pointsrecord_and_more.py b/apps/users/migrations/0003_pointsrecord_and_more.py new file mode 100644 index 0000000..abf53d7 --- /dev/null +++ b/apps/users/migrations/0003_pointsrecord_and_more.py @@ -0,0 +1,73 @@ +# Generated by Django 4.2 on 2026-02-12 02:53 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ("users", "0002_smscode_user_birthday_user_deletion_requested_at_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="PointsRecord", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("amount", models.IntegerField(verbose_name="变动数量")), + ( + "type", + models.CharField( + choices=[ + ("unlock_shelf", "解锁书架"), + ("reward", "奖励"), + ("admin_adjust", "管理员调整"), + ], + max_length=30, + verbose_name="类型", + ), + ), + ( + "description", + models.CharField( + blank=True, default="", max_length=200, verbose_name="描述" + ), + ), + ( + "created_at", + models.DateTimeField(auto_now_add=True, verbose_name="创建时间"), + ), + ( + "user", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="points_records", + to=settings.AUTH_USER_MODEL, + verbose_name="用户", + ), + ), + ], + options={ + "verbose_name": "积分流水", + "verbose_name_plural": "积分流水", + "db_table": "points_record", + "ordering": ["-created_at"], + }, + ), + migrations.AddIndex( + model_name="pointsrecord", + index=models.Index( + fields=["user", "-created_at"], name="points_reco_user_id_b675f0_idx" + ), + ), + ] diff --git a/apps/users/models.py b/apps/users/models.py index 2abba15..9a71c5a 100644 --- a/apps/users/models.py +++ b/apps/users/models.py @@ -58,6 +58,37 @@ class User(AbstractBaseUser, PermissionsMixin): return self.phone +class PointsRecord(models.Model): + """积分流水记录""" + + TYPE_CHOICES = [ + ('unlock_shelf', '解锁书架'), + ('reward', '奖励'), + ('admin_adjust', '管理员调整'), + ] + + user = models.ForeignKey( + User, on_delete=models.CASCADE, + related_name='points_records', verbose_name='用户' + ) + amount = models.IntegerField('变动数量') + type = models.CharField('类型', max_length=30, choices=TYPE_CHOICES) + description = models.CharField('描述', max_length=200, blank=True, default='') + created_at = models.DateTimeField('创建时间', auto_now_add=True) + + class Meta: + db_table = 'points_record' + verbose_name = '积分流水' + verbose_name_plural = verbose_name + ordering = ['-created_at'] + indexes = [ + models.Index(fields=['user', '-created_at']), + ] + + def __str__(self): + return f'{self.user.phone} {self.amount:+d} ({self.type})' + + class SmsCode(models.Model): """短信验证码""" phone = models.CharField('手机号', max_length=20) diff --git a/apps/users/serializers.py b/apps/users/serializers.py index 8b7daf4..934bb15 100644 --- a/apps/users/serializers.py +++ b/apps/users/serializers.py @@ -2,7 +2,7 @@ 用户模块序列化器 """ from rest_framework import serializers -from .models import User +from .models import User, PointsRecord class UserSerializer(serializers.ModelSerializer): @@ -10,8 +10,16 @@ class UserSerializer(serializers.ModelSerializer): class Meta: model = User - fields = ['id', 'phone', 'nickname', 'avatar', 'gender', 'birthday', 'created_at'] - read_only_fields = ['id', 'phone', 'created_at'] + fields = ['id', 'phone', 'nickname', 'avatar', 'gender', 'birthday', 'points', 'created_at'] + read_only_fields = ['id', 'phone', 'points', 'created_at'] + + +class PointsRecordSerializer(serializers.ModelSerializer): + """积分流水序列化器""" + + class Meta: + model = PointsRecord + fields = ['id', 'amount', 'type', 'description', 'created_at'] class UserDetailSerializer(serializers.ModelSerializer): diff --git a/apps/users/views.py b/apps/users/views.py index 05275fc..eaf5054 100644 --- a/apps/users/views.py +++ b/apps/users/views.py @@ -24,7 +24,9 @@ from .serializers import ( UpdateUserSerializer, SendCodeSerializer, CodeLoginSerializer, + PointsRecordSerializer, ) +from .models import PointsRecord def get_app_tokens(user): @@ -254,6 +256,33 @@ class UserViewSet(viewsets.ViewSet): return success(data={'avatar_url': avatar_url}) + @action(detail=False, methods=['get']) + def points(self, request): + """ + 查询积分余额 + GET /api/v1/users/points/ + """ + return success(data={'points': request.user.points}) + + @action(detail=False, methods=['get'], url_path='points/records') + def points_records(self, request): + """ + 积分流水记录 + GET /api/v1/users/points/records/?page=1&page_size=20 + """ + queryset = PointsRecord.objects.filter(user=request.user) + + page = int(request.query_params.get('page', 1)) + page_size = int(request.query_params.get('page_size', 20)) + start = (page - 1) * page_size + total = queryset.count() + items = queryset[start:start + page_size] + + return success(data={ + 'total': total, + 'items': PointsRecordSerializer(items, many=True).data, + }) + @extend_schema(tags=['管理员-用户管理']) class AdminUserManageViewSet(viewsets.ViewSet): diff --git a/config/settings.py b/config/settings.py index 36df6b3..df5a494 100644 --- a/config/settings.py +++ b/config/settings.py @@ -4,10 +4,13 @@ Django settings for RTC_DEMO project. import os from pathlib import Path from datetime import timedelta +from dotenv import load_dotenv # Build paths inside the project like this: BASE_DIR / 'subdir'. BASE_DIR = Path(__file__).resolve().parent.parent +load_dotenv(BASE_DIR / '.env') + # SECURITY WARNING: keep the secret key used in production secret! SECRET_KEY = os.environ.get('DJANGO_SECRET_KEY', 'django-insecure-dev-key-change-in-production') @@ -16,6 +19,9 @@ DEBUG = os.environ.get('DJANGO_DEBUG', 'True').lower() == 'true' ALLOWED_HOSTS = os.environ.get('DJANGO_ALLOWED_HOSTS', '*').split(',') +# 纯 API 服务,禁用 APPEND_SLASH 避免 POST/PUT/PATCH/DELETE 请求因缺少尾部斜杠而触发 RuntimeError +APPEND_SLASH = False + # Application definition INSTALLED_APPS = [ 'django.contrib.admin', @@ -42,9 +48,11 @@ INSTALLED_APPS = [ ] MIDDLEWARE = [ + 'utils.middleware.ExceptionReportMiddleware', 'corsheaders.middleware.CorsMiddleware', 'django.middleware.security.SecurityMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware', + 'utils.middleware.TrailingSlashMiddleware', 'django.middleware.common.CommonMiddleware', 'django.middleware.csrf.CsrfViewMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', @@ -186,6 +194,13 @@ ALIYUN_PHONE_AUTH = { 'ACCESS_KEY_SECRET': os.environ.get('PHONE_AUTH_ACCESS_KEY_SECRET', ALIYUN_ACCESS_KEY_SECRET), } +# LLM Settings - Volcengine / 火山引擎豆包 (Story Generation) +LLM_CONFIG = { + 'API_KEY': os.environ.get('VOLCENGINE_API_KEY', ''), + 'API_BASE_URL': os.environ.get('VOLCENGINE_API_BASE_URL', 'https://ark.cn-beijing.volces.com/api/v3'), + 'MODEL_NAME': os.environ.get('VOLCENGINE_MODEL_NAME', 'doubao-seed-1-6-lite-251015'), +} + # Swagger/OpenAPI Settings SPECTACULAR_SETTINGS = { 'TITLE': 'RTC API', diff --git a/requirements.txt b/requirements.txt index 0abefa2..dfbf347 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,3 +29,5 @@ urllib3==2.6.3 drf-spectacular==0.27.1 alibabacloud_dysmsapi20170525>=4.4.0 alibabacloud_dypnsapi20170525>=3.0.0 +openai>=1.0.0 +edge-tts>=6.1.0 diff --git a/tests.py b/tests.py index ab9bd35..d6b0efd 100644 --- a/tests.py +++ b/tests.py @@ -9,10 +9,12 @@ from django.test import TestCase from django.urls import reverse from rest_framework.test import APITestCase, APIClient from rest_framework import status -from apps.users.models import User +from apps.users.models import User, PointsRecord from apps.admins.models import AdminUser from apps.spirits.models import Spirit from apps.devices.models import DeviceType, DeviceBatch, Device, UserDevice +from apps.stories.models import StoryShelf, Story +from apps.users.views import get_app_tokens # ==================== App端测试 ==================== @@ -944,7 +946,623 @@ class ExceptionHandlerIntegrationTests(APITestCase): with patch('utils.exceptions.report_to_log_center') as mock_report: custom_exception_handler(biz_exc, context) - + # 业务异常不应触发上报 mock_report.assert_not_called() + +# ==================== 故事模块测试 ==================== + +class StoryTestBase(APITestCase): + """故事模块测试基类""" + + def setUp(self): + self.user = User.objects.create_user(phone='13800130001', nickname='故事测试用户') + tokens = get_app_tokens(self.user) + self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {tokens["access"]}') + # 创建默认书架 + self.shelf = StoryShelf.objects.create(user=self.user, name='我的书架') + + +class StoryShelfTests(StoryTestBase): + """书架接口测试""" + + def test_list_shelves(self): + """测试获取书架列表""" + url = '/api/v1/stories/shelves/' + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['code'], 0) + self.assertEqual(len(response.data['data']), 1) + self.assertEqual(response.data['data'][0]['name'], '我的书架') + + def test_list_shelves_auto_create_default(self): + """测试首次查询自动创建默认书架""" + new_user = User.objects.create_user(phone='13800130099') + tokens = get_app_tokens(new_user) + self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {tokens["access"]}') + + url = '/api/v1/stories/shelves/' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 0) + self.assertEqual(len(response.data['data']), 1) + self.assertEqual(response.data['data'][0]['name'], '我的书架') + + def test_list_shelves_includes_story_count(self): + """测试书架列表包含故事数量""" + Story.objects.create( + user=self.user, shelf=self.shelf, + title='测试故事', content='内容' + ) + + url = '/api/v1/stories/shelves/' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data'][0]['story_count'], 1) + + def test_create_shelf(self): + """测试创建书架""" + url = '/api/v1/stories/shelves/' + data = {'name': '新书架'} + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['name'], '新书架') + self.assertEqual(response.data['data']['capacity'], 10) + + def test_delete_shelf(self): + """测试删除书架""" + shelf = StoryShelf.objects.create(user=self.user, name='待删除书架') + story = Story.objects.create( + user=self.user, shelf=shelf, + title='测试', content='内容' + ) + + url = f'/api/v1/stories/shelves/{shelf.id}/' + response = self.client.delete(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertFalse(StoryShelf.objects.filter(id=shelf.id).exists()) + # 故事保留,shelf_id 置 null + story.refresh_from_db() + self.assertIsNone(story.shelf) + + def test_delete_shelf_not_found(self): + """测试删除不存在的书架""" + url = '/api/v1/stories/shelves/99999/' + response = self.client.delete(url) + + self.assertNotEqual(response.data['code'], 0) + + def test_shelf_isolation(self): + """测试书架用户隔离""" + other_user = User.objects.create_user(phone='13800130002') + other_shelf = StoryShelf.objects.create(user=other_user, name='别人的书架') + + url = f'/api/v1/stories/shelves/{other_shelf.id}/' + response = self.client.delete(url) + + self.assertNotEqual(response.data['code'], 0) + # 确认没被删除 + self.assertTrue(StoryShelf.objects.filter(id=other_shelf.id).exists()) + + +class ShelfUnlockTests(StoryTestBase): + """书架解锁测试""" + + def test_unlock_shelf_success(self): + """测试积分解锁书架 - 成功""" + self.user.points = 200 + self.user.save(update_fields=['points']) + + url = '/api/v1/stories/shelves/unlock/' + response = self.client.post(url, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['remaining_points'], 100) + self.assertIn('shelf', response.data['data']) + + # 验证积分扣除 + self.user.refresh_from_db() + self.assertEqual(self.user.points, 100) + + # 验证积分流水 + record = PointsRecord.objects.filter(user=self.user).first() + self.assertIsNotNone(record) + self.assertEqual(record.amount, -100) + self.assertEqual(record.type, 'unlock_shelf') + + def test_unlock_shelf_not_enough_points(self): + """测试积分解锁书架 - 积分不足""" + self.user.points = 50 + self.user.save(update_fields=['points']) + + url = '/api/v1/stories/shelves/unlock/' + response = self.client.post(url, format='json') + + self.assertEqual(response.data['code'], 603) # POINTS_NOT_ENOUGH + # 积分不应变化 + self.user.refresh_from_db() + self.assertEqual(self.user.points, 50) + + def test_unlock_shelf_zero_points(self): + """测试积分解锁书架 - 零积分""" + self.user.points = 0 + self.user.save(update_fields=['points']) + + url = '/api/v1/stories/shelves/unlock/' + response = self.client.post(url, format='json') + + self.assertEqual(response.data['code'], 603) + + def test_unlock_shelf_naming(self): + """测试解锁书架自动命名""" + self.user.points = 500 + self.user.save(update_fields=['points']) + + url = '/api/v1/stories/shelves/unlock/' + + # 第一次解锁(已有1个默认书架) + response = self.client.post(url, format='json') + self.assertEqual(response.data['data']['shelf']['name'], '书架 2') + + # 第二次解锁 + response = self.client.post(url, format='json') + self.assertEqual(response.data['data']['shelf']['name'], '书架 3') + + +class StoryTests(StoryTestBase): + """故事接口测试""" + + def test_list_stories(self): + """测试获取故事列表""" + Story.objects.create( + user=self.user, shelf=self.shelf, + title='故事1', content='内容1' + ) + Story.objects.create( + user=self.user, shelf=self.shelf, + title='故事2', content='内容2' + ) + + url = '/api/v1/stories/' + response = self.client.get(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['total'], 2) + self.assertEqual(len(response.data['data']['items']), 2) + + def test_list_stories_filter_by_shelf(self): + """测试按书架筛选故事""" + shelf2 = StoryShelf.objects.create(user=self.user, name='书架2') + Story.objects.create( + user=self.user, shelf=self.shelf, + title='书架1故事', content='内容' + ) + Story.objects.create( + user=self.user, shelf=shelf2, + title='书架2故事', content='内容' + ) + + url = f'/api/v1/stories/?shelf_id={self.shelf.id}' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['total'], 1) + self.assertEqual(response.data['data']['items'][0]['title'], '书架1故事') + + def test_list_stories_empty(self): + """测试空故事列表""" + url = '/api/v1/stories/' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['total'], 0) + + def test_create_story(self): + """测试保存故事""" + url = '/api/v1/stories/' + data = { + 'title': '新故事', + 'content': '这是故事内容', + 'shelf_id': self.shelf.id, + } + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['title'], '新故事') + self.assertEqual(response.data['data']['content'], '这是故事内容') + + def test_create_story_with_optional_fields(self): + """测试保存故事 - 包含可选字段""" + url = '/api/v1/stories/' + data = { + 'title': '完整故事', + 'content': '故事正文', + 'shelf_id': self.shelf.id, + 'cover_url': 'https://example.com/cover.jpg', + 'generation_mode': 'ai', + 'prompt': '角色=小猫, 场景=森林', + } + response = self.client.post(url, data, format='json') + + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['generation_mode'], 'ai') + + def test_create_story_shelf_full(self): + """测试保存故事 - 书架已满""" + # 创建小容量书架 + small_shelf = StoryShelf.objects.create( + user=self.user, name='小书架', capacity=2 + ) + Story.objects.create(user=self.user, shelf=small_shelf, title='故事1', content='内容') + Story.objects.create(user=self.user, shelf=small_shelf, title='故事2', content='内容') + + url = '/api/v1/stories/' + data = { + 'title': '溢出故事', + 'content': '这本放不下了', + 'shelf_id': small_shelf.id, + } + response = self.client.post(url, data, format='json') + + self.assertEqual(response.data['code'], 604) # SHELF_FULL + + def test_create_story_shelf_not_found(self): + """测试保存故事 - 书架不存在""" + url = '/api/v1/stories/' + data = { + 'title': '故事', + 'content': '内容', + 'shelf_id': 99999, + } + response = self.client.post(url, data, format='json') + + self.assertNotEqual(response.data['code'], 0) + + def test_create_story_other_user_shelf(self): + """测试保存故事到他人书架""" + other_user = User.objects.create_user(phone='13800130003') + other_shelf = StoryShelf.objects.create(user=other_user, name='他人书架') + + url = '/api/v1/stories/' + data = { + 'title': '故事', + 'content': '内容', + 'shelf_id': other_shelf.id, + } + response = self.client.post(url, data, format='json') + + self.assertNotEqual(response.data['code'], 0) + + def test_delete_story(self): + """测试删除故事""" + story = Story.objects.create( + user=self.user, shelf=self.shelf, + title='待删除', content='内容' + ) + + url = f'/api/v1/stories/{story.id}/' + response = self.client.delete(url) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertFalse(Story.objects.filter(id=story.id).exists()) + + def test_delete_story_not_found(self): + """测试删除不存在的故事""" + url = '/api/v1/stories/99999/' + response = self.client.delete(url) + + self.assertEqual(response.data['code'], 600) # STORY_NOT_FOUND + + def test_story_isolation(self): + """测试故事用户隔离""" + other_user = User.objects.create_user(phone='13800130004') + other_shelf = StoryShelf.objects.create(user=other_user, name='他人书架') + other_story = Story.objects.create( + user=other_user, shelf=other_shelf, + title='他人故事', content='内容' + ) + + url = f'/api/v1/stories/{other_story.id}/' + response = self.client.delete(url) + + self.assertNotEqual(response.data['code'], 0) + self.assertTrue(Story.objects.filter(id=other_story.id).exists()) + + def test_story_capacity_limit(self): + """测试书架容量为10的限制""" + # 默认书架容量 = 10 + for i in range(10): + Story.objects.create( + user=self.user, shelf=self.shelf, + title=f'故事{i+1}', content=f'内容{i+1}' + ) + + url = '/api/v1/stories/' + data = { + 'title': '第11本', + 'content': '超出容量', + 'shelf_id': self.shelf.id, + } + response = self.client.post(url, data, format='json') + + self.assertEqual(response.data['code'], 604) # SHELF_FULL + + +class StoryGenerateTests(StoryTestBase): + """故事生成接口测试""" + + def test_generate_story_returns_sse(self): + """测试生成故事返回 SSE 流""" + from unittest.mock import patch + + mock_events = [ + 'event: stage\ndata: {"stage":"connecting","progress":0,"message":"正在收集灵感碎片..."}\n\n', + 'event: done\ndata: {"stage":"done","progress":100,"title":"测试故事","content":"故事内容"}\n\n', + ] + + with patch('apps.stories.services.llm_service.generate_story_stream') as mock_gen: + mock_gen.return_value = iter(mock_events) + + url = '/api/v1/stories/generate/' + data = { + 'characters': ['小猫'], + 'scenes': ['森林'], + 'props': ['魔法棒'], + } + response = self.client.post(url, data, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response['Content-Type'], 'text/event-stream') + + def test_generate_story_empty_params(self): + """测试生成故事 - 空参数(允许,有默认值)""" + from unittest.mock import patch + + mock_events = [ + 'event: done\ndata: {"stage":"done","progress":100,"title":"默认故事","content":"内容"}\n\n', + ] + + with patch('apps.stories.services.llm_service.generate_story_stream') as mock_gen: + mock_gen.return_value = iter(mock_events) + + url = '/api/v1/stories/generate/' + response = self.client.post(url, {}, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class StoryTTSTests(StoryTestBase): + """TTS 音频接口测试""" + + def test_tts_check_no_audio(self): + """测试查询音频状态 - 无音频""" + story = Story.objects.create( + user=self.user, shelf=self.shelf, + title='无音频故事', content='内容' + ) + + url = f'/api/v1/stories/{story.id}/tts/' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 0) + self.assertFalse(response.data['data']['exists']) + self.assertEqual(response.data['data']['audio_url'], '') + + def test_tts_check_has_audio(self): + """测试查询音频状态 - 有音频""" + story = Story.objects.create( + user=self.user, shelf=self.shelf, + title='有音频故事', content='内容', + audio_url='https://oss.example.com/audio.mp3' + ) + + url = f'/api/v1/stories/{story.id}/tts/' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 0) + self.assertTrue(response.data['data']['exists']) + self.assertEqual( + response.data['data']['audio_url'], + 'https://oss.example.com/audio.mp3' + ) + + def test_tts_generate_returns_sse(self): + """测试生成 TTS 返回 SSE 流""" + from unittest.mock import patch + + story = Story.objects.create( + user=self.user, shelf=self.shelf, + title='TTS测试', content='这是要转换的故事内容' + ) + + mock_events = [ + 'event: stage\ndata: {"stage":"connecting","message":"正在连接..."}\n\n', + 'event: done\ndata: {"stage":"done","audio_url":"https://oss.example.com/audio.mp3"}\n\n', + ] + + with patch('apps.stories.services.tts_service.generate_tts_stream') as mock_tts: + mock_tts.return_value = iter(mock_events) + + url = f'/api/v1/stories/{story.id}/tts/' + response = self.client.post(url, format='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response['Content-Type'], 'text/event-stream') + + def test_tts_story_not_found(self): + """测试 TTS - 故事不存在""" + url = '/api/v1/stories/99999/tts/' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 600) # STORY_NOT_FOUND + + def test_tts_story_isolation(self): + """测试 TTS - 不能访问他人故事""" + other_user = User.objects.create_user(phone='13800130005') + other_shelf = StoryShelf.objects.create(user=other_user, name='他人书架') + other_story = Story.objects.create( + user=other_user, shelf=other_shelf, + title='他人故事', content='内容' + ) + + url = f'/api/v1/stories/{other_story.id}/tts/' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 600) + + +class PointsTests(StoryTestBase): + """积分接口测试""" + + def test_query_points(self): + """测试查询积分余额""" + self.user.points = 500 + self.user.save(update_fields=['points']) + + url = '/api/v1/users/points/' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['points'], 500) + + def test_query_points_default_zero(self): + """测试查询积分余额 - 默认为0""" + url = '/api/v1/users/points/' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['points'], 0) + + def test_points_records_list(self): + """测试积分流水记录""" + PointsRecord.objects.create( + user=self.user, amount=100, + type='reward', description='注册奖励' + ) + PointsRecord.objects.create( + user=self.user, amount=-100, + type='unlock_shelf', description='解锁书架' + ) + + url = '/api/v1/users/points/records/' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['total'], 2) + self.assertEqual(len(response.data['data']['items']), 2) + + def test_points_records_empty(self): + """测试积分流水记录 - 空""" + url = '/api/v1/users/points/records/' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['total'], 0) + + def test_points_records_pagination(self): + """测试积分流水分页""" + for i in range(5): + PointsRecord.objects.create( + user=self.user, amount=10, + type='reward', description=f'奖励{i+1}' + ) + + url = '/api/v1/users/points/records/?page=1&page_size=3' + response = self.client.get(url) + + self.assertEqual(response.data['data']['total'], 5) + self.assertEqual(len(response.data['data']['items']), 3) + + def test_points_records_isolation(self): + """测试积分流水用户隔离""" + other_user = User.objects.create_user(phone='13800130006') + PointsRecord.objects.create( + user=other_user, amount=100, + type='reward', description='他人的奖励' + ) + + url = '/api/v1/users/points/records/' + response = self.client.get(url) + + self.assertEqual(response.data['data']['total'], 0) + + +class LLMServiceTests(TestCase): + """LLM 服务单元测试""" + + def test_build_user_prompt(self): + """测试构建用户提示词""" + from apps.stories.services.llm_service import build_user_prompt + + prompt = build_user_prompt(['小猫', '小狗'], ['森林'], ['魔法棒']) + self.assertIn('小猫', prompt) + self.assertIn('小狗', prompt) + self.assertIn('森林', prompt) + self.assertIn('魔法棒', prompt) + + def test_build_user_prompt_partial(self): + """测试构建提示词 - 部分参数""" + from apps.stories.services.llm_service import build_user_prompt + + prompt = build_user_prompt(['公主'], [], []) + self.assertIn('公主', prompt) + self.assertNotIn('场景', prompt) + self.assertNotIn('道具', prompt) + + def test_parse_story_json_valid(self): + """测试解析故事 JSON - 有效""" + from apps.stories.services.llm_service import _parse_story_json + + text = '{"title": "小猫冒险", "content": "从前有一只小猫..."}' + result = _parse_story_json(text) + self.assertEqual(result['title'], '小猫冒险') + self.assertEqual(result['content'], '从前有一只小猫...') + + def test_parse_story_json_with_markdown(self): + """测试解析故事 JSON - 包含 markdown 代码块""" + from apps.stories.services.llm_service import _parse_story_json + + text = '```json\n{"title": "森林故事", "content": "在深深的森林里..."}\n```' + result = _parse_story_json(text) + self.assertEqual(result['title'], '森林故事') + + def test_parse_story_json_invalid(self): + """测试解析故事 JSON - 无效 JSON""" + from apps.stories.services.llm_service import _parse_story_json + + text = '这不是一个有效的 JSON 格式的文本' + result = _parse_story_json(text) + self.assertEqual(result['title'], '新故事') + self.assertIn('这不是', result['content']) + + def test_sse_event_format(self): + """测试 SSE 事件格式化""" + from apps.stories.services.llm_service import _sse_event + + event = _sse_event('stage', {'stage': 'connecting', 'progress': 0}) + self.assertTrue(event.startswith('event: stage\n')) + self.assertIn('data: ', event) + self.assertTrue(event.endswith('\n\n')) + + def test_generate_stream_without_api_key(self): + """测试未配置 API Key 时返回错误事件""" + from apps.stories.services.llm_service import generate_story_stream + from unittest.mock import patch + + with patch('apps.stories.services.llm_service.settings') as mock_settings: + mock_settings.LLM_CONFIG = {'API_KEY': '', 'API_BASE_URL': '', 'MODEL_NAME': ''} + events = list(generate_story_stream(['小猫'], [], [])) + + self.assertEqual(len(events), 1) + self.assertIn('error', events[0]) + self.assertIn('未配置', events[0]) + diff --git a/utils/exceptions.py b/utils/exceptions.py index af1900e..e6217d9 100644 --- a/utils/exceptions.py +++ b/utils/exceptions.py @@ -118,6 +118,7 @@ class ErrorCode: SHELF_NOT_FOUND = 601 SHELF_LOCKED = 602 POINTS_NOT_ENOUGH = 603 + SHELF_FULL = 604 # 音乐模块 700-799 TRACK_NOT_FOUND = 700 diff --git a/utils/middleware.py b/utils/middleware.py new file mode 100644 index 0000000..83d5943 --- /dev/null +++ b/utils/middleware.py @@ -0,0 +1,123 @@ +""" +全局异常捕获中间件 + +两层防线确保异常上报到 Log Center: +1. got_request_exception 信号 —— 捕获被 Django convert_exception_to_response 吞掉的异常 + (如 CommonMiddleware 的 APPEND_SLASH RuntimeError) +2. ExceptionReportMiddleware 的 try/except —— 兜底捕获穿透所有内层包裹的异常 +""" +import os +import sys +import traceback +import threading +import requests +from django.http import JsonResponse +from django.core.signals import got_request_exception + + +LOG_CENTER_URL = os.environ.get('LOG_CENTER_URL', 'https://qiyuan-log-center-api.airlabs.art') +LOG_CENTER_ENABLED = os.environ.get('LOG_CENTER_ENABLED', 'true').lower() == 'true' + + +def _send_to_log_center(payload): + """异步发送日志到 Log Center""" + def send_async(): + try: + requests.post( + f"{LOG_CENTER_URL}/api/v1/logs/report", + json=payload, + timeout=3, + ) + except Exception: + pass + + thread = threading.Thread(target=send_async) + thread.daemon = True + thread.start() + + +def _report_exception(exc, request): + """构造 payload 并上报异常""" + if not LOG_CENTER_ENABLED: + return + + try: + tb = traceback.extract_tb(exc.__traceback__) if exc.__traceback__ else [] + last_frame = tb[-1] if tb else None + + payload = { + "project_id": "rtc_backend", + "environment": os.environ.get('ENVIRONMENT', 'production'), + "level": "ERROR", + "error": { + "type": type(exc).__name__, + "message": str(exc), + "file_path": last_frame.filename if last_frame else "unknown", + "line_number": last_frame.lineno if last_frame else 0, + "stack_trace": traceback.format_exception(exc) if exc.__traceback__ else [str(exc)], + }, + "context": { + "url": request.path, + "method": request.method, + "view": "middleware", + }, + } + _send_to_log_center(payload) + except Exception: + pass + + +def _on_request_exception(sender, request, **kwargs): + """ + Django 信号回调:convert_exception_to_response 内部触发。 + 此时 sys.exc_info() 仍持有完整的异常上下文。 + """ + exc_info = sys.exc_info() + exc = exc_info[1] + if exc: + _report_exception(exc, request) + + +# 模块加载时注册信号,全局生效 +got_request_exception.connect(_on_request_exception) + + +class TrailingSlashMiddleware: + """ + 为缺少尾部斜杠的 API 请求补全 '/',直接修改 request.path_info, + 不做 HTTP 重定向,因此 POST/PUT/PATCH 请求体完好保留。 + + 配合 APPEND_SLASH = False 使用,替代 CommonMiddleware 的重定向逻辑。 + 必须放在 CommonMiddleware 之前。 + """ + + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + if not request.path_info.endswith('/'): + request.path_info = request.path_info + '/' + return self.get_response(request) + + +class ExceptionReportMiddleware: + """ + 全局异常捕获中间件,必须放在 MIDDLEWARE 列表的第一个位置。 + + 作为第二道防线:如果异常穿过了所有内层中间件的 + convert_exception_to_response 包裹,这里的 try/except 仍会兜底。 + """ + + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + try: + response = self.get_response(request) + return response + except Exception as exc: + _report_exception(exc, request) + return JsonResponse( + {"code": 1, "message": str(exc), "data": None}, + status=500, + ) diff --git a/utils/routers.py b/utils/routers.py new file mode 100644 index 0000000..1e5681c --- /dev/null +++ b/utils/routers.py @@ -0,0 +1,9 @@ +from rest_framework.routers import DefaultRouter + + +class OptionalSlashRouter(DefaultRouter): + """尾部斜杠可选的 Router,兼容 /path 和 /path/ 两种形式""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.trailing_slash = '/?'