rtc_backend/tests.py
2026-02-27 16:38:50 +08:00

2342 lines
89 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RTC_DEMO API 完整测试用例
覆盖所有API接口
运行方式: python manage.py test tests --verbosity=2 --keepdb
"""
import json
from datetime import date
from unittest.mock import patch, MagicMock
from django.test import TestCase
from django.urls import reverse
from rest_framework.test import APITestCase, APIClient
from rest_framework import status
from apps.users.models import User, PointsRecord
from apps.admins.models import AdminUser
from apps.spirits.models import Spirit
from apps.devices.models import DeviceType, DeviceBatch, Device, UserDevice, RoleMemory
from apps.stories.models import StoryShelf, Story
from apps.music.models import Track
from apps.users.views import get_app_tokens
# ==================== App端测试 ====================
class UserAuthTests(APITestCase):
    """App-side user authentication tests."""

    PHONE_LOGIN_URL = '/api/v1/auth/phone-login/'
    TOKEN_REFRESH_URL = '/api/v1/auth/refresh/'

    def _phone_login(self, phone):
        """POST ``phone`` to the one-tap login endpoint and return the response."""
        return self.client.post(self.PHONE_LOGIN_URL, {'phone': phone}, format='json')

    def _refresh(self, payload):
        """POST ``payload`` to the token refresh endpoint and return the response."""
        return self.client.post(self.TOKEN_REFRESH_URL, payload, format='json')

    def test_phone_login_new_user(self):
        """One-tap phone login with a number that has no account yet."""
        resp = self._phone_login('13800138001')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertIn('token', resp.data['data'])
        self.assertTrue(resp.data['data']['is_new_user'])
        self.assertEqual(resp.data['data']['user']['phone'], '13800138001')

    def test_phone_login_existing_user(self):
        """One-tap phone login with a number that already has an account."""
        User.objects.create_user(phone='13800138002', nickname='测试用户')

        resp = self._phone_login('13800138002')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertFalse(resp.data['data']['is_new_user'])

    def test_phone_login_invalid_phone(self):
        """A malformed phone number yields a non-zero business code."""
        resp = self._phone_login('1234567')

        self.assertNotEqual(resp.data['code'], 0)

    def test_phone_login_disabled_user(self):
        """A user created with ``is_active=False`` gets a non-zero business code."""
        User.objects.create_user(phone='13800138003', is_active=False)

        resp = self._phone_login('13800138003')

        self.assertNotEqual(resp.data['code'], 0)

    def test_token_refresh(self):
        """A refresh token obtained from login can be exchanged for an access token."""
        login_resp = self._phone_login('13800138004')
        refresh_token = login_resp.data['data']['token']['refresh']

        resp = self._refresh({'refresh': refresh_token})

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertIn('access', resp.data['data'])

    def test_token_refresh_invalid(self):
        """Refreshing with a bogus token yields a non-zero business code."""
        resp = self._refresh({'refresh': 'invalid_token'})

        self.assertNotEqual(resp.data['code'], 0)

    def test_token_refresh_empty(self):
        """Refreshing with an empty body yields a non-zero business code."""
        resp = self._refresh({})

        self.assertNotEqual(resp.data['code'], 0)
class UserProfileTests(APITestCase):
    """App-side user profile tests."""

    ME_URL = '/api/v1/users/me/'
    UPDATE_ME_URL = '/api/v1/users/update_me/'

    def setUp(self):
        """Create a user, log in by phone, and attach the access token to the client."""
        self.user = User.objects.create_user(phone='13800138010', nickname='测试用户')
        login_resp = self.client.post(
            '/api/v1/auth/phone-login/', {'phone': '13800138010'}, format='json'
        )
        self.access_token = login_resp.data['data']['token']['access']
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.access_token}')

    def _put_profile(self, payload):
        """PUT ``payload`` to the profile update endpoint and return the response."""
        return self.client.put(self.UPDATE_ME_URL, payload, format='json')

    def test_get_user_info(self):
        """Fetch the current user's profile."""
        resp = self.client.get(self.ME_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['phone'], '13800138010')

    def test_get_user_info_unauthorized(self):
        """Fetching the profile without credentials returns HTTP 401."""
        self.client.credentials()  # drop the Authorization header set in setUp

        resp = self.client.get(self.ME_URL)

        self.assertEqual(resp.status_code, status.HTTP_401_UNAUTHORIZED)

    def test_update_user_info(self):
        """Update the nickname and read it back from the response."""
        resp = self._put_profile({'nickname': '新昵称'})

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['nickname'], '新昵称')

    def test_update_user_avatar(self):
        """Update the avatar URL."""
        resp = self._put_profile({'avatar': 'https://example.com/avatar.jpg'})

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
class SpiritTests(APITestCase):
    """App-side spirit (agent) tests."""

    SPIRITS_URL = '/api/v1/spirits/'

    def setUp(self):
        """Create a user, log in by phone, and attach the access token to the client."""
        self.user = User.objects.create_user(phone='13800138020', nickname='测试用户')
        login_resp = self.client.post(
            '/api/v1/auth/phone-login/', {'phone': '13800138020'}, format='json'
        )
        self.access_token = login_resp.data['data']['token']['access']
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.access_token}')

    @staticmethod
    def _detail_url(spirit):
        """Return the detail endpoint for ``spirit``."""
        return f'/api/v1/spirits/{spirit.id}/'

    def test_create_spirit(self):
        """Create a spirit with name, prompt and voice id."""
        payload = {
            'name': '测试心灵',
            'prompt': '你是一个友好的助手',
            'voice_id': 'voice_001',
        }

        resp = self.client.post(self.SPIRITS_URL, payload, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['name'], '测试心灵')

    def test_create_spirit_minimal(self):
        """Create a spirit supplying only the name."""
        resp = self.client.post(self.SPIRITS_URL, {'name': '简单心灵'}, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)

    def test_list_spirits(self):
        """List the current user's spirits."""
        for spirit_name in ('心灵1', '心灵2'):
            Spirit.objects.create(user=self.user, name=spirit_name)

        resp = self.client.get(self.SPIRITS_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(len(resp.data['data']), 2)

    def test_list_spirits_empty(self):
        """Listing with no spirits returns an empty collection."""
        resp = self.client.get(self.SPIRITS_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(len(resp.data['data']), 0)

    def test_get_spirit_detail(self):
        """Fetch a single spirit by id."""
        spirit = Spirit.objects.create(user=self.user, name='详情测试')

        resp = self.client.get(self._detail_url(spirit))

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['data']['name'], '详情测试')

    def test_update_spirit(self):
        """Update a spirit's name and prompt."""
        spirit = Spirit.objects.create(user=self.user, name='原始名称')
        payload = {'name': '新名称', 'prompt': '新的提示词'}

        resp = self.client.put(self._detail_url(spirit), payload, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['data']['name'], '新名称')

    def test_delete_spirit(self):
        """Delete a spirit and confirm the row is gone."""
        spirit = Spirit.objects.create(user=self.user, name='待删除')

        resp = self.client.delete(self._detail_url(spirit))

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertFalse(Spirit.objects.filter(id=spirit.id).exists())

    def test_spirit_isolation(self):
        """A user must not be able to read another user's spirit."""
        other_user = User.objects.create_user(phone='13800138021')
        other_spirit = Spirit.objects.create(user=other_user, name='其他用户的心灵')

        resp = self.client.get(self._detail_url(other_spirit))

        # Expect a 404 or a permission error; anything but 200.
        self.assertNotEqual(resp.status_code, status.HTTP_200_OK)
class DeviceTests(APITestCase):
    """App-side device tests."""

    VERIFY_URL = '/api/v1/devices/verify/'
    BIND_URL = '/api/v1/devices/bind/'
    MY_DEVICES_URL = '/api/v1/devices/my_devices/'

    def setUp(self):
        """Authenticate a user and create one in-stock device with a known MAC."""
        self.user = User.objects.create_user(phone='13800138050', nickname='测试用户')
        login_resp = self.client.post(
            '/api/v1/auth/phone-login/', {'phone': '13800138050'}, format='json'
        )
        self.access_token = login_resp.data['data']['token']['access']
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.access_token}')

        self.device_type = DeviceType.objects.create(
            brand='AL',
            product_code='DZBJ-ON',
            name='电子吧唧-联网版',
        )
        self.device = Device.objects.create(
            sn='AL-DZBJ-ON-25W45-A01-00001',
            device_type=self.device_type,
            mac_address='AA:BB:CC:DD:EE:FF',
            status='in_stock',
        )

    def _anonymous_get(self, url):
        """Clear the client's credentials, then GET ``url``."""
        self.client.credentials()
        return self.client.get(url)

    def test_query_by_mac(self):
        """Look up the SN by MAC address without being logged in."""
        resp = self._anonymous_get('/api/v1/devices/query-by-mac/?mac=AA:BB:CC:DD:EE:FF')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['sn'], 'AL-DZBJ-ON-25W45-A01-00001')

    def test_query_by_mac_lowercase(self):
        """MAC lookup with a lowercase address."""
        resp = self._anonymous_get('/api/v1/devices/query-by-mac/?mac=aa:bb:cc:dd:ee:ff')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)

    def test_query_by_mac_dash_format(self):
        """MAC lookup with a dash-separated address."""
        resp = self._anonymous_get('/api/v1/devices/query-by-mac/?mac=AA-BB-CC-DD-EE-FF')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)

    def test_query_by_mac_not_found(self):
        """MAC lookup for an address with no device returns HTTP 404."""
        resp = self._anonymous_get('/api/v1/devices/query-by-mac/?mac=11:22:33:44:55:66')

        self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND)

    def test_query_by_mac_empty(self):
        """MAC lookup without the ``mac`` parameter yields a non-zero business code."""
        resp = self._anonymous_get('/api/v1/devices/query-by-mac/')

        self.assertNotEqual(resp.data['code'], 0)

    def test_verify_device(self):
        """Verify an existing, in-stock SN."""
        resp = self.client.post(
            self.VERIFY_URL, {'sn': 'AL-DZBJ-ON-25W45-A01-00001'}, format='json'
        )

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertTrue(resp.data['data']['is_bindable'])

    def test_verify_device_not_found(self):
        """Verifying an unknown SN yields a non-zero business code."""
        resp = self.client.post(self.VERIFY_URL, {'sn': 'NOT-EXIST-SN'}, format='json')

        self.assertNotEqual(resp.data['code'], 0)

    def test_bind_device(self):
        """Bind a device together with a spirit and check the device becomes 'bound'."""
        spirit = Spirit.objects.create(user=self.user, name='测试心灵')
        payload = {
            'sn': 'AL-DZBJ-ON-25W45-A01-00001',
            'spirit_id': spirit.id,
        }

        resp = self.client.post(self.BIND_URL, payload, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        # Check the persisted device status.
        self.device.refresh_from_db()
        self.assertEqual(self.device.status, 'bound')

    def test_bind_device_without_spirit(self):
        """Bind a device without supplying a spirit."""
        Device.objects.create(
            sn='AL-DZBJ-ON-25W45-A01-00002',
            device_type=self.device_type,
            status='in_stock',
        )

        resp = self.client.post(
            self.BIND_URL, {'sn': 'AL-DZBJ-ON-25W45-A01-00002'}, format='json'
        )

        self.assertEqual(resp.status_code, status.HTTP_200_OK)

    def test_my_devices(self):
        """List the current user's devices."""
        UserDevice.objects.create(user=self.user, device=self.device, is_active=True)

        resp = self.client.get(self.MY_DEVICES_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(len(resp.data['data']), 1)

    def test_my_devices_empty(self):
        """Listing devices with none bound returns an empty collection."""
        resp = self.client.get(self.MY_DEVICES_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(len(resp.data['data']), 0)

    def test_unbind_device(self):
        """Unbind a device and check the binding is deactivated."""
        binding = UserDevice.objects.create(
            user=self.user, device=self.device, is_active=True
        )
        self.device.status = 'bound'
        self.device.save()

        resp = self.client.delete(f'/api/v1/devices/{binding.id}/unbind/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        binding.refresh_from_db()
        self.assertFalse(binding.is_active)

    def test_update_spirit_on_device(self):
        """Switch the spirit attached to a bound device."""
        first_spirit = Spirit.objects.create(user=self.user, name='心灵1')
        second_spirit = Spirit.objects.create(user=self.user, name='心灵2')
        binding = UserDevice.objects.create(
            user=self.user, device=self.device, spirit=first_spirit, is_active=True
        )

        resp = self.client.put(
            f'/api/v1/devices/{binding.id}/update-spirit/',
            {'spirit_id': second_spirit.id},
            format='json',
        )

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        binding.refresh_from_db()
        self.assertEqual(binding.spirit_id, second_spirit.id)
# ==================== 管理端测试 ====================
class AdminAuthTests(APITestCase):
    """Admin-side authentication tests."""

    LOGIN_URL = '/api/admin/auth/login/'

    def setUp(self):
        """Create the super admin account used by the login tests."""
        self.admin = AdminUser.objects.create_user(
            username='admin',
            password='admin123',
            role='super_admin',
        )

    def _login(self, username, password):
        """POST the given credentials to the admin login endpoint."""
        credentials = {'username': username, 'password': password}
        return self.client.post(self.LOGIN_URL, credentials, format='json')

    def test_admin_login(self):
        """Log in with valid admin credentials."""
        resp = self._login('admin', 'admin123')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertIn('token', resp.data['data'])
        self.assertEqual(resp.data['data']['admin']['username'], 'admin')

    def test_admin_login_wrong_password(self):
        """Logging in with a wrong password yields a non-zero business code."""
        resp = self._login('admin', 'wrong')

        self.assertNotEqual(resp.data['code'], 0)

    def test_admin_login_user_not_exist(self):
        """Logging in as an unknown user yields a non-zero business code."""
        resp = self._login('notexist', 'password')

        self.assertNotEqual(resp.data['code'], 0)

    def test_admin_login_disabled(self):
        """Logging in to an account created with ``is_active=False`` is rejected."""
        AdminUser.objects.create_user(
            username='disabled', password='pass123', is_active=False
        )

        resp = self._login('disabled', 'pass123')

        self.assertNotEqual(resp.data['code'], 0)

    def test_admin_token_refresh(self):
        """Exchange an admin refresh token at the admin refresh endpoint."""
        login_resp = self._login('admin', 'admin123')
        refresh_token = login_resp.data['data']['token']['refresh']

        resp = self.client.post(
            '/api/admin/auth/refresh/', {'refresh': refresh_token}, format='json'
        )

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
class AdminProfileTests(APITestCase):
    """Admin-side profile tests."""

    LOGIN_URL = '/api/admin/auth/login/'
    CHANGE_PASSWORD_URL = '/api/admin/profile/change-password/'

    def setUp(self):
        """Create an admin, log in, and attach the access token to the client."""
        self.admin = AdminUser.objects.create_user(
            username='admin', password='admin123', role='admin'
        )
        login_resp = self._login('admin123')
        self.access_token = login_resp.data['data']['token']['access']
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.access_token}')

    def _login(self, password):
        """Log in as ``admin`` with ``password`` and return the response."""
        return self.client.post(
            self.LOGIN_URL, {'username': 'admin', 'password': password}, format='json'
        )

    def test_get_admin_profile(self):
        """Fetch the current admin's profile."""
        resp = self.client.get('/api/admin/profile/me/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['data']['username'], 'admin')

    def test_change_password(self):
        """Change the password, then log in with the new one."""
        payload = {'old_password': 'admin123', 'new_password': 'newpass123'}

        resp = self.client.put(self.CHANGE_PASSWORD_URL, payload, format='json')
        self.assertEqual(resp.status_code, status.HTTP_200_OK)

        # The new password must be accepted by a fresh, unauthenticated login.
        self.client.credentials()
        relogin_resp = self._login('newpass123')
        self.assertEqual(relogin_resp.data['code'], 0)

    def test_change_password_wrong_old(self):
        """Changing the password with a wrong old password is rejected."""
        payload = {'old_password': 'wrongpass', 'new_password': 'newpass123'}

        resp = self.client.put(self.CHANGE_PASSWORD_URL, payload, format='json')

        self.assertNotEqual(resp.data['code'], 0)
class AdminUserManageTests(APITestCase):
    """Admin account management tests."""

    ADMINS_URL = '/api/admin/admins/'

    def setUp(self):
        """Create a super admin, log in, and attach the access token to the client."""
        self.super_admin = AdminUser.objects.create_user(
            username='superadmin', password='super123', role='super_admin'
        )
        login_resp = self.client.post(
            '/api/admin/auth/login/',
            {'username': 'superadmin', 'password': 'super123'},
            format='json',
        )
        self.access_token = login_resp.data['data']['token']['access']
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.access_token}')

    def test_list_admins(self):
        """List admin accounts."""
        resp = self.client.get(self.ADMINS_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)

    def test_create_admin(self):
        """Create an admin account with the operator role."""
        payload = {
            'username': 'newadmin',
            'password': 'newpass123',
            'name': '新管理员',
            'role': 'operator',
        }

        resp = self.client.post(self.ADMINS_URL, payload, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['data']['username'], 'newadmin')

    def test_get_admin_detail(self):
        """Fetch a single admin account by id."""
        target = AdminUser.objects.create_user(username='testadmin', password='test123')

        resp = self.client.get(f'/api/admin/admins/{target.id}/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['data']['username'], 'testadmin')

    def test_toggle_admin_status(self):
        """Toggle a freshly created admin and check it ends up inactive."""
        target = AdminUser.objects.create_user(username='testadmin2', password='test123')

        resp = self.client.post(
            f'/api/admin/admins/{target.id}/toggle-status/', format='json'
        )

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        target.refresh_from_db()
        self.assertFalse(target.is_active)

    def test_reset_admin_password(self):
        """Reset another admin's password."""
        target = AdminUser.objects.create_user(username='testadmin3', password='oldpass')

        resp = self.client.post(
            f'/api/admin/admins/{target.id}/reset-password/',
            {'new_password': 'resetpass123'},
            format='json',
        )

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
class AdminDeviceTypeTests(APITestCase):
    """Admin-side device type tests."""

    DEVICE_TYPES_URL = '/api/admin/device-types/'

    def setUp(self):
        """Create a super admin, log in, and attach the access token to the client."""
        self.admin = AdminUser.objects.create_user(
            username='admin', password='admin123', role='super_admin'
        )
        login_resp = self.client.post(
            '/api/admin/auth/login/',
            {'username': 'admin', 'password': 'admin123'},
            format='json',
        )
        self.access_token = login_resp.data['data']['token']['access']
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.access_token}')

    def _create_via_api(self, product_code, name):
        """POST a new 'AL'-brand device type and return the response."""
        payload = {
            'brand': 'AL',
            'product_code': product_code,
            'name': name,
        }
        return self.client.post(self.DEVICE_TYPES_URL, payload, format='json')

    def test_create_device_type_network(self):
        """Creating the 'DZBJ-ON' product code reports ``is_network_required`` as true."""
        resp = self._create_via_api('DZBJ-ON', '电子吧唧-联网版')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertTrue(resp.data['data']['is_network_required'])

    def test_create_device_type_offline(self):
        """Creating the 'DZBJ-OFF' product code reports ``is_network_required`` as false."""
        resp = self._create_via_api('DZBJ-OFF', '电子吧唧-离线版')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertFalse(resp.data['data']['is_network_required'])

    def test_list_device_types(self):
        """List device types."""
        DeviceType.objects.create(brand='AL', product_code='TEST-ON', name='测试设备')

        resp = self.client.get(self.DEVICE_TYPES_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)

    def test_get_device_type_detail(self):
        """Fetch a single device type by id."""
        device_type = DeviceType.objects.create(
            brand='AL', product_code='TEST-DT', name='测试类型'
        )

        resp = self.client.get(f'/api/admin/device-types/{device_type.id}/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['data']['name'], '测试类型')

    def test_update_device_type(self):
        """Rename a device type via PUT."""
        device_type = DeviceType.objects.create(
            brand='AL', product_code='TEST-UP', name='原名称'
        )

        resp = self.client.put(
            f'/api/admin/device-types/{device_type.id}/', {'name': '新名称'}, format='json'
        )

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['data']['name'], '新名称')
class AdminBatchTests(APITestCase):
    """Admin-side device batch tests."""

    BATCHES_URL = '/api/admin/device-batches/'

    def setUp(self):
        """Authenticate a super admin and create the device type batches refer to."""
        self.admin = AdminUser.objects.create_user(
            username='admin', password='admin123', role='super_admin'
        )
        login_resp = self.client.post(
            '/api/admin/auth/login/',
            {'username': 'admin', 'password': 'admin123'},
            format='json',
        )
        self.access_token = login_resp.data['data']['token']['access']
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.access_token}')
        self.device_type = DeviceType.objects.create(
            brand='AL', product_code='DZBJ-ON', name='电子吧唧-联网版'
        )

    def _make_batch(self, batch_no, quantity):
        """Insert a batch row directly, dated 2026-01-28, for ``self.device_type``."""
        return DeviceBatch.objects.create(
            device_type=self.device_type,
            batch_no=batch_no,
            production_date=date(2026, 1, 28),
            quantity=quantity,
        )

    def test_create_batch_and_generate_sn(self):
        """Create a batch through the API and check SN codes and devices are generated."""
        payload = {
            'device_type': self.device_type.id,
            'batch_no': 'A01',
            'production_date': '2026-01-28',
            'quantity': 10,
            'remark': '测试批次',
        }

        resp = self.client.post(self.BATCHES_URL, payload, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['quantity'], 10)
        self.assertIn('sn_range', resp.data['data'])

        # The first SN must carry the brand/product-code prefix.
        first_sn = resp.data['data']['sn_range']['start']
        self.assertTrue(first_sn.startswith('AL-DZBJ-ON-'))

        # One device row per requested unit.
        created = Device.objects.filter(batch__batch_no='A01').count()
        self.assertEqual(created, 10)

    def test_list_batches(self):
        """List batches."""
        resp = self.client.get(self.BATCHES_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)

    def test_get_batch_detail(self):
        """Fetch a batch by id; the payload includes a ``statistics`` entry."""
        batch = self._make_batch('B01', 5)

        resp = self.client.get(f'/api/admin/device-batches/{batch.id}/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertIn('statistics', resp.data['data'])

    def test_get_batch_devices(self):
        """List the devices that belong to a batch."""
        batch = self._make_batch('C01', 3)
        for serial in ('TEST-001', 'TEST-002'):
            Device.objects.create(sn=serial, batch=batch, device_type=self.device_type)

        resp = self.client.get(f'/api/admin/device-batches/{batch.id}/devices/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(len(resp.data['data']['items']), 2)
class AuthSeparationTests(APITestCase):
    """Checks that App tokens and Admin tokens are not interchangeable."""

    APP_PROTECTED_URL = '/api/v1/users/me/'
    ADMIN_PROTECTED_URL = '/api/admin/device-types/'

    def setUp(self):
        """Obtain one App access token and one Admin access token."""
        # App user and token.
        self.app_user = User.objects.create_user(phone='13800138099')
        app_login = self.client.post(
            '/api/v1/auth/phone-login/', {'phone': '13800138099'}, format='json'
        )
        self.app_token = app_login.data['data']['token']['access']

        # Admin user and token.
        self.admin = AdminUser.objects.create_user(
            username='admin', password='admin123', role='admin'
        )
        admin_login = self.client.post(
            '/api/admin/auth/login/',
            {'username': 'admin', 'password': 'admin123'},
            format='json',
        )
        self.admin_token = admin_login.data['data']['token']['access']

    def test_app_token_cannot_access_admin_api(self):
        """An App token must not be accepted by an admin endpoint."""
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.app_token}')

        resp = self.client.get(self.ADMIN_PROTECTED_URL)

        self.assertNotEqual(resp.data.get('code', 0), 0)

    def test_admin_token_cannot_access_app_api(self):
        """An Admin token must not be accepted by an App endpoint."""
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.admin_token}')

        resp = self.client.get(self.APP_PROTECTED_URL)

        self.assertNotEqual(resp.data.get('code', 0), 0)

    def test_no_token_cannot_access_protected_api(self):
        """Protected App and Admin endpoints both return HTTP 401 without a token."""
        self.client.credentials()

        for protected_url in (self.APP_PROTECTED_URL, self.ADMIN_PROTECTED_URL):
            resp = self.client.get(protected_url)
            self.assertEqual(resp.status_code, status.HTTP_401_UNAUTHORIZED)
# ==================== Log Center 集成测试 ====================
class LogCenterIntegrationTests(TestCase):
    """Integration tests for error reporting to Log Center."""

    @staticmethod
    def _inline_thread(*args, **kwargs):
        """Stand-in for ``threading.Thread`` whose ``start()`` runs the target inline.

        Running the worker synchronously lets a test assert on the HTTP call
        right after ``report_to_log_center`` returns, with no sleeping and
        while the ``requests.post`` patch is still active.
        """
        target = kwargs.get('target')
        if target is None and len(args) > 1:
            # threading.Thread(group, target, ...): target is the 2nd positional.
            target = args[1]
        target_args = kwargs.get('args') or ()
        target_kwargs = kwargs.get('kwargs') or {}

        def start():
            if target is not None:
                target(*target_args, **target_kwargs)

        return MagicMock(start=start, daemon=True)

    @staticmethod
    def _fake_context(path, method, view):
        """Build a handler-style context dict around a mocked request."""
        mock_request = MagicMock()
        mock_request.path = path
        mock_request.method = method
        return {'request': mock_request, 'view': view}

    def test_report_to_log_center_function_exists(self):
        """``report_to_log_center`` is importable and callable."""
        from utils.exceptions import report_to_log_center

        self.assertTrue(callable(report_to_log_center))

    def test_report_to_log_center_with_exception(self):
        """Reporting an exception while enabled results in an HTTP post."""
        from utils.exceptions import report_to_log_center

        # Raise and catch so the exception carries a real traceback.
        try:
            raise ValueError("Test error message")
        except ValueError as e:
            test_exc = e

        context = self._fake_context('/api/test/', 'GET', 'TestView')

        with patch('utils.exceptions.requests.post') as mock_post:
            with patch('utils.exceptions.LOG_CENTER_ENABLED', True):
                with patch('utils.exceptions.threading.Thread',
                           side_effect=self._inline_thread):
                    report_to_log_center(test_exc, context)

        # The worker ran inline, so the call is already recorded; no sleep needed.
        mock_post.assert_called_once()

    def test_report_payload_structure(self):
        """The posted payload carries project, level, error and request context."""
        from utils.exceptions import report_to_log_center

        captured = {}

        def capture_post(url, json=None, timeout=None, **_extra):
            # ``json`` must keep this name: it is matched as a keyword argument.
            captured['payload'] = json
            return MagicMock(status_code=200)

        try:
            raise TypeError("Type mismatch error")
        except TypeError as e:
            test_exc = e

        context = self._fake_context('/api/users/me/', 'POST', 'UserView')

        with patch('utils.exceptions.requests.post', side_effect=capture_post):
            with patch('utils.exceptions.LOG_CENTER_ENABLED', True):
                with patch('utils.exceptions.threading.Thread',
                           side_effect=self._inline_thread):
                    report_to_log_center(test_exc, context)

        # Fail loudly if nothing was posted instead of silently skipping the checks.
        payload = captured.get('payload')
        self.assertIsNotNone(payload, "report_to_log_center did not post a payload")

        self.assertEqual(payload['project_id'], 'rtc_backend')
        self.assertEqual(payload['level'], 'ERROR')
        self.assertEqual(payload['error']['type'], 'TypeError')
        self.assertEqual(payload['error']['message'], 'Type mismatch error')
        self.assertIn('stack_trace', payload['error'])
        self.assertEqual(payload['context']['url'], '/api/users/me/')
        self.assertEqual(payload['context']['method'], 'POST')

    def test_report_disabled_when_flag_off(self):
        """Nothing is posted when ``LOG_CENTER_ENABLED`` is false."""
        from utils.exceptions import report_to_log_center

        try:
            raise Exception("Should not be reported")
        except Exception as e:
            test_exc = e

        with patch('utils.exceptions.requests.post') as mock_post:
            with patch('utils.exceptions.LOG_CENTER_ENABLED', False):
                report_to_log_center(test_exc, {})

        mock_post.assert_not_called()

    def test_report_silent_failure(self):
        """A failing HTTP post must not make ``report_to_log_center`` raise."""
        from utils.exceptions import report_to_log_center

        try:
            raise Exception("Test exception")
        except Exception as e:
            test_exc = e

        with patch('utils.exceptions.requests.post',
                   side_effect=Exception("Network error")):
            with patch('utils.exceptions.LOG_CENTER_ENABLED', True):
                try:
                    report_to_log_center(test_exc, {})
                except Exception:
                    self.fail("report_to_log_center should not raise exceptions")
class ExceptionHandlerIntegrationTests(APITestCase):
    """Checks which exceptions the custom handler forwards to Log Center."""

    REPORT_TARGET = 'utils.exceptions.report_to_log_center'

    def test_exception_triggers_log_center_report(self):
        """A non-business exception is passed to ``report_to_log_center`` once."""
        from utils.exceptions import custom_exception_handler

        error = ValueError("Database connection failed")

        # Handler context built from mocks.
        fake_request = MagicMock()
        fake_request.path = '/api/test/'
        fake_request.method = 'GET'
        handler_context = {'request': fake_request, 'view': MagicMock()}

        with patch(self.REPORT_TARGET) as report_mock:
            custom_exception_handler(error, handler_context)

            # Reported exactly once, with the exception as first positional argument.
            report_mock.assert_called_once()
            positional, _keyword = report_mock.call_args
            self.assertEqual(positional[0], error)

    def test_business_exception_not_reported(self):
        """A ``BusinessException`` is handled without being reported."""
        from utils.exceptions import custom_exception_handler, BusinessException

        error = BusinessException(code=100, message="用户不存在")
        handler_context = {'request': MagicMock(), 'view': MagicMock()}

        with patch(self.REPORT_TARGET) as report_mock:
            custom_exception_handler(error, handler_context)

            report_mock.assert_not_called()
# ==================== 故事模块测试 ====================
class StoryTestBase(APITestCase):
    """Shared fixture for the story module tests."""

    def setUp(self):
        """Create a user, authenticate the client as that user, and add one shelf."""
        self.user = User.objects.create_user(
            phone='13800130001',
            nickname='故事测试用户',
        )
        tokens = get_app_tokens(self.user)
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {tokens["access"]}')
        # Default shelf that the subclasses' tests operate on.
        self.shelf = StoryShelf.objects.create(
            user=self.user,
            name='我的书架',
        )
class StoryShelfTests(StoryTestBase):
    """Story shelf endpoint tests."""

    SHELVES_URL = '/api/v1/stories/shelves/'

    def test_list_shelves(self):
        """List shelves; the fixture shelf is the only entry."""
        resp = self.client.get(self.SHELVES_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        shelves = resp.data['data']
        self.assertEqual(len(shelves), 1)
        self.assertEqual(shelves[0]['name'], '我的书架')

    def test_list_shelves_auto_create_default(self):
        """A user with no shelves gets a default shelf on the first listing."""
        newcomer = User.objects.create_user(phone='13800130099')
        tokens = get_app_tokens(newcomer)
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {tokens["access"]}')

        resp = self.client.get(self.SHELVES_URL)

        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(len(resp.data['data']), 1)
        self.assertEqual(resp.data['data'][0]['name'], '我的书架')

    def test_list_shelves_includes_story_count(self):
        """Each listed shelf reports how many stories it holds."""
        Story.objects.create(
            user=self.user,
            shelf=self.shelf,
            title='测试故事',
            content='内容',
        )

        resp = self.client.get(self.SHELVES_URL)

        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data'][0]['story_count'], 1)

    def test_create_shelf(self):
        """Create a shelf; the response reports a capacity of 10."""
        resp = self.client.post(self.SHELVES_URL, {'name': '新书架'}, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['name'], '新书架')
        self.assertEqual(resp.data['data']['capacity'], 10)

    def test_delete_shelf(self):
        """Deleting a shelf removes it but keeps its stories."""
        doomed_shelf = StoryShelf.objects.create(user=self.user, name='待删除书架')
        orphan = Story.objects.create(
            user=self.user,
            shelf=doomed_shelf,
            title='测试',
            content='内容',
        )

        resp = self.client.delete(f'/api/v1/stories/shelves/{doomed_shelf.id}/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertFalse(StoryShelf.objects.filter(id=doomed_shelf.id).exists())
        # The story survives; its shelf reference is set to null.
        orphan.refresh_from_db()
        self.assertIsNone(orphan.shelf)

    def test_delete_shelf_not_found(self):
        """Deleting a non-existent shelf yields a non-zero business code."""
        resp = self.client.delete('/api/v1/stories/shelves/99999/')

        self.assertNotEqual(resp.data['code'], 0)

    def test_shelf_isolation(self):
        """A user must not be able to delete another user's shelf."""
        stranger = User.objects.create_user(phone='13800130002')
        foreign_shelf = StoryShelf.objects.create(user=stranger, name='别人的书架')

        resp = self.client.delete(f'/api/v1/stories/shelves/{foreign_shelf.id}/')

        self.assertNotEqual(resp.data['code'], 0)
        # The other user's shelf must still exist.
        self.assertTrue(StoryShelf.objects.filter(id=foreign_shelf.id).exists())
class ShelfUnlockTests(StoryTestBase):
    """Tests for unlocking an extra shelf with points."""

    UNLOCK_URL = '/api/v1/stories/shelves/unlock/'

    def _give_points(self, amount):
        """Set the fixture user's point balance to ``amount`` and persist it."""
        self.user.points = amount
        self.user.save(update_fields=['points'])

    def _unlock(self):
        """POST to the unlock endpoint and return the response."""
        return self.client.post(self.UNLOCK_URL, format='json')

    def test_unlock_shelf_success(self):
        """Unlocking with 200 points leaves 100 and writes a points record."""
        self._give_points(200)

        resp = self._unlock()

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['remaining_points'], 100)
        self.assertIn('shelf', resp.data['data'])

        # The balance was actually debited.
        self.user.refresh_from_db()
        self.assertEqual(self.user.points, 100)

        # A ledger entry for the debit exists.
        ledger_entry = PointsRecord.objects.filter(user=self.user).first()
        self.assertIsNotNone(ledger_entry)
        self.assertEqual(ledger_entry.amount, -100)
        self.assertEqual(ledger_entry.type, 'unlock_shelf')

    def test_unlock_shelf_not_enough_points(self):
        """Unlocking with 50 points is refused and the balance is untouched."""
        self._give_points(50)

        resp = self._unlock()

        self.assertEqual(resp.data['code'], 603)  # POINTS_NOT_ENOUGH
        # The balance must not change.
        self.user.refresh_from_db()
        self.assertEqual(self.user.points, 50)

    def test_unlock_shelf_zero_points(self):
        """Unlocking with zero points is refused."""
        self._give_points(0)

        resp = self._unlock()

        self.assertEqual(resp.data['code'], 603)

    def test_unlock_shelf_naming(self):
        """Unlocked shelves are named by their ordinal position."""
        self._give_points(500)

        # One default shelf already exists, so the first unlock is number 2,
        # and the second unlock is number 3.
        for expected_name in ('书架 2', '书架 3'):
            resp = self._unlock()
            self.assertEqual(resp.data['data']['shelf']['name'], expected_name)
class StoryTests(StoryTestBase):
    """Tests for the story list, create and delete endpoints."""

    STORIES_URL = '/api/v1/stories/'

    def _post_story(self, payload):
        """POST ``payload`` as JSON to the story collection endpoint."""
        return self.client.post(self.STORIES_URL, payload, format='json')

    def test_list_stories(self):
        """Listing returns both of the user's stories."""
        for title, content in (('故事1', '内容1'), ('故事2', '内容2')):
            Story.objects.create(
                user=self.user, shelf=self.shelf, title=title, content=content,
            )
        resp = self.client.get(self.STORIES_URL)
        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        page = resp.data['data']
        self.assertEqual(page['total'], 2)
        self.assertEqual(len(page['items']), 2)

    def test_list_stories_filter_by_shelf(self):
        """The ``shelf_id`` query parameter restricts the list to that shelf."""
        second_shelf = StoryShelf.objects.create(user=self.user, name='书架2')
        Story.objects.create(
            user=self.user, shelf=self.shelf, title='书架1故事', content='内容',
        )
        Story.objects.create(
            user=self.user, shelf=second_shelf, title='书架2故事', content='内容',
        )
        resp = self.client.get(f'/api/v1/stories/?shelf_id={self.shelf.id}')
        self.assertEqual(resp.data['code'], 0)
        page = resp.data['data']
        self.assertEqual(page['total'], 1)
        self.assertEqual(page['items'][0]['title'], '书架1故事')

    def test_list_stories_empty(self):
        """Listing with no stories reports a total of zero."""
        resp = self.client.get(self.STORIES_URL)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['total'], 0)

    def test_create_story(self):
        """Saving a story returns the stored title and content."""
        payload = {
            'title': '新故事',
            'content': '这是故事内容',
            'shelf_id': self.shelf.id,
        }
        resp = self._post_story(payload)
        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        saved = resp.data['data']
        self.assertEqual(saved['title'], '新故事')
        self.assertEqual(saved['content'], '这是故事内容')

    def test_create_story_with_optional_fields(self):
        """Saving a story with the optional fields keeps ``generation_mode``."""
        payload = {
            'title': '完整故事',
            'content': '故事正文',
            'shelf_id': self.shelf.id,
            'cover_url': 'https://example.com/cover.jpg',
            'generation_mode': 'ai',
            'prompt': '角色=小猫, 场景=森林',
        }
        resp = self._post_story(payload)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['generation_mode'], 'ai')

    def test_create_story_shelf_full(self):
        """Saving to a shelf that is already at capacity is rejected."""
        # A shelf with room for two stories, filled to the brim.
        tiny_shelf = StoryShelf.objects.create(
            user=self.user, name='小书架', capacity=2,
        )
        for title in ('故事1', '故事2'):
            Story.objects.create(
                user=self.user, shelf=tiny_shelf, title=title, content='内容',
            )
        resp = self._post_story({
            'title': '溢出故事',
            'content': '这本放不下了',
            'shelf_id': tiny_shelf.id,
        })
        self.assertEqual(resp.data['code'], 604)  # SHELF_FULL

    def test_create_story_shelf_not_found(self):
        """Saving to an unknown shelf id yields a non-zero business code."""
        resp = self._post_story({
            'title': '故事',
            'content': '内容',
            'shelf_id': 99999,
        })
        self.assertNotEqual(resp.data['code'], 0)

    def test_create_story_other_user_shelf(self):
        """Saving to a shelf owned by another user is rejected."""
        stranger = User.objects.create_user(phone='13800130003')
        foreign_shelf = StoryShelf.objects.create(user=stranger, name='他人书架')
        resp = self._post_story({
            'title': '故事',
            'content': '内容',
            'shelf_id': foreign_shelf.id,
        })
        self.assertNotEqual(resp.data['code'], 0)

    def test_delete_story(self):
        """Deleting an owned story removes it from the database."""
        doomed = Story.objects.create(
            user=self.user, shelf=self.shelf, title='待删除', content='内容',
        )
        resp = self.client.delete(f'/api/v1/stories/{doomed.id}/')
        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertFalse(Story.objects.filter(id=doomed.id).exists())

    def test_delete_story_not_found(self):
        """Deleting an unknown story id returns code 600."""
        resp = self.client.delete('/api/v1/stories/99999/')
        self.assertEqual(resp.data['code'], 600)  # STORY_NOT_FOUND

    def test_story_isolation(self):
        """A story owned by another user cannot be deleted by this user."""
        stranger = User.objects.create_user(phone='13800130004')
        foreign_shelf = StoryShelf.objects.create(user=stranger, name='他人书架')
        foreign_story = Story.objects.create(
            user=stranger, shelf=foreign_shelf, title='他人故事', content='内容',
        )
        resp = self.client.delete(f'/api/v1/stories/{foreign_story.id}/')
        self.assertNotEqual(resp.data['code'], 0)
        self.assertTrue(Story.objects.filter(id=foreign_story.id).exists())

    def test_story_capacity_limit(self):
        """An eleventh story on self.shelf (after ten were saved) is rejected."""
        # Fill the shelf with ten stories (the test expects its capacity to be 10).
        for number in range(1, 11):
            Story.objects.create(
                user=self.user,
                shelf=self.shelf,
                title=f'故事{number}',
                content=f'内容{number}',
            )
        resp = self._post_story({
            'title': '第11本',
            'content': '超出容量',
            'shelf_id': self.shelf.id,
        })
        self.assertEqual(resp.data['code'], 604)  # SHELF_FULL
class StoryGenerateTests(StoryTestBase):
    """Tests for the story generation endpoint, with the LLM stream stubbed out."""

    GENERATE_URL = '/api/v1/stories/generate/'
    # Dotted path of the streaming generator replaced by a mock in every test
    # here, kept in one place so the tests cannot drift apart.
    STREAM_PATH = 'apps.stories.services.llm_service.generate_story_stream'

    def test_generate_story_returns_sse(self):
        """A generate request answers 200 with an SSE content type."""
        mock_events = [
            'event: stage\ndata: {"stage":"connecting","progress":0,"message":"正在收集灵感碎片..."}\n\n',
            'event: done\ndata: {"stage":"done","progress":100,"title":"测试故事","content":"故事内容"}\n\n',
        ]
        # `patch` is the module-level import; no per-test re-import is needed.
        with patch(self.STREAM_PATH) as mock_gen:
            mock_gen.return_value = iter(mock_events)
            data = {
                'characters': ['小猫'],
                'scenes': ['森林'],
                'props': ['魔法棒'],
            }
            response = self.client.post(self.GENERATE_URL, data, format='json')
            self.assertEqual(response.status_code, status.HTTP_200_OK)
            self.assertEqual(response['Content-Type'], 'text/event-stream')

    def test_generate_story_empty_params(self):
        """An empty request body is accepted (the test expects defaults to apply)."""
        mock_events = [
            'event: done\ndata: {"stage":"done","progress":100,"title":"默认故事","content":"内容"}\n\n',
        ]
        with patch(self.STREAM_PATH) as mock_gen:
            mock_gen.return_value = iter(mock_events)
            response = self.client.post(self.GENERATE_URL, {}, format='json')
            self.assertEqual(response.status_code, status.HTTP_200_OK)
class StoryTTSTests(StoryTestBase):
    """Tests for the per-story TTS audio endpoint (status query and generation)."""

    def _tts_url(self, story_id):
        """Return the TTS endpoint URL for ``story_id``."""
        return f'/api/v1/stories/{story_id}/tts/'

    def test_tts_check_no_audio(self):
        """Querying a story without audio reports ``exists`` false and an empty URL."""
        story = Story.objects.create(
            user=self.user, shelf=self.shelf,
            title='无音频故事', content='内容',
        )
        response = self.client.get(self._tts_url(story.id))
        self.assertEqual(response.data['code'], 0)
        self.assertFalse(response.data['data']['exists'])
        self.assertEqual(response.data['data']['audio_url'], '')

    def test_tts_check_has_audio(self):
        """Querying a story with audio reports ``exists`` true and the stored URL."""
        story = Story.objects.create(
            user=self.user, shelf=self.shelf,
            title='有音频故事', content='内容',
            audio_url='https://oss.example.com/audio.mp3',
        )
        response = self.client.get(self._tts_url(story.id))
        self.assertEqual(response.data['code'], 0)
        self.assertTrue(response.data['data']['exists'])
        self.assertEqual(
            response.data['data']['audio_url'],
            'https://oss.example.com/audio.mp3',
        )

    def test_tts_generate_returns_sse(self):
        """A TTS generation request answers 200 with an SSE content type."""
        story = Story.objects.create(
            user=self.user, shelf=self.shelf,
            title='TTS测试', content='这是要转换的故事内容',
        )
        mock_events = [
            'event: stage\ndata: {"stage":"connecting","message":"正在连接..."}\n\n',
            'event: done\ndata: {"stage":"done","audio_url":"https://oss.example.com/audio.mp3"}\n\n',
        ]
        # `patch` is the module-level import; no per-test re-import is needed.
        with patch('apps.stories.services.tts_service.generate_tts_stream') as mock_tts:
            mock_tts.return_value = iter(mock_events)
            response = self.client.post(self._tts_url(story.id), format='json')
            self.assertEqual(response.status_code, status.HTTP_200_OK)
            self.assertEqual(response['Content-Type'], 'text/event-stream')

    def test_tts_story_not_found(self):
        """Querying TTS for an unknown story id returns code 600."""
        response = self.client.get(self._tts_url(99999))
        self.assertEqual(response.data['code'], 600)  # STORY_NOT_FOUND

    def test_tts_story_isolation(self):
        """Another user's story is reported as not found (code 600)."""
        other_user = User.objects.create_user(phone='13800130005')
        other_shelf = StoryShelf.objects.create(user=other_user, name='他人书架')
        other_story = Story.objects.create(
            user=other_user, shelf=other_shelf,
            title='他人故事', content='内容',
        )
        response = self.client.get(self._tts_url(other_story.id))
        self.assertEqual(response.data['code'], 600)
class PointsTests(StoryTestBase):
    """Tests for the points balance and points ledger endpoints."""

    BALANCE_URL = '/api/v1/users/points/'
    RECORDS_URL = '/api/v1/users/points/records/'

    def test_query_points(self):
        """The balance endpoint returns the user's stored points."""
        self.user.points = 500
        self.user.save(update_fields=['points'])
        resp = self.client.get(self.BALANCE_URL)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['points'], 500)

    def test_query_points_default_zero(self):
        """An untouched user has a balance of zero."""
        resp = self.client.get(self.BALANCE_URL)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['points'], 0)

    def test_points_records_list(self):
        """The ledger endpoint lists both of the user's records."""
        entries = (
            (100, 'reward', '注册奖励'),
            (-100, 'unlock_shelf', '解锁书架'),
        )
        for amount, record_type, description in entries:
            PointsRecord.objects.create(
                user=self.user,
                amount=amount,
                type=record_type,
                description=description,
            )
        resp = self.client.get(self.RECORDS_URL)
        self.assertEqual(resp.data['code'], 0)
        page = resp.data['data']
        self.assertEqual(page['total'], 2)
        self.assertEqual(len(page['items']), 2)

    def test_points_records_empty(self):
        """The ledger endpoint reports a total of zero when there are no records."""
        resp = self.client.get(self.RECORDS_URL)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['total'], 0)

    def test_points_records_pagination(self):
        """``page_size`` limits the returned items while ``total`` counts all."""
        for number in range(1, 6):
            PointsRecord.objects.create(
                user=self.user,
                amount=10,
                type='reward',
                description=f'奖励{number}',
            )
        resp = self.client.get('/api/v1/users/points/records/?page=1&page_size=3')
        page = resp.data['data']
        self.assertEqual(page['total'], 5)
        self.assertEqual(len(page['items']), 3)

    def test_points_records_isolation(self):
        """Records belonging to another user are not listed."""
        stranger = User.objects.create_user(phone='13800130006')
        PointsRecord.objects.create(
            user=stranger, amount=100, type='reward', description='他人的奖励',
        )
        resp = self.client.get(self.RECORDS_URL)
        self.assertEqual(resp.data['data']['total'], 0)
