From 9ac26a5f1108a52576032ac418ee3105c4e39b1c Mon Sep 17 00:00:00 2001 From: repair-agent Date: Fri, 27 Feb 2026 16:38:50 +0800 Subject: [PATCH] fix: auto repair bugs #52 --- ...004_rolememory_and_devicetype_templates.py | 64 ++++ apps/devices/models.py | 48 ++- apps/devices/serializers.py | 67 ++++- apps/devices/views.py | 254 ++++++++++++++-- tests.py | 273 +++++++++++++++++- utils/exceptions.py | 3 + 6 files changed, 668 insertions(+), 41 deletions(-) create mode 100644 apps/devices/migrations/0004_rolememory_and_devicetype_templates.py diff --git a/apps/devices/migrations/0004_rolememory_and_devicetype_templates.py b/apps/devices/migrations/0004_rolememory_and_devicetype_templates.py new file mode 100644 index 0000000..3a23a5f --- /dev/null +++ b/apps/devices/migrations/0004_rolememory_and_devicetype_templates.py @@ -0,0 +1,64 @@ +# Generated by Django 6.0.1 on 2026-02-27 06:23 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('devices', '0003_device_battery_device_icon_device_is_ai_and_more'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.AddField( + model_name='devicetype', + name='default_prompt', + field=models.TextField(blank=True, default='', verbose_name='默认提示词模板'), + ), + migrations.AddField( + model_name='devicetype', + name='default_voice_id', + field=models.CharField(blank=True, default='', max_length=100, verbose_name='默认音色ID'), + ), + migrations.CreateModel( + name='RoleMemory', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('is_bound', models.BooleanField(default=True, verbose_name='是否已绑定设备')), + ('nickname', models.CharField(blank=True, default='', max_length=50, verbose_name='设备昵称')), + ('user_name', models.CharField(blank=True, default='', max_length=50, verbose_name='用户称呼')), + ('volume', models.IntegerField(default=50, verbose_name='音量')), + ('brightness', models.IntegerField(default=50, verbose_name='亮度')), + ('allow_interrupt', models.BooleanField(default=True, verbose_name='允许打断')), + ('privacy_mode', models.BooleanField(default=False, verbose_name='隐私模式')), + ('prompt', models.TextField(blank=True, default='', verbose_name='提示词')), + ('voice_id', models.CharField(blank=True, default='', max_length=100, verbose_name='音色ID')), + ('memory_summary', models.TextField(blank=True, default='', verbose_name='聊天记忆摘要')), + ('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')), + ('device_type', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='role_memories', to='devices.devicetype', verbose_name='设备类型')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='role_memories', to=settings.AUTH_USER_MODEL, verbose_name='用户')), + ], + options={ + 'verbose_name': '角色记忆', + 'verbose_name_plural': '角色记忆', + 'db_table': 'role_memory', + }, + ), + migrations.AddField( + model_name='userdevice', + name='role_memory', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='user_devices', to='devices.rolememory', verbose_name='角色记忆'), + ), + migrations.AddIndex( + model_name='rolememory', + index=models.Index(fields=['user', 'device_type'], name='role_memory_user_id_c7dd09_idx'), + ), + migrations.AddIndex( + model_name='rolememory', + index=models.Index(fields=['is_bound'], name='role_memory_is_boun_556b2c_idx'), + ), + ] diff --git a/apps/devices/models.py b/apps/devices/models.py index 746941d..519ad5e 100644 --- a/apps/devices/models.py +++ b/apps/devices/models.py @@ -15,9 +15,11 @@ class DeviceType(models.Model): name = models.CharField('名称', max_length=100) is_network_required = models.BooleanField('是否需要联网', default=True) is_active = models.BooleanField('是否启用', default=True) + default_prompt = models.TextField('默认提示词模板', blank=True, default='') + default_voice_id = models.CharField('默认音色ID', max_length=100, blank=True, default='') created_at = models.DateTimeField('创建时间', auto_now_add=True) updated_at = models.DateTimeField('更新时间', auto_now=True) - + class Meta: db_table = 'device_type' verbose_name = '设备类型' @@ -121,6 +123,11 @@ class UserDevice(models.Model): user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='user_devices', verbose_name='用户') device = models.ForeignKey(Device, on_delete=models.CASCADE, related_name='user_devices', verbose_name='设备') spirit = models.ForeignKey(Spirit, on_delete=models.SET_NULL, null=True, blank=True, related_name='user_devices', verbose_name='绑定的智能体') + role_memory = models.ForeignKey( + 'RoleMemory', on_delete=models.SET_NULL, + null=True, blank=True, + related_name='user_devices', verbose_name='角色记忆' + ) bind_type = models.CharField('绑定类型', max_length=20, choices=BIND_TYPE_CHOICES, default='owner') bind_time = models.DateTimeField('绑定时间', auto_now_add=True) is_active = models.BooleanField('是否有效', default=True) @@ -178,3 +185,42 @@ class DeviceWifi(models.Model): def __str__(self): return f"{self.device.sn} - {self.ssid}" + + +class RoleMemory(models.Model): + """角色记忆 - 按用户+设备类型存储,同类型可有多个""" + + id = models.BigAutoField(primary_key=True) + user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='role_memories', verbose_name='用户') + device_type = models.ForeignKey(DeviceType, on_delete=models.PROTECT, related_name='role_memories', verbose_name='设备类型') + is_bound = models.BooleanField('是否已绑定设备', default=True) + + # 基础设备设置 + nickname = models.CharField('设备昵称', max_length=50, blank=True, default='') + user_name = models.CharField('用户称呼', max_length=50, blank=True, default='') + volume = models.IntegerField('音量', default=50) + brightness = models.IntegerField('亮度', default=50) + allow_interrupt = models.BooleanField('允许打断', default=True) + privacy_mode = models.BooleanField('隐私模式', default=False) + + # Agent 信息 + prompt = models.TextField('提示词', blank=True, default='') + voice_id = models.CharField('音色ID', max_length=100, blank=True, default='') + + # 聊天记忆(摘要式) + memory_summary = models.TextField('聊天记忆摘要', blank=True, default='') + + created_at = models.DateTimeField('创建时间', auto_now_add=True) + updated_at = models.DateTimeField('更新时间', auto_now=True) + + class Meta: + db_table = 'role_memory' + verbose_name = '角色记忆' + verbose_name_plural = '角色记忆' + indexes = [ + models.Index(fields=['user', 'device_type']), + models.Index(fields=['is_bound']), + ] + + def __str__(self): + return f"{self.user.phone} - {self.device_type.name} - #{self.id}" diff --git a/apps/devices/serializers.py b/apps/devices/serializers.py index c08d7b4..d237f69 100644 --- a/apps/devices/serializers.py +++ b/apps/devices/serializers.py @@ -2,15 +2,19 @@ 设备模块序列化器 """ from rest_framework import serializers -from .models import DeviceType, DeviceBatch, Device, UserDevice, DeviceSettings, DeviceWifi +from .models import DeviceType, DeviceBatch, Device, UserDevice, DeviceSettings, DeviceWifi, RoleMemory class DeviceTypeSerializer(serializers.ModelSerializer): """设备类型序列化器""" - + + default_prompt = serializers.CharField(required=False, default='', allow_blank=True) + default_voice_id = serializers.CharField(required=False, default='', allow_blank=True) + class Meta: model = DeviceType - fields = ['id', 'brand', 'product_code', 'name', 'is_network_required', 'is_active', 'created_at'] + fields = ['id', 'brand', 'product_code', 'name', 'is_network_required', 'is_active', + 'default_prompt', 'default_voice_id', 'created_at'] read_only_fields = ['id', 'is_network_required', 'created_at'] @@ -53,15 +57,35 @@ class DeviceSimpleSerializer(serializers.ModelSerializer): fields = ['id', 'sn', 'mac_address', 'status', 'created_at'] +class RoleMemorySerializer(serializers.ModelSerializer): + """角色记忆序列化器""" + device_type_name = serializers.CharField(source='device_type.name', read_only=True) + + class Meta: + model = RoleMemory + fields = [ + 'id', 'device_type', 'device_type_name', 'is_bound', + 'nickname', 'user_name', 'volume', 'brightness', + 'allow_interrupt', 'privacy_mode', + 'prompt', 'voice_id', + 'memory_summary', + 'created_at', 'updated_at', + ] + read_only_fields = ['id', 'device_type', 'device_type_name', 'is_bound', + 'created_at', 'updated_at'] + + class UserDeviceSerializer(serializers.ModelSerializer): """用户设备绑定序列化器""" - + device = DeviceSerializer(read_only=True) spirit_name = serializers.CharField(source='spirit.name', read_only=True, allow_null=True) - + role_memory = RoleMemorySerializer(read_only=True) + class Meta: model = UserDevice - fields = ['id', 'device', 'spirit', 'spirit_name', 'bind_type', 'bind_time', 'is_active'] + fields = ['id', 'device', 'spirit', 'spirit_name', 'role_memory', + 'bind_type', 'bind_time', 'is_active'] class BindDeviceSerializer(serializers.Serializer): @@ -112,11 +136,13 @@ class DeviceDetailSerializer(serializers.ModelSerializer): wifi_list = DeviceWifiSerializer(many=True, read_only=True) status = serializers.SerializerMethodField() bound_spirit = serializers.SerializerMethodField() + role_memory = serializers.SerializerMethodField() class Meta: model = Device fields = ['id', 'sn', 'name', 'status', 'battery', 'firmware_version', - 'mac_address', 'is_ai', 'icon', 'settings', 'wifi_list', 'bound_spirit'] + 'mac_address', 'is_ai', 'icon', 'settings', 'wifi_list', + 'bound_spirit', 'role_memory'] def get_status(self, obj): return 'online' if obj.is_online else 'offline' @@ -127,6 +153,12 @@ class DeviceDetailSerializer(serializers.ModelSerializer): return {'id': user_device.spirit.id, 'name': user_device.spirit.name} return None + def get_role_memory(self, obj): + user_device = self.context.get('user_device') + if user_device and user_device.role_memory: + return RoleMemorySerializer(user_device.role_memory).data + return None + class DeviceSettingsUpdateSerializer(serializers.Serializer): """更新设备设置序列化器""" @@ -149,3 +181,24 @@ class DeviceReportStatusSerializer(serializers.Serializer): def validate_mac_address(self, value): return value.upper().replace('-', ':') + + +class RoleMemorySettingsUpdateSerializer(serializers.Serializer): + """更新角色记忆-设备设置""" + nickname = serializers.CharField(max_length=50, required=False) + user_name = serializers.CharField(max_length=50, required=False) + volume = serializers.IntegerField(min_value=0, max_value=100, required=False) + brightness = serializers.IntegerField(min_value=0, max_value=100, required=False) + allow_interrupt = serializers.BooleanField(required=False) + privacy_mode = serializers.BooleanField(required=False) + + +class RoleMemoryAgentUpdateSerializer(serializers.Serializer): + """更新角色记忆-Agent信息""" + prompt = serializers.CharField(required=False, allow_blank=True) + voice_id = serializers.CharField(max_length=100, required=False, allow_blank=True) + + +class RoleMemoryMemoryUpdateSerializer(serializers.Serializer): + """更新角色记忆-聊天记忆摘要""" + memory_summary = serializers.CharField(required=True, allow_blank=True) diff --git a/apps/devices/views.py b/apps/devices/views.py index ffd65b9..fe6711a 100644 --- a/apps/devices/views.py +++ b/apps/devices/views.py @@ -9,7 +9,7 @@ from drf_spectacular.utils import extend_schema from utils.response import success, error from utils.exceptions import ErrorCode from apps.admins.authentication import AppJWTAuthentication -from .models import Device, UserDevice, DeviceType, DeviceSettings, DeviceWifi +from .models import Device, UserDevice, DeviceType, DeviceSettings, DeviceWifi, RoleMemory from .serializers import ( DeviceSerializer, UserDeviceSerializer, @@ -19,17 +19,21 @@ from .serializers import ( DeviceDetailSerializer, DeviceSettingsUpdateSerializer, DeviceReportStatusSerializer, + RoleMemorySerializer, + RoleMemorySettingsUpdateSerializer, + RoleMemoryAgentUpdateSerializer, + RoleMemoryMemoryUpdateSerializer, ) @extend_schema(tags=['设备']) class DeviceViewSet(viewsets.ViewSet): """设备视图集(App端)""" - + authentication_classes = [AppJWTAuthentication] permission_classes = [IsAuthenticated] - - @action(detail=False, methods=['get'], url_path='query-by-mac', + + @action(detail=False, methods=['get'], url_path='query-by-mac', authentication_classes=[], permission_classes=[AllowAny]) def query_by_mac(self, request): """ @@ -39,10 +43,10 @@ class DeviceViewSet(viewsets.ViewSet): mac = request.query_params.get('mac', '') if not mac: return error(message='MAC地址不能为空') - + # 统一格式 mac = mac.upper().replace('-', ':') - + try: device = Device.objects.select_related('device_type').get(mac_address=mac) return success(data={ @@ -54,11 +58,11 @@ class DeviceViewSet(viewsets.ViewSet): }) except Device.DoesNotExist: return error( - code=404, + code=404, message='未找到对应的设备,请检查MAC地址是否正确或设备是否已完成入库', status_code=status.HTTP_404_NOT_FOUND ) - + @action(detail=False, methods=['get']) def latest(self, request): """ @@ -69,7 +73,7 @@ class DeviceViewSet(viewsets.ViewSet): UserDevice.objects.filter( user=request.user, is_active=True - ).select_related('device', 'device__device_type', 'spirit') + ).select_related('device', 'device__device_type', 'spirit', 'role_memory') .order_by('-bind_time')[:1] ) if not devices: @@ -86,23 +90,23 @@ class DeviceViewSet(viewsets.ViewSet): serializer = DeviceVerifySerializer(data=request.data) if not serializer.is_valid(): return error(message=str(serializer.errors)) - + sn = serializer.validated_data['sn'] - + try: device = Device.objects.select_related('device_type').get(sn=sn) except Device.DoesNotExist: return error(code=ErrorCode.DEVICE_NOT_FOUND, message='设备不存在') - + # 检查是否已被绑定 is_bindable = device.status != 'bound' - + return success(data={ 'sn': device.sn, 'is_bindable': is_bindable, 'device_type': DeviceTypeSerializer(device.device_type).data if device.device_type else None }) - + @action(detail=False, methods=['post']) def bind(self, request): """ @@ -112,28 +116,40 @@ class DeviceViewSet(viewsets.ViewSet): serializer = BindDeviceSerializer(data=request.data) if not serializer.is_valid(): return error(message=str(serializer.errors)) - + sn = serializer.validated_data['sn'] spirit_id = serializer.validated_data.get('spirit_id') - + try: - device = Device.objects.get(sn=sn) + device = Device.objects.select_related('device_type').get(sn=sn) except Device.DoesNotExist: return error(code=ErrorCode.DEVICE_NOT_FOUND, message='设备不存在') - + # 检查是否已被绑定 if device.status == 'bound': # 检查是否是当前用户绑定的 existing = UserDevice.objects.filter(device=device, is_active=True).first() if existing and existing.user != request.user: return error(code=ErrorCode.DEVICE_ALREADY_BOUND, message='设备已被其他用户绑定') - + + # 创建角色记忆 + role_memory = None + if device.device_type: + role_memory = RoleMemory.objects.create( + user=request.user, + device_type=device.device_type, + is_bound=True, + prompt=device.device_type.default_prompt, + voice_id=device.device_type.default_voice_id, + ) + # 创建绑定关系 user_device, created = UserDevice.objects.update_or_create( user=request.user, device=device, defaults={ 'spirit_id': spirit_id, + 'role_memory': role_memory, 'is_active': True } ) @@ -144,12 +160,12 @@ class DeviceViewSet(viewsets.ViewSet): device.name = device_name device.status = 'bound' device.save() - + return success( data=UserDeviceSerializer(user_device).data, message='绑定成功' if created else '更新绑定成功' ) - + @action(detail=False, methods=['get']) def my_devices(self, request): """ @@ -157,13 +173,13 @@ class DeviceViewSet(viewsets.ViewSet): GET /api/v1/devices/my_devices """ user_devices = UserDevice.objects.filter( - user=request.user, + user=request.user, is_active=True - ).select_related('device', 'device__device_type', 'spirit') - + ).select_related('device', 'device__device_type', 'spirit', 'role_memory') + serializer = UserDeviceSerializer(user_devices, many=True) return success(data=serializer.data) - + @action(detail=True, methods=['delete']) def unbind(self, request, pk=None): """ @@ -171,22 +187,29 @@ class DeviceViewSet(viewsets.ViewSet): DELETE /api/v1/devices/{id}/unbind """ try: - user_device = UserDevice.objects.get(id=pk, user=request.user) + user_device = UserDevice.objects.select_related('role_memory').get( + id=pk, user=request.user + ) except UserDevice.DoesNotExist: return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在') - + # 更新绑定状态 user_device.is_active = False user_device.save() - + + # 将关联的角色记忆标记为闲置 + if user_device.role_memory: + user_device.role_memory.is_bound = False + user_device.role_memory.save(update_fields=['is_bound', 'updated_at']) + # 检查设备是否还有其他活跃绑定 active_bindings = UserDevice.objects.filter(device=user_device.device, is_active=True).count() if active_bindings == 0: user_device.device.status = 'out_stock' user_device.device.save() - + return success(message='解绑成功') - + @action(detail=True, methods=['put'], url_path='update-spirit') def update_spirit(self, request, pk=None): """ @@ -197,7 +220,7 @@ class DeviceViewSet(viewsets.ViewSet): user_device = UserDevice.objects.get(id=pk, user=request.user, is_active=True) except UserDevice.DoesNotExist: return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在') - + spirit_id = request.data.get('spirit_id') user_device.spirit_id = spirit_id user_device.save() @@ -213,7 +236,7 @@ class DeviceViewSet(viewsets.ViewSet): """ try: user_device = UserDevice.objects.select_related( - 'device', 'spirit' + 'device', 'spirit', 'role_memory', 'role_memory__device_type' ).get(id=pk, user=request.user, is_active=True) except UserDevice.DoesNotExist: return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在') @@ -329,3 +352,170 @@ class DeviceViewSet(viewsets.ViewSet): data={'device_id': device.id, 'sn': device.sn}, message='状态上报成功' ) + + # ==================== 角色记忆相关端点 ==================== + + @action(detail=True, methods=['get'], url_path='role-memory') + def get_role_memory(self, request, pk=None): + """ + 获取设备的角色记忆 + GET /api/v1/devices/{user_device_id}/role-memory/ + """ + try: + user_device = UserDevice.objects.select_related( + 'role_memory', 'role_memory__device_type' + ).get(id=pk, user=request.user, is_active=True) + except UserDevice.DoesNotExist: + return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在') + + if not user_device.role_memory: + return error(code=ErrorCode.ROLE_MEMORY_NOT_FOUND, message='角色记忆不存在') + + return success(data=RoleMemorySerializer(user_device.role_memory).data) + + @action(detail=True, methods=['put'], url_path='role-memory/settings') + def update_role_memory_settings(self, request, pk=None): + """ + 更新角色记忆-设备设置 + PUT /api/v1/devices/{user_device_id}/role-memory/settings/ + """ + try: + user_device = UserDevice.objects.select_related('role_memory').get( + id=pk, user=request.user, is_active=True + ) + except UserDevice.DoesNotExist: + return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在') + + if not user_device.role_memory: + return error(code=ErrorCode.ROLE_MEMORY_NOT_FOUND, message='角色记忆不存在') + + serializer = RoleMemorySettingsUpdateSerializer(data=request.data) + if not serializer.is_valid(): + return error(message=str(serializer.errors)) + + rm = user_device.role_memory + for field, value in serializer.validated_data.items(): + setattr(rm, field, value) + rm.save() + + return success(data=RoleMemorySerializer(rm).data, message='设置已保存') + + @action(detail=True, methods=['put'], url_path='role-memory/agent') + def update_role_memory_agent(self, request, pk=None): + """ + 更新角色记忆-Agent信息(提示词、音色) + PUT /api/v1/devices/{user_device_id}/role-memory/agent/ + """ + try: + user_device = UserDevice.objects.select_related('role_memory').get( + id=pk, user=request.user, is_active=True + ) + except UserDevice.DoesNotExist: + return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在') + + if not user_device.role_memory: + return error(code=ErrorCode.ROLE_MEMORY_NOT_FOUND, message='角色记忆不存在') + + serializer = RoleMemoryAgentUpdateSerializer(data=request.data) + if not serializer.is_valid(): + return error(message=str(serializer.errors)) + + rm = user_device.role_memory + for field, value in serializer.validated_data.items(): + setattr(rm, field, value) + rm.save() + + return success(data=RoleMemorySerializer(rm).data, message='Agent信息已更新') + + @action(detail=True, methods=['put'], url_path='role-memory/memory') + def update_role_memory_summary(self, request, pk=None): + """ + 更新角色记忆-聊天记忆摘要 + PUT /api/v1/devices/{user_device_id}/role-memory/memory/ + """ + try: + user_device = UserDevice.objects.select_related('role_memory').get( + id=pk, user=request.user, is_active=True + ) + except UserDevice.DoesNotExist: + return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在') + + if not user_device.role_memory: + return error(code=ErrorCode.ROLE_MEMORY_NOT_FOUND, message='角色记忆不存在') + + serializer = RoleMemoryMemoryUpdateSerializer(data=request.data) + if not serializer.is_valid(): + return error(message=str(serializer.errors)) + + rm = user_device.role_memory + rm.memory_summary = serializer.validated_data['memory_summary'] + rm.save(update_fields=['memory_summary', 'updated_at']) + + return success(data=RoleMemorySerializer(rm).data, message='记忆摘要已更新') + + @action(detail=False, methods=['get'], url_path='role-memories') + def role_memory_list(self, request): + """ + 获取角色记忆列表 + GET /api/v1/devices/role-memories/?device_type_id=1&is_bound=false + """ + qs = RoleMemory.objects.filter( + user=request.user + ).select_related('device_type').order_by('-created_at') + + device_type_id = request.query_params.get('device_type_id') + if device_type_id: + qs = qs.filter(device_type_id=device_type_id) + + is_bound = request.query_params.get('is_bound') + if is_bound is not None: + qs = qs.filter(is_bound=is_bound.lower() == 'true') + + return success(data=RoleMemorySerializer(qs, many=True).data) + + @action(detail=True, methods=['put'], url_path='switch-role-memory') + def switch_role_memory(self, request, pk=None): + """ + 切换设备的角色记忆 + PUT /api/v1/devices/{user_device_id}/switch-role-memory/ + body: { "role_memory_id": 5 } + """ + try: + user_device = UserDevice.objects.select_related( + 'device', 'device__device_type', 'role_memory' + ).get(id=pk, user=request.user, is_active=True) + except UserDevice.DoesNotExist: + return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在') + + role_memory_id = request.data.get('role_memory_id') + if not role_memory_id: + return error(message='请指定角色记忆ID') + + try: + new_rm = RoleMemory.objects.get(id=role_memory_id, user=request.user) + except RoleMemory.DoesNotExist: + return error(code=ErrorCode.ROLE_MEMORY_NOT_FOUND, message='角色记忆不存在') + + # 校验: 目标记忆必须是同一设备类型 + if user_device.device.device_type_id and new_rm.device_type_id != user_device.device.device_type_id: + return error(code=ErrorCode.ROLE_MEMORY_TYPE_MISMATCH, message='只能切换到同类型设备的角色记忆') + + # 校验: 目标记忆必须是闲置状态 + if new_rm.is_bound: + return error(code=ErrorCode.ROLE_MEMORY_ALREADY_BOUND, message='该角色记忆正在被其他设备使用') + + # 旧记忆标记闲置 + old_rm = user_device.role_memory + if old_rm: + old_rm.is_bound = False + old_rm.save(update_fields=['is_bound', 'updated_at']) + + # 新记忆标记绑定 + new_rm.is_bound = True + new_rm.save(update_fields=['is_bound', 'updated_at']) + + # 更新绑定关系 + user_device.role_memory = new_rm + user_device.save(update_fields=['role_memory']) + + return success(data=RoleMemorySerializer(new_rm).data, message='切换成功') diff --git a/tests.py b/tests.py index cf9464a..12f5e74 100644 --- a/tests.py +++ b/tests.py @@ -13,7 +13,7 @@ from rest_framework import status from apps.users.models import User, PointsRecord from apps.admins.models import AdminUser from apps.spirits.models import Spirit -from apps.devices.models import DeviceType, DeviceBatch, Device, UserDevice +from apps.devices.models import DeviceType, DeviceBatch, Device, UserDevice, RoleMemory from apps.stories.models import StoryShelf, Story from apps.music.models import Track from apps.users.views import get_app_tokens @@ -2068,3 +2068,274 @@ class MigrateHistoricalTracksTests(TestCase): self.assertTrue(track.audio_url.startswith('https://qy-rtc.oss-cn-beijing.aliyuncs.com/')) self.assertTrue(track.cover_url.startswith('https://qy-rtc.oss-cn-beijing.aliyuncs.com/')) + +# ==================== 角色记忆测试 ==================== + +class RoleMemoryTests(APITestCase): + """角色记忆功能测试""" + + def setUp(self): + self.user = User.objects.create_user(phone='13800139000', nickname='记忆测试用户') + tokens = get_app_tokens(self.user) + self.client.credentials(HTTP_AUTHORIZATION=f'Bearer {tokens["access"]}') + + self.device_type, _ = DeviceType.objects.get_or_create( + product_code='KPBL-ON-RM', + defaults={ + 'brand': 'AL', + 'name': '卡皮巴拉-联网版', + 'default_prompt': '你是一只可爱的卡皮巴拉', + 'default_voice_id': 'voice_kpbl_01', + } + ) + # 确保模板字段正确(get_or_create 可能返回已有记录) + if not self.device_type.default_prompt: + self.device_type.default_prompt = '你是一只可爱的卡皮巴拉' + self.device_type.default_voice_id = 'voice_kpbl_01' + self.device_type.save() + + self.device, _ = Device.objects.get_or_create( + sn='AL-KPBL-ON-25W01-RM-00001', + defaults={ + 'device_type': self.device_type, + 'status': 'in_stock', + } + ) + # 重置设备状态 + self.device.status = 'in_stock' + self.device.save() + # 清理旧绑定关系和角色记忆 + UserDevice.objects.filter(user=self.user).delete() + RoleMemory.objects.filter(user=self.user).delete() + + def test_bind_creates_role_memory(self): + """测试绑定设备自动创建角色记忆""" + url = '/api/v1/devices/bind/' + data = {'sn': 'AL-KPBL-ON-25W01-RM-00001'} + response = self.client.post(url, data, format='json') + + self.assertEqual(response.data['code'], 0) + self.assertIsNotNone(response.data['data']['role_memory']) + rm_data = response.data['data']['role_memory'] + self.assertEqual(rm_data['prompt'], '你是一只可爱的卡皮巴拉') + self.assertEqual(rm_data['voice_id'], 'voice_kpbl_01') + self.assertTrue(rm_data['is_bound']) + self.assertEqual(rm_data['volume'], 50) + self.assertEqual(rm_data['brightness'], 50) + + def test_bind_creates_new_memory_each_time(self): + """测试每次绑定新设备都创建新的角色记忆""" + device2 = Device.objects.create( + sn='AL-KPBL-ON-25W01-A01-00002', + device_type=self.device_type, + status='in_stock' + ) + # 绑定第一个设备 + self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json') + # 绑定第二个设备 + self.client.post('/api/v1/devices/bind/', {'sn': device2.sn}, format='json') + + self.assertEqual(RoleMemory.objects.filter(user=self.user).count(), 2) + + def test_unbind_marks_memory_idle(self): + """测试解绑后角色记忆标记为闲置""" + self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json') + ud = UserDevice.objects.get(user=self.user, device=self.device) + rm = ud.role_memory + + url = f'/api/v1/devices/{ud.id}/unbind/' + response = self.client.delete(url) + + self.assertEqual(response.data['code'], 0) + rm.refresh_from_db() + self.assertFalse(rm.is_bound) + + def test_unbind_preserves_memory(self): + """测试解绑不删除角色记忆""" + self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json') + ud = UserDevice.objects.get(user=self.user, device=self.device) + + self.client.delete(f'/api/v1/devices/{ud.id}/unbind/') + + self.assertEqual(RoleMemory.objects.filter(user=self.user).count(), 1) + + def test_get_role_memory(self): + """测试获取角色记忆""" + self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json') + ud = UserDevice.objects.get(user=self.user, device=self.device) + + url = f'/api/v1/devices/{ud.id}/role-memory/' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['prompt'], '你是一只可爱的卡皮巴拉') + + def test_update_role_memory_settings(self): + """测试更新角色记忆设备设置""" + self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json') + ud = UserDevice.objects.get(user=self.user, device=self.device) + + url = f'/api/v1/devices/{ud.id}/role-memory/settings/' + data = {'nickname': '我的卡皮', 'volume': 80, 'brightness': 30} + response = self.client.put(url, data, format='json') + + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['nickname'], '我的卡皮') + self.assertEqual(response.data['data']['volume'], 80) + self.assertEqual(response.data['data']['brightness'], 30) + + def test_update_role_memory_agent(self): + """测试更新Agent信息""" + self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json') + ud = UserDevice.objects.get(user=self.user, device=self.device) + + url = f'/api/v1/devices/{ud.id}/role-memory/agent/' + data = {'prompt': '你是一只会讲故事的卡皮巴拉', 'voice_id': 'voice_new'} + response = self.client.put(url, data, format='json') + + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['prompt'], '你是一只会讲故事的卡皮巴拉') + self.assertEqual(response.data['data']['voice_id'], 'voice_new') + + def test_update_role_memory_summary(self): + """测试更新聊天记忆摘要""" + self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json') + ud = UserDevice.objects.get(user=self.user, device=self.device) + + url = f'/api/v1/devices/{ud.id}/role-memory/memory/' + data = {'memory_summary': '用户喜欢恐龙故事,不喜欢太吓人的情节'} + response = self.client.put(url, data, format='json') + + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['memory_summary'], '用户喜欢恐龙故事,不喜欢太吓人的情节') + + def test_role_memory_list(self): + """测试角色记忆列表""" + device2 = Device.objects.create( + sn='AL-KPBL-ON-25W01-A01-00002', + device_type=self.device_type, + status='in_stock' + ) + self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json') + self.client.post('/api/v1/devices/bind/', {'sn': device2.sn}, format='json') + + url = '/api/v1/devices/role-memories/' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 0) + self.assertEqual(len(response.data['data']), 2) + + def test_role_memory_list_filter_by_device_type(self): + """测试角色记忆列表按设备类型过滤""" + other_type = DeviceType.objects.create( + brand='AL', product_code='OTHER-ON', name='其他设备' + ) + other_device = Device.objects.create( + sn='AL-OTHER-ON-25W01-A01-00001', + device_type=other_type, + status='in_stock' + ) + self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json') + self.client.post('/api/v1/devices/bind/', {'sn': other_device.sn}, format='json') + + url = f'/api/v1/devices/role-memories/?device_type_id={self.device_type.id}' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 0) + self.assertEqual(len(response.data['data']), 1) + self.assertEqual(response.data['data'][0]['device_type'], self.device_type.id) + + def test_role_memory_list_filter_by_is_bound(self): + """测试角色记忆列表按绑定状态过滤""" + self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json') + ud = UserDevice.objects.get(user=self.user, device=self.device) + # 创建一个闲置的记忆 + RoleMemory.objects.create( + user=self.user, device_type=self.device_type, is_bound=False + ) + + url = '/api/v1/devices/role-memories/?is_bound=false' + response = self.client.get(url) + + self.assertEqual(response.data['code'], 0) + self.assertEqual(len(response.data['data']), 1) + self.assertFalse(response.data['data'][0]['is_bound']) + + def test_switch_role_memory(self): + """测试切换角色记忆""" + self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json') + ud = UserDevice.objects.get(user=self.user, device=self.device) + old_rm = ud.role_memory + + # 创建一个闲置的同类型记忆 + idle_rm = RoleMemory.objects.create( + user=self.user, device_type=self.device_type, + is_bound=False, prompt='闲置的提示词', + memory_summary='之前的记忆内容' + ) + + url = f'/api/v1/devices/{ud.id}/switch-role-memory/' + response = self.client.put(url, {'role_memory_id': idle_rm.id}, format='json') + + self.assertEqual(response.data['code'], 0) + self.assertEqual(response.data['data']['prompt'], '闲置的提示词') + + # 验证状态变化 + old_rm.refresh_from_db() + idle_rm.refresh_from_db() + ud.refresh_from_db() + self.assertFalse(old_rm.is_bound) + self.assertTrue(idle_rm.is_bound) + self.assertEqual(ud.role_memory_id, idle_rm.id) + + def test_switch_rejects_different_type(self): + """测试切换到不同类型的记忆被拒绝""" + self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json') + ud = UserDevice.objects.get(user=self.user, device=self.device) + + other_type = DeviceType.objects.create( + brand='AL', product_code='OTHER-ON', name='其他设备' + ) + other_rm = RoleMemory.objects.create( + user=self.user, device_type=other_type, is_bound=False + ) + + url = f'/api/v1/devices/{ud.id}/switch-role-memory/' + response = self.client.put(url, {'role_memory_id': other_rm.id}, format='json') + + self.assertNotEqual(response.data['code'], 0) + + def test_switch_rejects_bound_memory(self): + """测试切换到已绑定的记忆被拒绝""" + device2 = Device.objects.create( + sn='AL-KPBL-ON-25W01-A01-00002', + device_type=self.device_type, + status='in_stock' + ) + self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json') + self.client.post('/api/v1/devices/bind/', {'sn': device2.sn}, format='json') + + ud1 = UserDevice.objects.get(user=self.user, device=self.device) + ud2 = UserDevice.objects.get(user=self.user, device=device2) + + url = f'/api/v1/devices/{ud1.id}/switch-role-memory/' + response = self.client.put(url, {'role_memory_id': ud2.role_memory_id}, format='json') + + self.assertNotEqual(response.data['code'], 0) + + def test_user_isolation(self): + """测试用户隔离 - 不能访问其他用户的角色记忆""" + other_user = User.objects.create_user(phone='13800139001') + other_rm = RoleMemory.objects.create( + user=other_user, device_type=self.device_type, is_bound=False + ) + + self.client.post('/api/v1/devices/bind/', {'sn': self.device.sn}, format='json') + ud = UserDevice.objects.get(user=self.user, device=self.device) + + # 尝试切换到其他用户的记忆 + url = f'/api/v1/devices/{ud.id}/switch-role-memory/' + response = self.client.put(url, {'role_memory_id': other_rm.id}, format='json') + + self.assertNotEqual(response.data['code'], 0) + diff --git a/utils/exceptions.py b/utils/exceptions.py index 5c5e7e8..ddb2b30 100644 --- a/utils/exceptions.py +++ b/utils/exceptions.py @@ -103,6 +103,9 @@ class ErrorCode: DEVICE_ALREADY_BOUND = 201 DEVICE_MAC_EXISTS = 202 DEVICE_SN_INVALID = 203 + ROLE_MEMORY_NOT_FOUND = 204 + ROLE_MEMORY_TYPE_MISMATCH = 205 + ROLE_MEMORY_ALREADY_BOUND = 206 # 智能体模块 300-399 SPIRIT_NOT_FOUND = 300