2342 lines
89 KiB
Python
2342 lines
89 KiB
Python
"""
|
||
RTC_DEMO API 完整测试用例
|
||
覆盖所有API接口
|
||
运行方式: python manage.py test tests --verbosity=2 --keepdb
|
||
"""
|
||
import json
|
||
from datetime import date
|
||
from unittest.mock import patch, MagicMock
|
||
from django.test import TestCase
|
||
from django.urls import reverse
|
||
from rest_framework.test import APITestCase, APIClient
|
||
from rest_framework import status
|
||
from apps.users.models import User, PointsRecord
|
||
from apps.admins.models import AdminUser
|
||
from apps.spirits.models import Spirit
|
||
from apps.devices.models import DeviceType, DeviceBatch, Device, UserDevice, RoleMemory
|
||
from apps.stories.models import StoryShelf, Story
|
||
from apps.music.models import Track
|
||
from apps.users.views import get_app_tokens
|
||
|
||
|
||
# ==================== App端测试 ====================
|
||
|
||
class UserAuthTests(APITestCase):
    """Authentication tests for the app-side (end-user) API."""

    PHONE_LOGIN_URL = '/api/v1/auth/phone-login/'
    TOKEN_REFRESH_URL = '/api/v1/auth/refresh/'

    def _phone_login(self, phone):
        """POST a one-tap phone login for ``phone`` and return the response."""
        return self.client.post(self.PHONE_LOGIN_URL, {'phone': phone}, format='json')

    def _refresh(self, payload):
        """POST ``payload`` to the token refresh endpoint and return the response."""
        return self.client.post(self.TOKEN_REFRESH_URL, payload, format='json')

    def test_phone_login_new_user(self):
        """One-tap phone login with a phone number that has no account yet."""
        resp = self._phone_login('13800138001')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertIn('token', body['data'])
        self.assertTrue(body['data']['is_new_user'])
        self.assertEqual(body['data']['user']['phone'], '13800138001')

    def test_phone_login_existing_user(self):
        """One-tap phone login with a phone number that already has an account."""
        User.objects.create_user(phone='13800138002', nickname='测试用户')

        resp = self._phone_login('13800138002')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertFalse(resp.data['data']['is_new_user'])

    def test_phone_login_invalid_phone(self):
        """A malformed phone number yields a non-zero business code."""
        resp = self._phone_login('1234567')

        self.assertNotEqual(resp.data['code'], 0)

    def test_phone_login_disabled_user(self):
        """Login for an inactive user yields a non-zero business code."""
        User.objects.create_user(phone='13800138003', is_active=False)

        resp = self._phone_login('13800138003')

        self.assertNotEqual(resp.data['code'], 0)

    def test_token_refresh(self):
        """A refresh token obtained at login can be exchanged for an access token."""
        login_resp = self._phone_login('13800138004')
        refresh_token = login_resp.data['data']['token']['refresh']

        resp = self._refresh({'refresh': refresh_token})

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertIn('access', resp.data['data'])

    def test_token_refresh_invalid(self):
        """Refreshing with an invalid token yields a non-zero business code."""
        resp = self._refresh({'refresh': 'invalid_token'})

        self.assertNotEqual(resp.data['code'], 0)

    def test_token_refresh_empty(self):
        """Refreshing with an empty payload yields a non-zero business code."""
        resp = self._refresh({})

        self.assertNotEqual(resp.data['code'], 0)
|
||
|
||
|
||
class UserProfileTests(APITestCase):
    """User profile tests for the app-side API."""

    ME_URL = '/api/v1/users/me/'
    UPDATE_ME_URL = '/api/v1/users/update_me/'

    def setUp(self):
        """Create a user, log in by phone and attach the access token to the client."""
        phone = '13800138010'
        self.user = User.objects.create_user(phone=phone, nickname='测试用户')
        login = self.client.post('/api/v1/auth/phone-login/', {'phone': phone}, format='json')
        self.access_token = login.data['data']['token']['access']
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.access_token}')

    def test_get_user_info(self):
        """Fetch the current user's profile."""
        resp = self.client.get(self.ME_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['phone'], '13800138010')

    def test_get_user_info_unauthorized(self):
        """Fetching the profile without credentials returns HTTP 401."""
        self.client.credentials()  # drop the Authorization header set in setUp

        resp = self.client.get(self.ME_URL)

        self.assertEqual(resp.status_code, status.HTTP_401_UNAUTHORIZED)

    def test_update_user_info(self):
        """Update the current user's nickname."""
        resp = self.client.put(self.UPDATE_ME_URL, {'nickname': '新昵称'}, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['nickname'], '新昵称')

    def test_update_user_avatar(self):
        """Update the current user's avatar URL."""
        resp = self.client.put(
            self.UPDATE_ME_URL,
            {'avatar': 'https://example.com/avatar.jpg'},
            format='json',
        )

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
|
||
|
||
|
||
class SpiritTests(APITestCase):
    """Spirit (agent) tests for the app-side API."""

    LIST_URL = '/api/v1/spirits/'

    def setUp(self):
        """Create a user, log in by phone and attach the access token to the client."""
        phone = '13800138020'
        self.user = User.objects.create_user(phone=phone, nickname='测试用户')
        login = self.client.post('/api/v1/auth/phone-login/', {'phone': phone}, format='json')
        self.access_token = login.data['data']['token']['access']
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.access_token}')

    @staticmethod
    def _detail_url(spirit):
        """Return the detail endpoint URL for ``spirit``."""
        return f'/api/v1/spirits/{spirit.id}/'

    def test_create_spirit(self):
        """Create a spirit with name, prompt and voice id."""
        payload = {
            'name': '测试心灵',
            'prompt': '你是一个友好的助手',
            'voice_id': 'voice_001'
        }

        resp = self.client.post(self.LIST_URL, payload, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['name'], '测试心灵')

    def test_create_spirit_minimal(self):
        """Create a spirit supplying only the name field."""
        resp = self.client.post(self.LIST_URL, {'name': '简单心灵'}, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)

    def test_list_spirits(self):
        """List the current user's spirits."""
        for spirit_name in ('心灵1', '心灵2'):
            Spirit.objects.create(user=self.user, name=spirit_name)

        resp = self.client.get(self.LIST_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(len(resp.data['data']), 2)

    def test_list_spirits_empty(self):
        """Listing spirits when the user has none returns an empty list."""
        resp = self.client.get(self.LIST_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(len(resp.data['data']), 0)

    def test_get_spirit_detail(self):
        """Fetch a single spirit by id."""
        target = Spirit.objects.create(user=self.user, name='详情测试')

        resp = self.client.get(self._detail_url(target))

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['data']['name'], '详情测试')

    def test_update_spirit(self):
        """Update a spirit's name and prompt."""
        target = Spirit.objects.create(user=self.user, name='原始名称')

        resp = self.client.put(
            self._detail_url(target),
            {'name': '新名称', 'prompt': '新的提示词'},
            format='json',
        )

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['data']['name'], '新名称')

    def test_delete_spirit(self):
        """Delete a spirit and confirm the row is gone."""
        target = Spirit.objects.create(user=self.user, name='待删除')

        resp = self.client.delete(self._detail_url(target))

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertFalse(Spirit.objects.filter(id=target.id).exists())

    def test_spirit_isolation(self):
        """A user cannot read a spirit that belongs to another user."""
        stranger = User.objects.create_user(phone='13800138021')
        foreign_spirit = Spirit.objects.create(user=stranger, name='其他用户的心灵')

        resp = self.client.get(self._detail_url(foreign_spirit))

        # Expected to be a 404 or a permission error; only "not 200" is asserted.
        self.assertNotEqual(resp.status_code, status.HTTP_200_OK)