class LLMServiceTests(TestCase):
    """Unit tests for the helpers in ``apps.stories.services.llm_service``."""

    def test_build_user_prompt(self):
        """Every character, scene and prop passed in appears in the prompt."""
        from apps.stories.services.llm_service import build_user_prompt
        prompt = build_user_prompt(['小猫', '小狗'], ['森林'], ['魔法棒'])
        for keyword in ('小猫', '小狗', '森林', '魔法棒'):
            self.assertIn(keyword, prompt)

    def test_build_user_prompt_partial(self):
        """With only characters given, the scene and prop labels are absent."""
        from apps.stories.services.llm_service import build_user_prompt
        prompt = build_user_prompt(['公主'], [], [])
        self.assertIn('公主', prompt)
        self.assertNotIn('场景', prompt)
        self.assertNotIn('道具', prompt)

    def test_parse_story_json_valid(self):
        """A plain JSON object is parsed into title and content."""
        from apps.stories.services.llm_service import _parse_story_json
        text = '{"title": "小猫冒险", "content": "从前有一只小猫..."}'
        result = _parse_story_json(text)
        self.assertEqual(result['title'], '小猫冒险')
        self.assertEqual(result['content'], '从前有一只小猫...')

    def test_parse_story_json_with_markdown(self):
        """JSON wrapped in a markdown code fence is still parsed."""
        from apps.stories.services.llm_service import _parse_story_json
        text = '```json\n{"title": "森林故事", "content": "在深深的森林里..."}\n```'
        result = _parse_story_json(text)
        self.assertEqual(result['title'], '森林故事')

    def test_parse_story_json_invalid(self):
        """Non-JSON text falls back to a default title with the text as content."""
        from apps.stories.services.llm_service import _parse_story_json
        text = '这不是一个有效的 JSON 格式的文本'
        result = _parse_story_json(text)
        self.assertEqual(result['title'], '新故事')
        self.assertIn('这不是', result['content'])

    def test_sse_event_format(self):
        """An SSE frame starts with the event line, has a data line, ends blank."""
        from apps.stories.services.llm_service import _sse_event
        event = _sse_event('stage', {'stage': 'connecting', 'progress': 0})
        self.assertTrue(event.startswith('event: stage\n'))
        self.assertIn('data: ', event)
        self.assertTrue(event.endswith('\n\n'))

    def test_generate_stream_without_api_key(self):
        """With an empty API key the stream yields a single error event."""
        from apps.stories.services.llm_service import generate_story_stream
        # `patch` is the module-level import; no per-test re-import is needed.
        with patch('apps.stories.services.llm_service.settings') as mock_settings:
            mock_settings.LLM_CONFIG = {'API_KEY': '', 'API_BASE_URL': '', 'MODEL_NAME': ''}
            # Consume the generator while the patched settings are active.
            events = list(generate_story_stream(['小猫'], [], []))
        self.assertEqual(len(events), 1)
        self.assertIn('error', events[0])
        self.assertIn('未配置', events[0])
# ==================== 音乐模块测试 ====================
class MusicTestBase(APITestCase):
    """Shared base for the music tests: creates a user and authenticates the client."""

    def setUp(self):
        """Create the test user and attach its access token as a Bearer header."""
        self.user = User.objects.create_user(
            phone='13800140001',
            nickname='音乐测试用户',
        )
        access_token = get_app_tokens(self.user)["access"]
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {access_token}')
class MusicPlaylistTests(MusicTestBase):
    """Tests for the playlist endpoint."""

    PLAYLIST_URL = '/api/v1/music/playlist/'

    def _fetch_playlist(self):
        """GET the playlist endpoint and return the ``playlist`` list of the payload."""
        response = self.client.get(self.PLAYLIST_URL)
        return response.data['data']['playlist']

    def test_playlist_auto_create_defaults(self):
        """The first playlist request yields three completed default tracks."""
        response = self.client.get(self.PLAYLIST_URL)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response.data['code'], 0)
        playlist = response.data['data']['playlist']
        self.assertEqual(len(playlist), 3)
        # Every entry must be a default track.
        for t in playlist:
            self.assertTrue(t['is_default'])
            self.assertEqual(t['generation_status'], 'completed')

    def test_playlist_default_track_titles(self):
        """The three expected default titles are present."""
        titles = [t['title'] for t in self._fetch_playlist()]
        self.assertIn('卡皮巴拉蹦蹦蹦', titles)
        self.assertIn('卡皮巴拉快乐水', titles)
        self.assertIn('卡皮巴拉快乐营业', titles)

    def test_playlist_defaults_not_duplicated(self):
        """Repeated requests do not create the default tracks again."""
        self.client.get(self.PLAYLIST_URL)
        self.client.get(self.PLAYLIST_URL)
        count = Track.objects.filter(user=self.user, is_default=True).count()
        self.assertEqual(count, 3)

    def test_playlist_ordering(self):
        """User tracks come first, default tracks after them."""
        # First request creates the default tracks.
        self.client.get(self.PLAYLIST_URL)
        # Then add one user track.
        Track.objects.create(
            user=self.user, title='我的歌', is_default=False,
            generation_status='completed',
        )
        playlist = self._fetch_playlist()
        self.assertEqual(len(playlist), 4)
        # The first entry must be the user track.
        self.assertEqual(playlist[0]['title'], '我的歌')
        self.assertFalse(playlist[0]['is_default'])
        # The remaining three are default tracks.
        for t in playlist[1:]:
            self.assertTrue(t['is_default'])

    def test_playlist_user_isolation(self):
        """Tracks owned by another user do not show up in the playlist."""
        other_user = User.objects.create_user(phone='13800140099')
        Track.objects.create(
            user=other_user, title='他人的歌',
            generation_status='completed',
        )
        titles = [t['title'] for t in self._fetch_playlist()]
        self.assertNotIn('他人的歌', titles)

    def test_playlist_default_tracks_have_audio_url(self):
        """Every default track carries an audio URL containing ``qy-rtc``."""
        playlist = self._fetch_playlist()
        # Guard against a vacuous pass: the loop below checks nothing if the
        # playlist comes back empty.
        self.assertTrue(playlist, 'playlist is empty; nothing was checked')
        for t in playlist:
            self.assertTrue(t['audio_url'])
            self.assertIn('qy-rtc', t['audio_url'])

    def test_playlist_default_tracks_have_lyrics(self):
        """Every default track carries non-empty lyrics."""
        playlist = self._fetch_playlist()
        # Same guard as above: an empty playlist must fail, not pass silently.
        self.assertTrue(playlist, 'playlist is empty; nothing was checked')
        for t in playlist:
            self.assertTrue(t['lyrics'])