|
||
|
||
|
||
class DeviceTests(APITestCase):
    """Device tests for the app-side API."""

    DEVICE_SN = 'AL-DZBJ-ON-25W45-A01-00001'
    QUERY_BY_MAC_URL = '/api/v1/devices/query-by-mac/'
    VERIFY_URL = '/api/v1/devices/verify/'
    BIND_URL = '/api/v1/devices/bind/'
    MY_DEVICES_URL = '/api/v1/devices/my_devices/'

    def setUp(self):
        """Log in a fresh user and create one in-stock device with a known MAC."""
        phone = '13800138050'
        self.user = User.objects.create_user(phone=phone, nickname='测试用户')
        login = self.client.post('/api/v1/auth/phone-login/', {'phone': phone}, format='json')
        self.access_token = login.data['data']['token']['access']
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.access_token}')

        self.device_type = DeviceType.objects.create(
            brand='AL',
            product_code='DZBJ-ON',
            name='电子吧唧-联网版',
        )
        self.device = Device.objects.create(
            sn=self.DEVICE_SN,
            device_type=self.device_type,
            mac_address='AA:BB:CC:DD:EE:FF',
            status='in_stock',
        )

    def _anonymous_mac_query(self, mac=None):
        """Clear credentials and GET the query-by-mac endpoint.

        When ``mac`` is None the request is sent without a ``mac`` query parameter.
        """
        self.client.credentials()  # the endpoint is exercised without authentication
        if mac is None:
            return self.client.get(self.QUERY_BY_MAC_URL)
        return self.client.get(f'{self.QUERY_BY_MAC_URL}?mac={mac}')

    def test_query_by_mac(self):
        """Look up an SN by MAC address (no login required)."""
        resp = self._anonymous_mac_query('AA:BB:CC:DD:EE:FF')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['sn'], 'AL-DZBJ-ON-25W45-A01-00001')

    def test_query_by_mac_lowercase(self):
        """MAC lookup with a lowercase address."""
        resp = self._anonymous_mac_query('aa:bb:cc:dd:ee:ff')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)

    def test_query_by_mac_dash_format(self):
        """MAC lookup with a dash-separated address."""
        resp = self._anonymous_mac_query('AA-BB-CC-DD-EE-FF')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)

    def test_query_by_mac_not_found(self):
        """MAC lookup for an unknown device returns HTTP 404."""
        resp = self._anonymous_mac_query('11:22:33:44:55:66')

        self.assertEqual(resp.status_code, status.HTTP_404_NOT_FOUND)

    def test_query_by_mac_empty(self):
        """MAC lookup without the ``mac`` parameter yields a non-zero business code."""
        resp = self._anonymous_mac_query()

        self.assertNotEqual(resp.data['code'], 0)

    def test_verify_device(self):
        """Verify a device SN that exists and is bindable."""
        resp = self.client.post(self.VERIFY_URL, {'sn': self.DEVICE_SN}, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertTrue(resp.data['data']['is_bindable'])

    def test_verify_device_not_found(self):
        """Verifying an unknown SN yields a non-zero business code."""
        resp = self.client.post(self.VERIFY_URL, {'sn': 'NOT-EXIST-SN'}, format='json')

        self.assertNotEqual(resp.data['code'], 0)

    def test_bind_device(self):
        """Bind a device together with a spirit and check the device becomes 'bound'."""
        companion = Spirit.objects.create(user=self.user, name='测试心灵')
        payload = {
            'sn': self.DEVICE_SN,
            'spirit_id': companion.id
        }

        resp = self.client.post(self.BIND_URL, payload, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)

        # Re-read the row to check the device status.
        self.device.refresh_from_db()
        self.assertEqual(self.device.status, 'bound')

    def test_bind_device_without_spirit(self):
        """Bind a device without supplying a spirit."""
        second_sn = 'AL-DZBJ-ON-25W45-A01-00002'
        Device.objects.create(
            sn=second_sn,
            device_type=self.device_type,
            status='in_stock',
        )

        resp = self.client.post(self.BIND_URL, {'sn': second_sn}, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)

    def test_my_devices(self):
        """List the current user's devices."""
        UserDevice.objects.create(user=self.user, device=self.device, is_active=True)

        resp = self.client.get(self.MY_DEVICES_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(len(resp.data['data']), 1)

    def test_my_devices_empty(self):
        """Listing devices when the user has none returns an empty list."""
        resp = self.client.get(self.MY_DEVICES_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(len(resp.data['data']), 0)

    def test_unbind_device(self):
        """Unbind a device and check the binding row is deactivated."""
        binding = UserDevice.objects.create(user=self.user, device=self.device, is_active=True)
        self.device.status = 'bound'
        self.device.save()

        resp = self.client.delete(f'/api/v1/devices/{binding.id}/unbind/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)

        binding.refresh_from_db()
        self.assertFalse(binding.is_active)

    def test_update_spirit_on_device(self):
        """Switch the spirit attached to a bound device."""
        old_spirit = Spirit.objects.create(user=self.user, name='心灵1')
        new_spirit = Spirit.objects.create(user=self.user, name='心灵2')
        binding = UserDevice.objects.create(
            user=self.user, device=self.device, spirit=old_spirit, is_active=True
        )

        resp = self.client.put(
            f'/api/v1/devices/{binding.id}/update-spirit/',
            {'spirit_id': new_spirit.id},
            format='json',
        )

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        binding.refresh_from_db()
        self.assertEqual(binding.spirit_id, new_spirit.id)
|
||
|
||
|
||
# ==================== 管理端测试 ====================
|
||
|
||
class AdminAuthTests(APITestCase):
    """Authentication tests for the admin-side API."""

    LOGIN_URL = '/api/admin/auth/login/'
    REFRESH_URL = '/api/admin/auth/refresh/'

    def setUp(self):
        """Create the super admin account used by the login tests."""
        self.admin = AdminUser.objects.create_user(
            username='admin',
            password='admin123',
            role='super_admin',
        )

    def _login(self, username, password):
        """POST admin credentials to the login endpoint and return the response."""
        credentials = {'username': username, 'password': password}
        return self.client.post(self.LOGIN_URL, credentials, format='json')

    def test_admin_login(self):
        """Admin login with correct credentials."""
        resp = self._login('admin', 'admin123')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertIn('token', resp.data['data'])
        self.assertEqual(resp.data['data']['admin']['username'], 'admin')

    def test_admin_login_wrong_password(self):
        """Admin login with a wrong password yields a non-zero business code."""
        resp = self._login('admin', 'wrong')

        self.assertNotEqual(resp.data['code'], 0)

    def test_admin_login_user_not_exist(self):
        """Admin login with an unknown username yields a non-zero business code."""
        resp = self._login('notexist', 'password')

        self.assertNotEqual(resp.data['code'], 0)

    def test_admin_login_disabled(self):
        """Admin login for an inactive account yields a non-zero business code."""
        AdminUser.objects.create_user(
            username='disabled', password='pass123', is_active=False
        )

        resp = self._login('disabled', 'pass123')

        self.assertNotEqual(resp.data['code'], 0)

    def test_admin_token_refresh(self):
        """A refresh token obtained at admin login can be refreshed."""
        login_resp = self._login('admin', 'admin123')
        refresh_token = login_resp.data['data']['token']['refresh']

        resp = self.client.post(self.REFRESH_URL, {'refresh': refresh_token}, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
|
||
|
||
|
||
class AdminProfileTests(APITestCase):
    """Own-profile tests for the admin-side API."""

    LOGIN_URL = '/api/admin/auth/login/'
    CHANGE_PASSWORD_URL = '/api/admin/profile/change-password/'

    def setUp(self):
        """Create an admin, log in and attach the access token to the client."""
        self.admin = AdminUser.objects.create_user(
            username='admin', password='admin123', role='admin'
        )
        login = self._login('admin123')
        self.access_token = login.data['data']['token']['access']
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.access_token}')

    def _login(self, password):
        """POST a login for user ``admin`` with ``password`` and return the response."""
        return self.client.post(
            self.LOGIN_URL, {'username': 'admin', 'password': password}, format='json'
        )

    def test_get_admin_profile(self):
        """Fetch the logged-in admin's profile."""
        resp = self.client.get('/api/admin/profile/me/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['data']['username'], 'admin')

    def test_change_password(self):
        """Change the password, then log in again with the new one."""
        resp = self.client.put(
            self.CHANGE_PASSWORD_URL,
            {'old_password': 'admin123', 'new_password': 'newpass123'},
            format='json',
        )

        self.assertEqual(resp.status_code, status.HTTP_200_OK)

        # Confirm that the new password can be used to log in.
        self.client.credentials()
        relogin = self._login('newpass123')
        self.assertEqual(relogin.data['code'], 0)

    def test_change_password_wrong_old(self):
        """Changing the password with a wrong old password yields a non-zero code."""
        resp = self.client.put(
            self.CHANGE_PASSWORD_URL,
            {'old_password': 'wrongpass', 'new_password': 'newpass123'},
            format='json',
        )

        self.assertNotEqual(resp.data['code'], 0)
|
||
|
||
|
||
class AdminUserManageTests(APITestCase):
    """Tests for managing admin accounts through the admin-side API."""

    ADMINS_URL = '/api/admin/admins/'

    def setUp(self):
        """Create a super admin, log in and attach the access token to the client."""
        self.super_admin = AdminUser.objects.create_user(
            username='superadmin', password='super123', role='super_admin'
        )
        login = self.client.post(
            '/api/admin/auth/login/',
            {'username': 'superadmin', 'password': 'super123'},
            format='json',
        )
        self.access_token = login.data['data']['token']['access']
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.access_token}')

    def test_list_admins(self):
        """List admin accounts."""
        resp = self.client.get(self.ADMINS_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)

    def test_create_admin(self):
        """Create a new admin account."""
        payload = {
            'username': 'newadmin',
            'password': 'newpass123',
            'name': '新管理员',
            'role': 'operator'
        }

        resp = self.client.post(self.ADMINS_URL, payload, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['data']['username'], 'newadmin')

    def test_get_admin_detail(self):
        """Fetch a single admin account by id."""
        target = AdminUser.objects.create_user(username='testadmin', password='test123')

        resp = self.client.get(f'{self.ADMINS_URL}{target.id}/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['data']['username'], 'testadmin')

    def test_toggle_admin_status(self):
        """Toggle an admin's status and check the account ends up inactive."""
        target = AdminUser.objects.create_user(username='testadmin2', password='test123')

        resp = self.client.post(f'{self.ADMINS_URL}{target.id}/toggle-status/', format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        target.refresh_from_db()
        self.assertFalse(target.is_active)

    def test_reset_admin_password(self):
        """Reset another admin's password."""
        target = AdminUser.objects.create_user(username='testadmin3', password='oldpass')

        resp = self.client.post(
            f'{self.ADMINS_URL}{target.id}/reset-password/',
            {'new_password': 'resetpass123'},
            format='json',
        )

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
|
||
|
||
|
||
class AdminDeviceTypeTests(APITestCase):
    """Device type tests for the admin-side API."""

    DEVICE_TYPES_URL = '/api/admin/device-types/'

    def setUp(self):
        """Create a super admin, log in and attach the access token to the client."""
        self.admin = AdminUser.objects.create_user(
            username='admin', password='admin123', role='super_admin'
        )
        login = self.client.post(
            '/api/admin/auth/login/',
            {'username': 'admin', 'password': 'admin123'},
            format='json',
        )
        self.access_token = login.data['data']['token']['access']
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.access_token}')

    def _post_device_type(self, product_code, name):
        """POST a new brand-'AL' device type and return the response."""
        payload = {
            'brand': 'AL',
            'product_code': product_code,
            'name': name
        }
        return self.client.post(self.DEVICE_TYPES_URL, payload, format='json')

    def test_create_device_type_network(self):
        """Create a device type that is reported as requiring a network."""
        resp = self._post_device_type('DZBJ-ON', '电子吧唧-联网版')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertTrue(resp.data['data']['is_network_required'])

    def test_create_device_type_offline(self):
        """Create a device type that is reported as not requiring a network."""
        resp = self._post_device_type('DZBJ-OFF', '电子吧唧-离线版')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertFalse(resp.data['data']['is_network_required'])

    def test_list_device_types(self):
        """List device types."""
        DeviceType.objects.create(brand='AL', product_code='TEST-ON', name='测试设备')

        resp = self.client.get(self.DEVICE_TYPES_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)

    def test_get_device_type_detail(self):
        """Fetch a single device type by id."""
        device_type = DeviceType.objects.create(
            brand='AL', product_code='TEST-DT', name='测试类型'
        )

        resp = self.client.get(f'{self.DEVICE_TYPES_URL}{device_type.id}/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['data']['name'], '测试类型')

    def test_update_device_type(self):
        """Rename a device type."""
        device_type = DeviceType.objects.create(
            brand='AL', product_code='TEST-UP', name='原名称'
        )

        resp = self.client.put(
            f'{self.DEVICE_TYPES_URL}{device_type.id}/', {'name': '新名称'}, format='json'
        )

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['data']['name'], '新名称')
|
||
|
||
|
||
class AdminBatchTests(APITestCase):
    """Device batch tests for the admin-side API."""

    BATCHES_URL = '/api/admin/device-batches/'

    def setUp(self):
        """Log in as a super admin and create the device type used by the batches."""
        self.admin = AdminUser.objects.create_user(
            username='admin', password='admin123', role='super_admin'
        )
        login = self.client.post(
            '/api/admin/auth/login/',
            {'username': 'admin', 'password': 'admin123'},
            format='json',
        )
        self.access_token = login.data['data']['token']['access']
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.access_token}')

        self.device_type = DeviceType.objects.create(
            brand='AL', product_code='DZBJ-ON', name='电子吧唧-联网版'
        )

    def _make_batch(self, batch_no, quantity):
        """Insert a batch row for ``self.device_type`` produced on 2026-01-28."""
        return DeviceBatch.objects.create(
            device_type=self.device_type,
            batch_no=batch_no,
            production_date=date(2026, 1, 28),
            quantity=quantity,
        )

    def test_create_batch_and_generate_sn(self):
        """Create a batch through the API and check the generated SN codes."""
        payload = {
            'device_type': self.device_type.id,
            'batch_no': 'A01',
            'production_date': '2026-01-28',
            'quantity': 10,
            'remark': '测试批次'
        }

        resp = self.client.post(self.BATCHES_URL, payload, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertEqual(resp.data['data']['quantity'], 10)
        self.assertIn('sn_range', resp.data['data'])

        # The first SN must carry the brand and product-code prefix.
        first_sn = resp.data['data']['sn_range']['start']
        self.assertTrue(first_sn.startswith('AL-DZBJ-ON-'))

        # One device row per requested unit.
        created = Device.objects.filter(batch__batch_no='A01').count()
        self.assertEqual(created, 10)

    def test_list_batches(self):
        """List batches."""
        resp = self.client.get(self.BATCHES_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)

    def test_get_batch_detail(self):
        """Fetch a single batch; the payload includes a 'statistics' entry."""
        batch = self._make_batch('B01', 5)

        resp = self.client.get(f'{self.BATCHES_URL}{batch.id}/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertIn('statistics', resp.data['data'])

    def test_get_batch_devices(self):
        """List the devices that belong to a batch."""
        batch = self._make_batch('C01', 3)
        for serial in ('TEST-001', 'TEST-002'):
            Device.objects.create(sn=serial, batch=batch, device_type=self.device_type)

        resp = self.client.get(f'{self.BATCHES_URL}{batch.id}/devices/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(len(resp.data['data']['items']), 2)
|
||
|
||
|
||
class AuthSeparationTests(APITestCase):
    """Checks that app tokens and admin tokens are not interchangeable."""

    APP_PROTECTED_URL = '/api/v1/users/me/'
    ADMIN_PROTECTED_URL = '/api/admin/device-types/'

    def setUp(self):
        """Obtain one app access token and one admin access token."""
        # App-side user and token.
        self.app_user = User.objects.create_user(phone='13800138099')
        app_login = self.client.post(
            '/api/v1/auth/phone-login/', {'phone': '13800138099'}, format='json'
        )
        self.app_token = app_login.data['data']['token']['access']

        # Admin-side user and token.
        self.admin = AdminUser.objects.create_user(
            username='admin', password='admin123', role='admin'
        )
        admin_login = self.client.post(
            '/api/admin/auth/login/',
            {'username': 'admin', 'password': 'admin123'},
            format='json',
        )
        self.admin_token = admin_login.data['data']['token']['access']

    def test_app_token_cannot_access_admin_api(self):
        """An app token is rejected by an admin-side endpoint."""
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.app_token}')

        resp = self.client.get(self.ADMIN_PROTECTED_URL)

        self.assertNotEqual(resp.data.get('code', 0), 0)

    def test_admin_token_cannot_access_app_api(self):
        """An admin token is rejected by an app-side endpoint."""
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {self.admin_token}')

        resp = self.client.get(self.APP_PROTECTED_URL)

        self.assertNotEqual(resp.data.get('code', 0), 0)

    def test_no_token_cannot_access_protected_api(self):
        """Without a token both protected endpoints return HTTP 401."""
        self.client.credentials()

        for protected_url in (self.APP_PROTECTED_URL, self.ADMIN_PROTECTED_URL):
            resp = self.client.get(protected_url)
            self.assertEqual(resp.status_code, status.HTTP_401_UNAUTHORIZED)
|
||
|
||
|
||
# ==================== Log Center 集成测试 ====================
|
||
|
||
class LogCenterIntegrationTests(TestCase):
    """Integration tests for reporting errors to Log Center.

    ``utils.exceptions`` is imported inside each test, as in the original
    code, so the module is only loaded when a test actually runs.
    """

    @staticmethod
    def _raised(exc):
        """Raise ``exc`` and return it, so it carries a real ``__traceback__``."""
        try:
            raise exc
        except Exception as caught:
            return caught

    @staticmethod
    def _request_context(path, method, view):
        """Build a handler-style context dict around a mocked request."""
        request = MagicMock()
        request.path = path
        request.method = method
        return {'request': request, 'view': view}

    @staticmethod
    def _inline_thread(*args, **kwargs):
        """Stand-in for ``threading.Thread`` whose ``start()`` runs the target inline.

        Mirrors the ``Thread(group, target, name, args, kwargs)`` signature:
        a positional target is the second argument, and the thread's
        ``args``/``kwargs`` are forwarded to the target.
        """
        target = kwargs.get('target')
        if target is None and len(args) > 1:
            target = args[1]
        call_args = kwargs.get('args') or ()
        call_kwargs = kwargs.get('kwargs') or {}

        def start():
            # A thread without a target does nothing when started.
            if target is not None:
                target(*call_args, **call_kwargs)

        return MagicMock(start=start, daemon=True)

    def test_report_to_log_center_function_exists(self):
        """``report_to_log_center`` is importable and callable."""
        from utils.exceptions import report_to_log_center

        self.assertTrue(callable(report_to_log_center))

    def test_report_to_log_center_with_exception(self):
        """Smoke test: reporting an exception with a request context does not raise."""
        from utils.exceptions import report_to_log_center
        import time

        test_exc = self._raised(ValueError("Test error message"))
        context = self._request_context('/api/test/', 'GET', 'TestView')

        with patch('utils.exceptions.requests.post'), \
                patch('utils.exceptions.LOG_CENTER_ENABLED', True):
            report_to_log_center(test_exc, context)

            # Give the background thread time to run while the patches are
            # still active (synchronous execution would be more reliable).
            time.sleep(0.5)

    def test_report_payload_structure(self):
        """The payload posted to Log Center has the expected structure."""
        from utils.exceptions import report_to_log_center

        captured = []

        def capture_post(*args, **kwargs):
            # Record the JSON body; tolerate any extra keyword arguments.
            captured.append(kwargs.get('json'))
            return MagicMock(status_code=200)

        test_exc = self._raised(TypeError("Type mismatch error"))
        context = self._request_context('/api/users/me/', 'POST', 'UserView')

        with patch('utils.exceptions.requests.post', side_effect=capture_post), \
                patch('utils.exceptions.LOG_CENTER_ENABLED', True), \
                patch('utils.exceptions.threading.Thread', side_effect=self._inline_thread):
            # The target runs inline instead of on a real thread.
            report_to_log_center(test_exc, context)

        # Fail loudly if nothing was posted, rather than skipping the checks.
        self.assertTrue(captured, 'report_to_log_center did not call requests.post')
        payload = captured[-1]
        self.assertIsNotNone(payload, 'requests.post was called without a json payload')

        # Check the payload structure.
        self.assertEqual(payload['project_id'], 'rtc_backend')
        self.assertEqual(payload['level'], 'ERROR')
        self.assertEqual(payload['error']['type'], 'TypeError')
        self.assertEqual(payload['error']['message'], 'Type mismatch error')
        self.assertIn('stack_trace', payload['error'])
        self.assertEqual(payload['context']['url'], '/api/users/me/')
        self.assertEqual(payload['context']['method'], 'POST')

    def test_report_disabled_when_flag_off(self):
        """Nothing is posted when ``LOG_CENTER_ENABLED`` is off."""
        from utils.exceptions import report_to_log_center

        test_exc = self._raised(Exception("Should not be reported"))

        with patch('utils.exceptions.requests.post') as mock_post, \
                patch('utils.exceptions.LOG_CENTER_ENABLED', False):
            report_to_log_center(test_exc, {})

            # requests.post must not be called.
            mock_post.assert_not_called()

    def test_report_silent_failure(self):
        """A failing upload must not raise out of ``report_to_log_center``."""
        from utils.exceptions import report_to_log_center

        test_exc = self._raised(Exception("Test exception"))

        with patch('utils.exceptions.requests.post', side_effect=Exception("Network error")), \
                patch('utils.exceptions.LOG_CENTER_ENABLED', True):
            # No exception may escape the reporter.
            try:
                report_to_log_center(test_exc, {})
            except Exception:
                self.fail("report_to_log_center should not raise exceptions")
|
||
|
||
|
||
class ExceptionHandlerIntegrationTests(APITestCase):
    """Exception handler tests: which exceptions trigger a Log Center report."""

    REPORT_TARGET = 'utils.exceptions.report_to_log_center'

    def test_exception_triggers_log_center_report(self):
        """A ``ValueError`` passed to the handler is reported to Log Center."""
        from utils.exceptions import custom_exception_handler

        # The exception handed to the handler.
        failure = ValueError("Database connection failed")

        # Mocked handler context.
        fake_request = MagicMock(path='/api/test/', method='GET')
        handler_context = {'request': fake_request, 'view': MagicMock()}

        with patch(self.REPORT_TARGET) as report_mock:
            # Invoke the exception handler.
            custom_exception_handler(failure, handler_context)

        # The reporter ran exactly once, with the exception as first argument.
        report_mock.assert_called_once()
        positional, _keywords = report_mock.call_args
        self.assertEqual(positional[0], failure)

    def test_business_exception_not_reported(self):
        """A ``BusinessException`` is not reported to Log Center."""
        from utils.exceptions import custom_exception_handler, BusinessException

        # The business exception handed to the handler.
        business_failure = BusinessException(code=100, message="用户不存在")
        handler_context = {'request': MagicMock(), 'view': MagicMock()}

        with patch(self.REPORT_TARGET) as report_mock:
            custom_exception_handler(business_failure, handler_context)

        # Business exceptions must not trigger a report.
        report_mock.assert_not_called()
|
||
|
||
|
||
# ==================== 故事模块测试 ====================
|
||
|
||
class StoryTestBase(APITestCase):
    """Shared fixture for the story-module test cases."""

    def setUp(self):
        """Create a user, authenticate the client as that user, add one shelf."""
        self.user = User.objects.create_user(phone='13800130001', nickname='故事测试用户')
        access_token = get_app_tokens(self.user)["access"]
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {access_token}')
        # Shelf that the tests treat as the user's default one
        self.shelf = StoryShelf.objects.create(user=self.user, name='我的书架')
|
||
|
||
|
||
class StoryShelfTests(StoryTestBase):
    """Tests for the shelf endpoints."""

    SHELVES_URL = '/api/v1/stories/shelves/'

    def test_list_shelves(self):
        """Listing shelves returns the one shelf created in setUp."""
        resp = self.client.get(self.SHELVES_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertEqual(len(body['data']), 1)
        self.assertEqual(body['data'][0]['name'], '我的书架')

    def test_list_shelves_auto_create_default(self):
        """A user without shelves gets a default shelf on the first listing."""
        newcomer = User.objects.create_user(phone='13800130099')
        access_token = get_app_tokens(newcomer)["access"]
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {access_token}')

        resp = self.client.get(self.SHELVES_URL)

        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertEqual(len(body['data']), 1)
        self.assertEqual(body['data'][0]['name'], '我的书架')

    def test_list_shelves_includes_story_count(self):
        """Each listed shelf reports how many stories it holds."""
        Story.objects.create(
            user=self.user,
            shelf=self.shelf,
            title='测试故事',
            content='内容',
        )

        resp = self.client.get(self.SHELVES_URL)

        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertEqual(body['data'][0]['story_count'], 1)

    def test_create_shelf(self):
        """Creating a shelf echoes its name and reports a capacity of 10."""
        payload = {'name': '新书架'}
        resp = self.client.post(self.SHELVES_URL, payload, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertEqual(body['data']['name'], '新书架')
        self.assertEqual(body['data']['capacity'], 10)

    def test_delete_shelf(self):
        """Deleting a shelf removes it but keeps its stories, detached."""
        doomed_shelf = StoryShelf.objects.create(user=self.user, name='待删除书架')
        orphan = Story.objects.create(
            user=self.user,
            shelf=doomed_shelf,
            title='测试',
            content='内容',
        )

        resp = self.client.delete(f'/api/v1/stories/shelves/{doomed_shelf.id}/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertFalse(StoryShelf.objects.filter(id=doomed_shelf.id).exists())
        # The story survives; its shelf reference becomes null
        orphan.refresh_from_db()
        self.assertIsNone(orphan.shelf)

    def test_delete_shelf_not_found(self):
        """Deleting an unknown shelf id yields a non-zero result code."""
        resp = self.client.delete('/api/v1/stories/shelves/99999/')

        self.assertNotEqual(resp.data['code'], 0)

    def test_shelf_isolation(self):
        """A user cannot delete a shelf that belongs to someone else."""
        stranger = User.objects.create_user(phone='13800130002')
        foreign_shelf = StoryShelf.objects.create(user=stranger, name='别人的书架')

        resp = self.client.delete(f'/api/v1/stories/shelves/{foreign_shelf.id}/')

        self.assertNotEqual(resp.data['code'], 0)
        # The foreign shelf must still exist
        self.assertTrue(StoryShelf.objects.filter(id=foreign_shelf.id).exists())
|
||
|
||
|
||
class ShelfUnlockTests(StoryTestBase):
    """Tests for unlocking an extra shelf with points."""

    UNLOCK_URL = '/api/v1/stories/shelves/unlock/'

    def _set_points(self, amount):
        """Persist ``amount`` as the fixture user's points balance."""
        self.user.points = amount
        self.user.save(update_fields=['points'])

    def test_unlock_shelf_success(self):
        """With 200 points, unlocking succeeds and 100 points are deducted."""
        self._set_points(200)

        resp = self.client.post(self.UNLOCK_URL, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertEqual(body['data']['remaining_points'], 100)
        self.assertIn('shelf', body['data'])

        # The deduction is persisted on the user
        self.user.refresh_from_db()
        self.assertEqual(self.user.points, 100)

        # A matching points record was written
        ledger_entry = PointsRecord.objects.filter(user=self.user).first()
        self.assertIsNotNone(ledger_entry)
        self.assertEqual(ledger_entry.amount, -100)
        self.assertEqual(ledger_entry.type, 'unlock_shelf')

    def test_unlock_shelf_not_enough_points(self):
        """With 50 points, unlocking is rejected and the balance is untouched."""
        self._set_points(50)

        resp = self.client.post(self.UNLOCK_URL, format='json')

        self.assertEqual(resp.data['code'], 603)  # POINTS_NOT_ENOUGH
        # The balance must not change
        self.user.refresh_from_db()
        self.assertEqual(self.user.points, 50)

    def test_unlock_shelf_zero_points(self):
        """With 0 points, unlocking is rejected with code 603."""
        self._set_points(0)

        resp = self.client.post(self.UNLOCK_URL, format='json')

        self.assertEqual(resp.data['code'], 603)

    def test_unlock_shelf_naming(self):
        """Unlocked shelves are named sequentially after the existing one."""
        self._set_points(500)

        # First unlock (one default shelf already exists)
        first = self.client.post(self.UNLOCK_URL, format='json')
        self.assertEqual(first.data['data']['shelf']['name'], '书架 2')

        # Second unlock
        second = self.client.post(self.UNLOCK_URL, format='json')
        self.assertEqual(second.data['data']['shelf']['name'], '书架 3')
|
||
|
||
|
||
class StoryTests(StoryTestBase):
    """Tests for the story endpoints."""

    STORIES_URL = '/api/v1/stories/'

    def test_list_stories(self):
        """Listing returns both stories stored for the user."""
        for title, content in (('故事1', '内容1'), ('故事2', '内容2')):
            Story.objects.create(
                user=self.user, shelf=self.shelf,
                title=title, content=content,
            )

        resp = self.client.get(self.STORIES_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertEqual(body['data']['total'], 2)
        self.assertEqual(len(body['data']['items']), 2)

    def test_list_stories_filter_by_shelf(self):
        """The shelf_id query parameter restricts the listing to that shelf."""
        second_shelf = StoryShelf.objects.create(user=self.user, name='书架2')
        Story.objects.create(
            user=self.user, shelf=self.shelf,
            title='书架1故事', content='内容',
        )
        Story.objects.create(
            user=self.user, shelf=second_shelf,
            title='书架2故事', content='内容',
        )

        resp = self.client.get(f'/api/v1/stories/?shelf_id={self.shelf.id}')

        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertEqual(body['data']['total'], 1)
        self.assertEqual(body['data']['items'][0]['title'], '书架1故事')

    def test_list_stories_empty(self):
        """Listing with no stories reports a total of 0."""
        resp = self.client.get(self.STORIES_URL)

        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertEqual(body['data']['total'], 0)

    def test_create_story(self):
        """Saving a story echoes its title and content."""
        payload = {
            'title': '新故事',
            'content': '这是故事内容',
            'shelf_id': self.shelf.id,
        }
        resp = self.client.post(self.STORIES_URL, payload, format='json')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertEqual(body['data']['title'], '新故事')
        self.assertEqual(body['data']['content'], '这是故事内容')

    def test_create_story_with_optional_fields(self):
        """Saving a story with the optional fields echoes generation_mode."""
        payload = {
            'title': '完整故事',
            'content': '故事正文',
            'shelf_id': self.shelf.id,
            'cover_url': 'https://example.com/cover.jpg',
            'generation_mode': 'ai',
            'prompt': '角色=小猫, 场景=森林',
        }
        resp = self.client.post(self.STORIES_URL, payload, format='json')

        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertEqual(body['data']['generation_mode'], 'ai')

    def test_create_story_shelf_full(self):
        """Saving into a shelf already at capacity is rejected with code 604."""
        # A shelf that holds only two stories, filled up front
        tiny_shelf = StoryShelf.objects.create(
            user=self.user, name='小书架', capacity=2,
        )
        for title in ('故事1', '故事2'):
            Story.objects.create(user=self.user, shelf=tiny_shelf, title=title, content='内容')

        payload = {
            'title': '溢出故事',
            'content': '这本放不下了',
            'shelf_id': tiny_shelf.id,
        }
        resp = self.client.post(self.STORIES_URL, payload, format='json')

        self.assertEqual(resp.data['code'], 604)  # SHELF_FULL

    def test_create_story_shelf_not_found(self):
        """Saving into an unknown shelf id yields a non-zero result code."""
        payload = {
            'title': '故事',
            'content': '内容',
            'shelf_id': 99999,
        }
        resp = self.client.post(self.STORIES_URL, payload, format='json')

        self.assertNotEqual(resp.data['code'], 0)

    def test_create_story_other_user_shelf(self):
        """Saving into another user's shelf yields a non-zero result code."""
        stranger = User.objects.create_user(phone='13800130003')
        foreign_shelf = StoryShelf.objects.create(user=stranger, name='他人书架')

        payload = {
            'title': '故事',
            'content': '内容',
            'shelf_id': foreign_shelf.id,
        }
        resp = self.client.post(self.STORIES_URL, payload, format='json')

        self.assertNotEqual(resp.data['code'], 0)

    def test_delete_story(self):
        """Deleting an owned story removes it from the database."""
        doomed = Story.objects.create(
            user=self.user, shelf=self.shelf,
            title='待删除', content='内容',
        )

        resp = self.client.delete(f'/api/v1/stories/{doomed.id}/')

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertFalse(Story.objects.filter(id=doomed.id).exists())

    def test_delete_story_not_found(self):
        """Deleting an unknown story id yields code 600."""
        resp = self.client.delete('/api/v1/stories/99999/')

        self.assertEqual(resp.data['code'], 600)  # STORY_NOT_FOUND

    def test_story_isolation(self):
        """A user cannot delete a story that belongs to someone else."""
        stranger = User.objects.create_user(phone='13800130004')
        foreign_shelf = StoryShelf.objects.create(user=stranger, name='他人书架')
        foreign_story = Story.objects.create(
            user=stranger, shelf=foreign_shelf,
            title='他人故事', content='内容',
        )

        resp = self.client.delete(f'/api/v1/stories/{foreign_story.id}/')

        self.assertNotEqual(resp.data['code'], 0)
        self.assertTrue(Story.objects.filter(id=foreign_story.id).exists())

    def test_story_capacity_limit(self):
        """An eleventh story on the default shelf is rejected with code 604."""
        # Default shelf capacity = 10, so fill it completely first
        for number in range(1, 11):
            Story.objects.create(
                user=self.user, shelf=self.shelf,
                title=f'故事{number}', content=f'内容{number}',
            )

        payload = {
            'title': '第11本',
            'content': '超出容量',
            'shelf_id': self.shelf.id,
        }
        resp = self.client.post(self.STORIES_URL, payload, format='json')

        self.assertEqual(resp.data['code'], 604)  # SHELF_FULL
|
||
|
||
|
||
class StoryGenerateTests(StoryTestBase):
    """Tests for the story generation endpoint."""

    GENERATE_URL = '/api/v1/stories/generate/'
    STREAM_TARGET = 'apps.stories.services.llm_service.generate_story_stream'

    def test_generate_story_returns_sse(self):
        """Generating a story responds 200 with an event-stream content type."""
        canned_events = [
            'event: stage\ndata: {"stage":"connecting","progress":0,"message":"正在收集灵感碎片..."}\n\n',
            'event: done\ndata: {"stage":"done","progress":100,"title":"测试故事","content":"故事内容"}\n\n',
        ]
        payload = {
            'characters': ['小猫'],
            'scenes': ['森林'],
            'props': ['魔法棒'],
        }

        with patch(self.STREAM_TARGET) as stream_mock:
            stream_mock.return_value = iter(canned_events)
            resp = self.client.post(self.GENERATE_URL, payload, format='json')

            self.assertEqual(resp.status_code, status.HTTP_200_OK)
            self.assertEqual(resp['Content-Type'], 'text/event-stream')

    def test_generate_story_empty_params(self):
        """An empty request body is accepted (the docstring notes defaults exist)."""
        canned_events = [
            'event: done\ndata: {"stage":"done","progress":100,"title":"默认故事","content":"内容"}\n\n',
        ]

        with patch(self.STREAM_TARGET) as stream_mock:
            stream_mock.return_value = iter(canned_events)
            resp = self.client.post(self.GENERATE_URL, {}, format='json')

            self.assertEqual(resp.status_code, status.HTTP_200_OK)
|
||
|
||
|
||
class StoryTTSTests(StoryTestBase):
    """Tests for the story TTS audio endpoint."""

    def _tts_url(self, story_id):
        """Return the TTS endpoint path for ``story_id``."""
        return f'/api/v1/stories/{story_id}/tts/'

    def test_tts_check_no_audio(self):
        """A story without audio reports exists=False and an empty audio_url."""
        silent_story = Story.objects.create(
            user=self.user, shelf=self.shelf,
            title='无音频故事', content='内容',
        )

        resp = self.client.get(self._tts_url(silent_story.id))

        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertFalse(body['data']['exists'])
        self.assertEqual(body['data']['audio_url'], '')

    def test_tts_check_has_audio(self):
        """A story with audio reports exists=True and returns its audio_url."""
        voiced_story = Story.objects.create(
            user=self.user, shelf=self.shelf,
            title='有音频故事', content='内容',
            audio_url='https://oss.example.com/audio.mp3',
        )

        resp = self.client.get(self._tts_url(voiced_story.id))

        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertTrue(body['data']['exists'])
        self.assertEqual(
            body['data']['audio_url'],
            'https://oss.example.com/audio.mp3',
        )

    def test_tts_generate_returns_sse(self):
        """Requesting TTS generation responds 200 with an event-stream type."""
        target_story = Story.objects.create(
            user=self.user, shelf=self.shelf,
            title='TTS测试', content='这是要转换的故事内容',
        )
        canned_events = [
            'event: stage\ndata: {"stage":"connecting","message":"正在连接..."}\n\n',
            'event: done\ndata: {"stage":"done","audio_url":"https://oss.example.com/audio.mp3"}\n\n',
        ]

        with patch('apps.stories.services.tts_service.generate_tts_stream') as tts_mock:
            tts_mock.return_value = iter(canned_events)
            resp = self.client.post(self._tts_url(target_story.id), format='json')

            self.assertEqual(resp.status_code, status.HTTP_200_OK)
            self.assertEqual(resp['Content-Type'], 'text/event-stream')

    def test_tts_story_not_found(self):
        """Querying TTS for an unknown story id yields code 600."""
        resp = self.client.get('/api/v1/stories/99999/tts/')

        self.assertEqual(resp.data['code'], 600)  # STORY_NOT_FOUND

    def test_tts_story_isolation(self):
        """Querying TTS for another user's story yields code 600."""
        stranger = User.objects.create_user(phone='13800130005')
        foreign_shelf = StoryShelf.objects.create(user=stranger, name='他人书架')
        foreign_story = Story.objects.create(
            user=stranger, shelf=foreign_shelf,
            title='他人故事', content='内容',
        )

        resp = self.client.get(self._tts_url(foreign_story.id))

        self.assertEqual(resp.data['code'], 600)
|
||
|
||
|
||
class PointsTests(StoryTestBase):
    """Tests for the points balance and points-record endpoints."""

    BALANCE_URL = '/api/v1/users/points/'
    RECORDS_URL = '/api/v1/users/points/records/'

    def test_query_points(self):
        """The balance endpoint returns the user's stored points."""
        self.user.points = 500
        self.user.save(update_fields=['points'])

        resp = self.client.get(self.BALANCE_URL)

        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertEqual(body['data']['points'], 500)

    def test_query_points_default_zero(self):
        """A user whose points were never set reports a balance of 0."""
        resp = self.client.get(self.BALANCE_URL)

        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertEqual(body['data']['points'], 0)

    def test_points_records_list(self):
        """The records endpoint lists every record owned by the user."""
        PointsRecord.objects.create(
            user=self.user, amount=100,
            type='reward', description='注册奖励',
        )
        PointsRecord.objects.create(
            user=self.user, amount=-100,
            type='unlock_shelf', description='解锁书架',
        )

        resp = self.client.get(self.RECORDS_URL)

        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertEqual(body['data']['total'], 2)
        self.assertEqual(len(body['data']['items']), 2)

    def test_points_records_empty(self):
        """With no records the endpoint reports a total of 0."""
        resp = self.client.get(self.RECORDS_URL)

        body = resp.data
        self.assertEqual(body['code'], 0)
        self.assertEqual(body['data']['total'], 0)

    def test_points_records_pagination(self):
        """page_size limits the returned items while total counts all records."""
        for number in range(1, 6):
            PointsRecord.objects.create(
                user=self.user, amount=10,
                type='reward', description=f'奖励{number}',
            )

        resp = self.client.get('/api/v1/users/points/records/?page=1&page_size=3')

        page = resp.data['data']
        self.assertEqual(page['total'], 5)
        self.assertEqual(len(page['items']), 3)

    def test_points_records_isolation(self):
        """Records owned by another user are not visible."""
        stranger = User.objects.create_user(phone='13800130006')
        PointsRecord.objects.create(
            user=stranger, amount=100,
            type='reward', description='他人的奖励',
        )

        resp = self.client.get(self.RECORDS_URL)

        self.assertEqual(resp.data['data']['total'], 0)
|
||
|
||
|
||
class LLMServiceTests(TestCase):
    """Unit tests for the story LLM service helpers."""

    def test_build_user_prompt(self):
        """Every supplied character, scene and prop appears in the prompt."""
        from apps.stories.services.llm_service import build_user_prompt

        built = build_user_prompt(['小猫', '小狗'], ['森林'], ['魔法棒'])
        for expected_word in ('小猫', '小狗', '森林', '魔法棒'):
            self.assertIn(expected_word, built)

    def test_build_user_prompt_partial(self):
        """With only characters given, scene and prop labels are left out."""
        from apps.stories.services.llm_service import build_user_prompt

        built = build_user_prompt(['公主'], [], [])
        self.assertIn('公主', built)
        for absent_label in ('场景', '道具'):
            self.assertNotIn(absent_label, built)

    def test_parse_story_json_valid(self):
        """Well-formed JSON is parsed into its title and content."""
        from apps.stories.services.llm_service import _parse_story_json

        raw = '{"title": "小猫冒险", "content": "从前有一只小猫..."}'
        parsed = _parse_story_json(raw)
        self.assertEqual(parsed['title'], '小猫冒险')
        self.assertEqual(parsed['content'], '从前有一只小猫...')

    def test_parse_story_json_with_markdown(self):
        """JSON wrapped in a markdown code fence is still parsed."""
        from apps.stories.services.llm_service import _parse_story_json

        raw = '```json\n{"title": "森林故事", "content": "在深深的森林里..."}\n```'
        parsed = _parse_story_json(raw)
        self.assertEqual(parsed['title'], '森林故事')

    def test_parse_story_json_invalid(self):
        """Non-JSON text gets the title '新故事' and is kept in the content."""
        from apps.stories.services.llm_service import _parse_story_json

        raw = '这不是一个有效的 JSON 格式的文本'
        parsed = _parse_story_json(raw)
        self.assertEqual(parsed['title'], '新故事')
        self.assertIn('这不是', parsed['content'])

    def test_sse_event_format(self):
        """An SSE event starts with its event line and ends with a blank line."""
        from apps.stories.services.llm_service import _sse_event

        formatted = _sse_event('stage', {'stage': 'connecting', 'progress': 0})
        self.assertTrue(formatted.startswith('event: stage\n'))
        self.assertIn('data: ', formatted)
        self.assertTrue(formatted.endswith('\n\n'))

    def test_generate_stream_without_api_key(self):
        """With an empty API key the stream yields exactly one error event."""
        from apps.stories.services.llm_service import generate_story_stream

        blank_config = {'API_KEY': '', 'API_BASE_URL': '', 'MODEL_NAME': ''}
        with patch('apps.stories.services.llm_service.settings') as settings_mock:
            settings_mock.LLM_CONFIG = blank_config
            emitted = list(generate_story_stream(['小猫'], [], []))

        self.assertEqual(len(emitted), 1)
        only_event = emitted[0]
        self.assertIn('error', only_event)
        self.assertIn('未配置', only_event)
|
||
|
||
|
||
# ==================== 音乐模块测试 ====================
|
||
|
||
class MusicTestBase(APITestCase):
    """Shared fixture for the music-module test cases."""

    def setUp(self):
        """Create a user and authenticate the test client as that user."""
        self.user = User.objects.create_user(phone='13800140001', nickname='音乐测试用户')
        access_token = get_app_tokens(self.user)["access"]
        self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {access_token}')
|
||
|
||
|
||
class MusicPlaylistTests(MusicTestBase):
    """Tests for the playlist endpoint."""

    PLAYLIST_URL = '/api/v1/music/playlist/'

    def test_playlist_auto_create_defaults(self):
        """The first playlist request returns three completed default tracks."""
        resp = self.client.get(self.PLAYLIST_URL)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        tracks = resp.data['data']['playlist']
        self.assertEqual(len(tracks), 3)
        # Every returned track must be a default one
        for entry in tracks:
            self.assertTrue(entry['is_default'])
            self.assertEqual(entry['generation_status'], 'completed')

    def test_playlist_default_track_titles(self):
        """The default tracks carry the three expected titles."""
        resp = self.client.get(self.PLAYLIST_URL)

        returned_titles = [entry['title'] for entry in resp.data['data']['playlist']]
        self.assertIn('卡皮巴拉蹦蹦蹦', returned_titles)
        self.assertIn('卡皮巴拉快乐水', returned_titles)
        self.assertIn('卡皮巴拉快乐营业', returned_titles)

    def test_playlist_defaults_not_duplicated(self):
        """Repeated requests leave exactly three default tracks stored."""
        for _ in range(2):
            self.client.get(self.PLAYLIST_URL)

        default_total = Track.objects.filter(user=self.user, is_default=True).count()
        self.assertEqual(default_total, 3)

    def test_playlist_ordering(self):
        """User tracks are listed before the default tracks."""
        # Initial request (expected to create the default tracks)
        self.client.get(self.PLAYLIST_URL)

        # Add one user-owned track
        Track.objects.create(
            user=self.user,
            title='我的歌',
            is_default=False,
            generation_status='completed',
        )

        resp = self.client.get(self.PLAYLIST_URL)
        tracks = resp.data['data']['playlist']
        self.assertEqual(len(tracks), 4)
        # The user's track comes first
        leading = tracks[0]
        self.assertEqual(leading['title'], '我的歌')
        self.assertFalse(leading['is_default'])
        # The remaining three are default tracks
        for entry in tracks[1:]:
            self.assertTrue(entry['is_default'])

    def test_playlist_user_isolation(self):
        """Another user's track never appears in this user's playlist."""
        stranger = User.objects.create_user(phone='13800140099')
        Track.objects.create(
            user=stranger,
            title='他人的歌',
            generation_status='completed',
        )

        resp = self.client.get(self.PLAYLIST_URL)

        returned_titles = [entry['title'] for entry in resp.data['data']['playlist']]
        self.assertNotIn('他人的歌', returned_titles)

    def test_playlist_default_tracks_have_audio_url(self):
        """Each default track has an audio_url containing 'qy-rtc'."""
        resp = self.client.get(self.PLAYLIST_URL)

        for entry in resp.data['data']['playlist']:
            audio_url = entry['audio_url']
            self.assertTrue(audio_url)
            self.assertIn('qy-rtc', audio_url)

    def test_playlist_default_tracks_have_lyrics(self):
        """Each default track has non-empty lyrics."""
        resp = self.client.get(self.PLAYLIST_URL)

        for entry in resp.data['data']['playlist']:
            self.assertTrue(entry['lyrics'])
|
||
|
||
|
||
class MusicDeleteTests(MusicTestBase):
    """Tests for the track deletion endpoint."""

    def _delete_track(self, track_id):
        """Issue a DELETE for ``track_id`` and return the response."""
        return self.client.delete(f'/api/v1/music/{track_id}/')

    def test_delete_user_track(self):
        """A user-generated track can be deleted."""
        doomed = Track.objects.create(
            user=self.user,
            title='待删除歌曲',
            generation_status='completed',
        )

        resp = self._delete_track(doomed.id)

        self.assertEqual(resp.status_code, status.HTTP_200_OK)
        self.assertEqual(resp.data['code'], 0)
        self.assertFalse(Track.objects.filter(id=doomed.id).exists())

    def test_delete_default_track_rejected(self):
        """Deleting a default track is refused with code 703."""
        protected = Track.objects.create(
            user=self.user,
            title='默认歌曲',
            is_default=True,
            generation_status='completed',
        )

        resp = self._delete_track(protected.id)

        self.assertEqual(resp.data['code'], 703)  # MUSIC_DEFAULT_UNDELETABLE
        self.assertTrue(Track.objects.filter(id=protected.id).exists())

    def test_delete_track_not_found(self):
        """Deleting an unknown track id yields code 700."""
        resp = self.client.delete('/api/v1/music/99999/')

        self.assertEqual(resp.data['code'], 700)  # TRACK_NOT_FOUND

    def test_delete_other_user_track(self):
        """Deleting another user's track yields code 700 and keeps the track."""
        stranger = User.objects.create_user(phone='13800140002')
        foreign_track = Track.objects.create(
            user=stranger,
            title='他人歌曲',
            generation_status='completed',
        )

        resp = self._delete_track(foreign_track.id)

        self.assertEqual(resp.data['code'], 700)
        self.assertTrue(Track.objects.filter(id=foreign_track.id).exists())
|
||
|
||
|
||
class MusicFavoriteTests(MusicTestBase):
    """Tests for the favorite toggle endpoint."""

    def test_favorite_toggle(self):
        """Posting twice first marks the track as favorite, then unmarks it."""
        song = Track.objects.create(
            user=self.user,
            title='测试歌曲',
            generation_status='completed',
        )
        endpoint = f'/api/v1/music/{song.id}/favorite/'

        # First POST: favorite
        marked = self.client.post(endpoint)
        self.assertEqual(marked.data['code'], 0)
        self.assertTrue(marked.data['data']['is_favorite'])

        # Second POST: un-favorite
        unmarked = self.client.post(endpoint)
        self.assertEqual(unmarked.data['code'], 0)
        self.assertFalse(unmarked.data['data']['is_favorite'])

    def test_favorite_not_found(self):
        """Toggling favorite on an unknown track id yields code 700."""
        resp = self.client.post('/api/v1/music/99999/favorite/')

        self.assertEqual(resp.data['code'], 700)
|
||
|
||
|
||
class MusicGenerateTests(MusicTestBase):
    """Tests for the music generation endpoint."""

    GENERATE_URL = '/api/v1/music/generate/'
    STREAM_TARGET = 'apps.music.services.music_generation_service.generate_music_stream'

    def _set_points(self, amount):
        """Persist ``amount`` as the fixture user's points balance."""
        self.user.points = amount
        self.user.save(update_fields=['points'])

    def test_generate_points_not_enough(self):
        """With 50 points, generation is rejected and the balance is untouched."""
        self._set_points(50)

        payload = {'text': '开心的一天', 'mood': 'happy'}
        resp = self.client.post(self.GENERATE_URL, payload, format='json')

        self.assertEqual(resp.data['code'], 603)  # POINTS_NOT_ENOUGH
        # The balance must not change
        self.user.refresh_from_db()
        self.assertEqual(self.user.points, 50)

    def test_generate_zero_points(self):
        """A user whose points were never set is rejected with code 603."""
        payload = {'text': '开心的一天', 'mood': 'happy'}
        resp = self.client.post(self.GENERATE_URL, payload, format='json')

        self.assertEqual(resp.data['code'], 603)

    def test_generate_returns_sse(self):
        """Generating music responds 200 with an event-stream content type."""
        self._set_points(200)

        canned_events = [
            'data: {"stage":"lyrics","progress":10,"message":"AI 正在创作词曲..."}\n\n',
            'data: {"stage":"done","progress":100,"message":"新歌出炉!","track_id":1}\n\n',
        ]
        payload = {'text': '开心的一天', 'mood': 'happy'}

        with patch(self.STREAM_TARGET) as stream_mock:
            stream_mock.return_value = iter(canned_events)
            resp = self.client.post(self.GENERATE_URL, payload, format='json')

            self.assertEqual(resp.status_code, status.HTTP_200_OK)
            self.assertEqual(resp['Content-Type'], 'text/event-stream')

    def test_generate_deducts_points(self):
        """Generating music deducts 100 points and writes a points record."""
        self._set_points(200)

        canned_events = [
            'data: {"stage":"done","progress":100}\n\n',
        ]
        payload = {'text': '开心的一天', 'mood': 'happy'}

        with patch(self.STREAM_TARGET) as stream_mock:
            stream_mock.return_value = iter(canned_events)
            self.client.post(self.GENERATE_URL, payload, format='json')

            # The deduction is persisted on the user
            self.user.refresh_from_db()
            self.assertEqual(self.user.points, 100)

            # A matching points record was written
            ledger_entry = PointsRecord.objects.filter(
                user=self.user,
                type='generate_music',
            ).first()
            self.assertIsNotNone(ledger_entry)
            self.assertEqual(ledger_entry.amount, -100)

    def test_generate_creates_track(self):
        """Generating music stores a Track in the 'generating' state."""
        self._set_points(200)

        canned_events = [
            'data: {"stage":"done","progress":100}\n\n',
        ]
        payload = {'text': '开心的一天', 'mood': 'happy'}

        with patch(self.STREAM_TARGET) as stream_mock:
            stream_mock.return_value = iter(canned_events)
            self.client.post(self.GENERATE_URL, payload, format='json')

            created = Track.objects.filter(user=self.user, is_default=False).first()
            self.assertIsNotNone(created)
            self.assertEqual(created.mood, 'happy')
            self.assertEqual(created.prompt, '开心的一天')
            self.assertEqual(created.generation_status, 'generating')

    def test_generate_empty_text_allowed(self):
        """A request with mood 'random' and no text is accepted."""
        self._set_points(200)

        canned_events = [
            'data: {"stage":"done","progress":100}\n\n',
        ]
        payload = {'mood': 'random'}

        with patch(self.STREAM_TARGET) as stream_mock:
            stream_mock.return_value = iter(canned_events)
            resp = self.client.post(self.GENERATE_URL, payload, format='json')

            self.assertEqual(resp.status_code, status.HTTP_200_OK)
            self.assertEqual(resp['Content-Type'], 'text/event-stream')
|
||
|
||
|
||
class MusicServiceTests(TestCase):
|
||
"""音乐生成服务单元测试"""
|
||
|
||
def test_clean_lyrics(self):
    """Section tags are stripped from lyrics while the sung lines remain."""
    from apps.music.services.music_generation_service import clean_lyrics

    tagged = '[verse 1]\n泡在温泉里\n[chorus]\n咔咔咔咔\n[outro]\n(水花声...)'
    stripped = clean_lyrics(tagged)

    for removed_tag in ('[verse', '[chorus]'):
        self.assertNotIn(removed_tag, stripped)
    for kept_line in ('泡在温泉里', '咔咔咔咔'):
        self.assertIn(kept_line, stripped)
|
||
|
||
def test_clean_lyrics_empty(self):
    """Empty and missing lyrics pass through clean_lyrics unchanged."""
    from apps.music.services.music_generation_service import clean_lyrics

    self.assertEqual(clean_lyrics(''), '')
    # Identity check: assertEqual(x, None) would also accept any object
    # whose __eq__ happens to compare equal to None.
    self.assertIsNone(clean_lyrics(None))
|
||
|
||
def test_sse_event_format(self):
    """An SSE event is 'data: ' + a JSON payload, terminated by a blank line."""
    from apps.music.services.music_generation_service import sse_event

    prefix = 'data: '
    event = sse_event({"stage": "lyrics", "progress": 10})
    self.assertTrue(event.startswith(prefix))
    self.assertTrue(event.endswith('\n\n'))

    # Drop only the leading prefix. str.replace would remove every
    # occurrence of 'data: ', including one inside the JSON payload itself.
    data = json.loads(event[len(prefix):].strip())
    self.assertEqual(data['stage'], 'lyrics')
    self.assertEqual(data['progress'], 10)
|
||
|
||
def test_parse_llm_json_valid(self):
|
||
"""测试解析有效 JSON"""
|
||
from apps.music.services.music_generation_service import _parse_llm_json
|
||
|
||
text = '{"song_title": "咔咔之歌", "style": "Pop music", "lyrics": "[verse]\\n泡温泉"}'
|
||
result = _parse_llm_json(text)
|
||
self.assertEqual(result['song_title'], '咔咔之歌')
|
||
self.assertEqual(result['style'], 'Pop music')
|
||
|
||
def test_parse_llm_json_with_markdown(self):
|
||
"""测试解析 markdown 包裹的 JSON"""
|
||
from apps.music.services.music_generation_service import _parse_llm_json
|
||
|
||
text = '```json\n{"song_title": "温泉曲", "style": "Lofi", "lyrics": "la la la"}\n```'
|
||
result = _parse_llm_json(text)
|
||
self.assertEqual(result['song_title'], '温泉曲')
|
||
|
||
def test_parse_llm_json_invalid_fallback(self):
|
||
"""测试解析无效 JSON 时的正则回退"""
|
||
from apps.music.services.music_generation_service import _parse_llm_json
|
||
|
||
# Malformed JSON with unescaped newlines in string values
|
||
text = '{"song_title": "测试歌", "style": "Pop", "lyrics": "line1\nline2"}'
|
||
# json.loads should handle this since \n is valid in JSON
|
||
result = _parse_llm_json(text)
|
||
self.assertIn('song_title', result)
|
||
|
||
def test_generate_stream_without_api_key(self):
|
||
"""测试未配置 API Key 时退还积分"""
|
||
from apps.music.services.music_generation_service import generate_music_stream
|
||
|
||
user = User.objects.create_user(phone='13800149001', nickname='测试')
|
||
user.points = 100
|
||
user.save(update_fields=['points'])
|
||
track = Track.objects.create(
|
||
user=user, title='测试', generation_status='generating'
|
||
)
|
||
|
||
with patch('apps.music.services.music_generation_service._get_api_key', return_value=''):
|
||
events = list(generate_music_stream(user, track, '测试', 'happy'))
|
||
|
||
self.assertEqual(len(events), 1)
|
||
self.assertIn('error', events[0])
|
||
# SSE uses ensure_ascii=True, so check the decoded JSON
|
||
event_data = json.loads(events[0].replace('data: ', '').strip())
|
||
self.assertIn('未配置', event_data['message'])
|
||
|
||
# 验证积分退还
|
||
user.refresh_from_db()
|
||
self.assertEqual(user.points, 200) # 100 original + 100 refunded
|
||
track.refresh_from_db()
|
||
self.assertEqual(track.generation_status, 'failed')
|
||
|
||
def test_refund_points(self):
|
||
"""测试积分退还"""
|
||
from apps.music.services.music_generation_service import _refund_points
|
||
|
||
user = User.objects.create_user(phone='13800149002', nickname='退款测试')
|
||
user.points = 50
|
||
user.save(update_fields=['points'])
|
||
track = Track.objects.create(
|
||
user=user, title='失败歌曲', generation_status='generating'
|
||
)
|
||
|
||
_refund_points(user, track)
|
||
|
||
user.refresh_from_db()
|
||
self.assertEqual(user.points, 150)
|
||
track.refresh_from_db()
|
||
self.assertEqual(track.generation_status, 'failed')
|
||
|
||
# 验证退款记录
|
||
record = PointsRecord.objects.filter(
|
||
user=user, type='refund_music'
|
||
).first()
|
||
self.assertIsNotNone(record)
|
||
self.assertEqual(record.amount, 100)
|
||
|
||
|
||
class DefaultTracksTests(TestCase):
    """Tests for seeding a user's default tracks."""

    @staticmethod
    def _seed_defaults(phone, times=1):
        """Create a user, run ``ensure_default_tracks`` ``times`` times, return the default tracks."""
        from apps.music.utils import ensure_default_tracks

        owner = User.objects.create_user(phone=phone)
        for _ in range(times):
            ensure_default_tracks(owner)
        return Track.objects.filter(user=owner, is_default=True)

    def test_ensure_default_tracks_creates_three(self):
        """A single call creates exactly three default tracks."""
        defaults = self._seed_defaults('13800149010')
        self.assertEqual(defaults.count(), 3)

    def test_ensure_default_tracks_idempotent(self):
        """Calling it twice still leaves exactly three default tracks."""
        defaults = self._seed_defaults('13800149011', times=2)
        self.assertEqual(defaults.count(), 3)

    def test_default_tracks_have_all_fields(self):
        """Every default track has its content fields set and is completed."""
        required_fields = ('title', 'lyrics', 'audio_url', 'cover_url')

        for track in self._seed_defaults('13800149012'):
            for field_name in required_fields:
                self.assertTrue(getattr(track, field_name))
            self.assertEqual(track.generation_status, 'completed')
|
||
|
||
|
||
class MigrateHistoricalTracksTests(TestCase):
    """Tests for the ``migrate_historical_tracks`` management command."""

    def setUp(self):
        # The tests below count the tracks owned by this user after the command runs.
        self.user = User.objects.create_user(phone='18682237028')

    @staticmethod
    def _run_command(*args, runs=1):
        """Invoke the command ``runs`` times and return everything it wrote to stdout."""
        from django.core.management import call_command
        from io import StringIO

        buffer = StringIO()
        for _ in range(runs):
            call_command('migrate_historical_tracks', *args, stdout=buffer)
        return buffer.getvalue()

    def _track_count(self, **filters):
        """Count this user's tracks, optionally narrowed by extra filters."""
        return Track.objects.filter(user=self.user, **filters).count()

    def test_creates_15_tracks(self):
        """One run creates 15 non-default tracks."""
        self._run_command()
        self.assertEqual(self._track_count(is_default=False), 15)

    def test_idempotent(self):
        """Running the command twice does not create duplicates."""
        self._run_command(runs=2)
        self.assertEqual(self._track_count(is_default=False), 15)

    def test_dry_run(self):
        """``--dry-run`` writes nothing to the database and says so in its output."""
        output = self._run_command('--dry-run')
        self.assertEqual(self._track_count(), 0)
        self.assertIn('dry-run', output)

    def test_tracks_have_oss_urls(self):
        """Audio and cover URLs of every migrated track point at the OSS bucket."""
        self._run_command()

        oss_prefix = 'https://qy-rtc.oss-cn-beijing.aliyuncs.com/'
        for track in Track.objects.filter(user=self.user):
            self.assertTrue(track.audio_url.startswith(oss_prefix))
            self.assertTrue(track.cover_url.startswith(oss_prefix))
|
||
|
||
|
||
# ==================== 角色记忆测试 ====================
|
||
|
||
class RoleMemoryTests(APITestCase):
|
||
"""角色记忆功能测试"""
|
||
|
||
def setUp(self):
    """Authenticate a new user and prepare an in-stock device with no bindings."""
    self.user = User.objects.create_user(phone='13800139000', nickname='记忆测试用户')
    access_token = get_app_tokens(self.user)["access"]
    self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {access_token}')

    type_defaults = {
        'brand': 'AL',
        'name': '卡皮巴拉-联网版',
        'default_prompt': '你是一只可爱的卡皮巴拉',
        'default_voice_id': 'voice_kpbl_01',
    }
    self.device_type = DeviceType.objects.get_or_create(
        product_code='KPBL-ON-RM',
        defaults=type_defaults,
    )[0]
    # get_or_create may return an existing row, in which case the defaults
    # above were not applied; backfill the template fields if the prompt is empty.
    if not self.device_type.default_prompt:
        self.device_type.default_prompt = '你是一只可爱的卡皮巴拉'
        self.device_type.default_voice_id = 'voice_kpbl_01'
        self.device_type.save()

    self.device = Device.objects.get_or_create(
        sn='AL-KPBL-ON-25W01-RM-00001',
        defaults={'device_type': self.device_type, 'status': 'in_stock'},
    )[0]
    # Force the device back to stock regardless of how it was obtained.
    self.device.status = 'in_stock'
    self.device.save()

    # Drop any bindings and role memories still attached to this user.
    for model in (UserDevice, RoleMemory):
        model.objects.filter(user=self.user).delete()
|
||
|
||
def test_bind_creates_role_memory(self):
    """Binding a device returns a role memory filled from the device type template."""
    response = self.client.post(
        '/api/v1/devices/bind/',
        {'sn': 'AL-KPBL-ON-25W01-RM-00001'},
        format='json',
    )

    self.assertEqual(response.data['code'], 0)
    memory = response.data['data']['role_memory']
    self.assertIsNotNone(memory)

    # Prompt and voice are the values configured on the device type in setUp.
    self.assertEqual(memory['prompt'], '你是一只可爱的卡皮巴拉')
    self.assertEqual(memory['voice_id'], 'voice_kpbl_01')
    self.assertTrue(memory['is_bound'])
    # Volume and brightness are both expected to be 50 on a new memory.
    for setting in ('volume', 'brightness'):
        self.assertEqual(memory[setting], 50)
|
||
|
||
def test_bind_creates_new_memory_each_time(self):
    """Binding two devices leaves the user with two role memories."""
    second_device = Device.objects.create(
        sn='AL-KPBL-ON-25W01-A01-00002',
        device_type=self.device_type,
        status='in_stock'
    )

    # Bind the setUp device first, then the extra one.
    for serial in (self.device.sn, second_device.sn):
        self.client.post('/api/v1/devices/bind/', {'sn': serial}, format='json')

    memories = RoleMemory.objects.filter(user=self.user)
    self.assertEqual(memories.count(), 2)
|
||
|
||
def test_unbind_marks_memory_idle(self):
    """After unbinding, the role memory is flagged as not bound."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)
    # Hold on to the memory before the binding is removed.
    memory = binding.role_memory

    response = self.client.delete(f'/api/v1/devices/{binding.id}/unbind/')

    self.assertEqual(response.data['code'], 0)
    memory.refresh_from_db()
    self.assertFalse(memory.is_bound)
|
||
|
||
def test_unbind_preserves_memory(self):
    """Unbinding a device does not delete its role memory."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)

    unbind_url = f'/api/v1/devices/{binding.id}/unbind/'
    self.client.delete(unbind_url)

    remaining = RoleMemory.objects.filter(user=self.user).count()
    self.assertEqual(remaining, 1)
|
||
|
||
def test_get_role_memory(self):
    """The role-memory endpoint returns the memory of a bound device."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)

    response = self.client.get(f'/api/v1/devices/{binding.id}/role-memory/')
    body = response.data

    self.assertEqual(body['code'], 0)
    self.assertEqual(body['data']['prompt'], '你是一只可爱的卡皮巴拉')
|
||
|
||
def test_update_role_memory_settings(self):
    """Updating nickname, volume and brightness echoes the new values back."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)

    new_settings = {'nickname': '我的卡皮', 'volume': 80, 'brightness': 30}
    response = self.client.put(
        f'/api/v1/devices/{binding.id}/role-memory/settings/',
        new_settings,
        format='json',
    )

    self.assertEqual(response.data['code'], 0)
    # Each submitted field must come back with the value that was sent.
    for field_name, expected in new_settings.items():
        self.assertEqual(response.data['data'][field_name], expected)
|
||
|
||
def test_update_role_memory_agent(self):
    """Updating the agent prompt and voice echoes the new values back."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)

    agent_fields = {'prompt': '你是一只会讲故事的卡皮巴拉', 'voice_id': 'voice_new'}
    response = self.client.put(
        f'/api/v1/devices/{binding.id}/role-memory/agent/',
        agent_fields,
        format='json',
    )

    self.assertEqual(response.data['code'], 0)
    for field_name, expected in agent_fields.items():
        self.assertEqual(response.data['data'][field_name], expected)
|
||
|
||
def test_update_role_memory_summary(self):
    """Updating the chat memory summary echoes the new summary back."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)

    summary = '用户喜欢恐龙故事,不喜欢太吓人的情节'
    response = self.client.put(
        f'/api/v1/devices/{binding.id}/role-memory/memory/',
        {'memory_summary': summary},
        format='json',
    )

    self.assertEqual(response.data['code'], 0)
    self.assertEqual(response.data['data']['memory_summary'], summary)
|
||
|
||
def test_role_memory_list(self):
    """The list endpoint returns one role memory per bound device."""
    second_device = Device.objects.create(
        sn='AL-KPBL-ON-25W01-A01-00002',
        device_type=self.device_type,
        status='in_stock'
    )
    for serial in (self.device.sn, second_device.sn):
        self.client.post('/api/v1/devices/bind/', {'sn': serial}, format='json')

    response = self.client.get('/api/v1/devices/role-memories/')
    body = response.data

    self.assertEqual(body['code'], 0)
    self.assertEqual(len(body['data']), 2)
|
||
|
||
def test_role_memory_list_filter_by_device_type(self):
    """``device_type_id`` narrows the list to memories of that device type."""
    # A second device type with its own device, bound alongside the setUp device.
    other_type = DeviceType.objects.create(
        brand='AL', product_code='OTHER-ON', name='其他设备'
    )
    other_device = Device.objects.create(
        sn='AL-OTHER-ON-25W01-A01-00001',
        device_type=other_type,
        status='in_stock'
    )
    for serial in (self.device.sn, other_device.sn):
        self.client.post('/api/v1/devices/bind/', {'sn': serial}, format='json')

    wanted_type_id = self.device_type.id
    response = self.client.get(
        f'/api/v1/devices/role-memories/?device_type_id={wanted_type_id}'
    )

    self.assertEqual(response.data['code'], 0)
    rows = response.data['data']
    self.assertEqual(len(rows), 1)
    self.assertEqual(rows[0]['device_type'], self.device_type.id)
|
||
|
||
def test_role_memory_list_filter_by_is_bound(self):
    """``is_bound=false`` returns only idle role memories."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    # The bind must have succeeded, otherwise no bound memory exists and the
    # filter below would have nothing to exclude.  Stated as an explicit
    # assertion instead of an unused ``.get()`` result.
    self.assertTrue(
        UserDevice.objects.filter(user=self.user, device=self.device).exists()
    )
    # Add one idle memory next to the bound one created by the bind call.
    RoleMemory.objects.create(
        user=self.user, device_type=self.device_type, is_bound=False
    )

    url = '/api/v1/devices/role-memories/?is_bound=false'
    response = self.client.get(url)

    self.assertEqual(response.data['code'], 0)
    self.assertEqual(len(response.data['data']), 1)
    self.assertFalse(response.data['data'][0]['is_bound'])
|
||
|
||
def test_switch_role_memory(self):
    """Switching to an idle memory of the same type swaps the bound flags."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)
    previous_memory = binding.role_memory

    # An idle memory of the same device type to switch over to.
    spare_memory = RoleMemory.objects.create(
        user=self.user, device_type=self.device_type,
        is_bound=False, prompt='闲置的提示词',
        memory_summary='之前的记忆内容'
    )

    response = self.client.put(
        f'/api/v1/devices/{binding.id}/switch-role-memory/',
        {'role_memory_id': spare_memory.id},
        format='json',
    )

    self.assertEqual(response.data['code'], 0)
    self.assertEqual(response.data['data']['prompt'], '闲置的提示词')

    # Reload all three rows and check the swap was persisted.
    for row in (previous_memory, spare_memory, binding):
        row.refresh_from_db()
    self.assertFalse(previous_memory.is_bound)
    self.assertTrue(spare_memory.is_bound)
    self.assertEqual(binding.role_memory_id, spare_memory.id)
|
||
|
||
def test_switch_rejects_different_type(self):
    """Switching to a memory of another device type returns a non-zero code."""
    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)

    # An idle memory owned by the same user but of a different device type.
    foreign_type = DeviceType.objects.create(
        brand='AL', product_code='OTHER-ON', name='其他设备'
    )
    mismatched_memory = RoleMemory.objects.create(
        user=self.user, device_type=foreign_type, is_bound=False
    )

    response = self.client.put(
        f'/api/v1/devices/{binding.id}/switch-role-memory/',
        {'role_memory_id': mismatched_memory.id},
        format='json',
    )

    self.assertNotEqual(response.data['code'], 0)
|
||
|
||
def test_switch_rejects_bound_memory(self):
    """Switching to a memory already bound to another device returns a non-zero code."""
    second_device = Device.objects.create(
        sn='AL-KPBL-ON-25W01-A01-00002',
        device_type=self.device_type,
        status='in_stock'
    )
    for serial in (self.device.sn, second_device.sn):
        self.client.post('/api/v1/devices/bind/', {'sn': serial}, format='json')

    first_binding = UserDevice.objects.get(user=self.user, device=self.device)
    second_binding = UserDevice.objects.get(user=self.user, device=second_device)

    # Point the first device at the memory the second device is using.
    response = self.client.put(
        f'/api/v1/devices/{first_binding.id}/switch-role-memory/',
        {'role_memory_id': second_binding.role_memory_id},
        format='json',
    )

    self.assertNotEqual(response.data['code'], 0)
|
||
|
||
def test_user_isolation(self):
    """A user cannot switch a device to a role memory owned by someone else."""
    stranger = User.objects.create_user(phone='13800139001')
    strangers_memory = RoleMemory.objects.create(
        user=stranger, device_type=self.device_type, is_bound=False
    )

    self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json')
    binding = UserDevice.objects.get(user=self.user, device=self.device)

    # Same device type and idle, so only the ownership differs.
    response = self.client.put(
        f'/api/v1/devices/{binding.id}/switch-role-memory/',
        {'role_memory_id': strangers_memory.id},
        format='json',
    )

    self.assertNotEqual(response.data['code'], 0)
|
||
|