class MusicDeleteTests(MusicTestBase):
    """Tests for the track delete endpoint."""

    def _delete_track(self, track_id):
        """Send DELETE for ``track_id`` and return the response."""
        return self.client.delete(f'/api/v1/music/{track_id}/')

    def test_delete_user_track(self):
        """A user-generated track can be deleted."""
        own_track = Track.objects.create(
            user=self.user, title='待删除歌曲', generation_status='completed',
        )
        resp = self._delete_track(own_track.id)
        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertFalse(Track.objects.filter(id=own_track.id).exists())

    def test_delete_default_track_rejected(self):
        """Deleting a default track is refused and the track is kept."""
        default_track = Track.objects.create(
            user=self.user,
            title='默认歌曲',
            is_default=True,
            generation_status='completed',
        )
        resp = self._delete_track(default_track.id)
        self.assertEqual(resp.data['code'], 703)  # MUSIC_DEFAULT_UNDELETABLE
        self.assertTrue(Track.objects.filter(id=default_track.id).exists())

    def test_delete_track_not_found(self):
        """Deleting an unknown track id returns code 700."""
        resp = self._delete_track(99999)
        self.assertEqual(resp.data['code'], 700)  # TRACK_NOT_FOUND

    def test_delete_other_user_track(self):
        """Another user's track is reported as not found and is kept."""
        stranger = User.objects.create_user(phone='13800140002')
        foreign_track = Track.objects.create(
            user=stranger, title='他人歌曲', generation_status='completed',
        )
        resp = self._delete_track(foreign_track.id)
        self.assertEqual(resp.data['code'], 700)
        self.assertTrue(Track.objects.filter(id=foreign_track.id).exists())
class MusicFavoriteTests(MusicTestBase):
    """Tests for the favorite toggle endpoint."""

    def test_favorite_toggle(self):
        """Two consecutive POSTs switch the favorite flag on and then off."""
        song = Track.objects.create(
            user=self.user, title='测试歌曲', generation_status='completed',
        )
        toggle_url = f'/api/v1/music/{song.id}/favorite/'
        # First POST marks the track as favorite, the second one clears it.
        for expected_flag in (True, False):
            resp = self.client.post(toggle_url)
            self.assertEqual(resp.data['code'], 0)
            self.assertEqual(bool(resp.data['data']['is_favorite']), expected_flag)

    def test_favorite_not_found(self):
        """Toggling favorite on an unknown track id returns code 700."""
        resp = self.client.post('/api/v1/music/99999/favorite/')
        self.assertEqual(resp.data['code'], 700)
class MusicGenerateTests(MusicTestBase):
    """Tests for the music generation endpoint: points check, SSE response, side effects."""

    GENERATE_URL = '/api/v1/music/generate/'
    # Dotted path of the streaming generator replaced by a mock, kept in one
    # place so the tests cannot drift apart.
    STREAM_PATH = 'apps.music.services.music_generation_service.generate_music_stream'

    def _set_points(self, amount):
        """Persist ``amount`` as the test user's point balance."""
        self.user.points = amount
        self.user.save(update_fields=['points'])

    def _post_with_stubbed_stream(self, data, mock_events):
        """POST ``data`` while the generator is mocked to yield ``mock_events``."""
        with patch(self.STREAM_PATH) as mock_gen:
            mock_gen.return_value = iter(mock_events)
            return self.client.post(self.GENERATE_URL, data, format='json')

    def test_generate_points_not_enough(self):
        """A balance of 50 is rejected with code 603 and left unchanged."""
        self._set_points(50)
        data = {'text': '开心的一天', 'mood': 'happy'}
        response = self.client.post(self.GENERATE_URL, data, format='json')
        self.assertEqual(response.data['code'], 603)  # POINTS_NOT_ENOUGH
        # The balance must not have changed.
        self.user.refresh_from_db()
        self.assertEqual(self.user.points, 50)

    def test_generate_zero_points(self):
        """A zero balance is rejected with code 603 and left unchanged."""
        # Set the balance explicitly instead of relying on the model default,
        # so the test still checks the zero case if that default ever changes.
        self._set_points(0)
        data = {'text': '开心的一天', 'mood': 'happy'}
        response = self.client.post(self.GENERATE_URL, data, format='json')
        self.assertEqual(response.data['code'], 603)
        self.user.refresh_from_db()
        self.assertEqual(self.user.points, 0)

    def test_generate_returns_sse(self):
        """With enough points the endpoint answers 200 with an SSE content type."""
        self._set_points(200)
        mock_events = [
            'data: {"stage":"lyrics","progress":10,"message":"AI 正在创作词曲..."}\n\n',
            'data: {"stage":"done","progress":100,"message":"新歌出炉!","track_id":1}\n\n',
        ]
        data = {'text': '开心的一天', 'mood': 'happy'}
        response = self._post_with_stubbed_stream(data, mock_events)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response['Content-Type'], 'text/event-stream')

    def test_generate_deducts_points(self):
        """A generate request deducts 100 points and writes a ledger entry."""
        self._set_points(200)
        mock_events = [
            'data: {"stage":"done","progress":100}\n\n',
        ]
        data = {'text': '开心的一天', 'mood': 'happy'}
        self._post_with_stubbed_stream(data, mock_events)
        # The stored balance reflects the deduction.
        self.user.refresh_from_db()
        self.assertEqual(self.user.points, 100)
        # A ledger entry of the matching type was written.
        record = PointsRecord.objects.filter(
            user=self.user, type='generate_music',
        ).first()
        self.assertIsNotNone(record)
        self.assertEqual(record.amount, -100)

    def test_generate_creates_track(self):
        """A generate request creates a Track in the ``generating`` state."""
        self._set_points(200)
        mock_events = [
            'data: {"stage":"done","progress":100}\n\n',
        ]
        data = {'text': '开心的一天', 'mood': 'happy'}
        self._post_with_stubbed_stream(data, mock_events)
        track = Track.objects.filter(user=self.user, is_default=False).first()
        self.assertIsNotNone(track)
        self.assertEqual(track.mood, 'happy')
        self.assertEqual(track.prompt, '开心的一天')
        self.assertEqual(track.generation_status, 'generating')

    def test_generate_empty_text_allowed(self):
        """A request without ``text`` is accepted when the mood is ``random``."""
        self._set_points(200)
        mock_events = [
            'data: {"stage":"done","progress":100}\n\n',
        ]
        data = {'mood': 'random'}
        response = self._post_with_stubbed_stream(data, mock_events)
        self.assertEqual(response.status_code, status.HTTP_200_OK)
        self.assertEqual(response['Content-Type'], 'text/event-stream')
class MusicServiceTests(TestCase):
    """Unit tests for the helpers in ``apps.music.services.music_generation_service``."""

    def _sse_payload(self, event):
        """Return the decoded JSON payload of a single ``data: ...`` SSE frame.

        Only the leading ``data: `` prefix is removed. A blanket
        ``str.replace`` would also strip that text from inside the payload.
        """
        prefix = 'data: '
        self.assertTrue(event.startswith(prefix))
        # json.loads tolerates the trailing blank-line terminator.
        return json.loads(event[len(prefix):])

    def test_clean_lyrics(self):
        """Section tags are removed while the lyric lines are kept."""
        from apps.music.services.music_generation_service import clean_lyrics
        raw = '[verse 1]\n泡在温泉里\n[chorus]\n咔咔咔咔\n[outro]\n(水花声...)'
        cleaned = clean_lyrics(raw)
        self.assertNotIn('[verse', cleaned)
        self.assertNotIn('[chorus]', cleaned)
        self.assertIn('泡在温泉里', cleaned)
        self.assertIn('咔咔咔咔', cleaned)

    def test_clean_lyrics_empty(self):
        """Empty and ``None`` lyrics are returned unchanged."""
        from apps.music.services.music_generation_service import clean_lyrics
        self.assertEqual(clean_lyrics(''), '')
        self.assertIsNone(clean_lyrics(None))

    def test_sse_event_format(self):
        """An SSE frame is ``data: <json>`` followed by a blank line."""
        from apps.music.services.music_generation_service import sse_event
        event = sse_event({"stage": "lyrics", "progress": 10})
        self.assertTrue(event.startswith('data: '))
        self.assertTrue(event.endswith('\n\n'))
        data = self._sse_payload(event)
        self.assertEqual(data['stage'], 'lyrics')
        self.assertEqual(data['progress'], 10)

    def test_parse_llm_json_valid(self):
        """A plain JSON object is parsed into its fields."""
        from apps.music.services.music_generation_service import _parse_llm_json
        text = '{"song_title": "咔咔之歌", "style": "Pop music", "lyrics": "[verse]\\n泡温泉"}'
        result = _parse_llm_json(text)
        self.assertEqual(result['song_title'], '咔咔之歌')
        self.assertEqual(result['style'], 'Pop music')

    def test_parse_llm_json_with_markdown(self):
        """JSON wrapped in a markdown code fence is still parsed."""
        from apps.music.services.music_generation_service import _parse_llm_json
        text = '```json\n{"song_title": "温泉曲", "style": "Lofi", "lyrics": "la la la"}\n```'
        result = _parse_llm_json(text)
        self.assertEqual(result['song_title'], '温泉曲')

    def test_parse_llm_json_invalid_fallback(self):
        """Malformed JSON still yields a result containing ``song_title``."""
        from apps.music.services.music_generation_service import _parse_llm_json
        # The "\n" below is a real newline character inside a JSON string
        # value. Strict json.loads rejects raw control characters, so this
        # input is malformed JSON and should exercise the parser's fallback
        # path (a regex fallback, per this test's intent).
        text = '{"song_title": "测试歌", "style": "Pop", "lyrics": "line1\nline2"}'
        result = _parse_llm_json(text)
        self.assertIn('song_title', result)

    def test_generate_stream_without_api_key(self):
        """Without an API key the stream yields one error event and refunds points."""
        from apps.music.services.music_generation_service import generate_music_stream
        user = User.objects.create_user(phone='13800149001', nickname='测试')
        user.points = 100
        user.save(update_fields=['points'])
        track = Track.objects.create(
            user=user, title='测试', generation_status='generating',
        )
        with patch('apps.music.services.music_generation_service._get_api_key', return_value=''):
            # Consume the generator while the patched key lookup is active.
            events = list(generate_music_stream(user, track, '测试', 'happy'))
        self.assertEqual(len(events), 1)
        self.assertIn('error', events[0])
        # SSE uses ensure_ascii=True, so check the decoded JSON.
        event_data = self._sse_payload(events[0])
        self.assertIn('未配置', event_data['message'])
        # The points were refunded.
        user.refresh_from_db()
        self.assertEqual(user.points, 200)  # 100 original + 100 refunded
        track.refresh_from_db()
        self.assertEqual(track.generation_status, 'failed')

    def test_refund_points(self):
        """``_refund_points`` adds 100 points, fails the track and logs a refund."""
        from apps.music.services.music_generation_service import _refund_points
        user = User.objects.create_user(phone='13800149002', nickname='退款测试')
        user.points = 50
        user.save(update_fields=['points'])
        track = Track.objects.create(
            user=user, title='失败歌曲', generation_status='generating',
        )
        _refund_points(user, track)
        user.refresh_from_db()
        self.assertEqual(user.points, 150)
        track.refresh_from_db()
        self.assertEqual(track.generation_status, 'failed')
        # A refund ledger entry was written.
        record = PointsRecord.objects.filter(
            user=user, type='refund_music',
        ).first()
        self.assertIsNotNone(record)
        self.assertEqual(record.amount, 100)
class DefaultTracksTests(TestCase):
    """Tests for ``ensure_default_tracks``, which seeds a user's default tracks."""

    def _default_tracks(self, user):
        """Return the queryset of ``user``'s default tracks."""
        return Track.objects.filter(user=user, is_default=True)

    def test_ensure_default_tracks_creates_three(self):
        """One call creates three default tracks."""
        from apps.music.utils import ensure_default_tracks
        user = User.objects.create_user(phone='13800149010')
        ensure_default_tracks(user)
        self.assertEqual(self._default_tracks(user).count(), 3)

    def test_ensure_default_tracks_idempotent(self):
        """A second call does not create further default tracks."""
        from apps.music.utils import ensure_default_tracks
        user = User.objects.create_user(phone='13800149011')
        ensure_default_tracks(user)
        ensure_default_tracks(user)
        self.assertEqual(self._default_tracks(user).count(), 3)

    def test_default_tracks_have_all_fields(self):
        """Every default track has title, lyrics, audio, cover and is completed."""
        from apps.music.utils import ensure_default_tracks
        user = User.objects.create_user(phone='13800149012')
        ensure_default_tracks(user)
        tracks = list(self._default_tracks(user))
        # Guard against a vacuous pass: the loop below checks nothing if no
        # default track was created.
        self.assertTrue(tracks, 'no default tracks were created')
        for track in tracks:
            self.assertTrue(track.title)
            self.assertTrue(track.lyrics)
            self.assertTrue(track.audio_url)
            self.assertTrue(track.cover_url)
            self.assertEqual(track.generation_status, 'completed')
class MigrateHistoricalTracksTests(TestCase):
    """Tests for the ``migrate_historical_tracks`` management command."""

    COMMAND = 'migrate_historical_tracks'
    # Prefix every migrated audio/cover URL is expected to start with.
    OSS_PREFIX = 'https://qy-rtc.oss-cn-beijing.aliyuncs.com/'

    def setUp(self):
        """Create the user the tests expect the migrated tracks to belong to."""
        self.user = User.objects.create_user(phone='18682237028')

    def _run_command(self, *args):
        """Run the command with ``args`` and return everything it wrote to stdout."""
        from io import StringIO
        from django.core.management import call_command
        out = StringIO()
        call_command(self.COMMAND, *args, stdout=out)
        return out.getvalue()

    def test_creates_15_tracks(self):
        """One run creates 15 non-default tracks for the user."""
        self._run_command()
        count = Track.objects.filter(user=self.user, is_default=False).count()
        self.assertEqual(count, 15)

    def test_idempotent(self):
        """A second run does not create further tracks."""
        self._run_command()
        self._run_command()
        count = Track.objects.filter(user=self.user, is_default=False).count()
        self.assertEqual(count, 15)

    def test_dry_run(self):
        """``--dry-run`` writes nothing to the database and says so on stdout."""
        output = self._run_command('--dry-run')
        count = Track.objects.filter(user=self.user).count()
        self.assertEqual(count, 0)
        self.assertIn('dry-run', output)

    def test_tracks_have_oss_urls(self):
        """Every migrated track has audio and cover URLs under the OSS prefix."""
        self._run_command()
        tracks = list(Track.objects.filter(user=self.user))
        # Guard against a vacuous pass: the loop below checks nothing if the
        # command created no tracks.
        self.assertTrue(tracks, 'the command created no tracks')
        for track in tracks:
            self.assertTrue(track.audio_url.startswith(self.OSS_PREFIX))
            self.assertTrue(track.cover_url.startswith(self.OSS_PREFIX))
# ==================== 角色记忆测试 ====================
class RoleMemoryTests(APITestCase):
"""角色记忆功能测试"""
def setUp(self):
    """Authenticate a fresh user and force the device fixtures into a known state."""
    self.user = User.objects.create_user(phone='13800139000', nickname='记忆测试用户')
    tokens = get_app_tokens(self.user)
    self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {tokens["access"]}')
    # update_or_create rather than get_or_create: a record with this product
    # code may already exist, and the tests assert on the exact template
    # prompt and voice id. The former patch-up only ran when the prompt was
    # empty, so a different non-empty prompt or voice id slipped through.
    self.device_type, _ = DeviceType.objects.update_or_create(
        product_code='KPBL-ON-RM',
        defaults={
            'brand': 'AL',
            'name': '卡皮巴拉-联网版',
            'default_prompt': '你是一只可爱的卡皮巴拉',
            'default_voice_id': 'voice_kpbl_01',
        },
    )
    # Same reasoning for the device: whether new or pre-existing, it must be
    # of the type above and back in stock before each test.
    self.device, _ = Device.objects.update_or_create(
        sn='AL-KPBL-ON-25W01-RM-00001',
        defaults={
            'device_type': self.device_type,
            'status': 'in_stock',
        },
    )
    # Drop any leftover bindings and role memories for this user.
    UserDevice.objects.filter(user=self.user).delete()
    RoleMemory.objects.filter(user=self.user).delete()
def test_bind_creates_role_memory(self):
    """Binding a device returns a role memory seeded from the device type."""
    payload = {'sn': 'AL-KPBL-ON-25W01-RM-00001'}
    resp = self.client.post('/api/v1/devices/bind/', payload, format='json')
    self.assertEqual(resp.data['code'], 0)
    self.assertIsNotNone(resp.data['data']['role_memory'])
    memory = resp.data['data']['role_memory']
    # Prompt and voice come from the device type template set up in setUp.
    self.assertEqual(memory['prompt'], '你是一只可爱的卡皮巴拉')
    self.assertEqual(memory['voice_id'], 'voice_kpbl_01')
    self.assertTrue(memory['is_bound'])
    # The test expects 50 for both device settings on a new memory.
    self.assertEqual(memory['volume'], 50)
    self.assertEqual(memory['brightness'], 50)
def test_bind_creates_new_memory_each_time(self):
    """Binding two different devices leaves the user with two role memories."""
    second_device = Device.objects.create(
        sn='AL-KPBL-ON-25W01-A01-00002',
        device_type=self.device_type,
        status='in_stock',
    )
    bind_url = '/api/v1/devices/bind/'
    # Bind the fixture device first, then the second one.
    for serial in (self.device.sn, second_device.sn):
        self.client.post(bind_url, {'sn': serial}, format='json')
    memory_count = RoleMemory.objects.filter(user=self.user).count()
    self.assertEqual(memory_count, 2)
def test_unbind_marks_memory_idle(self):
    """Unbinding a device flips its role memory to not bound."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)
    memory = binding.role_memory
    resp = self.client.delete(f'/api/v1/devices/{binding.id}/unbind/')
    self.assertEqual(resp.data['code'], 0)
    # Reload the memory to see the state written by the unbind request.
    memory.refresh_from_db()
    self.assertFalse(memory.is_bound)
def test_unbind_preserves_memory(self):
    """Unbinding a device does not delete its role memory."""
    bind_payload = {'sn': self.device.sn}
    self.client.post('/api/v1/devices/bind/', bind_payload, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)
    self.client.delete(f'/api/v1/devices/{binding.id}/unbind/')
    remaining = RoleMemory.objects.filter(user=self.user).count()
    self.assertEqual(remaining, 1)
def test_get_role_memory(self):
    """The role-memory endpoint returns the memory of a bound device."""
    bind_payload = {'sn': self.device.sn}
    self.client.post('/api/v1/devices/bind/', bind_payload, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)
    resp = self.client.get(f'/api/v1/devices/{binding.id}/role-memory/')
    body = resp.data
    self.assertEqual(body['code'], 0)
    self.assertEqual(body['data']['prompt'], '你是一只可爱的卡皮巴拉')
def test_update_role_memory_settings(self):
    """PUT on the settings endpoint stores nickname, volume and brightness."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)
    settings_url = f'/api/v1/devices/{binding.id}/role-memory/settings/'
    new_settings = {'nickname': '我的卡皮', 'volume': 80, 'brightness': 30}
    resp = self.client.put(settings_url, new_settings, format='json')
    self.assertEqual(resp.data['code'], 0)
    stored = resp.data['data']
    self.assertEqual(stored['nickname'], '我的卡皮')
    self.assertEqual(stored['volume'], 80)
    self.assertEqual(stored['brightness'], 30)
def test_update_role_memory_agent(self):
    """PUT new agent info (prompt and voice id) on a role memory."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)
    new_prompt = '你是一只会讲故事的卡皮巴拉'
    new_voice = 'voice_new'

    response = self.client.put(
        f'/api/v1/devices/{binding.id}/role-memory/agent/',
        {'prompt': new_prompt, 'voice_id': new_voice},
        format='json',
    )

    self.assertEqual(response.data['code'], 0)
    self.assertEqual(response.data['data']['prompt'], new_prompt)
    self.assertEqual(response.data['data']['voice_id'], new_voice)
def test_update_role_memory_summary(self):
    """PUT a new chat-memory summary on a role memory and read it back."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)
    # Single definition of the summary text, used for both request and check.
    summary_text = '用户喜欢恐龙故事,不喜欢太吓人的情节'

    response = self.client.put(
        f'/api/v1/devices/{binding.id}/role-memory/memory/',
        {'memory_summary': summary_text},
        format='json',
    )

    self.assertEqual(response.data['code'], 0)
    self.assertEqual(response.data['data']['memory_summary'], summary_text)
def test_role_memory_list(self):
    """List role memories after binding two devices; expect two entries."""
    second_device = Device.objects.create(
        sn='AL-KPBL-ON-25W01-A01-00002',
        device_type=self.device_type,
        status='in_stock',
    )
    # Bind the fixture device first, then the extra one.
    for device in (self.device, second_device):
        self.client.post('/api/v1/devices/bind/', {'sn': device.sn}, format='json')

    response = self.client.get('/api/v1/devices/role-memories/')

    self.assertEqual(response.data['code'], 0)
    entries = response.data['data']
    self.assertEqual(len(entries), 2)
def test_role_memory_list_filter_by_device_type(self):
    """Filter the role-memory list by ``device_type_id``."""
    second_type = DeviceType.objects.create(
        brand='AL',
        product_code='OTHER-ON',
        name='其他设备',
    )
    second_type_device = Device.objects.create(
        sn='AL-OTHER-ON-25W01-A01-00001',
        device_type=second_type,
        status='in_stock',
    )
    # Bind one device of each type.
    for device in (self.device, second_type_device):
        self.client.post('/api/v1/devices/bind/', {'sn': device.sn}, format='json')

    response = self.client.get(
        f'/api/v1/devices/role-memories/?device_type_id={self.device_type.id}'
    )

    self.assertEqual(response.data['code'], 0)
    self.assertEqual(len(response.data['data']), 1)
    only_entry = response.data['data'][0]
    self.assertEqual(only_entry['device_type'], self.device_type.id)
def test_role_memory_list_filter_by_is_bound(self):
    """Filter the role-memory list by binding state (``is_bound=false``).

    One device is bound, and a second, idle memory is created directly.
    Only the idle memory should be returned by the filtered listing.
    """
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    # Explicitly confirm the bind took effect; without a bound memory in
    # the data set the filter below would not be filtering anything.
    self.assertTrue(
        UserDevice.objects.filter(user=self.user, device=self.device).exists()
    )
    # Create an idle (unbound) memory directly in the database.
    RoleMemory.objects.create(
        user=self.user, device_type=self.device_type, is_bound=False
    )
    url = '/api/v1/devices/role-memories/?is_bound=false'
    response = self.client.get(url)
    self.assertEqual(response.data['code'], 0)
    self.assertEqual(len(response.data['data']), 1)
    self.assertFalse(response.data['data'][0]['is_bound'])
def test_switch_role_memory(self):
    """Switch a bound device over to an idle role memory of the same type."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)
    previous_memory = binding.role_memory
    # An idle memory of the same device type to switch to.
    replacement = RoleMemory.objects.create(
        user=self.user,
        device_type=self.device_type,
        is_bound=False,
        prompt='闲置的提示词',
        memory_summary='之前的记忆内容',
    )

    response = self.client.put(
        f'/api/v1/devices/{binding.id}/switch-role-memory/',
        {'role_memory_id': replacement.id},
        format='json',
    )

    self.assertEqual(response.data['code'], 0)
    self.assertEqual(response.data['data']['prompt'], '闲置的提示词')
    # Reload all three rows so the state checks use persisted values.
    for record in (previous_memory, replacement, binding):
        record.refresh_from_db()
    self.assertFalse(previous_memory.is_bound)
    self.assertTrue(replacement.is_bound)
    self.assertEqual(binding.role_memory_id, replacement.id)
def test_switch_rejects_different_type(self):
    """Switching to a memory of another device type yields a non-zero code."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)
    foreign_type = DeviceType.objects.create(
        brand='AL',
        product_code='OTHER-ON',
        name='其他设备',
    )
    foreign_memory = RoleMemory.objects.create(
        user=self.user,
        device_type=foreign_type,
        is_bound=False,
    )

    response = self.client.put(
        f'/api/v1/devices/{binding.id}/switch-role-memory/',
        {'role_memory_id': foreign_memory.id},
        format='json',
    )

    self.assertNotEqual(response.data['code'], 0)
def test_switch_rejects_bound_memory(self):
    """Switching to a memory already attached to another device is rejected."""
    second_device = Device.objects.create(
        sn='AL-KPBL-ON-25W01-A01-00002',
        device_type=self.device_type,
        status='in_stock',
    )
    # Bind both devices, fixture device first.
    for device in (self.device, second_device):
        self.client.post('/api/v1/devices/bind/', {'sn': device.sn}, format='json')
    first_binding = UserDevice.objects.get(user=self.user, device=self.device)
    second_binding = UserDevice.objects.get(user=self.user, device=second_device)

    # Ask the first device to take over the memory the second one is using.
    response = self.client.put(
        f'/api/v1/devices/{first_binding.id}/switch-role-memory/',
        {'role_memory_id': second_binding.role_memory_id},
        format='json',
    )

    self.assertNotEqual(response.data['code'], 0)
def test_user_isolation(self):
    """User isolation: switching to another user's role memory is rejected."""
    other_user = User.objects.create_user(phone='13800139001')
    other_rm = RoleMemory.objects.create(
        user=other_user,
        device_type=self.device_type,
        is_bound=False,
    )
    bind_payload = {'sn': self.device.sn}
    self.client.post('/api/v1/devices/bind/', bind_payload, format='json')
    ud = UserDevice.objects.get(user=self.user, device=self.device)

    # Attempt to switch this user's device to the other user's memory.
    response = self.client.put(
        f'/api/v1/devices/{ud.id}/switch-role-memory/',
        {'role_memory_id': other_rm.id},
        format='json',
    )

    self.assertNotEqual(response.data['code'], 0)