Compare commits

..

1 Commits

Author SHA1 Message Date
repair-agent
de7e1861eb fix: auto repair bugs #41, #40, #39 2026-02-25 14:50:25 +08:00
105 changed files with 43 additions and 11394 deletions

View File

@ -37,15 +37,6 @@ jobs:
--tag ${{ secrets.SWR_SERVER }}/${{ secrets.SWR_ORG }}/rtc-backend:latest \
. 2>&1 | tee /tmp/build.log
- name: Build and Push HW WebSocket Service
run: |
set -o pipefail
docker buildx build \
--push \
--provenance=false \
--tag ${{ secrets.SWR_SERVER }}/${{ secrets.SWR_ORG }}/hw-ws-service:latest \
./hw_service_go 2>&1 | tee -a /tmp/build.log
- name: Setup Kubectl
run: |
curl -LO "https://dl.k8s.io/release/v1.28.2/bin/linux/amd64/kubectl" || \
@ -77,17 +68,13 @@ jobs:
# 2. 替换镜像地址
sed -i "s|\${CI_REGISTRY_IMAGE}/backend:latest|${{ secrets.SWR_SERVER }}/${{ secrets.SWR_ORG }}/rtc-backend:latest|g" $DEPLOY_FILE
sed -i "s|\${CI_REGISTRY_IMAGE}/hw-ws-service:latest|${{ secrets.SWR_SERVER }}/${{ secrets.SWR_ORG }}/hw-ws-service:latest|g" hw_service_go/k8s/deployment.yaml
# 3. 应用配置并捕获输出
set -o pipefail
{
kubectl apply -f $DEPLOY_FILE
kubectl apply -f $INGRESS_FILE
kubectl apply -f hw_service_go/k8s/deployment.yaml
kubectl apply -f hw_service_go/k8s/service.yaml
kubectl rollout restart deployment/$DEPLOY_NAME
kubectl rollout restart deployment/hw-ws-service
} 2>&1 | tee /tmp/deploy.log
- name: Report failure to Log Center

4
=3.0.1
View File

@ -1,4 +0,0 @@
Requirement already satisfied: opuslib in ./venv/lib/python3.14/site-packages (3.0.1)
[notice] A new release of pip is available: 25.3 -> 26.0.1
[notice] To update, run: /Users/maidong/Desktop/zyc/qy_gitlab/rtc_backend/venv/bin/python3.14 -m pip install --upgrade pip

View File

@ -14,8 +14,6 @@ RUN sed -i 's/deb.debian.org/mirrors.aliyun.com/g' /etc/apt/sources.list.d/debia
gcc \
default-libmysqlclient-dev \
pkg-config \
ffmpeg \
libopus-dev \
&& rm -rf /var/lib/apt/lists/*
# Install python dependencies

View File

@ -1,68 +0,0 @@
"""
管理端批次导出视图
"""
from datetime import datetime
from rest_framework import viewsets
from rest_framework.decorators import action
from drf_spectacular.utils import extend_schema
from utils.response import success, error
from apps.admins.authentication import AdminJWTAuthentication
from apps.admins.permissions import IsAdminUser
from apps.devices.models import DeviceBatch
from apps.devices.serializers import DeviceBatchSerializer
def parse_export_date(date_str):
"""
解析导出日期字符串支持 YYYY-MM-DD YYYY/MM/DD 两种格式
Bug #46 fix: normalize '/' separators to '-' before parsing, instead of
calling strptime directly on user input which fails for YYYY/MM/DD.
"""
# Normalize separators so both YYYY/MM/DD and YYYY-MM-DD are accepted
normalized = date_str.replace('/', '-')
try:
return datetime.strptime(normalized, '%Y-%m-%d')
except ValueError:
raise ValueError(
f'日期格式无效: {date_str},请使用 YYYY-MM-DD 或 YYYY/MM/DD 格式'
)
@extend_schema(tags=['管理员-库存'])
class AdminBatchExportViewSet(viewsets.ViewSet):
"""管理端批次导出视图集"""
authentication_classes = [AdminJWTAuthentication]
permission_classes = [IsAdminUser]
@action(detail=False, methods=['get'], url_path='export')
def export(self, request):
"""
按日期范围导出批次列表
GET /api/admin/batch-export/export?start_date=2026-01-01&end_date=2026-12-31
"""
start_str = request.query_params.get('start_date', '')
end_str = request.query_params.get('end_date', '')
queryset = DeviceBatch.objects.all().order_by('-created_at')
if start_str:
try:
# Bug #46 fix: use parse_export_date which normalises the separator
start_date = parse_export_date(start_str)
except ValueError as exc:
return error(message=str(exc))
queryset = queryset.filter(created_at__date__gte=start_date.date())
if end_str:
try:
end_date = parse_export_date(end_str)
except ValueError as exc:
return error(message=str(exc))
queryset = queryset.filter(created_at__date__lte=end_date.date())
serializer = DeviceBatchSerializer(queryset, many=True)
return success(data={'items': serializer.data})

View File

@ -1,59 +0,0 @@
"""
设备模块管理端视图
Bug #44 fix: replace unsanitized raw SQL device search with Django ORM queries
to eliminate SQL injection risk.
"""
from rest_framework import viewsets
from rest_framework.decorators import action
from drf_spectacular.utils import extend_schema
from utils.response import success, error
from apps.admins.authentication import AdminJWTAuthentication
from apps.admins.permissions import IsAdminUser
from .models import Device
from .serializers import DeviceSerializer
def search_devices_by_sn(keyword):
"""
通过SN码关键字搜索设备
Bug #44 fix: use ORM filter (sn__icontains) instead of raw SQL string
interpolation, which was vulnerable to SQL injection:
# VULNERABLE (old code):
query = f'SELECT * FROM device WHERE sn LIKE %{keyword}%'
cursor.execute(query)
# SAFE (new code):
Device.objects.filter(sn__icontains=keyword)
"""
return Device.objects.filter(sn__icontains=keyword)
@extend_schema(tags=['管理员-设备'])
class AdminDeviceViewSet(viewsets.ViewSet):
"""设备管理视图集 - 管理端"""
authentication_classes = [AdminJWTAuthentication]
permission_classes = [IsAdminUser]
@action(detail=False, methods=['get'], url_path='search')
def search(self, request):
"""
通过SN码搜索设备管理端
GET /api/admin/devices/search?keyword=<sn_keyword>
Bug #44 fix: keyword is passed as a parameter binding, not interpolated
into raw SQL, so it cannot cause SQL injection.
"""
keyword = request.query_params.get('keyword', '')
if not keyword:
return error(message='请输入搜索关键字')
# Bug #44 fix: ORM-based safe query replaces raw SQL interpolation
devices = search_devices_by_sn(keyword)
serializer = DeviceSerializer(devices, many=True)
return success(data={'items': serializer.data, 'total': devices.count()})

View File

@ -1,64 +0,0 @@
# Generated by Django 6.0.1 on 2026-02-27 06:23
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('devices', '0003_device_battery_device_icon_device_is_ai_and_more'),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [
migrations.AddField(
model_name='devicetype',
name='default_prompt',
field=models.TextField(blank=True, default='', verbose_name='默认提示词模板'),
),
migrations.AddField(
model_name='devicetype',
name='default_voice_id',
field=models.CharField(blank=True, default='', max_length=100, verbose_name='默认音色ID'),
),
migrations.CreateModel(
name='RoleMemory',
fields=[
('id', models.BigAutoField(primary_key=True, serialize=False)),
('is_bound', models.BooleanField(default=True, verbose_name='是否已绑定设备')),
('nickname', models.CharField(blank=True, default='', max_length=50, verbose_name='设备昵称')),
('user_name', models.CharField(blank=True, default='', max_length=50, verbose_name='用户称呼')),
('volume', models.IntegerField(default=50, verbose_name='音量')),
('brightness', models.IntegerField(default=50, verbose_name='亮度')),
('allow_interrupt', models.BooleanField(default=True, verbose_name='允许打断')),
('privacy_mode', models.BooleanField(default=False, verbose_name='隐私模式')),
('prompt', models.TextField(blank=True, default='', verbose_name='提示词')),
('voice_id', models.CharField(blank=True, default='', max_length=100, verbose_name='音色ID')),
('memory_summary', models.TextField(blank=True, default='', verbose_name='聊天记忆摘要')),
('created_at', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('updated_at', models.DateTimeField(auto_now=True, verbose_name='更新时间')),
('device_type', models.ForeignKey(on_delete=django.db.models.deletion.PROTECT, related_name='role_memories', to='devices.devicetype', verbose_name='设备类型')),
('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='role_memories', to=settings.AUTH_USER_MODEL, verbose_name='用户')),
],
options={
'verbose_name': '角色记忆',
'verbose_name_plural': '角色记忆',
'db_table': 'role_memory',
},
),
migrations.AddField(
model_name='userdevice',
name='role_memory',
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='user_devices', to='devices.rolememory', verbose_name='角色记忆'),
),
migrations.AddIndex(
model_name='rolememory',
index=models.Index(fields=['user', 'device_type'], name='role_memory_user_id_c7dd09_idx'),
),
migrations.AddIndex(
model_name='rolememory',
index=models.Index(fields=['is_bound'], name='role_memory_is_boun_556b2c_idx'),
),
]

View File

@ -15,8 +15,6 @@ class DeviceType(models.Model):
name = models.CharField('名称', max_length=100)
is_network_required = models.BooleanField('是否需要联网', default=True)
is_active = models.BooleanField('是否启用', default=True)
default_prompt = models.TextField('默认提示词模板', blank=True, default='')
default_voice_id = models.CharField('默认音色ID', max_length=100, blank=True, default='')
created_at = models.DateTimeField('创建时间', auto_now_add=True)
updated_at = models.DateTimeField('更新时间', auto_now=True)
@ -123,11 +121,6 @@ class UserDevice(models.Model):
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='user_devices', verbose_name='用户')
device = models.ForeignKey(Device, on_delete=models.CASCADE, related_name='user_devices', verbose_name='设备')
spirit = models.ForeignKey(Spirit, on_delete=models.SET_NULL, null=True, blank=True, related_name='user_devices', verbose_name='绑定的智能体')
role_memory = models.ForeignKey(
'RoleMemory', on_delete=models.SET_NULL,
null=True, blank=True,
related_name='user_devices', verbose_name='角色记忆'
)
bind_type = models.CharField('绑定类型', max_length=20, choices=BIND_TYPE_CHOICES, default='owner')
bind_time = models.DateTimeField('绑定时间', auto_now_add=True)
is_active = models.BooleanField('是否有效', default=True)
@ -185,42 +178,3 @@ class DeviceWifi(models.Model):
def __str__(self):
return f"{self.device.sn} - {self.ssid}"
class RoleMemory(models.Model):
"""角色记忆 - 按用户+设备类型存储,同类型可有多个"""
id = models.BigAutoField(primary_key=True)
user = models.ForeignKey(User, on_delete=models.CASCADE, related_name='role_memories', verbose_name='用户')
device_type = models.ForeignKey(DeviceType, on_delete=models.PROTECT, related_name='role_memories', verbose_name='设备类型')
is_bound = models.BooleanField('是否已绑定设备', default=True)
# 基础设备设置
nickname = models.CharField('设备昵称', max_length=50, blank=True, default='')
user_name = models.CharField('用户称呼', max_length=50, blank=True, default='')
volume = models.IntegerField('音量', default=50)
brightness = models.IntegerField('亮度', default=50)
allow_interrupt = models.BooleanField('允许打断', default=True)
privacy_mode = models.BooleanField('隐私模式', default=False)
# Agent 信息
prompt = models.TextField('提示词', blank=True, default='')
voice_id = models.CharField('音色ID', max_length=100, blank=True, default='')
# 聊天记忆(摘要式)
memory_summary = models.TextField('聊天记忆摘要', blank=True, default='')
created_at = models.DateTimeField('创建时间', auto_now_add=True)
updated_at = models.DateTimeField('更新时间', auto_now=True)
class Meta:
db_table = 'role_memory'
verbose_name = '角色记忆'
verbose_name_plural = '角色记忆'
indexes = [
models.Index(fields=['user', 'device_type']),
models.Index(fields=['is_bound']),
]
def __str__(self):
return f"{self.user.phone} - {self.device_type.name} - #{self.id}"

View File

@ -2,19 +2,15 @@
设备模块序列化器
"""
from rest_framework import serializers
from .models import DeviceType, DeviceBatch, Device, UserDevice, DeviceSettings, DeviceWifi, RoleMemory
from .models import DeviceType, DeviceBatch, Device, UserDevice, DeviceSettings, DeviceWifi
class DeviceTypeSerializer(serializers.ModelSerializer):
"""设备类型序列化器"""
default_prompt = serializers.CharField(required=False, default='', allow_blank=True)
default_voice_id = serializers.CharField(required=False, default='', allow_blank=True)
class Meta:
model = DeviceType
fields = ['id', 'brand', 'product_code', 'name', 'is_network_required', 'is_active',
'default_prompt', 'default_voice_id', 'created_at']
fields = ['id', 'brand', 'product_code', 'name', 'is_network_required', 'is_active', 'created_at']
read_only_fields = ['id', 'is_network_required', 'created_at']
@ -57,35 +53,15 @@ class DeviceSimpleSerializer(serializers.ModelSerializer):
fields = ['id', 'sn', 'mac_address', 'status', 'created_at']
class RoleMemorySerializer(serializers.ModelSerializer):
"""角色记忆序列化器"""
device_type_name = serializers.CharField(source='device_type.name', read_only=True)
class Meta:
model = RoleMemory
fields = [
'id', 'device_type', 'device_type_name', 'is_bound',
'nickname', 'user_name', 'volume', 'brightness',
'allow_interrupt', 'privacy_mode',
'prompt', 'voice_id',
'memory_summary',
'created_at', 'updated_at',
]
read_only_fields = ['id', 'device_type', 'device_type_name', 'is_bound',
'created_at', 'updated_at']
class UserDeviceSerializer(serializers.ModelSerializer):
"""用户设备绑定序列化器"""
device = DeviceSerializer(read_only=True)
spirit_name = serializers.CharField(source='spirit.name', read_only=True, allow_null=True)
role_memory = RoleMemorySerializer(read_only=True)
class Meta:
model = UserDevice
fields = ['id', 'device', 'spirit', 'spirit_name', 'role_memory',
'bind_type', 'bind_time', 'is_active']
fields = ['id', 'device', 'spirit', 'spirit_name', 'bind_type', 'bind_time', 'is_active']
class BindDeviceSerializer(serializers.Serializer):
@ -136,13 +112,11 @@ class DeviceDetailSerializer(serializers.ModelSerializer):
wifi_list = DeviceWifiSerializer(many=True, read_only=True)
status = serializers.SerializerMethodField()
bound_spirit = serializers.SerializerMethodField()
role_memory = serializers.SerializerMethodField()
class Meta:
model = Device
fields = ['id', 'sn', 'name', 'status', 'battery', 'firmware_version',
'mac_address', 'is_ai', 'icon', 'settings', 'wifi_list',
'bound_spirit', 'role_memory']
'mac_address', 'is_ai', 'icon', 'settings', 'wifi_list', 'bound_spirit']
def get_status(self, obj):
return 'online' if obj.is_online else 'offline'
@ -153,12 +127,6 @@ class DeviceDetailSerializer(serializers.ModelSerializer):
return {'id': user_device.spirit.id, 'name': user_device.spirit.name}
return None
def get_role_memory(self, obj):
user_device = self.context.get('user_device')
if user_device and user_device.role_memory:
return RoleMemorySerializer(user_device.role_memory).data
return None
class DeviceSettingsUpdateSerializer(serializers.Serializer):
"""更新设备设置序列化器"""
@ -181,24 +149,3 @@ class DeviceReportStatusSerializer(serializers.Serializer):
def validate_mac_address(self, value):
return value.upper().replace('-', ':')
class RoleMemorySettingsUpdateSerializer(serializers.Serializer):
"""更新角色记忆-设备设置"""
nickname = serializers.CharField(max_length=50, required=False)
user_name = serializers.CharField(max_length=50, required=False)
volume = serializers.IntegerField(min_value=0, max_value=100, required=False)
brightness = serializers.IntegerField(min_value=0, max_value=100, required=False)
allow_interrupt = serializers.BooleanField(required=False)
privacy_mode = serializers.BooleanField(required=False)
class RoleMemoryAgentUpdateSerializer(serializers.Serializer):
"""更新角色记忆-Agent信息"""
prompt = serializers.CharField(required=False, allow_blank=True)
voice_id = serializers.CharField(max_length=100, required=False, allow_blank=True)
class RoleMemoryMemoryUpdateSerializer(serializers.Serializer):
"""更新角色记忆-聊天记忆摘要"""
memory_summary = serializers.CharField(required=True, allow_blank=True)

View File

@ -9,7 +9,7 @@ from drf_spectacular.utils import extend_schema
from utils.response import success, error
from utils.exceptions import ErrorCode
from apps.admins.authentication import AppJWTAuthentication
from .models import Device, UserDevice, DeviceType, DeviceSettings, DeviceWifi, RoleMemory
from .models import Device, UserDevice, DeviceType, DeviceSettings, DeviceWifi
from .serializers import (
DeviceSerializer,
UserDeviceSerializer,
@ -19,10 +19,6 @@ from .serializers import (
DeviceDetailSerializer,
DeviceSettingsUpdateSerializer,
DeviceReportStatusSerializer,
RoleMemorySerializer,
RoleMemorySettingsUpdateSerializer,
RoleMemoryAgentUpdateSerializer,
RoleMemoryMemoryUpdateSerializer,
)
@ -73,7 +69,7 @@ class DeviceViewSet(viewsets.ViewSet):
UserDevice.objects.filter(
user=request.user,
is_active=True
).select_related('device', 'device__device_type', 'spirit', 'role_memory', 'role_memory__device_type')
).select_related('device', 'device__device_type', 'spirit')
.order_by('-bind_time')[:1]
)
if not devices:
@ -121,7 +117,7 @@ class DeviceViewSet(viewsets.ViewSet):
spirit_id = serializer.validated_data.get('spirit_id')
try:
device = Device.objects.select_related('device_type').get(sn=sn)
device = Device.objects.get(sn=sn)
except Device.DoesNotExist:
return error(code=ErrorCode.DEVICE_NOT_FOUND, message='设备不存在')
@ -132,24 +128,12 @@ class DeviceViewSet(viewsets.ViewSet):
if existing and existing.user != request.user:
return error(code=ErrorCode.DEVICE_ALREADY_BOUND, message='设备已被其他用户绑定')
# 创建角色记忆(始终创建新的)
role_memory = None
if device.device_type:
role_memory = RoleMemory.objects.create(
user=request.user,
device_type=device.device_type,
is_bound=True,
prompt=device.device_type.default_prompt,
voice_id=device.device_type.default_voice_id,
)
# 创建绑定关系
user_device, created = UserDevice.objects.update_or_create(
user=request.user,
device=device,
defaults={
'spirit_id': spirit_id,
'role_memory': role_memory,
'is_active': True
}
)
@ -175,7 +159,7 @@ class DeviceViewSet(viewsets.ViewSet):
user_devices = UserDevice.objects.filter(
user=request.user,
is_active=True
).select_related('device', 'device__device_type', 'spirit', 'role_memory', 'role_memory__device_type')
).select_related('device', 'device__device_type', 'spirit')
serializer = UserDeviceSerializer(user_devices, many=True)
return success(data=serializer.data)
@ -187,9 +171,7 @@ class DeviceViewSet(viewsets.ViewSet):
DELETE /api/v1/devices/{id}/unbind
"""
try:
user_device = UserDevice.objects.select_related('role_memory').get(
id=pk, user=request.user
)
user_device = UserDevice.objects.get(id=pk, user=request.user)
except UserDevice.DoesNotExist:
return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在')
@ -197,11 +179,6 @@ class DeviceViewSet(viewsets.ViewSet):
user_device.is_active = False
user_device.save()
# 将关联的角色记忆标记为闲置
if user_device.role_memory:
user_device.role_memory.is_bound = False
user_device.role_memory.save(update_fields=['is_bound', 'updated_at'])
# 检查设备是否还有其他活跃绑定
active_bindings = UserDevice.objects.filter(device=user_device.device, is_active=True).count()
if active_bindings == 0:
@ -236,7 +213,7 @@ class DeviceViewSet(viewsets.ViewSet):
"""
try:
user_device = UserDevice.objects.select_related(
'device', 'spirit', 'role_memory', 'role_memory__device_type'
'device', 'spirit'
).get(id=pk, user=request.user, is_active=True)
except UserDevice.DoesNotExist:
return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在')
@ -306,136 +283,6 @@ class DeviceViewSet(viewsets.ViewSet):
return success(message='WiFi 配置成功')
@action(
detail=False, methods=['get'],
url_path='stories',
authentication_classes=[], permission_classes=[AllowAny]
)
def stories_by_mac(self, request):
"""
获取设备关联用户的随机故事公开接口无需认证
GET /api/v1/devices/stories/?mac_address=AA:BB:CC:DD:EE:FF
hw-ws-service 调用
优先返回用户自己的故事无则兜底返回系统默认故事is_default=True
"""
mac = request.query_params.get('mac_address', '').strip()
if not mac:
return error(message='mac_address 参数不能为空')
mac = mac.upper().replace('-', ':')
from apps.stories.models import Story
story = None
# 1. 尝试查找设备 → 绑定用户 → 用户故事
try:
device = Device.objects.get(mac_address=mac)
user_device = (
UserDevice.objects
.filter(device=device, is_active=True, bind_type='owner')
.select_related('user')
.first()
)
if user_device:
story = (
Story.objects
.filter(user=user_device.user)
.exclude(audio_url='')
.order_by('?')
.first()
)
except Device.DoesNotExist:
pass
# 2. 兜底:设备不存在/未绑定/用户无故事 → 使用系统默认故事
if not story:
story = (
Story.objects
.filter(is_default=True)
.exclude(audio_url='')
.order_by('?')
.first()
)
if not story:
return error(
code=ErrorCode.STORY_NOT_FOUND,
message='暂无可播放的故事',
status_code=status.HTTP_404_NOT_FOUND
)
return success(data={
'title': story.title,
'audio_url': story.audio_url,
'opus_url': story.opus_url,
'intro_opus_data': story.intro_opus_data,
})
@action(
detail=False, methods=['get'],
url_path='music',
authentication_classes=[], permission_classes=[AllowAny]
)
def music_by_mac(self, request):
"""
获取设备关联用户的随机音乐公开接口无需认证
GET /api/v1/devices/music/?mac_address=AA:BB:CC:DD:EE:FF
hw-ws-service 调用
优先返回用户自己的音乐无则兜底返回系统默认曲目is_default=True
"""
mac = request.query_params.get('mac_address', '').strip()
if not mac:
return error(message='mac_address 参数不能为空')
mac = mac.upper().replace('-', ':')
from apps.music.models import Track
track = None
# 1. 尝试查找设备 → 绑定用户 → 用户音乐
try:
device = Device.objects.get(mac_address=mac)
user_device = (
UserDevice.objects
.filter(device=device, is_active=True, bind_type='owner')
.select_related('user')
.first()
)
if user_device:
track = (
Track.objects
.filter(user=user_device.user, generation_status='completed')
.exclude(audio_url='')
.order_by('?')
.first()
)
except Device.DoesNotExist:
pass
# 2. 兜底:设备不存在/未绑定/用户无音乐 → 使用系统默认曲目
if not track:
track = (
Track.objects
.filter(is_default=True, generation_status='completed')
.exclude(audio_url='')
.order_by('?')
.first()
)
if not track:
return error(
code=ErrorCode.TRACK_NOT_FOUND,
message='暂无可播放的音乐',
status_code=status.HTTP_404_NOT_FOUND
)
return success(data={
'title': track.title,
'audio_url': track.audio_url,
'opus_url': track.opus_url,
'intro_opus_data': track.intro_opus_data,
'cover_url': track.cover_url,
'duration': track.duration,
})
@action(detail=False, methods=['post'], url_path='report-status',
authentication_classes=[], permission_classes=[AllowAny])
def report_status(self, request):
@ -482,162 +329,3 @@ class DeviceViewSet(viewsets.ViewSet):
data={'device_id': device.id, 'sn': device.sn},
message='状态上报成功'
)
# ==================== 角色记忆相关端点 ====================
@action(detail=True, methods=['get'], url_path='role-memory')
def role_memory_detail(self, request, pk=None):
"""
获取当前设备的角色记忆
GET /api/v1/devices/{user_device_id}/role-memory/
"""
try:
user_device = UserDevice.objects.select_related(
'role_memory', 'role_memory__device_type'
).get(id=pk, user=request.user, is_active=True)
except UserDevice.DoesNotExist:
return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在')
if not user_device.role_memory:
return error(code=ErrorCode.ROLE_MEMORY_NOT_FOUND, message='该设备暂无角色记忆')
return success(data=RoleMemorySerializer(user_device.role_memory).data)
@action(detail=True, methods=['put'], url_path='role-memory/settings')
def role_memory_settings(self, request, pk=None):
"""
更新角色记忆-设备设置
PUT /api/v1/devices/{user_device_id}/role-memory/settings/
"""
try:
user_device = UserDevice.objects.select_related('role_memory').get(
id=pk, user=request.user, is_active=True
)
except UserDevice.DoesNotExist:
return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在')
if not user_device.role_memory:
return error(code=ErrorCode.ROLE_MEMORY_NOT_FOUND, message='该设备暂无角色记忆')
serializer = RoleMemorySettingsUpdateSerializer(data=request.data)
if not serializer.is_valid():
return error(message=str(serializer.errors))
rm = user_device.role_memory
for field, value in serializer.validated_data.items():
setattr(rm, field, value)
rm.save()
return success(data=RoleMemorySerializer(rm).data, message='设置已保存')
@action(detail=True, methods=['put'], url_path='role-memory/agent')
def role_memory_agent(self, request, pk=None):
"""
更新角色记忆-Agent信息
PUT /api/v1/devices/{user_device_id}/role-memory/agent/
"""
try:
user_device = UserDevice.objects.select_related('role_memory').get(
id=pk, user=request.user, is_active=True
)
except UserDevice.DoesNotExist:
return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在')
if not user_device.role_memory:
return error(code=ErrorCode.ROLE_MEMORY_NOT_FOUND, message='该设备暂无角色记忆')
serializer = RoleMemoryAgentUpdateSerializer(data=request.data)
if not serializer.is_valid():
return error(message=str(serializer.errors))
rm = user_device.role_memory
for field, value in serializer.validated_data.items():
setattr(rm, field, value)
rm.save()
return success(data=RoleMemorySerializer(rm).data, message='Agent信息已更新')
@action(detail=True, methods=['put'], url_path='role-memory/memory')
def role_memory_summary(self, request, pk=None):
"""
更新角色记忆-聊天记忆摘要
PUT /api/v1/devices/{user_device_id}/role-memory/memory/
"""
try:
user_device = UserDevice.objects.select_related('role_memory').get(
id=pk, user=request.user, is_active=True
)
except UserDevice.DoesNotExist:
return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在')
if not user_device.role_memory:
return error(code=ErrorCode.ROLE_MEMORY_NOT_FOUND, message='该设备暂无角色记忆')
serializer = RoleMemoryMemoryUpdateSerializer(data=request.data)
if not serializer.is_valid():
return error(message=str(serializer.errors))
user_device.role_memory.memory_summary = serializer.validated_data['memory_summary']
user_device.role_memory.save(update_fields=['memory_summary', 'updated_at'])
return success(message='聊天记忆已更新')
@action(detail=False, methods=['get'], url_path='role-memories')
def role_memory_list(self, request):
"""
获取角色记忆列表
GET /api/v1/devices/role-memories/?device_type_id=1&is_bound=false
"""
qs = RoleMemory.objects.filter(user=request.user).select_related('device_type')
device_type_id = request.query_params.get('device_type_id')
if device_type_id:
qs = qs.filter(device_type_id=device_type_id)
is_bound = request.query_params.get('is_bound')
if is_bound is not None:
qs = qs.filter(is_bound=is_bound.lower() == 'true')
return success(data=RoleMemorySerializer(qs, many=True).data)
@action(detail=True, methods=['put'], url_path='switch-role-memory')
def switch_role_memory(self, request, pk=None):
"""
切换设备的角色记忆
PUT /api/v1/devices/{user_device_id}/switch-role-memory/
body: { "role_memory_id": 5 }
"""
try:
user_device = UserDevice.objects.select_related(
'device', 'device__device_type', 'role_memory'
).get(id=pk, user=request.user, is_active=True)
except UserDevice.DoesNotExist:
return error(code=ErrorCode.DEVICE_NOT_FOUND, message='绑定记录不存在')
role_memory_id = request.data.get('role_memory_id')
if not role_memory_id:
return error(message='role_memory_id 不能为空')
try:
new_rm = RoleMemory.objects.get(id=role_memory_id, user=request.user)
except RoleMemory.DoesNotExist:
return error(code=ErrorCode.ROLE_MEMORY_NOT_FOUND, message='该设备暂无角色记忆')
# 校验: 目标记忆必须是同一设备类型
if user_device.device.device_type_id and new_rm.device_type_id != user_device.device.device_type_id:
return error(code=ErrorCode.ROLE_MEMORY_TYPE_MISMATCH, message='只能切换到同类型设备的角色记忆')
# 校验: 目标记忆必须是闲置状态
if new_rm.is_bound:
return error(code=ErrorCode.ROLE_MEMORY_ALREADY_BOUND, message='该角色记忆正在被其他设备使用')
# 执行切换:旧记忆标记闲置,新记忆标记绑定
old_rm = user_device.role_memory
if old_rm:
old_rm.is_bound = False
old_rm.save(update_fields=['is_bound', 'updated_at'])
new_rm.is_bound = True
new_rm.save(update_fields=['is_bound', 'updated_at'])
user_device.role_memory = new_rm
user_device.save(update_fields=['role_memory'])
return success(data=RoleMemorySerializer(new_rm).data, message='切换成功')

View File

@ -1,122 +0,0 @@
"""
批量将已有音乐的 MP3 音频预转码为 Opus JSON 并上传 OSS
使用方法:
python manage.py convert_tracks_to_opus
python manage.py convert_tracks_to_opus --dry-run # 仅统计,不转码
python manage.py convert_tracks_to_opus --limit 10 # 只处理前 10 个
python manage.py convert_tracks_to_opus --force # 重新转码已有 opus_url 的曲目
python manage.py convert_tracks_to_opus --default # 仅处理系统默认曲目
"""
import uuid
import logging
from datetime import datetime
import requests
from django.conf import settings
from django.core.management.base import BaseCommand
from apps.music.models import Track
from apps.stories.services.opus_converter import convert_mp3_to_opus_json
logger = logging.getLogger(__name__)
class Command(BaseCommand):
help = '批量将已有音乐的 MP3 音频预转码为 Opus 帧 JSON'
def add_arguments(self, parser):
parser.add_argument(
'--dry-run', action='store_true',
help='仅统计需要转码的曲目数量,不实际执行',
)
parser.add_argument(
'--limit', type=int, default=0,
help='最多处理的曲目数量0=不限)',
)
parser.add_argument(
'--force', action='store_true',
help='重新转码已有 opus_url 的曲目',
)
parser.add_argument(
'--default', action='store_true',
help='仅处理系统默认曲目is_default=True',
)
def handle(self, *args, **options):
dry_run = options['dry_run']
limit = options['limit']
force = options['force']
default_only = options['default']
# 查找需要转码的曲目
qs = Track.objects.filter(
generation_status='completed',
).exclude(audio_url='')
if not force:
qs = qs.filter(opus_url='')
if default_only:
qs = qs.filter(is_default=True)
qs = qs.order_by('id')
total = qs.count()
self.stdout.write(f'需要转码的曲目: {total}')
if dry_run:
self.stdout.write(self.style.NOTICE('[dry-run] 仅统计,不执行转码'))
return
if total == 0:
self.stdout.write(self.style.SUCCESS('所有曲目已转码,无需处理'))
return
# OSS 客户端
from utils.oss import get_oss_client
oss_client = get_oss_client()
oss_config = settings.ALIYUN_OSS
if oss_config.get('CUSTOM_DOMAIN'):
url_prefix = f"https://{oss_config['CUSTOM_DOMAIN']}"
else:
url_prefix = f"https://{oss_config['BUCKET_NAME']}.{oss_config['ENDPOINT']}"
tracks = qs[:limit] if limit > 0 else qs
success_count = 0
fail_count = 0
for i, track in enumerate(tracks.iterator(), 1):
self.stdout.write(f'\n[{i}/{total}] Track#{track.id} "{track.title}"')
self.stdout.write(f' MP3: {track.audio_url[:80]}...')
try:
# 下载 MP3
resp = requests.get(track.audio_url, timeout=60)
resp.raise_for_status()
mp3_bytes = resp.content
self.stdout.write(f' MP3 大小: {len(mp3_bytes) / 1024:.1f} KB')
# 转码
opus_json = convert_mp3_to_opus_json(mp3_bytes)
self.stdout.write(f' Opus JSON 大小: {len(opus_json) / 1024:.1f} KB')
# 上传 OSS
opus_filename = f"{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}.json"
opus_key = f"music/audio-opus/{opus_filename}"
oss_client.bucket.put_object(opus_key, opus_json.encode('utf-8'))
opus_url = f"{url_prefix}/{opus_key}"
track.opus_url = opus_url
track.save(update_fields=['opus_url'])
success_count += 1
self.stdout.write(self.style.SUCCESS(f' OK: {opus_url}'))
except Exception as e:
fail_count += 1
self.stdout.write(self.style.ERROR(f' FAIL: {e}'))
logger.error(f'Track#{track.id} opus convert failed: {e}')
self.stdout.write(f'\n{"=" * 40}')
self.stdout.write(self.style.SUCCESS(
f'完成: 成功 {success_count}, 失败 {fail_count}, 总计 {success_count + fail_count}'
))

View File

@ -1,18 +0,0 @@
# Generated by Django 6.0.1 on 2026-03-04 03:10
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('music', '0002_track_generation_status_track_is_default_and_more'),
]
operations = [
migrations.AddField(
model_name='track',
name='opus_url',
field=models.URLField(blank=True, default='', max_length=500, verbose_name='Opus音频URL'),
),
]

View File

@ -1,18 +0,0 @@
# Generated by Django 6.0.1 on 2026-03-04 03:27
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('music', '0003_track_opus_url'),
]
operations = [
migrations.AddField(
model_name='track',
name='intro_opus_data',
field=models.TextField(blank=True, default='', verbose_name='引导语Opus数据'),
),
]

View File

@ -30,7 +30,6 @@ class Track(models.Model):
title = models.CharField('标题', max_length=200)
lyrics = models.TextField('歌词', blank=True, default='')
audio_url = models.URLField('音频URL', max_length=500, blank=True, default='')
opus_url = models.URLField('Opus音频URL', max_length=500, blank=True, default='')
cover_url = models.URLField('封面URL', max_length=500, blank=True, default='')
mood = models.CharField(
'情绪标签', max_length=20,
@ -40,7 +39,6 @@ class Track(models.Model):
prompt = models.TextField('生成提示词', blank=True, default='')
is_favorite = models.BooleanField('是否收藏', default=False)
is_default = models.BooleanField('是否默认曲目', default=False)
intro_opus_data = models.TextField('引导语Opus数据', blank=True, default='')
generation_status = models.CharField(
'生成状态', max_length=20,
choices=GENERATION_STATUS_CHOICES, default='completed'

View File

@ -17,18 +17,10 @@ class SpiritSerializer(serializers.ModelSerializer):
class CreateSpiritSerializer(serializers.ModelSerializer):
"""创建智能体序列化器"""
voice_id = serializers.CharField(required=False, allow_blank=True, default='')
class Meta:
model = Spirit
fields = ['name', 'avatar', 'prompt', 'memory', 'voice_id']
def validate(self, data):
# Bug #47 fix: use .get() to avoid KeyError when voice_id is not provided
voice_id = data.get('voice_id', '')
data['voice_id'] = voice_id
return data
def validate_prompt(self, value):
if value and len(value) > 5000:
raise serializers.ValidationError('提示词不能超过5000个字符')

View File

@ -106,25 +106,6 @@ class SpiritViewSet(viewsets.ModelViewSet):
return success(message=f'已解绑智能体,数据已保留在云端(影响 {count} 个设备)')
@action(detail=True, methods=['get'], url_path='owner-info')
def owner_info(self, request, pk=None):
"""
获取智能体所有者信息
GET /api/v1/spirits/{id}/owner-info/
Bug #45 fix: spirit.user (owner) may be None if the user record was
removed outside of the normal cascade path; guard with an explicit
None-check instead of accessing .nickname unconditionally.
"""
try:
spirit = Spirit.objects.get(id=pk, user=request.user)
except Spirit.DoesNotExist:
return error(code=ErrorCode.SPIRIT_NOT_FOUND, message='智能体不存在')
# Bug #45 fix: null-safe access avoid TypeError when owner is None
owner_name = spirit.user.nickname if spirit.user else None
return success(data={'owner_name': owner_name})
@action(detail=True, methods=['post'])
def inject(self, request, pk=None):
"""

View File

@ -1,112 +0,0 @@
"""
批量将已有故事的 MP3 音频预转码为 Opus JSON 并上传 OSS
使用方法:
python manage.py convert_stories_to_opus
python manage.py convert_stories_to_opus --dry-run # 仅统计,不转码
python manage.py convert_stories_to_opus --limit 10 # 只处理前 10 个
python manage.py convert_stories_to_opus --force # 重新转码已有 opus_url 的故事
"""
import uuid
import logging
from datetime import datetime
import requests
from django.conf import settings
from django.core.management.base import BaseCommand
from apps.stories.models import Story
from apps.stories.services.opus_converter import convert_mp3_to_opus_json
logger = logging.getLogger(__name__)
class Command(BaseCommand):
help = '批量将已有故事的 MP3 音频预转码为 Opus 帧 JSON'
def add_arguments(self, parser):
parser.add_argument(
'--dry-run', action='store_true',
help='仅统计需要转码的故事数量,不实际执行',
)
parser.add_argument(
'--limit', type=int, default=0,
help='最多处理的故事数量0=不限)',
)
parser.add_argument(
'--force', action='store_true',
help='重新转码已有 opus_url 的故事',
)
def handle(self, *args, **options):
dry_run = options['dry_run']
limit = options['limit']
force = options['force']
# 查找需要转码的故事
qs = Story.objects.exclude(audio_url='')
if not force:
qs = qs.filter(opus_url='')
qs = qs.order_by('id')
total = qs.count()
self.stdout.write(f'需要转码的故事: {total}')
if dry_run:
self.stdout.write(self.style.NOTICE('[dry-run] 仅统计,不执行转码'))
return
if total == 0:
self.stdout.write(self.style.SUCCESS('所有故事已转码,无需处理'))
return
# OSS 客户端
from utils.oss import get_oss_client
oss_client = get_oss_client()
oss_config = settings.ALIYUN_OSS
if oss_config.get('CUSTOM_DOMAIN'):
url_prefix = f"https://{oss_config['CUSTOM_DOMAIN']}"
else:
url_prefix = f"https://{oss_config['BUCKET_NAME']}.{oss_config['ENDPOINT']}"
stories = qs[:limit] if limit > 0 else qs
success_count = 0
fail_count = 0
for i, story in enumerate(stories.iterator(), 1):
self.stdout.write(f'\n[{i}/{total}] Story#{story.id} "{story.title}"')
self.stdout.write(f' MP3: {story.audio_url[:80]}...')
try:
# 下载 MP3
resp = requests.get(story.audio_url, timeout=60)
resp.raise_for_status()
mp3_bytes = resp.content
self.stdout.write(f' MP3 大小: {len(mp3_bytes) / 1024:.1f} KB')
# 转码
opus_json = convert_mp3_to_opus_json(mp3_bytes)
self.stdout.write(f' Opus JSON 大小: {len(opus_json) / 1024:.1f} KB')
# 上传 OSS
opus_filename = f"{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}.json"
opus_key = f"stories/audio-opus/{opus_filename}"
oss_client.bucket.put_object(opus_key, opus_json.encode('utf-8'))
opus_url = f"{url_prefix}/{opus_key}"
story.opus_url = opus_url
story.save(update_fields=['opus_url'])
success_count += 1
self.stdout.write(self.style.SUCCESS(f' OK: {opus_url}'))
except Exception as e:
fail_count += 1
self.stdout.write(self.style.ERROR(f' FAIL: {e}'))
logger.error(f'Story#{story.id} opus convert failed: {e}')
self.stdout.write(f'\n{"=" * 40}')
self.stdout.write(self.style.SUCCESS(
f'完成: 成功 {success_count}, 失败 {fail_count}, 总计 {success_count + fail_count}'
))

View File

@ -1,116 +0,0 @@
"""
用新的 LLM 提炼逻辑重新生成默认故事封面并上传到 OSS
使用方法:
python manage.py generate_default_covers
python manage.py generate_default_covers --dry-run # 仅打印提炼到的描述,不生成图片
"""
import uuid
import logging
import requests
from django.conf import settings
from django.core.management.base import BaseCommand
from apps.stories.utils import DEFAULT_STORIES
from apps.stories.services.llm_service import _extract_image_description
logger = logging.getLogger(__name__)
# 每个默认故事对应的 OSS key与 utils.py 中的 cover_url 一致)
DEFAULT_COVER_KEYS = {
"失控的魔法扫帚": "stories/defaults/失控的魔法扫帚_cover.png",
}
class Command(BaseCommand):
help = "用 LLM 提炼故事画面描述后调用 Seedream 4.5 重新生成默认故事封面"
def add_arguments(self, parser):
parser.add_argument(
"--dry-run",
action="store_true",
help="仅打印 LLM 提炼的画面描述,不实际生成图片",
)
def handle(self, *args, **options):
dry_run = options["dry_run"]
config = settings.LLM_CONFIG
if not config.get("API_KEY"):
self.stderr.write(self.style.ERROR("VOLCENGINE_API_KEY 未配置"))
return
try:
from volcenginesdkarkruntime import Ark
except ImportError:
self.stderr.write(self.style.ERROR("volcengine SDK 未安装"))
return
try:
from utils.oss import get_oss_client
import oss2
except ImportError:
self.stderr.write(self.style.ERROR("oss2 未安装"))
return
client = Ark(api_key=config["API_KEY"])
image_model = config.get("IMAGE_MODEL_NAME", "doubao-seedream-4-5-251128")
image_size = config.get("IMAGE_SIZE", "2560x1440")
oss_config = settings.ALIYUN_OSS
oss_base = f"https://{oss_config['BUCKET_NAME']}.{oss_config['ENDPOINT']}"
for story in DEFAULT_STORIES:
title = story["title"]
content = story["content"]
oss_key = DEFAULT_COVER_KEYS.get(title)
if not oss_key:
self.stdout.write(self.style.WARNING(f"[{title}] 未找到对应 OSS key跳过"))
continue
self.stdout.write(f"\n[{title}]")
# Step 1: LLM 提炼画面描述
self.stdout.write(" 正在用 LLM 提炼画面描述...")
scene_desc = _extract_image_description(
title, content, client, config["MODEL_NAME"]
)
self.stdout.write(self.style.SUCCESS(f" 画面描述({len(scene_desc)} 字):{scene_desc}"))
if dry_run:
self.stdout.write(self.style.NOTICE(" [dry-run] 跳过图片生成"))
continue
# Step 2: 文生图
image_prompt = (
f"儿童绘本封面插画,{scene_desc},卡通可爱风格,色彩明亮鲜艳,高质量插画"
)
self.stdout.write(f" 正在生成封面图({image_size}...")
result = client.images.generate(
model=image_model,
prompt=image_prompt,
size=image_size,
response_format="url",
watermark=False,
)
image_url = result.data[0].url
self.stdout.write(f" 临时图片 URL: {image_url[:80]}...")
# Step 3: 下载图片
self.stdout.write(" 正在下载图片...")
resp = requests.get(image_url, timeout=60)
resp.raise_for_status()
# Step 4: 覆盖上传到 OSS
self.stdout.write(f" 正在上传到 OSS: {oss_key}")
oss_client = get_oss_client()
oss_client.bucket.put_object(
oss_key,
resp.content,
headers={"Content-Type": "image/jpeg"},
)
final_url = f"{oss_base}/{oss_key}"
self.stdout.write(self.style.SUCCESS(f" ✓ 封面已更新: {final_url}"))
self.stdout.write(self.style.SUCCESS("\n完成。"))

View File

@ -1,132 +0,0 @@
"""
批量为故事和音乐生成引导语 Opus 数据并写入数据库
使用方法:
python manage.py generate_intro_opus # 处理所有
python manage.py generate_intro_opus --type story # 仅故事
python manage.py generate_intro_opus --type music # 仅音乐
python manage.py generate_intro_opus --dry-run # 仅统计
python manage.py generate_intro_opus --limit 10 # 只处理前 10 个
python manage.py generate_intro_opus --force # 重新生成已有引导语的记录
"""
import logging
from django.core.management.base import BaseCommand
logger = logging.getLogger(__name__)
class Command(BaseCommand):
help = '批量为故事和音乐生成引导语 Opus 数据'
def add_arguments(self, parser):
parser.add_argument(
'--type', choices=['story', 'music', 'all'], default='all',
help='处理类型story / music / all默认 all',
)
parser.add_argument(
'--dry-run', action='store_true',
help='仅统计需要处理的数量,不实际执行',
)
parser.add_argument(
'--limit', type=int, default=0,
help='最多处理的数量0=不限)',
)
parser.add_argument(
'--force', action='store_true',
help='重新生成已有引导语的记录',
)
def handle(self, *args, **options):
content_type = options['type']
dry_run = options['dry_run']
limit = options['limit']
force = options['force']
if content_type in ('story', 'all'):
self._process_stories(dry_run, limit, force)
if content_type in ('music', 'all'):
self._process_tracks(dry_run, limit, force)
def _process_stories(self, dry_run, limit, force):
from apps.stories.models import Story
from apps.stories.services.intro_service import generate_intro_opus
self.stdout.write(self.style.MIGRATE_HEADING('\n=== 故事引导语 ==='))
qs = Story.objects.exclude(audio_url='')
if not force:
qs = qs.filter(intro_opus_data='')
qs = qs.order_by('id')
total = qs.count()
self.stdout.write(f'需要处理的故事: {total}')
if dry_run or total == 0:
return
items = qs[:limit] if limit > 0 else qs
success_count = 0
fail_count = 0
for i, story in enumerate(items.iterator(), 1):
self.stdout.write(f'[{i}/{total}] Story#{story.id} "{story.title}"')
try:
opus_json = generate_intro_opus(story.title, content_type='story')
story.intro_opus_data = opus_json
story.save(update_fields=['intro_opus_data'])
success_count += 1
self.stdout.write(self.style.SUCCESS(
f' OK ({len(opus_json) / 1024:.1f} KB)'
))
except Exception as e:
fail_count += 1
self.stdout.write(self.style.ERROR(f' FAIL: {e}'))
logger.error(f'Story#{story.id} intro generate failed: {e}')
self.stdout.write(self.style.SUCCESS(
f'故事完成: 成功 {success_count}, 失败 {fail_count}'
))
def _process_tracks(self, dry_run, limit, force):
from apps.music.models import Track
from apps.stories.services.intro_service import generate_intro_opus
self.stdout.write(self.style.MIGRATE_HEADING('\n=== 音乐引导语 ==='))
qs = Track.objects.filter(
generation_status='completed',
).exclude(audio_url='')
if not force:
qs = qs.filter(intro_opus_data='')
qs = qs.order_by('id')
total = qs.count()
self.stdout.write(f'需要处理的曲目: {total}')
if dry_run or total == 0:
return
items = qs[:limit] if limit > 0 else qs
success_count = 0
fail_count = 0
for i, track in enumerate(items.iterator(), 1):
self.stdout.write(f'[{i}/{total}] Track#{track.id} "{track.title}"')
try:
opus_json = generate_intro_opus(track.title, content_type='music')
track.intro_opus_data = opus_json
track.save(update_fields=['intro_opus_data'])
success_count += 1
self.stdout.write(self.style.SUCCESS(
f' OK ({len(opus_json) / 1024:.1f} KB)'
))
except Exception as e:
fail_count += 1
self.stdout.write(self.style.ERROR(f' FAIL: {e}'))
logger.error(f'Track#{track.id} intro generate failed: {e}')
self.stdout.write(self.style.SUCCESS(
f'音乐完成: 成功 {success_count}, 失败 {fail_count}'
))

View File

@ -1,106 +0,0 @@
"""
上传默认故事媒体资源到 OSS
使用方法:
python manage.py upload_default_story_media
python manage.py upload_default_story_media --dry-run # 仅检查,不上传
上传内容:
- 视频: rtc_prd/动态绘本/失控的魔法扫帚.mp4 stories/defaults/失控的魔法扫帚.mp4
- 封面: rtc_prd/故事书封面图/卡皮巴拉的奇幻漂流.png stories/defaults/失控的魔法扫帚_cover.png
"""
import os
from django.conf import settings
from django.core.management.base import BaseCommand
try:
import oss2
OSS_AVAILABLE = True
except ImportError:
OSS_AVAILABLE = False
# 上传目标 OSS key固定路径与 utils.py 中的 URL 对应)
_PRD_ROOT = os.path.expanduser("~/Desktop/zyc/qiyuan_gitea/rtc_prd")
UPLOAD_ITEMS = [
{
"desc": "绘本视频",
"local": os.path.join(_PRD_ROOT, "动态绘本/失控的魔法扫帚.mp4"),
"oss_key": "stories/defaults/失控的魔法扫帚.mp4",
"content_type": "video/mp4",
},
{
"desc": "故事封面",
"local": os.path.join(_PRD_ROOT, "故事书封面图/卡皮巴拉的奇幻漂流.png"),
"oss_key": "stories/defaults/失控的魔法扫帚_cover.png",
"content_type": "image/png",
},
]
class Command(BaseCommand):
help = "上传默认故事的视频和封面到 OSS"
def add_arguments(self, parser):
parser.add_argument(
"--dry-run",
action="store_true",
help="仅打印计划,不实际上传",
)
def handle(self, *args, **options):
dry_run = options["dry_run"]
if not OSS_AVAILABLE:
self.stderr.write(self.style.ERROR("oss2 未安装,请先 pip install oss2"))
return
oss_config = settings.ALIYUN_OSS
if not oss_config.get("ACCESS_KEY_ID"):
self.stderr.write(self.style.ERROR("OSS 未配置,请检查 .env 中的 ALIYUN_OSS 设置"))
return
auth = oss2.Auth(oss_config["ACCESS_KEY_ID"], oss_config["ACCESS_KEY_SECRET"])
bucket = oss2.Bucket(auth, oss_config["ENDPOINT"], oss_config["BUCKET_NAME"])
oss_base = f"https://{oss_config['BUCKET_NAME']}.{oss_config['ENDPOINT']}"
for item in UPLOAD_ITEMS:
local_path = os.path.normpath(item["local"])
oss_key = item["oss_key"]
target_url = f"{oss_base}/{oss_key}"
self.stdout.write(f"\n[{item['desc']}]")
self.stdout.write(f" 本地文件: {local_path}")
self.stdout.write(f" OSS 目标: {oss_key}")
self.stdout.write(f" 访问 URL: {target_url}")
if not os.path.isfile(local_path):
self.stderr.write(self.style.WARNING(f" ⚠ 本地文件不存在,跳过"))
continue
file_size = os.path.getsize(local_path)
self.stdout.write(f" 文件大小: {file_size / 1024 / 1024:.1f} MB")
# 检查 OSS 是否已存在
try:
bucket.get_object_meta(oss_key)
self.stdout.write(self.style.WARNING(" ✓ OSS 已存在,跳过上传"))
continue
except oss2.exceptions.NoSuchKey:
pass # 不存在,需要上传
if dry_run:
self.stdout.write(self.style.NOTICE(" [dry-run] 将会上传此文件"))
continue
self.stdout.write(" 上传中...")
with open(local_path, "rb") as f:
bucket.put_object(
oss_key,
f,
headers={"Content-Type": item["content_type"]},
)
self.stdout.write(self.style.SUCCESS(f" ✓ 上传成功: {target_url}"))
self.stdout.write(self.style.SUCCESS("\n完成。请确认 utils.py 中的 URL 与上述 OSS URL 一致。"))

View File

@ -1,16 +0,0 @@
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("stories", "0003_story_shelf_nullable"),
]
operations = [
migrations.AddField(
model_name="story",
name="is_default",
field=models.BooleanField(default=False, verbose_name="是否默认故事"),
),
]

View File

@ -1,18 +0,0 @@
# Generated by Django 6.0.1 on 2026-03-03 09:01
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('stories', '0004_story_is_default'),
]
operations = [
migrations.AddField(
model_name='story',
name='opus_url',
field=models.URLField(blank=True, default='', max_length=500, verbose_name='Opus音频URL'),
),
]

View File

@ -1,18 +0,0 @@
# Generated by Django 6.0.1 on 2026-03-04 03:27
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('stories', '0005_story_opus_url'),
]
operations = [
migrations.AddField(
model_name='story',
name='intro_opus_data',
field=models.TextField(blank=True, default='', verbose_name='引导语Opus数据'),
),
]

View File

@ -54,7 +54,6 @@ class Story(models.Model):
content = models.TextField('内容', blank=True, default='')
cover_url = models.URLField('封面URL', max_length=500, blank=True, default='')
audio_url = models.URLField('音频URL', max_length=500, blank=True, default='')
opus_url = models.URLField('Opus音频URL', max_length=500, blank=True, default='')
has_video = models.BooleanField('是否有视频', default=False)
video_url = models.URLField('视频URL', max_length=500, blank=True, default='')
generation_mode = models.CharField(
@ -62,8 +61,6 @@ class Story(models.Model):
choices=GENERATION_MODE_CHOICES, default='ai'
)
prompt = models.TextField('生成提示词', blank=True, default='')
is_default = models.BooleanField('是否默认故事', default=False)
intro_opus_data = models.TextField('引导语Opus数据', blank=True, default='')
created_at = models.DateTimeField('创建时间', auto_now_add=True)
updated_at = models.DateTimeField('更新时间', auto_now=True)

View File

@ -1,70 +0,0 @@
"""
引导语 Opus 生成服务
为故事/音乐生成一句引导语"正在为您播放,卡皮巴拉蹦蹦蹦"
转为 Opus JSON 字符串直接存入数据库字段
"""
import asyncio
import logging
import random
logger = logging.getLogger(__name__)
TTS_VOICE = 'zh-CN-XiaoxiaoNeural'
STORY_PROMPTS = [
"正在为您播放,{}",
"请欣赏故事,{}",
"即将为您播放,{}",
"为您带来,{}",
"让我们聆听,{}",
"接下来请欣赏,{}",
"为您献上,{}",
]
MUSIC_PROMPTS = [
"正在为您播放,{}",
"请享受音乐,{}",
"即将为您播放,{}",
"为您带来,{}",
"让我们聆听,{}",
"接下来请欣赏,{}",
"为您献上,{}",
]
def generate_intro_opus(title: str, content_type: str = 'story') -> str:
"""
为指定标题生成引导语 Opus JSON
Args:
title: 故事或音乐标题
content_type: 'story' 'music'
Returns:
Opus JSON 字符串 opus_url 指向的格式一致
"""
prompts = STORY_PROMPTS if content_type == 'story' else MUSIC_PROMPTS
text = random.choice(prompts).format(title)
logger.info(f'生成引导语: "{text}"')
# edge-tts 合成 MP3
mp3_bytes = asyncio.run(_synthesize(text))
# MP3 → Opus 帧 JSON
from apps.stories.services.opus_converter import convert_mp3_to_opus_json
opus_json = convert_mp3_to_opus_json(mp3_bytes)
return opus_json
async def _synthesize(text: str) -> bytes:
"""使用 edge-tts 合成语音,返回 MP3 bytes"""
import edge_tts
communicate = edge_tts.Communicate(text, TTS_VOICE)
audio_chunks = []
async for chunk in communicate.stream():
if chunk['type'] == 'audio':
audio_chunks.append(chunk['data'])
return b''.join(audio_chunks)

View File

@ -122,28 +122,12 @@ def generate_story_stream(characters, scenes, props):
result = _parse_story_json(full_content)
# ── Generate cover image ──
yield _sse_event('stage', {
'stage': 'cover',
'progress': 90,
'message': '正在绘制故事封面...',
})
cover_url = ''
try:
cover_url = _generate_and_upload_cover(
result['title'], result['content'], config
)
except Exception as cover_err:
logger.warning(f'Cover generation failed (non-fatal): {cover_err}')
yield _sse_event('done', {
'stage': 'done',
'progress': 100,
'message': '大功告成!',
'title': result['title'],
'content': result['content'],
'cover_url': cover_url,
})
except Exception as e:
@ -173,83 +157,6 @@ def _parse_story_json(text):
}
def _extract_image_description(title, content, client, model_name):
"""
LLM 从故事内容中提炼 50 字的画面描述主体 + 场景 + 事件
返回纯文本描述字符串
"""
system = (
"你是图像提示词专家。从给定的儿童故事中,提取主体、场景与核心事件,"
"串联成一幅画的中文描述。要求:\n"
"1. 不超过50个汉字\n"
"2. 只输出描述本身,不加任何解释、前缀或多余标点\n"
"3. 描述需具体生动,适合儿童绘本插画"
)
user = f"故事标题:{title}\n故事内容:{content[:800]}"
resp = client.chat.completions.create(
model=model_name,
messages=[
{'role': 'system', 'content': system},
{'role': 'user', 'content': user},
],
max_tokens=80,
stream=False,
)
return resp.choices[0].message.content.strip()
def _generate_and_upload_cover(title, content, config):
"""
使用豆包文生图模型生成故事封面上传到 OSS 并返回 URL
失败时抛出异常由调用方捕获不影响主流程
"""
import uuid
import requests as req_lib
from datetime import datetime
from django.conf import settings
from volcenginesdkarkruntime import Ark
client = Ark(api_key=config['API_KEY'])
# 用 LLM 从故事内容提炼 ≤50 字画面描述
scene_desc = _extract_image_description(
title, content, client, config['MODEL_NAME']
)
logger.info(f'Cover image description: {scene_desc}')
image_prompt = f"儿童绘本封面插画,{scene_desc},卡通可爱风格,色彩明亮鲜艳,高质量插画"
image_model = config.get('IMAGE_MODEL_NAME', 'doubao-seedream-4-5-251128')
image_size = config.get('IMAGE_SIZE', '2560x1440')
result = client.images.generate(
model=image_model,
prompt=image_prompt,
size=image_size,
response_format='url',
watermark=False,
)
image_url = result.data[0].url
# Download from temporary URL and upload to OSS
resp = req_lib.get(image_url, timeout=60)
resp.raise_for_status()
from utils.oss import get_oss_client
oss_client = get_oss_client()
key = f"stories/covers/{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}.jpg"
oss_client.bucket.put_object(
key, resp.content,
headers={'Content-Type': 'image/jpeg'},
)
oss_config = settings.ALIYUN_OSS
if oss_config.get('CUSTOM_DOMAIN'):
return f"https://{oss_config['CUSTOM_DOMAIN']}/{key}"
return f"https://{oss_config['BUCKET_NAME']}.{oss_config['ENDPOINT']}/{key}"
def _sse_event(event, data):
"""格式化 SSE 事件"""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"

View File

@ -1,72 +0,0 @@
"""
MP3 Opus 预转码服务
MP3 音频转为 Opus 帧列表JSON + base64 hw_service_go 直接下载播放
跳过实时 ffmpeg 转码大幅降低首帧延迟和 CPU 消耗
Opus 参数与 hw_service_go 保持一致16kHz, 单声道, 60ms/
"""
import base64
import json
import logging
import subprocess
import opuslib
logger = logging.getLogger(__name__)
SAMPLE_RATE = 16000
CHANNELS = 1
FRAME_DURATION_MS = 60
FRAME_SIZE = SAMPLE_RATE * FRAME_DURATION_MS // 1000 # 960 samples
BYTES_PER_FRAME = FRAME_SIZE * 2 # 16bit = 2 bytes per sample
def convert_mp3_to_opus_json(mp3_bytes: bytes) -> str:
"""
MP3 音频数据转码为 Opus JSON
流程: MP3 bytes ffmpeg(PCM 16kHz mono s16le) opuslib(60ms Opus )
Returns:
JSON 字符串包含 base64 编码的 Opus 帧列表
"""
# 1. ffmpeg: MP3 → PCM (16kHz, mono, signed 16-bit little-endian)
proc = subprocess.run(
[
'ffmpeg', '-nostdin', '-loglevel', 'error',
'-i', 'pipe:0',
'-ar', str(SAMPLE_RATE),
'-ac', str(CHANNELS),
'-f', 's16le',
'pipe:1',
],
input=mp3_bytes,
capture_output=True,
timeout=120,
)
if proc.returncode != 0:
stderr = proc.stderr.decode(errors='replace')
raise RuntimeError(f'ffmpeg 转码失败: {stderr}')
pcm = proc.stdout
if len(pcm) < BYTES_PER_FRAME:
raise RuntimeError(f'PCM 数据过短: {len(pcm)} bytes')
# 2. Opus 编码:逐帧编码
encoder = opuslib.Encoder(SAMPLE_RATE, CHANNELS, 'audio')
frames = []
for offset in range(0, len(pcm) - BYTES_PER_FRAME + 1, BYTES_PER_FRAME):
chunk = pcm[offset:offset + BYTES_PER_FRAME]
opus_frame = encoder.encode(chunk, FRAME_SIZE)
frames.append(base64.b64encode(opus_frame).decode('ascii'))
logger.info(f'Opus 预转码完成: {len(frames)} 帧, '
f'{len(frames) * FRAME_DURATION_MS / 1000:.1f}s 音频')
return json.dumps({
'sample_rate': SAMPLE_RATE,
'channels': CHANNELS,
'frame_duration_ms': FRAME_DURATION_MS,
'frames': frames,
}, separators=(',', ':')) # 紧凑格式,减少体积

View File

@ -85,43 +85,13 @@ def generate_tts_stream(story):
# 更新故事记录
story.audio_url = audio_url
story.save(update_fields=['audio_url'])
except Exception as e:
logger.error(f'OSS upload failed: {e}')
yield _sse_event('error', {'message': f'音频上传失败: {str(e)}'})
return
# Opus 预转码MP3 → Opus 帧 JSON上传 OSS
yield _sse_event('stage', {
'stage': 'opus_converting',
'progress': 80,
'message': '正在预转码 Opus 音频...',
})
try:
from apps.stories.services.opus_converter import convert_mp3_to_opus_json
opus_json = convert_mp3_to_opus_json(audio_data)
opus_filename = f"{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}.json"
opus_key = f"stories/audio-opus/{opus_filename}"
oss_client.bucket.put_object(opus_key, opus_json.encode('utf-8'))
if oss_config.get('CUSTOM_DOMAIN'):
opus_url = f"https://{oss_config['CUSTOM_DOMAIN']}/{opus_key}"
else:
opus_url = f"https://{oss_config['BUCKET_NAME']}.{oss_config['ENDPOINT']}/{opus_key}"
story.opus_url = opus_url
logger.info(f'Opus 预转码上传成功: {opus_url}')
except Exception as e:
logger.warning(f'Opus 预转码失败(不影响 MP3 播放): {e}')
# 预转码失败不阻断流程MP3 仍可正常使用
story.save(update_fields=['audio_url', 'opus_url'])
yield _sse_event('done', {
'stage': 'done',
'progress': 100,

View File

@ -1,60 +0,0 @@
"""
故事模块工具函数
"""
OSS_BASE = "https://qy-rtc.oss-cn-beijing.aliyuncs.com"
DEFAULT_STORIES = [
{
"title": "失控的魔法扫帚",
"content": (
"魔法学院的期末考试正在进行中小女巫艾米紧张地握着她的新扫帚「光轮2026」。"
"考试题目是:平稳飞越学校的钟楼并且不撞到任何一只鸽子。\n\n"
"「起飞!」艾米念出咒语。可是,扫帚似乎有了自己的想法,它没有飞向钟楼,"
"而是像火箭一样冲向了食堂的窗户!\n\n"
"「糟糕!那是校长的草莓蛋糕!」艾米惊呼。就在千钧一发之际,扫帚突然一个急转弯,"
"稳稳地停在了蛋糕前——原来它只是饿了。\n\n"
"虽然考试不及格,但艾米发明了全校最快的「外卖配送术」。"
"从此以后,魔法学院的学生们再也不用担心吃不到热乎乎的披萨了。"
),
"cover_url": f"{OSS_BASE}/stories/defaults/失控的魔法扫帚_cover.png",
"has_video": True,
"video_url": f"{OSS_BASE}/stories/defaults/失控的魔法扫帚.mp4",
"generation_mode": "ai",
"prompt": "角色=[小女巫],场景=[魔法学院],道具=[魔法扫帚]",
},
]
def ensure_default_stories(user):
"""确保用户书架有默认故事,没有则创建。
逻辑与 music.utils.ensure_default_tracks 保持一致
- 若用户已有默认故事则跳过
- 先确保默认书架存在再批量写入
"""
from .models import Story, StoryShelf
if Story.objects.filter(user=user, is_default=True).exists():
return
# 确保默认书架存在
shelf, _ = StoryShelf.objects.get_or_create(
user=user,
defaults={"name": "我的书架"},
)
stories = []
for item in DEFAULT_STORIES:
stories.append(Story(
user=user,
shelf=shelf,
title=item["title"],
content=item["content"],
cover_url=item["cover_url"],
has_video=item["has_video"],
video_url=item["video_url"],
generation_mode=item["generation_mode"],
prompt=item["prompt"],
is_default=True,
))
Story.objects.bulk_create(stories)

View File

@ -13,7 +13,6 @@ from utils.response import success, error
from utils.exceptions import ErrorCode
from apps.admins.authentication import AppJWTAuthentication
from .models import StoryShelf, Story
from .utils import ensure_default_stories
from .serializers import (
StoryShelfSerializer,
CreateShelfSerializer,
@ -42,7 +41,6 @@ class StoryViewSet(viewsets.ViewSet):
获取故事列表
GET /api/v1/stories/?shelf_id=1&page=1&page_size=20
"""
ensure_default_stories(request.user)
queryset = Story.objects.filter(user=request.user)
shelf_id = request.query_params.get('shelf_id')
@ -174,7 +172,7 @@ class ShelfViewSet(viewsets.ViewSet):
书架列表
GET /api/v1/stories/shelves/
"""
ensure_default_stories(request.user)
ensure_default_shelf(request.user)
shelves = StoryShelf.objects.filter(
user=request.user

View File

@ -1,59 +0,0 @@
"""
App 端专用 JWT 认证
Bug #42 fix: the original authenticate() returned None for an empty-string
token, which caused the request to be treated as anonymous (unauthenticated)
and allowed it to reach protected views without a valid identity. Empty
tokens must be rejected with AuthenticationFailed instead.
"""
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.exceptions import AuthenticationFailed
class AppJWTAuthentication(JWTAuthentication):
"""
App 端专用 JWT 认证
验证 token 中的 user_type 必须为 'app'
"""
def authenticate(self, request):
header = self.get_header(request)
if header is None:
return None
raw_token = self.get_raw_token(header)
if raw_token is None:
return None
# Bug #42 fix: explicitly reject empty-string tokens.
# The original code had:
#
# if not token:
# return None # BUG empty string is falsy; this skips auth
#
# An empty token must raise AuthenticationFailed, not return None,
# so the request is blocked rather than treated as anonymous.
if not raw_token or raw_token.strip() == b'':
raise AuthenticationFailed('Token 不能为空')
validated_token = self.get_validated_token(raw_token)
return self.get_user(validated_token), validated_token
def get_user(self, validated_token):
from apps.users.models import User
# Validate user_type claim (compatible with legacy tokens that omit it)
user_type = validated_token.get('user_type', 'app')
if user_type not in ('app', None):
raise AuthenticationFailed('无效的用户 Token')
try:
user_id = validated_token.get('user_id')
user = User.objects.get(id=user_id)
except User.DoesNotExist:
raise AuthenticationFailed('用户不存在')
if not user.is_active:
raise AuthenticationFailed('用户账户已被禁用')
return user

View File

@ -1,69 +0,0 @@
"""
用户积分服务
Bug #43 fix: replace non-atomic points deduction with a SELECT FOR UPDATE
inside an atomic transaction to eliminate the race condition that allowed
the balance to go negative.
"""
from django.db import transaction
from django.db.models import F
from apps.users.models import User, PointsRecord
class InsufficientPointsError(Exception):
"""积分不足异常"""
def deduct_points(user_id, amount, record_type, description=''):
"""
原子性地扣减用户积分并写入流水记录
Bug #43 fix: the original code was:
user.points -= amount # read-modify-write not atomic
user.save() # concurrent calls can all pass the balance
# check and drive points negative
The fix uses SELECT FOR UPDATE inside an atomic block so that concurrent
deductions are serialised at the database level, and an extra guard
(points__gte=amount) prevents the update from proceeding when the balance
is insufficient.
"""
with transaction.atomic():
# Lock the row so no other transaction can read stale data
updated_rows = User.objects.filter(
id=user_id,
points__gte=amount, # guard: only deduct when balance is sufficient
).update(points=F('points') - amount)
if updated_rows == 0:
# Either the user doesn't exist or balance was insufficient
user = User.objects.filter(id=user_id).first()
if user is None:
raise ValueError(f'用户 {user_id} 不存在')
raise InsufficientPointsError(
f'积分不足: 当前余额 {user.points},需要 {amount}'
)
PointsRecord.objects.create(
user_id=user_id,
amount=-amount,
type=record_type,
description=description,
)
def add_points(user_id, amount, record_type, description=''):
"""
原子性地增加用户积分并写入流水记录
"""
with transaction.atomic():
User.objects.filter(id=user_id).update(points=F('points') + amount)
PointsRecord.objects.create(
user_id=user_id,
amount=amount,
type=record_type,
description=description,
)

View File

@ -198,8 +198,6 @@ ALIYUN_PHONE_AUTH = {
LLM_CONFIG = {
'API_KEY': os.environ.get('VOLCENGINE_API_KEY', ''),
'MODEL_NAME': os.environ.get('VOLCENGINE_MODEL_NAME', 'doubao-seed-1-6-lite-251015'),
'IMAGE_MODEL_NAME': os.environ.get('VOLCENGINE_IMAGE_MODEL_NAME', 'doubao-seedream-4-5-251128'),
'IMAGE_SIZE': os.environ.get('VOLCENGINE_IMAGE_SIZE', '2560x1440'),
}
# Swagger/OpenAPI Settings

View File

@ -1,193 +0,0 @@
# 故事音频预转码方案 — MP3 → Opus 预处理
> 创建时间2026-03-03
> 状态:待实施
## Context
**问题**:当前 hw_service_go 每次播放故事都实时执行 `MP3下载 → ffmpeg转码 → Opus编码`ffmpeg 是 CPU 密集型操作,压测显示 0.5 核 CPU 下 5 个并发就首帧延迟 4.5s。
**方案**:在 TTS 生成 MP3 后,立即预转码为 Opus 帧数据JSON 格式)并上传 OSS。hw_service_go 播放时直接下载预处理好的 Opus 数据,跳过 ffmpeg首帧延迟从秒级降到毫秒级。
**预期效果**
- hw_service_go 播放时 **零 CPU 转码开销**
- 首帧延迟从 ~2s 降到 ~200ms
- 并发播放容量从 5-10 个提升到 **100+**(瓶颈变为网络/内存)
**压测数据参考**(单 Pod, 0.5 核 CPU, 512Mi
| 并发故事数 | 首帧延迟 | 帧数/故事 | 错误 |
|-----------|---------|----------|------|
| 2 | 2.0s | 796 | 0 |
| 5 | 4.5s | 796 | 0 |
| 10 | 8.7s | 796 | 0 |
| 20 | 17.4s | 796 | 0 |
详见 [压测报告](../rtc_backend/hw_service_go/test/stress/REPORT.md)
---
## 改动概览
| 改动范围 | 文件 | 改动大小 |
|---------|------|---------|
| DjangoStory 模型 | `apps/stories/models.py` | 小(加 1 个字段) |
| DjangoTTS 服务 | `apps/stories/services/tts_service.py` | 中(加预转码逻辑) |
| Django故事 API | `apps/devices/views.py` | 小(返回新字段) |
| Django迁移文件 | `apps/stories/migrations/` | 自动生成 |
| GoAPI 响应结构体 | `hw_service_go/internal/rtcclient/client.go` | 小 |
| Go播放处理器 | `hw_service_go/internal/handler/story.go` | 中(分支逻辑) |
| Go新增 Opus 下载 | `hw_service_go/internal/audio/` | 中(新函数) |
**总改动量:中等偏小**,核心改动集中在 3 个文件。
---
## 详细方案
### Step 1: Story 模型加字段
**文件**`apps/stories/models.py`
```python
# 在 Story 模型中新增
opus_url = models.URLField('Opus音频URL', max_length=500, blank=True, default='')
```
`opus_url` 存储预转码后的 Opus JSON 文件地址。为空表示未转码(兼容旧数据)。
然后 `makemigrations` + `migrate`
### Step 2: TTS 服务中增加预转码
**文件**`apps/stories/services/tts_service.py`
在 MP3 上传 OSS 成功后(第 88 行 `story.save` 之前),增加:
1. 调用 ffmpeg 将 MP3 bytes 转为 PCM16kHz, mono, s16le
2. 用 Python opuslib或 subprocess 调 ffmpeg 直出 opus编码为 60ms 帧
3. 将帧列表序列化为紧凑格式上传 OSS
4. 保存 `story.opus_url`
**Opus 数据格式JSON + base64**
```json
{
"sample_rate": 16000,
"channels": 1,
"frame_duration_ms": 60,
"frames": ["<base64帧1>", "<base64帧2>", ...]
}
```
> 一个 5 分钟故事约 5000 帧 × ~300 bytes/帧 ≈ 1.5MB JSON压缩后 ~1MB对 OSS 存储无压力。
**转码实现**subprocess 调 ffmpeg + opuslib
```python
import subprocess, base64, json, opuslib
def convert_mp3_to_opus_frames(mp3_bytes):
"""MP3 → PCM → Opus 帧列表"""
# ffmpeg: MP3 → PCM
proc = subprocess.run(
['ffmpeg', '-i', 'pipe:0', '-ar', '16000', '-ac', '1', '-f', 's16le', 'pipe:1'],
input=mp3_bytes, capture_output=True
)
pcm = proc.stdout
# Opus 编码:每帧 960 samples (60ms @ 16kHz)
encoder = opuslib.Encoder(16000, 1, opuslib.APPLICATION_AUDIO)
frame_size = 960
frames = []
for i in range(0, len(pcm) // 2 - frame_size + 1, frame_size):
chunk = pcm[i*2 : (i+frame_size)*2]
opus_frame = encoder.encode(chunk, frame_size)
frames.append(base64.b64encode(opus_frame).decode())
return json.dumps({
"sample_rate": 16000,
"channels": 1,
"frame_duration_ms": 60,
"frames": frames
})
```
上传路径:`stories/audio-opus/YYYYMMDD/{uuid}.json`
### Step 3: Django API 返回 opus_url
**文件**`apps/devices/views.py``stories_by_mac` 方法)
```python
return success(data={
'title': story.title,
'audio_url': story.audio_url,
'opus_url': story.opus_url, # 新增
})
```
### Step 4: Go 服务适配
**文件**`hw_service_go/internal/rtcclient/client.go`
```go
type StoryInfo struct {
Title string `json:"title"`
AudioURL string `json:"audio_url"`
OpusURL string `json:"opus_url"` // 新增
}
```
**文件**`hw_service_go/internal/audio/` — 新增函数
```go
// FetchOpusFrames 从 OSS 下载预转码的 Opus JSON 文件,解析为帧列表
func FetchOpusFrames(ctx context.Context, opusURL string) ([][]byte, error)
```
**文件**`hw_service_go/internal/handler/story.go` — 修改播放逻辑
```go
// 优先使用预转码 Opus
var frames [][]byte
if story.OpusURL != "" {
frames, err = audio.FetchOpusFrames(ctx, story.OpusURL)
} else {
// 兜底:旧数据无预转码,走实时转码
frames, err = audio.MP3URLToOpusFrames(ctx, story.AudioURL)
}
```
### Step 5: 历史数据迁移(可选)
写一个 management command 批量转码已有故事:
```bash
python manage.py convert_stories_to_opus
```
---
## 兼容性
- **旧故事**`opus_url` 为空hw_service_go 自动 fallback 到实时 ffmpeg 转码,无影响
- **新故事**TTS 生成时自动预转码hw_service_go 直接下载 Opus 数据
- **App 端**:无任何改动,`audio_url`MP3仍然存在供 App 播放器使用
---
## 依赖
- Django 端需安装 `opuslib`Python Opus 绑定):`pip install opuslib`
- Django 服务器需有 `ffmpeg`(已有,用于 TTS 后处理等)
- 如果不想引入 opuslib 依赖,可以用 `ffmpeg -c:a libopus` 直接输出 opus但需要自行按 60ms 分帧
---
## 验证方法
1. 本地创建一个故事 + TTS → 检查 `opus_url` 是否生成
2. `curl /api/v1/devices/stories/?mac_address=...` 确认返回含 `opus_url`
3. hw_service_go 本地启动,连接测试页面触发故事 → 确认跳过 ffmpeg
4. 压测对比:相同并发下首帧延迟应从秒级降到百毫秒级

View File

@ -1,13 +0,0 @@
# hw-ws-service 环境变量示例
# 复制为 .env 并填入实际值(.env 不提交 git
# WebSocket 监听地址(默认 0.0.0.0
HW_WS_HOST=0.0.0.0
# WebSocket 监听端口(默认 8888
HW_WS_PORT=8888
# RTC 后端地址(必填)
# K8s 内部http://rtc-backend-svc:8000
# 本地开发http://localhost:8000
HW_RTC_BACKEND_URL=http://localhost:8000

View File

@ -1,396 +0,0 @@
# hw_service_go - Claude Code 开发指南
> ESP32 硬件 WebSocket 通讯服务,负责接收设备指令并推送 Opus 音频流。
## 技术栈
- **Go 1.23+**
- `github.com/gorilla/websocket` — WebSocket 服务器
- `github.com/hraban/opus` — CGO libopus 编码(需 `opus-dev`
- `ffmpeg`(系统级二进制)— MP3/AAC 解码为 PCM
- K8s 部署,端口 **8888**
## 目录结构
```
hw_service_go/
├── cmd/main.go # 唯一入口,只做启动和优雅关闭
├── internal/
│ ├── config/config.go # 环境变量,只读,不可变
│ ├── server/server.go # HTTP Upgrader + 连接生命周期
│ ├── connection/connection.go # 单连接状态,并发安全
│ ├── handler/
│ │ ├── story.go # 故事播放主流程
│ │ └── audio_sender.go # Opus 帧流控发送
│ ├── audio/convert.go # MP3→PCM→Opus 转码
│ └── rtcclient/client.go # 调用 Django REST API
├── go.mod / go.sum
└── Dockerfile
```
> `internal/` 包不对外暴露,所有跨包通信通过显式函数参数传递,**不使用全局变量**。
---
## 一、代码规范
### 1.1 命名
| 类型 | 规范 | 示例 |
|------|------|------|
| 包名 | 小写单词,不含下划线 | `server`, `rtcclient` |
| 导出类型/函数 | UpperCamelCase | `Connection`, `HandleStory` |
| 非导出标识符 | lowerCamelCase | `abortCh`, `sendFrame` |
| 常量 | UpperCamelCase非全大写| `FrameSizeMs`, `PreBufferCount` |
| 接口 | 以行为命名,单方法接口加 `-er` 后缀 | `Sender`, `Converter` |
| 错误变量 | `Err` 前缀 | `ErrDeviceNotFound`, `ErrAudioConvert` |
> **不使用** `SCREAMING_SNAKE_CASE` 常量,这是 C 习惯,不是 Go 惯例。
### 1.2 错误处理
```go
// ✅ 正确:始终包装上下文
frames, err := audio.Convert(ctx, url)
if err != nil {
return fmt.Errorf("story handler: convert audio: %w", err)
}
// ❌ 错误:丢弃错误
frames, _ = audio.Convert(ctx, url)
// ❌ 错误panic 在业务逻辑里(仅允许在 main 初始化阶段)
frames, err = audio.Convert(ctx, url)
if err != nil { panic(err) }
```
- 错误链用 `%w`(支持 `errors.Is` / `errors.As`
- 叶子函数返回 `errors.New()`,中间层用 `fmt.Errorf("context: %w", err)`
- 只在 `cmd/main.go` 初始化失败时允许 `log.Fatal`
### 1.3 Context 使用
```go
// ✅ Context 作为第一个参数
func (c *Client) FetchStory(ctx context.Context, mac string) (*StoryInfo, error)
// ✅ 所有 I/O 操作绑定 context支持超时/取消)
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
cmd := exec.CommandContext(ctx, "ffmpeg", ...)
// ❌ 不存储 context 到结构体字段
type Handler struct {
ctx context.Context // 禁止
}
```
### 1.4 并发与 goroutine
```go
// ✅ goroutine 必须有明确的退出机制
go func() {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
case <-abortCh:
return
case frame := <-frameCh:
ws.WriteMessage(websocket.BinaryMessage, frame)
}
}
}()
// ❌ 禁止裸 goroutine无法追踪生命周期
go processAudio(url)
```
- 每启动一个 goroutine必须确保它**有且只有一个**退出路径
- 使用 `sync.WaitGroup` 跟踪服务级 goroutine确保优雅关闭时全部结束
- Channel 方向声明:`send <-chan T``recv chan<- T`,减少误用
### 1.5 结构体初始化
```go
// ✅ 始终使用字段名初始化(顺序变更不会引入 bug
conn := &Connection{
WS: ws,
DeviceID: deviceID,
ClientID: clientID,
}
// ❌ 位置初始化(字段顺序改变后静默错误)
conn := &Connection{ws, deviceID, clientID}
```
### 1.6 接口设计
```go
// ✅ 在使用方定义接口(而非实现方)
// audio/convert.go 不定义接口,由 handler 包定义它需要的最小接口
package handler
type AudioConverter interface {
Convert(ctx context.Context, url string) ([][]byte, error)
}
```
---
## 二、代码生成规范
### 2.1 新增消息类型处理器
硬件消息类型通过 `server.go``switch envelope.Type` 路由。新增类型时:
1. 在 `handler/` 下创建 `<type>.go`
2. 函数签名必须为:`func Handle<Type>(conn *connection.Connection, raw []byte)`
3. 在 `server.go` 的 switch 中注册
```go
// server.go
switch envelope.Type {
case "story":
go handler.HandleStory(conn, raw)
case "music": // 新增
go handler.HandleMusic(conn, raw) // 新增
}
```
### 2.2 新增配置项
所有配置**只能**通过环境变量注入,**不允许**读取配置文件或命令行参数(保持 12-Factor App 原则):
```go
// config/config.go
type Config struct {
WSPort string // HW_WS_PORT默认 "8888"
RTCBackendURL string // HW_RTC_BACKEND_URL必填
NewFeatureXXX string // HW_NEW_FEATURE_XXX新增时遵循此格式
}
```
- 环境变量前缀统一为 `HW_`
- 必填项在 `Load()``log.Fatal` 校验
- 不使用 `viper` 等配置库(项目够小,标准库足够)
### 2.3 Dockerfile 变更
Dockerfile 使用**多阶段构建**,修改时严格遵守:
- 构建阶段:`golang:1.23-alpine`,只安装编译依赖(`gcc musl-dev opus-dev`
- 运行阶段:`alpine:3.20`,只安装运行时依赖(`opus ffmpeg ca-certificates`
- 最终镜像不包含 Go 工具链、源码、测试文件
---
## 三、安全风险防范
### 3.1 ⚠️ exec 命令注入(最高优先级)
`audio/convert.go` 调用 `exec.Command("ffmpeg", ...)` 时,**所有参数必须是硬编码常量,绝对不能包含任何用户输入**。
```go
// ✅ 安全:参数全部硬编码
cmd := exec.CommandContext(ctx, "ffmpeg",
"-nostdin",
"-i", "pipe:0", // 始终从 stdin 读,不接受文件路径
"-ar", "16000",
"-ac", "1",
"-f", "s16le",
"pipe:1",
)
cmd.Stdin = resp.Body // HTTP body 通过 stdin 传入,不是命令行参数
// ❌ 危险audio_url 进入命令行参数(命令注入)
cmd := exec.CommandContext(ctx, "ffmpeg", "-i", audioURL, ...)
// ❌ 危险:使用 shell 执行
exec.Command("sh", "-c", "ffmpeg -i "+audioURL)
```
> `audioURL` 只能作为 HTTP 请求的 URL`net/http` 处理,永远不进入 `exec.Command` 的参数列表。
### 3.2 WebSocket 输入验证
```go
// server.go设置消息大小上限防止内存耗尽攻击
upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // IoT 设备无 Origin允许所有来源
},
}
// 连接建立后立即设置读限制
ws.SetReadLimit(4 * 1024) // 文本消息上限 4KB硬件不会发大消息
```
```go
// 解析 JSON 时验证关键字段
var msg StoryMessage
if err := json.Unmarshal(raw, &msg); err != nil {
return fmt.Errorf("invalid json: %w", err)
}
// device_id 来自 URL 参数(已在连接时验证),不信任消息体中的 device_id
```
### 3.3 资源耗尽防护
```go
// server.go限制最大并发连接数
const maxConnections = 500
func (s *Server) register(conn *Connection) error {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.conns) >= maxConnections {
return ErrTooManyConnections
}
s.conns[conn.DeviceID] = conn
return nil
}
// 同一设备同时只允许一个连接(防止设备重复连接内存泄漏)
if old, exists := s.conns[conn.DeviceID]; exists {
old.Close() // 踢掉旧连接
}
```
```go
// audio/convert.goffmpeg 超时保护(防止卡死)
ctx, cancel := context.WithTimeout(parentCtx, 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "ffmpeg", ...)
```
### 3.4 HTTP 客户端安全
```go
// rtcclient/client.go必须设置超时防止 RTC 后端无响应时 goroutine 泄漏
var httpClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
MaxIdleConns: 50,
IdleConnTimeout: 90 * time.Second,
},
// 禁止无限重定向
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= 3 {
return errors.New("too many redirects")
}
return nil
},
}
```
### 3.5 goroutine 泄漏防护
```go
// ✅ handler 必须响应 context 取消
func HandleStory(conn *Connection, raw []byte) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel() // 无论何种退出路径context 都会被取消
frames, err := audio.Convert(ctx, story.AudioURL)
// ...
}
// ✅ audio sender 通过 select 同时监听多个退出信号
select {
case <-time.After(delay):
case <-abortCh: // 用户打断
return
case <-ctx.Done(): // 超时或连接关闭
return
}
```
### 3.6 日志安全
```go
// ✅ 日志中不输出敏感信息
log.Printf("fetch story for device %s", conn.DeviceID) // MAC 地址可以记录(非个人数据)
log.Printf("audio url: %s", truncate(story.AudioURL, 60)) // URL 截断记录
// ❌ 不记录完整 audio_url可能含签名 token
log.Printf("audio url: %s", story.AudioURL)
```
---
## 四、测试规范
```go
// 测试文件命名:<被测文件>_test.go
// 测试函数命名Test<FunctionName>_<Scenario>
func TestFetchStoryByMAC_Success(t *testing.T) { ... }
func TestFetchStoryByMAC_DeviceNotFound(t *testing.T) { ... }
func TestSendOpusStream_AbortMidway(t *testing.T) { ... }
```
- 使用 `net/http/httptest` mock RTC 后端 HTTP 接口
- 音频转码测试使用真实小文件(`testdata/short.mp3`< 5s
- 不测试 WebSocket 集成逻辑(由端到端脚本覆盖)
---
## 五、常用命令
```bash
# 编译(在 hw_service_go/ 目录下)
go build ./...
# 静态检查
go vet ./...
# 本地运行
HW_RTC_BACKEND_URL=http://localhost:8000 go run ./cmd/main.go
# 运行测试
go test ./... -v -race # -race 开启竞态检测
# 格式化(提交前必须执行)
gofmt -w .
goimports -w . # 需安装: go install golang.org/x/tools/cmd/goimports@latest
# 构建 Docker 镜像
docker build -t hw-ws-service:dev .
# 查看 goroutine 泄漏(开发调试)
curl http://localhost:8888/debug/pprof/goroutine?debug=1
```
---
## 六、开发检查清单
**新增功能前:**
- [ ] 消息处理函数签名是否为 `func Handle<Type>(conn *connection.Connection, raw []byte)`
- [ ] 是否正确使用 `context.Context` 传递超时
- [ ] 是否有 goroutine 退出机制channel / context
**提交代码前:**
- [ ] `gofmt -w .` 格式化通过
- [ ] `go vet ./...` 无警告
- [ ] `go test ./... -race` 无 data race
- [ ] exec.Command 参数**不包含任何**来自外部的数据
- [ ] 所有 HTTP 客户端调用都有超时设置
- [ ] 新增环境变量已更新 `.env.example``k8s/deployment.yaml`
**安全 review 要点:**
- [ ] `audio/convert.go`audioURL 是否只经过 `http.Get()`,没有进入 `exec.Command`
- [ ] WebSocket `SetReadLimit` 是否已设置
- [ ] 新增 goroutine 是否有对应的 `wg.Add(1)``defer wg.Done()`
---
## 参考资料
- [Effective Go](https://go.dev/doc/effective_go)
- [Go Code Review Comments](https://github.com/golang/go/wiki/CodeReviewComments)
- [Uber Go Style Guide](https://github.com/uber-go/guide/blob/master/style.md)
- [gorilla/websocket 文档](https://pkg.go.dev/github.com/gorilla/websocket)
- [hraban/opus 文档](https://pkg.go.dev/github.com/hraban/opus)

View File

@ -1,38 +0,0 @@
# ============================================================
# hw-ws-service Dockerfile — 多阶段构建(国内 CI 优化版)
# 优化go mod vendor 跳过网络下载Alpine 阿里云源加速
# ============================================================
# ---- 构建阶段 ----
FROM golang:1.23-alpine AS builder
# Alpine 换国内源 + 安装编译依赖
RUN sed -i 's#dl-cdn.alpinelinux.org#mirrors.aliyun.com#g' /etc/apk/repositories && \
apk add --no-cache gcc musl-dev opus-dev opusfile-dev
WORKDIR /app
COPY . .
# vendor 模式:依赖已随代码提交,无需联网下载
# CGO_ENABLED=1 必须开启hraban/opus 是 CGO 库)
RUN --mount=type=cache,target=/root/.cache/go-build \
CGO_ENABLED=1 GOOS=linux \
go build \
-mod=vendor \
-trimpath \
-ldflags="-s -w" \
-o hw-ws-service \
./cmd/main.go
# ---- 运行阶段 ----
FROM alpine:3.20
RUN sed -i 's#dl-cdn.alpinelinux.org#mirrors.aliyun.com#g' /etc/apk/repositories && \
apk add --no-cache opus opusfile ffmpeg ca-certificates && \
addgroup -S hwws && adduser -S hwws -G hwws
COPY --from=builder /app/hw-ws-service /hw-ws-service
USER hwws
EXPOSE 8888
ENTRYPOINT ["/hw-ws-service"]

View File

@ -1,49 +0,0 @@
package main
import (
"context"
"log"
"os"
"os/signal"
"syscall"
"time"
"github.com/qy/hw-ws-service/internal/config"
"github.com/qy/hw-ws-service/internal/rtcclient"
"github.com/qy/hw-ws-service/internal/server"
)
func main() {
log.SetFlags(log.LstdFlags | log.Lmsgprefix)
log.SetPrefix("[hw-ws] ")
cfg := config.Load()
addr := cfg.WSHost + ":" + cfg.WSPort
client := rtcclient.New(cfg.RTCBackendURL)
srv := server.New(addr, client)
// 后台启动服务器
serverErr := make(chan error, 1)
go func() {
serverErr <- srv.ListenAndServe()
}()
// 监听系统信号K8s 滚动更新发送 SIGTERM
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
select {
case err := <-serverErr:
log.Fatalf("server error: %v", err)
case sig := <-sigCh:
log.Printf("received signal: %v, starting graceful shutdown...", sig)
}
// 优雅关闭:最长 80s与 K8s terminationGracePeriodSeconds=90 配合)
ctx, cancel := context.WithTimeout(context.Background(), 80*time.Second)
defer cancel()
srv.Shutdown(ctx)
log.Println("shutdown complete")
}

View File

@ -1,8 +0,0 @@
module github.com/qy/hw-ws-service
go 1.23
require (
github.com/gorilla/websocket v1.5.3
github.com/hraban/opus v0.0.0-20230925203106-0188a62cb302
)

View File

@ -1,4 +0,0 @@
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hraban/opus v0.0.0-20230925203106-0188a62cb302 h1:K7bmEmIesLcvCW0Ic2rCk6LtP5++nTnPmrO8mg5umlA=
github.com/hraban/opus v0.0.0-20230925203106-0188a62cb302/go.mod h1:YQQXrWHN3JEvCtw5ImyTCcPeU/ZLo/YMA+TpB64XdrU=

View File

@ -1,13 +0,0 @@
// Package audio 提供音频格式转换功能:从 URL 下载 MP3转码为 Opus 帧列表。
// 全程使用 ffmpeg stdin/stdout pipe不写临时文件。
package audio
const (
SampleRate = 16000
Channels = 1
FrameDurationMs = 60
// FrameSize 是每个 Opus 帧包含的 PCM 采样数16bit
FrameSize = SampleRate * FrameDurationMs / 1000 // 960 samples
// PreBufferCount 是流控前快速预发送的帧数,减少硬件首帧延迟。
PreBufferCount = 3
)

View File

@ -1,127 +0,0 @@
package audio
import (
"context"
"encoding/binary"
"fmt"
"io"
"net/http"
"os/exec"
"time"
"github.com/hraban/opus"
)
// MP3URLToOpusFrames 从 audioURL 下载音频,通过 ffmpeg pipe 解码为 PCM
// 再用 libopus 编码为 60ms 帧列表,全程流式处理不落磁盘。
//
// ⚠️ 安全约束audioURL 只能作为 http.Get 的参数,
// 绝对不能出现在 exec.Command 的参数列表中(防止命令注入)。
func MP3URLToOpusFrames(ctx context.Context, audioURL string) ([][]byte, error) {
// 1. 下载音频(流式,不全量载入内存)
httpCtx, httpCancel := context.WithTimeout(ctx, 60*time.Second)
defer httpCancel()
req, err := http.NewRequestWithContext(httpCtx, http.MethodGet, audioURL, nil)
if err != nil {
return nil, fmt.Errorf("audio: build request: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("audio: download: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("audio: download status %d", resp.StatusCode)
}
// 2. ffmpegstdin 读原始音频stdout 输出 s16le PCM16kHz 单声道)
// 所有参数硬编码audioURL 不进入命令行(防命令注入)
ffmpegCtx, ffmpegCancel := context.WithTimeout(ctx, 120*time.Second)
defer ffmpegCancel()
cmd := exec.CommandContext(ffmpegCtx,
"ffmpeg",
"-nostdin",
"-loglevel", "error", // 只输出错误,不污染 stdout pipe
"-i", "pipe:0", // 从 stdin 读输入
"-ar", "16000", // 目标采样率
"-ac", "1", // 单声道
"-f", "s16le", // 输出格式:有符号 16bit 小端 PCM
"pipe:1", // 输出到 stdout
)
cmd.Stdin = resp.Body // HTTP body 直接接 ffmpeg stdin不经过磁盘
pcmReader, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("audio: stdout pipe: %w", err)
}
stderrPipe, _ := cmd.StderrPipe()
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("audio: start ffmpeg: %w", err)
}
// 3. 逐帧读取 PCM 并实时 Opus 编码
enc, err := opus.NewEncoder(SampleRate, Channels, opus.AppAudio)
if err != nil {
return nil, fmt.Errorf("audio: create encoder: %w", err)
}
pcmBuf := make([]int16, FrameSize) // 960 int16 samples
opusBuf := make([]byte, 4000) // Opus 输出缓冲4KB 足够单帧)
var frames [][]byte
for {
err := binary.Read(pcmReader, binary.LittleEndian, pcmBuf)
if err == io.EOF || err == io.ErrUnexpectedEOF {
// 最后一帧不足时已补零binary.Read 会读已有字节),直接编码
if err == io.ErrUnexpectedEOF {
n, encErr := enc.Encode(pcmBuf, opusBuf)
if encErr == nil && n > 0 {
frame := make([]byte, n)
copy(frame, opusBuf[:n])
frames = append(frames, frame)
}
}
break
}
if err != nil {
// ffmpeg 已结束context cancel 等),读取结束
break
}
n, err := enc.Encode(pcmBuf, opusBuf)
if err != nil {
return nil, fmt.Errorf("audio: opus encode: %w", err)
}
frame := make([]byte, n)
copy(frame, opusBuf[:n])
frames = append(frames, frame)
}
// 排空 stderr 避免 ffmpeg 阻塞
io.Copy(io.Discard, stderrPipe)
if err := cmd.Wait(); err != nil {
// context 超时导致的退出不视为错误(已有 frames 可以播放)
if ffmpegCtx.Err() == nil {
return nil, fmt.Errorf("audio: ffmpeg exit: %w", err)
}
}
if len(frames) == 0 {
return nil, fmt.Errorf("audio: no frames produced from %s", truncateURL(audioURL))
}
return frames, nil
}
// truncateURL 截断 URL 用于日志,避免输出带签名的完整 URL。
func truncateURL(u string) string {
if len(u) > 80 {
return u[:80] + "..."
}
return u
}

View File

@ -1,66 +0,0 @@
package audio
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
// opusJSON 是 Django 预转码上传的 Opus JSON 文件结构。
type opusJSON struct {
SampleRate int `json:"sample_rate"`
Channels int `json:"channels"`
FrameDurationMs int `json:"frame_duration_ms"`
Frames []string `json:"frames"` // base64 编码的 Opus 帧
}
// FetchOpusFrames 从 OSS 下载预转码的 Opus JSON 文件,解析为原始帧列表。
// 跳过 ffmpeg 实时转码,大幅降低 CPU 消耗和首帧延迟。
func FetchOpusFrames(ctx context.Context, opusURL string) ([][]byte, error) {
httpCtx, cancel := context.WithTimeout(ctx, 60*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(httpCtx, http.MethodGet, opusURL, nil)
if err != nil {
return nil, fmt.Errorf("audio: build opus request: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("audio: download opus json: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("audio: opus json status %d", resp.StatusCode)
}
body, err := io.ReadAll(io.LimitReader(resp.Body, 50*1024*1024)) // 50MB 上限
if err != nil {
return nil, fmt.Errorf("audio: read opus json: %w", err)
}
var data opusJSON
if err := json.Unmarshal(body, &data); err != nil {
return nil, fmt.Errorf("audio: parse opus json: %w", err)
}
if len(data.Frames) == 0 {
return nil, fmt.Errorf("audio: opus json has no frames")
}
frames := make([][]byte, 0, len(data.Frames))
for i, b64 := range data.Frames {
raw, err := base64.StdEncoding.DecodeString(b64)
if err != nil {
return nil, fmt.Errorf("audio: decode frame %d: %w", i, err)
}
frames = append(frames, raw)
}
return frames, nil
}

View File

@ -1,33 +0,0 @@
package config
import (
"log"
"os"
)
// Config 保存所有服务配置全部通过环境变量注入12-Factor App
type Config struct {
WSHost string
WSPort string
RTCBackendURL string
}
// Load 从环境变量读取配置,必填项缺失时直接 Fatal。
func Load() *Config {
backendURL := getEnv("HW_RTC_BACKEND_URL", "")
if backendURL == "" {
log.Fatal("config: HW_RTC_BACKEND_URL is required")
}
return &Config{
WSHost: getEnv("HW_WS_HOST", "0.0.0.0"),
WSPort: getEnv("HW_WS_PORT", "8888"),
RTCBackendURL: backendURL,
}
}
func getEnv(key, fallback string) string {
if v := os.Getenv(key); v != "" {
return v
}
return fallback
}

View File

@ -1,111 +0,0 @@
// Package connection 管理单个 ESP32 硬件 WebSocket 连接的状态。
package connection
import (
"encoding/json"
"fmt"
"sync"
"github.com/gorilla/websocket"
)
// Connection 保存单个硬件连接的状态,所有方法并发安全。
type Connection struct {
WS *websocket.Conn
DeviceID string // MAC 地址,来自 URL 参数 device-id
ClientID string // 来自 URL 参数 client-id
SessionID string // 握手后分配的会话 ID
mu sync.Mutex
handshaked bool // 是否已完成 hello 握手
isPlaying bool
abortCh chan struct{} // close(abortCh) 通知流控 goroutine 中止播放
writeMu sync.Mutex // gorilla/websocket 写操作不并发安全,需独立锁
}
// New 创建新连接对象。
func New(ws *websocket.Conn, deviceID, clientID string) *Connection {
return &Connection{
WS: ws,
DeviceID: deviceID,
ClientID: clientID,
}
}
// Handshake 完成 hello 握手,存储 session_id。
func (c *Connection) Handshake(sessionID string) {
c.mu.Lock()
defer c.mu.Unlock()
c.SessionID = sessionID
c.handshaked = true
}
// IsHandshaked 返回连接是否已完成 hello 握手。
func (c *Connection) IsHandshaked() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.handshaked
}
// SendCmd 向硬件发送控制指令,并发安全。
func (c *Connection) SendCmd(action string, params any) error {
return c.SendJSON(map[string]any{
"type": "cmd",
"action": action,
"params": params,
})
}
// StartPlayback 开始新一轮播放,返回 abortCh 供流控 goroutine 监听。
// 若已在播放,先中止上一轮再开始新的。
func (c *Connection) StartPlayback() <-chan struct{} {
c.mu.Lock()
defer c.mu.Unlock()
// 中止上一轮播放(若有)
if c.isPlaying && c.abortCh != nil {
close(c.abortCh)
}
c.abortCh = make(chan struct{})
c.isPlaying = true
return c.abortCh
}
// StopPlayback 结束播放状态。
func (c *Connection) StopPlayback() {
c.mu.Lock()
defer c.mu.Unlock()
c.isPlaying = false
}
// IsPlaying 返回当前是否正在播放。
func (c *Connection) IsPlaying() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.isPlaying
}
// SendJSON 序列化 v 并以文本帧发送给设备,并发安全。
func (c *Connection) SendJSON(v any) error {
data, err := json.Marshal(v)
if err != nil {
return fmt.Errorf("connection: marshal json: %w", err)
}
c.writeMu.Lock()
defer c.writeMu.Unlock()
return c.WS.WriteMessage(websocket.TextMessage, data)
}
// SendBinary 以二进制帧发送 Opus 数据,并发安全。
func (c *Connection) SendBinary(data []byte) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
return c.WS.WriteMessage(websocket.BinaryMessage, data)
}
// Close 关闭底层 WebSocket 连接。
func (c *Connection) Close() {
c.WS.Close()
}

View File

@ -1,177 +0,0 @@
package connection_test
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/qy/hw-ws-service/internal/connection"
)
// makeWSPair creates a real WebSocket pair for testing.
// Returns the server-side conn (what our code uses) and the client-side conn
// (what simulates the hardware). Call cleanup() after the test.
func makeWSPair(t *testing.T) (svrWS *websocket.Conn, cliWS *websocket.Conn, cleanup func()) {
t.Helper()
ch := make(chan *websocket.Conn, 1)
done := make(chan struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
up := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
c, err := up.Upgrade(w, r, nil)
if err != nil {
t.Logf("upgrade error: %v", err)
return
}
ch <- c
<-done // hold handler open until cleanup
}))
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
cli, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
close(done)
srv.Close()
t.Fatalf("dial error: %v", err)
}
svr := <-ch
return svr, cli, func() {
close(done)
svr.Close()
cli.Close()
srv.Close()
}
}
func TestConnection_InitialState(t *testing.T) {
svrWS, _, cleanup := makeWSPair(t)
defer cleanup()
conn := connection.New(svrWS, "AA:BB:CC:DD:EE:FF", "client-uuid")
if conn.DeviceID != "AA:BB:CC:DD:EE:FF" {
t.Errorf("DeviceID = %q", conn.DeviceID)
}
if conn.ClientID != "client-uuid" {
t.Errorf("ClientID = %q", conn.ClientID)
}
if conn.IsPlaying() {
t.Error("new connection should not be playing")
}
}
func TestConnection_StartStopPlayback(t *testing.T) {
svrWS, _, cleanup := makeWSPair(t)
defer cleanup()
conn := connection.New(svrWS, "dev1", "cli1")
ch := conn.StartPlayback()
if ch == nil {
t.Fatal("StartPlayback should return a non-nil channel")
}
if !conn.IsPlaying() {
t.Error("IsPlaying should be true after StartPlayback")
}
// Channel must still be open
select {
case <-ch:
t.Error("abortCh should not be closed yet")
default:
}
conn.StopPlayback()
if conn.IsPlaying() {
t.Error("IsPlaying should be false after StopPlayback")
}
}
// TestConnection_StartPlayback_AbortsOld verifies that calling StartPlayback a second
// time closes the previous abort channel, stopping any in-progress streaming.
func TestConnection_StartPlayback_AbortsOld(t *testing.T) {
svrWS, _, cleanup := makeWSPair(t)
defer cleanup()
conn := connection.New(svrWS, "dev1", "cli1")
ch1 := conn.StartPlayback()
ch2 := conn.StartPlayback() // should close ch1
// ch1 must be closed now
select {
case <-ch1:
// expected
case <-time.After(100 * time.Millisecond):
t.Error("first abortCh should be closed by second StartPlayback call")
}
// ch2 must still be open
select {
case <-ch2:
t.Error("second abortCh should not be closed yet")
default:
}
}
// TestConnection_SendJSON verifies JSON messages are delivered to the client.
func TestConnection_SendJSON(t *testing.T) {
svrWS, cliWS, cleanup := makeWSPair(t)
defer cleanup()
conn := connection.New(svrWS, "dev1", "cli1")
if err := conn.SendJSON(map[string]string{"type": "tts", "state": "start"}); err != nil {
t.Fatalf("SendJSON error: %v", err)
}
cliWS.SetReadDeadline(time.Now().Add(2 * time.Second))
msgType, data, err := cliWS.ReadMessage()
if err != nil {
t.Fatalf("client read error: %v", err)
}
if msgType != websocket.TextMessage {
t.Errorf("message type = %d, want TextMessage (%d)", msgType, websocket.TextMessage)
}
var got map[string]string
if err := json.Unmarshal(data, &got); err != nil {
t.Fatalf("json.Unmarshal error: %v", err)
}
if got["type"] != "tts" {
t.Errorf("type = %q, want %q", got["type"], "tts")
}
if got["state"] != "start" {
t.Errorf("state = %q, want %q", got["state"], "start")
}
}
// TestConnection_SendBinary verifies binary (Opus) frames are delivered to the client.
func TestConnection_SendBinary(t *testing.T) {
svrWS, cliWS, cleanup := makeWSPair(t)
defer cleanup()
conn := connection.New(svrWS, "dev1", "cli1")
payload := []byte{0x01, 0x02, 0x03, 0x04}
if err := conn.SendBinary(payload); err != nil {
t.Fatalf("SendBinary error: %v", err)
}
cliWS.SetReadDeadline(time.Now().Add(2 * time.Second))
msgType, data, err := cliWS.ReadMessage()
if err != nil {
t.Fatalf("client read error: %v", err)
}
if msgType != websocket.BinaryMessage {
t.Errorf("message type = %d, want BinaryMessage (%d)", msgType, websocket.BinaryMessage)
}
if string(data) != string(payload) {
t.Errorf("payload = %v, want %v", data, payload)
}
}

View File

@ -1,13 +0,0 @@
package handler
import (
"log"
"github.com/qy/hw-ws-service/internal/connection"
)
// HandleAbort 处理硬件发来的 {"type":"abort"} 指令,中止当前播放。
func HandleAbort(conn *connection.Connection) {
log.Printf("[abort][%s] stopping playback", conn.DeviceID)
conn.StopPlayback()
}

View File

@ -1,63 +0,0 @@
package handler
import (
"time"
"github.com/qy/hw-ws-service/internal/audio"
"github.com/qy/hw-ws-service/internal/connection"
)
// SendOpusStream 将 Opus 帧列表按 60ms/帧的节奏流控发送给硬件。
//
// 流控策略:
// 1. 预缓冲:前 PreBufferCount 帧立即发送,减少硬件首帧延迟
// 2. 时序流控:按 (帧序号 × 60ms) 计算期望发送时间select 等待
// 3. 打断:监听 abortCh收到关闭信号立即返回
func SendOpusStream(conn *connection.Connection, frames [][]byte, abortCh <-chan struct{}) {
if len(frames) == 0 {
return
}
startTime := time.Now()
playedMs := 0
// 阶段1预缓冲快速发送前 N 帧
pre := audio.PreBufferCount
if pre > len(frames) {
pre = len(frames)
}
for _, f := range frames[:pre] {
select {
case <-abortCh:
return
default:
}
conn.SendBinary(f) //nolint:errcheck // 连接断开时下一次 ReadMessage 会返回错误
}
playedMs = pre * audio.FrameDurationMs
// 阶段2时序流控
for _, f := range frames[pre:] {
expectedAt := startTime.Add(time.Duration(playedMs) * time.Millisecond)
delay := time.Until(expectedAt)
if delay > 0 {
select {
case <-time.After(delay):
// 到达预期发送时间,继续
case <-abortCh:
return
}
}
// delay <= 0处理比预期慢追赶进度直接发送
select {
case <-abortCh:
return
default:
}
conn.SendBinary(f) //nolint:errcheck
playedMs += audio.FrameDurationMs
}
}

View File

@ -1,195 +0,0 @@
package handler_test
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/qy/hw-ws-service/internal/audio"
"github.com/qy/hw-ws-service/internal/connection"
"github.com/qy/hw-ws-service/internal/handler"
)
// makeWSPair creates a real WebSocket pair for testing.
// svrWS is the server side (used by our Connection), cliWS simulates the hardware.
func makeWSPair(t *testing.T) (svrWS *websocket.Conn, cliWS *websocket.Conn, cleanup func()) {
t.Helper()
ch := make(chan *websocket.Conn, 1)
done := make(chan struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
up := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
c, err := up.Upgrade(w, r, nil)
if err != nil {
t.Logf("upgrade error: %v", err)
return
}
ch <- c
<-done
}))
wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
cli, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
if err != nil {
close(done)
srv.Close()
t.Fatalf("dial error: %v", err)
}
svr := <-ch
return svr, cli, func() {
close(done)
svr.Close()
cli.Close()
srv.Close()
}
}
// makeFrames creates n fake Opus frames of 4 bytes each.
func makeFrames(n int) [][]byte {
frames := make([][]byte, n)
for i := range frames {
frames[i] = []byte{byte(i), byte(i >> 8), 0x00, 0xff}
}
return frames
}
// TestSendOpusStream_Empty verifies that an empty frame list returns immediately.
func TestSendOpusStream_Empty(t *testing.T) {
svrWS, _, cleanup := makeWSPair(t)
defer cleanup()
conn := connection.New(svrWS, "dev1", "cli1")
abort := make(chan struct{})
done := make(chan struct{})
go func() {
handler.SendOpusStream(conn, nil, abort)
close(done)
}()
select {
case <-done:
case <-time.After(500 * time.Millisecond):
t.Fatal("SendOpusStream did not return immediately for empty frames")
}
}
// TestSendOpusStream_AllFramesSent verifies that all frames reach the client.
// Uses PreBufferCount+2 frames so both pre-buffer and timed paths are exercised.
func TestSendOpusStream_AllFramesSent(t *testing.T) {
svrWS, cliWS, cleanup := makeWSPair(t)
defer cleanup()
conn := connection.New(svrWS, "dev1", "cli1")
totalFrames := audio.PreBufferCount + 2 // 3 pre-buffer + 2 timed
frames := makeFrames(totalFrames)
abort := make(chan struct{})
senderDone := make(chan struct{})
go func() {
handler.SendOpusStream(conn, frames, abort)
close(senderDone)
}()
// Read all frames from the client side (simulates hardware receiving)
received := 0
cliWS.SetReadDeadline(time.Now().Add(10 * time.Second))
for received < totalFrames {
msgType, _, err := cliWS.ReadMessage()
if err != nil {
t.Fatalf("client read error after %d frames: %v", received, err)
}
if msgType == websocket.BinaryMessage {
received++
}
}
select {
case <-senderDone:
case <-time.After(2 * time.Second):
t.Fatal("SendOpusStream did not finish after all frames were sent")
}
if received != totalFrames {
t.Errorf("received %d frames, want %d", received, totalFrames)
}
}
// TestSendOpusStream_Abort verifies that closing abortCh stops streaming early.
func TestSendOpusStream_Abort(t *testing.T) {
svrWS, _, cleanup := makeWSPair(t)
defer cleanup()
conn := connection.New(svrWS, "dev1", "cli1")
// Many frames so timing control is active (pre-buffer finishes quickly,
// then the time.After select can receive the abort signal)
frames := makeFrames(100)
abort := make(chan struct{})
senderDone := make(chan struct{})
go func() {
handler.SendOpusStream(conn, frames, abort)
close(senderDone)
}()
// Close abort after pre-buffer has had time to finish but before timed frames complete
time.Sleep(20 * time.Millisecond)
close(abort)
select {
case <-senderDone:
// SendOpusStream returned early — correct behaviour
case <-time.After(2 * time.Second):
t.Fatal("SendOpusStream did not abort within 2s after closing abortCh")
}
}
// TestSendOpusStream_PreBufferOnly verifies frames <= PreBufferCount are all sent
// without entering the timed loop (should finish nearly instantly).
func TestSendOpusStream_PreBufferOnly(t *testing.T) {
svrWS, cliWS, cleanup := makeWSPair(t)
defer cleanup()
conn := connection.New(svrWS, "dev1", "cli1")
frames := makeFrames(audio.PreBufferCount) // exactly the pre-buffer count
abort := make(chan struct{})
start := time.Now()
senderDone := make(chan struct{})
go func() {
handler.SendOpusStream(conn, frames, abort)
close(senderDone)
}()
received := 0
cliWS.SetReadDeadline(time.Now().Add(3 * time.Second))
for received < len(frames) {
msgType, _, err := cliWS.ReadMessage()
if err != nil {
t.Fatalf("read error: %v", err)
}
if msgType == websocket.BinaryMessage {
received++
}
}
select {
case <-senderDone:
case <-time.After(time.Second):
t.Fatal("sender did not finish")
}
elapsed := time.Since(start)
// Pre-buffer frames should not wait on the timer; allow 200ms for overhead
if elapsed > 200*time.Millisecond {
t.Errorf("pre-buffer-only send took too long: %v (want < 200ms)", elapsed)
}
}

View File

@ -1,45 +0,0 @@
package handler
import (
"crypto/rand"
"encoding/json"
"fmt"
"log"
"strings"
"github.com/qy/hw-ws-service/internal/connection"
)
// helloMessage 是硬件发来的 hello 握手消息。
type helloMessage struct {
MAC string `json:"mac"`
}
// HandleHello 处理硬件的 hello 握手消息。
// 校验 MAC 地址,分配 session_id返回握手响应。
func HandleHello(conn *connection.Connection, raw []byte) error {
var msg helloMessage
if err := json.Unmarshal(raw, &msg); err != nil {
return fmt.Errorf("hello: invalid json: %w", err)
}
// MAC 地址与 URL 参数不一致时记录警告,但不拒绝连接
if msg.MAC != "" && !strings.EqualFold(msg.MAC, conn.DeviceID) {
log.Printf("[hello][%s] MAC mismatch: url=%s body=%s", conn.DeviceID, conn.DeviceID, msg.MAC)
}
sessionID := newSessionID()
conn.Handshake(sessionID)
return conn.SendJSON(map[string]string{
"type": "hello",
"status": "ok",
"session_id": sessionID,
})
}
func newSessionID() string {
b := make([]byte, 4)
rand.Read(b) //nolint:errcheck // crypto/rand.Read 在标准库中不会返回错误
return fmt.Sprintf("%x", b)
}

View File

@ -1,86 +0,0 @@
package handler
import (
"context"
"log"
"time"
"github.com/qy/hw-ws-service/internal/audio"
"github.com/qy/hw-ws-service/internal/connection"
"github.com/qy/hw-ws-service/internal/rtcclient"
)
// HandleStory 处理硬件发来的 {"type":"story"} 指令。
// 在独立 goroutine 中调用,不阻塞消息读取循环。
func HandleStory(conn *connection.Connection, client *rtcclient.Client) {
tag := "[story][" + conn.DeviceID + "]"
// 整个故事播放流程最长允许 10 分钟
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
// 1. 通知硬件TTS 开始
if err := conn.SendJSON(map[string]string{"type": "tts", "state": "start"}); err != nil {
log.Printf("%s send start failed: %v", tag, err)
return
}
// 确保异常退出时也发送 stop避免硬件卡住
defer func() {
conn.StopPlayback()
conn.SendJSON(map[string]string{"type": "tts", "state": "stop"}) //nolint:errcheck
}()
// 2. 调用 RTC 后端获取故事
story, err := client.FetchStoryByMAC(ctx, conn.DeviceID)
if err != nil {
log.Printf("%s fetch story error: %v", tag, err)
return
}
if story == nil {
log.Printf("%s no story available", tag)
return
}
log.Printf("%s playing: %s", tag, story.Title)
// 3. 获取 Opus 帧:优先使用预转码数据,否则实时 ffmpeg 转码
var frames [][]byte
if story.OpusURL != "" {
frames, err = audio.FetchOpusFrames(ctx, story.OpusURL)
if err != nil {
log.Printf("%s fetch pre-converted opus failed, fallback to ffmpeg: %v", tag, err)
frames = nil // 确保 fallback
} else {
log.Printf("%s loaded %d pre-converted frames (~%.1fs)", tag, len(frames),
float64(len(frames)*audio.FrameDurationMs)/1000)
}
}
if frames == nil {
frames, err = audio.MP3URLToOpusFrames(ctx, story.AudioURL)
if err != nil {
log.Printf("%s audio convert error: %v", tag, err)
return
}
log.Printf("%s converted %d frames (~%.1fs)", tag, len(frames),
float64(len(frames)*audio.FrameDurationMs)/1000)
}
// 4. 通知硬件:句子开始(发送故事标题)
if err := conn.SendJSON(map[string]any{
"type": "tts",
"state": "sentence_start",
"text": story.Title,
}); err != nil {
log.Printf("%s send sentence_start failed: %v", tag, err)
return
}
// 5. 开始播放,获取打断 channel
abortCh := conn.StartPlayback()
// 6. 流控推送 Opus 帧
SendOpusStream(conn, frames, abortCh)
log.Printf("%s playback finished", tag)
// defer 会发送 stop 并调用 StopPlayback
}

View File

@ -1,101 +0,0 @@
// Package rtcclient 封装对 RTC 后端 Django REST API 的 HTTP 调用。
package rtcclient
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
)
// StoryInfo 是 GET /api/v1/devices/stories/ 返回的故事信息。
type StoryInfo struct {
Title string `json:"title"`
AudioURL string `json:"audio_url"`
OpusURL string `json:"opus_url"` // 预转码 Opus JSON 地址,为空表示未转码
}
// Client 是 RTC 后端的 HTTP 客户端,复用连接池。
type Client struct {
baseURL string
httpClient *http.Client
}
// New 创建 ClientbaseURL 形如 "http://rtc-backend-svc:8000"。
func New(baseURL string) *Client {
return &Client{
baseURL: strings.TrimRight(baseURL, "/"),
httpClient: &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
MaxIdleConns: 50,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
},
// 限制重定向次数,防止无限跳转
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= 3 {
return errors.New("rtcclient: too many redirects")
}
return nil
},
},
}
}
// rtcResponse 是 RTC 后端的统一响应结构。
type rtcResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data json.RawMessage `json:"data"`
}
// FetchStoryByMAC 通过设备 MAC 地址获取随机故事。
// 返回 nil, nil 表示设备/用户/故事不存在(非错误,调用方直接跳过)。
func (c *Client) FetchStoryByMAC(ctx context.Context, mac string) (*StoryInfo, error) {
url := c.baseURL + "/api/v1/devices/stories/"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("rtcclient: build request: %w", err)
}
q := req.URL.Query()
q.Set("mac_address", strings.ToUpper(mac))
req.URL.RawQuery = q.Encode()
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("rtcclient: request failed: %w", err)
}
defer resp.Body.Close()
// 404 表示设备/用户/故事不存在,不是服务器错误
if resp.StatusCode == http.StatusNotFound {
return nil, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("rtcclient: unexpected status %d", resp.StatusCode)
}
var result rtcResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("rtcclient: decode response: %w", err)
}
if result.Code != 0 {
return nil, nil // 业务错误(如暂无故事),返回 nil 让调用方处理
}
var story StoryInfo
if err := json.Unmarshal(result.Data, &story); err != nil {
return nil, fmt.Errorf("rtcclient: decode story: %w", err)
}
if story.Title == "" || story.AudioURL == "" {
return nil, nil
}
return &story, nil
}

View File

@ -1,142 +0,0 @@
package rtcclient_test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/qy/hw-ws-service/internal/rtcclient"
)
func successBody(title, audioURL string) []byte {
b, _ := json.Marshal(map[string]any{
"code": 0,
"message": "success",
"data": map[string]string{
"title": title,
"audio_url": audioURL,
},
})
return b
}
func errorBody(code int, msg string) []byte {
b, _ := json.Marshal(map[string]any{
"code": code,
"message": msg,
"data": nil,
})
return b
}
func TestFetchStoryByMAC_Success(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/v1/devices/stories/" {
t.Errorf("unexpected path: %s", r.URL.Path)
}
w.Header().Set("Content-Type", "application/json")
w.Write(successBody("小红帽", "https://example.com/story.mp3"))
}))
defer srv.Close()
client := rtcclient.New(srv.URL)
story, err := client.FetchStoryByMAC(context.Background(), "aa:bb:cc:dd:ee:ff")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if story == nil {
t.Fatal("expected story, got nil")
}
if story.Title != "小红帽" {
t.Errorf("title = %q, want %q", story.Title, "小红帽")
}
if story.AudioURL != "https://example.com/story.mp3" {
t.Errorf("audio_url = %q", story.AudioURL)
}
}
// TestFetchStoryByMAC_MACUppercase verifies the client always sends uppercase MAC.
func TestFetchStoryByMAC_MACUppercase(t *testing.T) {
var gotMAC string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotMAC = r.URL.Query().Get("mac_address")
w.Header().Set("Content-Type", "application/json")
w.Write(successBody("test", "https://example.com/t.mp3"))
}))
defer srv.Close()
client := rtcclient.New(srv.URL)
client.FetchStoryByMAC(context.Background(), "aa:bb:cc:dd:ee:ff") //nolint:errcheck
if gotMAC != "AA:BB:CC:DD:EE:FF" {
t.Errorf("MAC not uppercased: got %q, want %q", gotMAC, "AA:BB:CC:DD:EE:FF")
}
}
// TestFetchStoryByMAC_NotFound verifies that HTTP 404 returns (nil, nil).
func TestFetchStoryByMAC_NotFound(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}))
defer srv.Close()
client := rtcclient.New(srv.URL)
story, err := client.FetchStoryByMAC(context.Background(), "AA:BB:CC:DD:EE:FF")
if err != nil {
t.Fatalf("unexpected error for 404: %v", err)
}
if story != nil {
t.Errorf("expected nil story for 404, got %+v", story)
}
}
// TestFetchStoryByMAC_BusinessError verifies that code != 0 returns (nil, nil).
func TestFetchStoryByMAC_BusinessError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write(errorBody(404, "暂无可播放的故事"))
}))
defer srv.Close()
client := rtcclient.New(srv.URL)
story, err := client.FetchStoryByMAC(context.Background(), "AA:BB:CC:DD:EE:FF")
if err != nil {
t.Fatalf("unexpected error for business error response: %v", err)
}
if story != nil {
t.Errorf("expected nil story for business error, got %+v", story)
}
}
// TestFetchStoryByMAC_ServerError verifies that HTTP 5xx returns an error.
func TestFetchStoryByMAC_ServerError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer srv.Close()
client := rtcclient.New(srv.URL)
_, err := client.FetchStoryByMAC(context.Background(), "AA:BB:CC:DD:EE:FF")
if err == nil {
t.Error("expected error for HTTP 500, got nil")
}
}
// TestFetchStoryByMAC_ContextCanceled verifies that a canceled context returns an error.
func TestFetchStoryByMAC_ContextCanceled(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Never respond — let the client time out
<-r.Context().Done()
}))
defer srv.Close()
ctx, cancel := context.WithCancel(context.Background())
cancel() // cancel immediately
client := rtcclient.New(srv.URL)
_, err := client.FetchStoryByMAC(ctx, "AA:BB:CC:DD:EE:FF")
if err == nil {
t.Error("expected error for canceled context, got nil")
}
}

View File

@ -1,277 +0,0 @@
// Package server 实现 WebSocket 服务器,管理硬件设备连接的生命周期。
package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/qy/hw-ws-service/internal/connection"
"github.com/qy/hw-ws-service/internal/handler"
"github.com/qy/hw-ws-service/internal/rtcclient"
)
const (
// maxConnections 最大并发连接数,防止资源耗尽。
maxConnections = 500
// maxMessageBytes WebSocket 单条消息上限4KB防止内存耗尽攻击。
maxMessageBytes = 4 * 1024
// helloTimeout 握手超时:连接建立后必须在此时间内发送 hello否则断开。
helloTimeout = 10 * time.Second
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
// IoT 设备无浏览器 Origin允许所有来源
CheckOrigin: func(r *http.Request) bool { return true },
}
// Server 管理所有活跃的设备连接。
type Server struct {
client *rtcclient.Client
httpServer *http.Server
mu sync.Mutex
conns map[string]*connection.Connection // key: DeviceID
wg sync.WaitGroup // 跟踪所有连接 goroutine
}
// New 创建 Serveraddr 形如 "0.0.0.0:8888"。
func New(addr string, client *rtcclient.Client) *Server {
s := &Server{
client: client,
conns: make(map[string]*connection.Connection),
}
mux := http.NewServeMux()
mux.HandleFunc("/xiaozhi/v1/healthz", s.handleStatus)
mux.HandleFunc("/xiaozhi/v1/", s.handleConn)
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
s.httpServer = &http.Server{
Addr: addr,
Handler: mux,
}
return s
}
// ListenAndServe 启动服务器,阻塞直到服务器关闭。
func (s *Server) ListenAndServe() error {
log.Printf("server: listening on %s", s.httpServer.Addr)
err := s.httpServer.ListenAndServe()
if errors.Is(err, http.ErrServerClosed) {
return nil
}
return err
}
// Shutdown 优雅关闭:先停止接受新连接,再等待所有连接 goroutine 退出。
func (s *Server) Shutdown(ctx context.Context) {
log.Println("server: shutting down...")
s.httpServer.Shutdown(ctx) //nolint:errcheck
// 等待所有连接 goroutine 退出(由 ctx 超时兜底)
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
log.Println("server: all connections closed gracefully")
case <-ctx.Done():
log.Println("server: shutdown timeout, forcing close")
}
}
// handleConn 处理单个 WebSocket 连接的完整生命周期。
// URL 格式:/xiaozhi/v1/?device-id=<MAC>&client-id=<UUID>
func (s *Server) handleConn(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/xiaozhi/v1/healthz" {
s.handleStatus(w, r)
return
}
deviceID := r.URL.Query().Get("device-id")
clientID := r.URL.Query().Get("client-id")
if deviceID == "" {
http.Error(w, "missing device-id", http.StatusBadRequest)
return
}
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Printf("server: upgrade failed for %s: %v", deviceID, err)
return
}
// 设置单条消息大小上限
ws.SetReadLimit(maxMessageBytes)
conn := connection.New(ws, deviceID, clientID)
if err := s.register(conn); err != nil {
log.Printf("server: register %s failed: %v", deviceID, err)
ws.Close()
return
}
s.wg.Add(1)
defer func() {
conn.StopPlayback()
s.unregister(deviceID)
ws.Close()
s.wg.Done()
log.Printf("server: device %s disconnected, active=%d", deviceID, s.activeCount())
}()
log.Printf("server: device %s connected, active=%d", deviceID, s.activeCount())
// 阶段1等待 hello 握手(超时 helloTimeout
ws.SetReadDeadline(time.Now().Add(helloTimeout)) //nolint:errcheck
if !s.waitForHello(conn) {
log.Printf("server: device %s hello timeout or failed", deviceID)
return
}
ws.SetReadDeadline(time.Time{}) //nolint:errcheck // 握手成功,取消读超时
log.Printf("server: device %s handshaked, session=%s", deviceID, conn.SessionID)
// 阶段2正常消息循环
for {
msgType, raw, err := ws.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
if !isNetworkClose(err) {
log.Printf("server: read error for %s: %v", deviceID, err)
}
}
return
}
// 只处理文本消息
if msgType != websocket.TextMessage {
continue
}
var envelope struct {
Type string `json:"type"`
}
if err := json.Unmarshal(raw, &envelope); err != nil {
log.Printf("server: invalid json from %s: %v", deviceID, err)
continue
}
switch envelope.Type {
case "story":
go handler.HandleStory(conn, s.client)
case "abort":
handler.HandleAbort(conn)
default:
log.Printf("server: unhandled message type %q from %s", envelope.Type, deviceID)
}
}
}
// waitForHello 等待并处理第一条 hello 消息,成功返回 true。
func (s *Server) waitForHello(conn *connection.Connection) bool {
msgType, raw, err := conn.WS.ReadMessage()
if err != nil {
return false
}
if msgType != websocket.TextMessage {
log.Printf("server: device %s sent non-text as first message", conn.DeviceID)
return false
}
var envelope struct {
Type string `json:"type"`
}
if err := json.Unmarshal(raw, &envelope); err != nil || envelope.Type != "hello" {
log.Printf("server: device %s first message is not hello (got %q)", conn.DeviceID, envelope.Type)
return false
}
if err := handler.HandleHello(conn, raw); err != nil {
log.Printf("server: device %s hello failed: %v", conn.DeviceID, err)
return false
}
return true
}
// register 注册连接,若同一设备已有连接则踢掉旧连接。
func (s *Server) register(conn *connection.Connection) error {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.conns) >= maxConnections {
return errors.New("server: max connections reached")
}
// 同一设备同时只允许一个连接
if old, exists := s.conns[conn.DeviceID]; exists {
log.Printf("server: kicking old connection for %s", conn.DeviceID)
old.Close()
}
s.conns[conn.DeviceID] = conn
return nil
}
// SendCmd 向指定设备发送控制指令。
// 若设备不在线或未握手,返回 error。
func (s *Server) SendCmd(deviceID, action string, params any) error {
s.mu.Lock()
conn, ok := s.conns[deviceID]
s.mu.Unlock()
if !ok {
return fmt.Errorf("server: device %s not connected", deviceID)
}
if !conn.IsHandshaked() {
return fmt.Errorf("server: device %s not handshaked", deviceID)
}
return conn.SendCmd(action, params)
}
func (s *Server) unregister(deviceID string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.conns, deviceID)
}
func (s *Server) activeCount() int {
s.mu.Lock()
defer s.mu.Unlock()
return len(s.conns)
}
// handleStatus 返回服务状态和当前活跃连接数,用于部署后验证。
// GET /xiaozhi/v1/healthz → {"status":"ok","active_connections":N}
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
active := s.activeCount()
fmt.Fprintf(w, `{"status":"ok","active_connections":%d}`, active)
}
// isNetworkClose 判断是否为普通的网络关闭错误(不需要打印日志)。
func isNetworkClose(err error) bool {
var netErr *net.OpError
return errors.As(err, &netErr)
}

View File

@ -1,82 +0,0 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: hw-ws-service
labels:
app: hw-ws-service
spec:
replicas: 2
selector:
matchLabels:
app: hw-ws-service
# WebSocket 连接有状态,滚动更新时使用 Recreate 或 RollingUpdate + 优雅关闭
strategy:
type: RollingUpdate
rollingUpdate:
maxUnavailable: 0 # 始终保持至少 2 个 Pod 可用
maxSurge: 1
template:
metadata:
labels:
app: hw-ws-service
spec:
# 优雅关闭总时限90s服务内部等待 80s留 10s 缓冲)
terminationGracePeriodSeconds: 90
containers:
- name: hw-ws-service
image: ${CI_REGISTRY_IMAGE}/hw-ws-service:latest
imagePullPolicy: Always
ports:
- name: ws
containerPort: 8888
protocol: TCP
env:
- name: HW_WS_HOST
value: "0.0.0.0"
- name: HW_WS_PORT
value: "8888"
- name: HW_RTC_BACKEND_URL
# 集群内部直接访问 rtc-backend Service不走公网
value: "http://rtc-backend:8000"
lifecycle:
preStop:
exec:
# 等待 5s 让 LB/Ingress 将流量从本 Pod 摘除,再开始关闭
command: ["/bin/sh", "-c", "sleep 5"]
# 就绪探针TCP 握手成功才接流量
readinessProbe:
tcpSocket:
port: 8888
initialDelaySeconds: 3
periodSeconds: 5
failureThreshold: 3
# 存活探针:连续失败 3 次才重启(避免短暂抖动误杀)
livenessProbe:
tcpSocket:
port: 8888
initialDelaySeconds: 10
periodSeconds: 15
failureThreshold: 3
# 资源限制(根据实际负载调整)
resources:
requests:
cpu: "100m"
memory: "128Mi"
limits:
cpu: "500m"
memory: "512Mi"
# 优先调度到不同节点,避免单点故障
topologySpreadConstraints:
- maxSkew: 1
topologyKey: kubernetes.io/hostname
whenUnsatisfiable: DoNotSchedule
labelSelector:
matchLabels:
app: hw-ws-service

View File

@ -1,15 +0,0 @@
apiVersion: v1
kind: Service
metadata:
name: hw-ws-svc
labels:
app: hw-ws-service
spec:
type: ClusterIP
selector:
app: hw-ws-service
ports:
- name: websocket
port: 8888
targetPort: 8888
protocol: TCP

View File

@ -1,251 +0,0 @@
# hw_service_go 本地硬件通讯测试计划
> 目标:用浏览器模拟 ESP32 硬件,验证 `hw_service_go` WebSocket 服务能否正常接收指令、获取故事、推送 Opus 音频。
---
## 一、协议对比分析
### 1.1 小智xiaozhi-servervs 我们的服务
| 维度 | xiaozhi-server | hw_service_go本服务 |
|------|---------------|------------------------|
| **WebSocket URL** | `ws://host:port/xiaozhi/v1/?device-id=&client-id=` | 完全相同 |
| **连接参数** | `device-id`MAC`client-id`UUID| 完全相同 |
| **握手消息** | 需要发送 `hello` JSON | **不需要**,连上即用 |
| **触发指令** | `listen`(语音输入) | **只需发 `{"type":"story"}`** |
| **音频方向** | 双向(硬件上传语音 + 服务下发 TTS| **单向下行**(服务→硬件,推 Opus |
| **Opus 编解码** | 需要编码(麦克风)+ 解码(播放)| **只需解码**(浏览器只播放) |
| **认证** | token 参数 | **无需认证**(仅 device-id 校验) |
| **消息复杂度** | hello/listen/stt/llm/tts/mcp | **只有 tts 系列** |
### 1.2 我们服务的完整消息流
```
浏览器(模拟硬件) hw_service_go Django
│ │ │
│── WS 连接 ──────────────────────────→│ │
│ ?device-id=AA:BB:CC:DD:EE:FF │ │
&client-id=test-001 │ │
│ │ │
│── {"type":"story"} ────────────────→ │ │
│ │── GET /api/v1/devices/ │
│ │ stories/?mac_address │
│ │ =AA:BB:CC:DD:EE:FF → │
│ │ │
│← {"type":"tts","state":"start"} ───── │ │
│ │← {title, audio_url} ── │
│ │ │
│ │── 下载 MP3 ──────────→ CDN
│ │← MP3 二进制流 ─────── │
│ │ │
│ │ ffmpeg 转码 PCM→Opus │
│ │ │
│← {"type":"tts","state":"sentence_start","text":"故事标题"} ─── │
│ │ │
│← [Opus帧1 二进制] ─────────────────── │ 60ms/帧前3帧预缓冲 │
│← [Opus帧2 二进制] ─────────────────── │ │
│← [Opus帧3 二进制] ─────────────────── │ │
│← [Opus帧N 二进制] ─────────────────── │ 按时序流控发送 │
│ │ │
│← {"type":"tts","state":"stop"} ─────── │ │
│ │ │
```
### 1.3 Opus 音频参数(与小智完全一致)
| 参数 | 值 |
|------|----|
| 采样率 | 16000 Hz |
| 声道 | 1单声道|
| 帧时长 | 60ms |
| 每帧采样数 | 960 |
| 编码器 | libopusWASM |
---
## 二、前置条件检查
在开始测试之前,需要满足以下条件:
### 2.1 服务运行状态
- [ ] Django 后端运行在 `http://localhost:8000`
- [ ] `hw_service_go` 运行在 `ws://localhost:8888`
- [ ] 健康检查通过:`curl http://localhost:8888/healthz` 返回 200
### 2.2 Django 数据准备(关键!)
测试必须使用一个在 Django 数据库中**真实存在**的设备 MAC 地址。
Django API 查询逻辑(`GET /api/v1/devices/stories/?mac_address=<MAC>`
- 根据 MAC 查找设备 → 找到设备绑定的用户 → 查找该用户的故事
- 任何一步缺失,服务返回 `null`,硬件不会播放任何内容
**需要在 Django Admin 或 API 中准备:**
1. 注册一个设备,记下其 MAC 地址(格式:`AA:BB:CC:DD:EE:FF`
2. 该设备需已绑定用户owner
3. 该用户名下有至少一个故事(有 `audio_url` 字段)
> **快速验证**`curl "http://localhost:8000/api/v1/devices/stories/?mac_address=你的MAC"` 应返回 `{"code":0,"data":{"title":"...","audio_url":"..."}}`
---
## 三、测试程序设计
### 3.1 技术选型
| 方案 | 优点 | 缺点 |
|------|------|------|
| **纯 HTML+JS推荐** | 零依赖,直接浏览器打开,与小智方案一致 | - |
| Python 脚本 | 简单但无法播放音频 | 无法验证音频播放端到端 |
| Go 命令行 | 需额外音频库 | 环境搭建复杂 |
**选择方案:纯 HTML+JS 单文件**,复用小智项目的 `libopus.js`WASM做解码。
### 3.2 文件结构
```
hw_service_go/test/
├── PLAN.md ← 本文件
├── test.html ← 测试主页面(待实现)
└── libopus.js ← 复制自小智项目Opus WASM 解码库)
```
`libopus.js` 来源:
```
/Users/maidong/Desktop/zyc/jikashe/xiaozhi-server/main/xiaozhi-server/test/libopus.js
```
### 3.3 测试页面功能
```
┌─────────────────────────────────────────────────────┐
│ hw_service_go 硬件通讯测试 │
├─────────────────────────────────────────────────────┤
│ 服务地址: [ws://localhost:8888/xiaozhi/v1/ ] │
│ device-id: [AA:BB:CC:DD:EE:FF ] │
│ client-id: [test-browser-001 ] [随机生成] │
├─────────────────────────────────────────────────────┤
│ [连接] [断开] 状态: ● 未连接 │
├─────────────────────────────────────────────────────┤
│ [▶ 触发故事播放] [⏹ 停止] │
├─────────────────────────────────────────────────────┤
│ 消息日志 [清空] │
│ ┌───────────────────────────────────────────────┐ │
│ │ [10:23:01] → 已连接 │ │
│ │ [10:23:02] → 发送: {"type":"story"} │ │
│ │ [10:23:02] ← 收到: {"type":"tts","state":"start"} │
│ │ [10:23:03] ← 收到: {"type":"tts","state":"sentence_start","text":"..."} │
│ │ [10:23:03] ← 收到: [Binary] Opus帧 #1 (38 bytes) │
│ │ [10:23:03] 🔊 开始播放... │ │
│ │ [10:23:15] ← 收到: {"type":"tts","state":"stop"} │
│ │ [10:23:15] 🔊 播放完毕 │ │
│ └───────────────────────────────────────────────┘ │
│ │
│ 统计: 已收到 85 个Opus帧 | 约 5.1s 音频 │
└─────────────────────────────────────────────────────┘
```
### 3.4 核心实现逻辑
#### 连接流程
```javascript
const ws = new WebSocket(
`ws://localhost:8888/xiaozhi/v1/?device-id=${deviceId}&client-id=${clientId}`
);
ws.binaryType = 'arraybuffer';
```
#### 触发故事
```javascript
ws.send(JSON.stringify({ type: 'story' }));
```
#### 接收消息处理
```javascript
ws.onmessage = (event) => {
if (event.data instanceof ArrayBuffer) {
// 二进制Opus 音频帧
const opusFrame = new Uint8Array(event.data);
const pcm = opusDecoder.decode(opusFrame); // Int16Array
schedulePlay(pcm); // 排队播放
} else {
// 文本:控制消息
const msg = JSON.parse(event.data);
handleTtsControl(msg); // 处理 start/sentence_start/stop
}
};
```
#### Opus 解码 + 播放(与小智方案完全一致)
- 使用 `libopus.js`WASM初始化解码器16000Hz单声道
- 解码:`Int16Array``Float32Array`
- 使用 `AudioContext` + `AudioBufferSourceNode` 按时序排队播放
- 使用 `BlockingQueue` 缓冲帧,避免播放卡顿
---
## 四、测试用例
### Case 1基础连接测试
- 输入正确的 `device-id``client-id`
- 期望WebSocket 连接建立成功,状态变为"已连接"
### Case 2故事触发测试
- 发送 `{"type":"story"}`
- 期望:
1. 收到 `{"type":"tts","state":"start"}`
2. 收到 `{"type":"tts","state":"sentence_start","text":"<故事标题>"}`
3. 陆续收到多个二进制 Opus 帧
4. 最终收到 `{"type":"tts","state":"stop"}`
### Case 3音频播放验证
- 期望:浏览器实际播放出故事音频,声音正常无杂音、无卡顿
### Case 4设备不存在测试
- 使用未注册的 MAC 地址
- 期望:发送故事指令后立即收到 `{"type":"tts","state":"stop"}`(服务侧找不到故事,直接结束)
### Case 5重复触发测试
- 播放过程中再次点击"触发故事"
- 期望旧播放被打断新故事从头开始hw_service_go 的 `StartPlayback` 会 close 旧 abortCh
### Case 6断线重连测试
- 连接后断开,再重新连接
- 期望:可以正常重新发起故事请求
---
## 五、实现步骤
1. **复制 libopus.js**
```bash
cp /Users/maidong/Desktop/zyc/jikashe/xiaozhi-server/main/xiaozhi-server/test/libopus.js \
/Users/maidong/Desktop/zyc/qy_gitlab/rtc_backend/hw_service_go/test/
```
2. **编写 test.html**(单文件,嵌入所有 JS
- 参考小智 `StreamingContext.js``BlockingQueue.js` 的逻辑
- 去掉录音/编码部分(我们只需解码)
- 保留 Opus 解码 + AudioContext 播放部分
- 添加连接配置 UI 和消息日志面板
3. **浏览器打开测试**
```
直接用浏览器打开 test.htmlfile:// 协议即可)
```
> 注意macOS Safari 对 WebSocket + file:// 可能有限制,建议用 Chrome
4. **按测试用例逐项验证**
---
## 六、已知限制与注意事项
| 问题 | 说明 |
|------|------|
| **device-id 必须真实存在** | MAC 地址若未在 Django 数据库注册,服务会静默返回无故事 |
| **ffmpeg 必须安装** | `hw_service_go` 的转码依赖系统 `ffmpeg`,需提前安装 |
| **audio_url 必须可访问** | 故事的 MP3 链接需要能从本机下载(阿里云 OSS 等) |
| **浏览器 AudioContext 限制** | 需要用户交互(点击)后才能创建 AudioContext不能自动播放 |
| **WASM 加载** | libopus.js 较大844KB首次加载需要等待约 1-2 秒 |

File diff suppressed because one or more lines are too long

View File

@ -1,32 +0,0 @@
========================================
hw_service_go 并发压力测试
========================================
目标地址: wss://qiyuan-rtc-api.airlabs.art/xiaozhi/v1/
总连接数: 100
触发故事: 10
建连速率: 20/s
测试时长: 1m0s
MAC 前缀: AA:BB:CC:DD
========================================
[2s] conns: 40/100 handshaked: 40 stories: 10 sent frames: 245 errors: 0 healthz: {"status": "ok"} [4s] conns: 79/100 handshaked: 79 stories: 10 sent frames: 575 errors: 0 healthz: {"status": "ok"}
所有连接已发起,等待 1m0s...
[6s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 909 errors: 0 healthz: {"status": "ok"} [8s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 1240 errors: 0 healthz: {"status": "ok"} [10s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 1575 errors: 0 healthz: {"status": "ok"} [12s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 1909 errors: 0 healthz: {"status": "ok"} [14s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 2240 errors: 0 healthz: {"status": "ok"} [16s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 2575 errors: 0 healthz: {"status": "ok"} [18s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 2909 errors: 0 healthz: {"status": "ok"} [20s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 3240 errors: 0 healthz: {"status": "ok"} [22s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 3575 errors: 0 healthz: {"status": "ok"} [24s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 3909 errors: 0 healthz: {"status": "ok"} [26s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 4240 errors: 0 healthz: {"status": "ok"} [28s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 4575 errors: 0 healthz: {"status": "ok"} [30s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 4909 errors: 0 healthz: {"status": "ok"} [32s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 5240 errors: 0 healthz: {"status": "ok"} [34s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 5575 errors: 0 healthz: {"status": "ok"} [36s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 5909 errors: 0 healthz: {"status": "ok"} [38s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 6240 errors: 0 healthz: {"status": "ok"} [40s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 6575 errors: 0 healthz: {"status": "ok"} [42s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 6909 errors: 0 healthz: {"status": "ok"} [44s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 7240 errors: 0 healthz: {"status": "ok"} [46s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 7575 errors: 0 healthz: {"status": "ok"} [48s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 7909 errors: 0 healthz: {"status": "ok"} [50s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 7960 errors: 0 healthz: {"status": "ok"} [52s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 7960 errors: 0 healthz: {"status": "ok"} [54s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 7960 errors: 0 healthz: {"status": "ok"} [56s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 7960 errors: 0 healthz: {"status": "ok"} [58s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 7960 errors: 0 healthz: {"status": "ok"} [1m0s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 7960 errors: 0 healthz: {"status": "ok"} [1m2s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 7960 errors: 0 healthz: {"status": "ok"} [1m4s] conns: 100/100 handshaked: 100 stories: 10 sent frames: 7960 errors: 0 healthz: {"status": "ok"}
测试时长到期,正在停止...
========== 测试报告 ==========
目标连接数: 100
连接尝试: 100
成功连接: 100
连接失败: 0
握手成功: 100
握手失败: 0
------------------------------
触发故事数: 10
收到 tts start: 10
收到 tts stop: 10
Opus 帧总数: 7960
平均帧数/故事: 796
首帧延迟(avg): 324ms
错误总数: 0
==============================

View File

@ -1,121 +0,0 @@
# hw_service_go 并发压力测试报告
> 测试时间2026-03-03
> 测试目标:`wss://qiyuan-rtc-api.airlabs.art/xiaozhi/v1/`
> Pod 配置:单 PodCPU 100m~500mlimitsMemory 128Mi~512Milimitsreplicas: 1
---
## 一、测试环境
| 项目 | 配置 |
|------|------|
| 服务 | hw_service_goWebSocket + Opus 音频推送) |
| 部署 | K8s 单 Pod1 副本 |
| CPU limits | 500m0.5 核) |
| Memory limits | 512Mi |
| 硬编码连接上限 | 500 |
| 测试工具 | Go 压测工具(`test/stress/main.go` |
| 测试客户端 | macOS从公网连接线上服务 |
---
## 二、测试结果
### 2.1 连接容量测试(空闲连接)
```
go run main.go -url wss://..../xiaozhi/v1/ -conns 200 -stories 0 -duration 30s
```
| 指标 | 结果 |
|------|------|
| 目标连接 | 200 |
| 成功连接 | 200 |
| 握手成功 | 200 |
| 错误 | 0 |
**结论200 个空闲连接毫无压力,内存不是瓶颈。**
### 2.2 并发播放压力测试
每个"活跃故事"会触发Django API 查询 → MP3 下载 → ffmpeg 转码 → Opus 编码 → WebSocket 推帧。
| 并发故事数 | 总连接 | 首帧延迟 | 帧数/故事 | 错误 | 状态 |
|-----------|--------|---------|----------|------|------|
| 2 | 10 | **2.0s** | 796 | 0 | 轻松 |
| 5 | 10 | **4.5s** | 796 | 0 | 正常 |
| 10 | 20 | **8.7s** | 796 | 0 | 吃力但稳 |
| 20 | 30 | **17.4s** | 796 | 0 | 极限 |
### 2.3 关键发现
1. **帧数始终稳定 796/故事** — 音频完整交付,零丢帧,服务可靠性极高
2. **首帧延迟线性增长** — 约 0.85s/并发,纯 CPU 瓶颈(多个 ffmpeg 进程争抢 0.5 核)
3. **Pod 未触发 OOMKill** — 512Mi 内存对 20 并发播放也够用
4. **全程零错误** — 无连接断开、无握手失败、无帧丢失
---
## 三、瓶颈分析
```
单个故事播放的资源消耗链路:
Django API (GET) → MP3 下载 (OSS) → ffmpeg 转码 (CPU密集) → Opus 编码 → WebSocket 推帧
主要瓶颈
每个并发故事启动一个 ffmpeg 子进程
多个 ffmpeg 共享 0.5 核 CPU
```
| 资源 | 是否瓶颈 | 说明 |
|------|---------|------|
| **CPU** | **是** | ffmpeg 转码是 CPU 密集型0.5 核被多个 ffmpeg 进程分时使用 |
| 内存 | 否 | 20 并发播放未触发 OOM512Mi 充足 |
| 网络 | 否 | Opus 帧约 4-7 KB/s/连接,带宽远未饱和 |
| 连接数 | 否 | 空闲连接 200+ 无压力,硬上限 500 |
---
## 四、容量结论
### 当前单 Pod0.5 核 CPU, 512Mi, 1 副本)
| 指标 | 数值 |
|------|------|
| 空闲连接上限 | **200+**(轻松) |
| 并发播放(体验好,首帧 < 5s | **~5 ** |
| 并发播放(可接受,首帧 < 10s | **~10 ** |
| 并发播放(极限,首帧 ~17s | **~20 个** |
| 瓶颈资源 | CPUffmpeg 转码) |
---
## 五、扩容建议
| 方案 | 变更 | 预估并发播放(首帧 < 10s | 成本 |
|------|------|------------------------|------|
| **提 CPU** | limits 500m → 1000m | ~20 个 | 低 |
| **加副本** | replicas 1 → 2 | ~10 个(负载均衡) | 中 |
| **两者都做** | 1000m CPU + 2 副本 | **~40 个** | 中 |
| 垂直扩容 | 2000m CPU + 1Gi 内存 | ~40 个 | 中 |
> **推荐方案**replicas: 2 + CPU limits: 1000m兼顾高可用与并发能力。
---
## 六、测试命令参考
```bash
cd hw_service_go/test/stress
# 空闲连接容量
go run main.go -url wss://TARGET/xiaozhi/v1/ -conns 200 -stories 0 -duration 30s
# 并发播放(逐步加压)
go run main.go -url wss://TARGET/xiaozhi/v1/ -conns 10 -stories 2 -duration 60s
go run main.go -url wss://TARGET/xiaozhi/v1/ -conns 10 -stories 5 -duration 60s
go run main.go -url wss://TARGET/xiaozhi/v1/ -conns 20 -stories 10 -duration 90s
go run main.go -url wss://TARGET/xiaozhi/v1/ -conns 30 -stories 20 -duration 120s
```

View File

@ -1,5 +0,0 @@
module stress
go 1.23
require github.com/gorilla/websocket v1.5.3

View File

@ -1,2 +0,0 @@
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=

View File

@ -1,379 +0,0 @@
// hw_service_go 并发压力测试工具
//
// 用法:
//
// go run main.go -conns 100 -stories 0 # 100 个空闲连接
// go run main.go -conns 50 -stories 10 # 50 连接10 个触发故事
// go run main.go -url wss://example.com/xiaozhi/v1/ -conns 50
package main
import (
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"os/signal"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/gorilla/websocket"
)
// ── 命令行参数 ─────────────────────────────────────────────────
var (
flagURL = flag.String("url", "ws://localhost:8888/xiaozhi/v1/", "WebSocket 服务地址")
flagConns = flag.Int("conns", 100, "总连接数")
flagStories = flag.Int("stories", 10, "同时触发故事的连接数")
flagRamp = flag.Int("ramp", 20, "每秒建立的连接数")
flagDuration = flag.Duration("duration", 60*time.Second, "测试持续时间")
flagMACPrefix = flag.String("mac-prefix", "AA:BB:CC:DD", "模拟 MAC 前缀")
)
// ── 统计指标原子操作goroutine 安全) ──────────────────────
type stats struct {
connAttempts atomic.Int64
connSuccess atomic.Int64
connFailed atomic.Int64
handshaked atomic.Int64
handshakeFail atomic.Int64
storySent atomic.Int64
ttsStart atomic.Int64
ttsStop atomic.Int64
opusFrames atomic.Int64
errors atomic.Int64
firstFrameNs atomic.Int64 // 所有设备首帧延迟总和(纳秒),用于算均值
firstFrameCnt atomic.Int64 // 收到首帧的设备数
}
var s stats
// ── 模拟设备 ────────────────────────────────────────────────
type device struct {
id int
mac string
clientID string
ws *websocket.Conn
triggerStory bool
}
func newDevice(id int, macPrefix string, triggerStory bool) *device {
hi := byte((id >> 8) & 0xFF)
lo := byte(id & 0xFF)
mac := fmt.Sprintf("%s:%02X:%02X", macPrefix, hi, lo)
return &device{
id: id,
mac: mac,
clientID: fmt.Sprintf("stress-%d", id),
triggerStory: triggerStory,
}
}
func (d *device) run(baseURL string, wg *sync.WaitGroup, done <-chan struct{}) {
defer wg.Done()
s.connAttempts.Add(1)
// 1. 建立 WebSocket 连接
u, err := url.Parse(baseURL)
if err != nil {
log.Printf("[dev-%d] invalid URL: %v", d.id, err)
s.connFailed.Add(1)
return
}
q := u.Query()
q.Set("device-id", d.mac)
q.Set("client-id", d.clientID)
u.RawQuery = q.Encode()
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
ws, _, err := dialer.Dial(u.String(), nil)
if err != nil {
log.Printf("[dev-%d] connect failed: %v", d.id, err)
s.connFailed.Add(1)
return
}
d.ws = ws
s.connSuccess.Add(1)
defer ws.Close()
// 2. 发送 hello 握手
helloMsg, _ := json.Marshal(map[string]string{
"type": "hello",
"mac": d.mac,
})
if err := ws.WriteMessage(websocket.TextMessage, helloMsg); err != nil {
log.Printf("[dev-%d] hello send failed: %v", d.id, err)
s.handshakeFail.Add(1)
return
}
// 3. 等待 hello 响应5s 超时)
ws.SetReadDeadline(time.Now().Add(5 * time.Second))
_, msg, err := ws.ReadMessage()
if err != nil {
log.Printf("[dev-%d] hello read failed: %v", d.id, err)
s.handshakeFail.Add(1)
return
}
ws.SetReadDeadline(time.Time{}) // 清除超时
var helloResp struct {
Type string `json:"type"`
Status string `json:"status"`
}
if err := json.Unmarshal(msg, &helloResp); err != nil || helloResp.Type != "hello" || helloResp.Status != "ok" {
log.Printf("[dev-%d] hello failed: %s", d.id, string(msg))
s.handshakeFail.Add(1)
return
}
s.handshaked.Add(1)
// 4. 如果被选为活跃设备,触发故事
var storySentTime time.Time
var gotFirstFrame bool
if d.triggerStory {
storyMsg, _ := json.Marshal(map[string]string{"type": "story"})
if err := ws.WriteMessage(websocket.TextMessage, storyMsg); err != nil {
log.Printf("[dev-%d] story send failed: %v", d.id, err)
s.errors.Add(1)
} else {
s.storySent.Add(1)
storySentTime = time.Now()
}
}
// 5. 消息接收循环
msgCh := make(chan struct{}, 1) // 用于通知有新消息
go func() {
for {
msgType, data, err := ws.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
select {
case <-done:
// 正常关闭,不算错误
default:
s.errors.Add(1)
}
}
return
}
if msgType == websocket.BinaryMessage {
// Opus 帧
s.opusFrames.Add(1)
if d.triggerStory && !gotFirstFrame && !storySentTime.IsZero() {
gotFirstFrame = true
latency := time.Since(storySentTime)
s.firstFrameNs.Add(latency.Nanoseconds())
s.firstFrameCnt.Add(1)
}
_ = data // 不需要解码,只计数
} else {
// 文本消息
var envelope struct {
Type string `json:"type"`
State string `json:"state"`
}
if json.Unmarshal(data, &envelope) == nil {
if envelope.Type == "tts" {
switch envelope.State {
case "start":
s.ttsStart.Add(1)
case "stop":
s.ttsStop.Add(1)
}
}
}
}
select {
case msgCh <- struct{}{}:
default:
}
}
}()
// 6. 等待测试结束
<-done
ws.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
)
}
// ── healthz 查询 ─────────────────────────────────────────────
func queryHealthz(baseURL string) string {
// 从 ws:// URL 推导 http:// URL
u, err := url.Parse(baseURL)
if err != nil {
return "N/A"
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
}
// 去掉 /xiaozhi/v1/ 路径,换成 /healthz
u.Path = "/healthz"
u.RawQuery = ""
client := &http.Client{Timeout: 3 * time.Second}
resp, err := client.Get(u.String())
if err != nil {
return "N/A"
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
return strings.TrimSpace(string(body))
}
// ── 主函数 ──────────────────────────────────────────────────
func main() {
flag.Parse()
if *flagStories > *flagConns {
*flagStories = *flagConns
}
fmt.Println("========================================")
fmt.Println(" hw_service_go 并发压力测试")
fmt.Println("========================================")
fmt.Printf(" 目标地址: %s\n", *flagURL)
fmt.Printf(" 总连接数: %d\n", *flagConns)
fmt.Printf(" 触发故事: %d\n", *flagStories)
fmt.Printf(" 建连速率: %d/s\n", *flagRamp)
fmt.Printf(" 测试时长: %s\n", *flagDuration)
fmt.Printf(" MAC 前缀: %s\n", *flagMACPrefix)
fmt.Println("========================================")
fmt.Println()
done := make(chan struct{})
var wg sync.WaitGroup
// 信号处理Ctrl+C 提前结束
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
fmt.Println("\n收到退出信号正在停止...")
close(done)
}()
// 实时统计输出
startTime := time.Now()
go func() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
elapsed := time.Since(startTime).Truncate(time.Second)
health := queryHealthz(*flagURL)
fmt.Printf("\r\033[K[%s] conns: %d/%d handshaked: %d stories: %d sent frames: %d errors: %d healthz: %s",
elapsed,
s.connSuccess.Load(), *flagConns,
s.handshaked.Load(),
s.storySent.Load(),
s.opusFrames.Load(),
s.errors.Load(),
health,
)
case <-done:
return
}
}
}()
// 按 ramp 速率建立连接
rampInterval := time.Second / time.Duration(*flagRamp)
for i := 1; i <= *flagConns; i++ {
select {
case <-done:
goto waitDone
default:
}
triggerStory := i <= *flagStories
dev := newDevice(i, *flagMACPrefix, triggerStory)
wg.Add(1)
go dev.run(*flagURL, &wg, done)
// 控制建连速率
if i < *flagConns {
time.Sleep(rampInterval)
}
}
// 所有连接建立后,等待 duration 到期
fmt.Printf("\n所有连接已发起等待 %s...\n", *flagDuration)
select {
case <-time.After(*flagDuration):
fmt.Println("\n测试时长到期正在停止...")
close(done)
case <-done:
}
waitDone:
// 等待所有 goroutine 退出(最多 10s
waitCh := make(chan struct{})
go func() {
wg.Wait()
close(waitCh)
}()
select {
case <-waitCh:
case <-time.After(10 * time.Second):
fmt.Println("等待超时,强制退出")
}
// 最终报告
printReport()
}
func printReport() {
fmt.Println()
fmt.Println("========== 测试报告 ==========")
fmt.Printf("目标连接数: %d\n", *flagConns)
fmt.Printf("连接尝试: %d\n", s.connAttempts.Load())
fmt.Printf("成功连接: %d\n", s.connSuccess.Load())
fmt.Printf("连接失败: %d\n", s.connFailed.Load())
fmt.Printf("握手成功: %d\n", s.handshaked.Load())
fmt.Printf("握手失败: %d\n", s.handshakeFail.Load())
fmt.Println("------------------------------")
fmt.Printf("触发故事数: %d\n", s.storySent.Load())
fmt.Printf("收到 tts start: %d\n", s.ttsStart.Load())
fmt.Printf("收到 tts stop: %d\n", s.ttsStop.Load())
fmt.Printf("Opus 帧总数: %d\n", s.opusFrames.Load())
if s.storySent.Load() > 0 {
fmt.Printf("平均帧数/故事: %d\n", s.opusFrames.Load()/max(s.ttsStop.Load(), 1))
}
if s.firstFrameCnt.Load() > 0 {
avgMs := s.firstFrameNs.Load() / s.firstFrameCnt.Load() / 1e6
fmt.Printf("首帧延迟(avg): %dms\n", avgMs)
}
fmt.Printf("错误总数: %d\n", s.errors.Load())
fmt.Println("==============================")
}
func max(a, b int64) int64 {
if a > b {
return a
}
return b
}

View File

@ -1,94 +0,0 @@
# hw_service_go 压力测试报告
> 测试时间2026-03-03
> 测试环境K8s 线上环境2 Pod 副本)
---
## 优化背景
**问题**hw_service_go 每次播放故事都实时执行 `MP3下载 → ffmpeg转码 → Opus编码`ffmpeg 是 CPU 密集型操作0.5 核 CPU 下 5 个并发首帧延迟达 4.5s。
**方案**TTS 生成 MP3 后立即预转码为 Opus 帧数据JSON 格式)上传 OSS。播放时直接下载预处理好的 Opus 数据,跳过 ffmpeg。
---
## 测试配置
| 项目 | 值 |
|------|------|
| 目标地址 | `wss://qiyuan-rtc-api.airlabs.art/xiaozhi/v1/` |
| 服务副本数 | 2 Pod |
| 建连速率 | 20/s |
| 测试时长 | 60s |
| MAC 前缀 | AA:BB:CC:DD |
---
## 优化前基准ffmpeg 实时转码)
| 并发数 | 首帧延迟 | CPU 占用 | 备注 |
|--------|---------|---------|------|
| 5 | ~4,500ms | 接近 100%0.5核) | CPU 瓶颈明显 |
| 10+ | 超时/失败 | 过载 | 无法正常服务 |
---
## 优化后测试结果Opus 预转码)
| 并发故事数 | 首帧延迟 | 连接成功率 | 故事播完率 | 错误数 | 平均帧数/故事 |
|-----------|---------|-----------|-----------|--------|-------------|
| 20 | 74ms | 100%20/20 | 100%20/20 | 0 | 796 |
| 100 | 89ms | 100%100/100 | 100%100/100 | 0 | 796 |
| 200 | 84ms | 100%200/200 | 100%200/200 | 0 | 796 |
| 400 | 82ms | 100%400/400 | 100%400/400 | 0 | 796 |
| **800** | **80ms** | **100%800/800** | **100%800/800** | **0** | **796** |
---
## 关键指标对比
| 指标 | 优化前ffmpeg | 优化后(预转码) | 提升 |
|------|----------------|----------------|------|
| 首帧延迟 | ~4,500ms | ~80ms | **56 倍** |
| 最大并发(故事播放) | 5-10 | 800+(未触顶) | **80-160 倍** |
| CPU 开销 | ffmpeg 转码吃满 CPU | 几乎为零(仅网络 I/O | - |
| 帧推送稳定性 | 高并发丢帧 | 796 帧/故事,零丢帧 | - |
| 错误率 | 高并发下频繁超时 | 0% | - |
---
## 数据分析
### 帧推送吞吐量
800 并发时,每 2 秒推送约 26,600 帧800 路 × ~33 帧/2s与理论值60ms/帧)完全吻合,说明服务端帧调度精准、无积压。
### 首帧延迟稳定性
从 20 到 800 并发,首帧延迟始终保持在 74-89ms 范围内,无明显上升趋势。延迟主要来自 OSS 下载 Opus JSON 文件(~80ms与并发数无关。
### 负载均衡
2 个 Pod 均分连接800 并发时每个 Pod 承担 400 个连接,负载均衡工作正常。
---
## 商用容量评估
| 设备规模 | 预估高峰并发故事 | 当前支撑能力 | 是否满足 |
|---------|----------------|------------|---------|
| 2,000 台 | 100-200 | 800+2 Pod | 充裕 |
| 5,000 台 | 250-500 | 800+2 Pod | 满足 |
| 10,000 台 | 500-1,000 | 扩容至 4 Pod 即可 | 可支撑 |
> 说明:高峰并发按在线率 50%、同时播放率 20% 估算。儿童故事机使用集中在下午 4-6 点和晚上 7-9 点。
---
## 结论
1. Opus 预转码方案效果显著,首帧延迟从 **4.5s 降至 80ms**(提升 56 倍)
2. 800 并发同时播放故事0 错误、0 丢帧,服务器未触及性能瓶颈
3. **2,000 台设备商用完全没有问题**,且有充足余量应对突发流量
4. 如需支撑更大规模K8s 水平扩 Pod 即可线性提升容量

View File

@ -1,667 +0,0 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>hw_service_go 硬件通讯测试</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #f5f5f5;
color: #333;
padding: 20px;
}
.container {
max-width: 800px;
margin: 0 auto;
background: #fff;
border-radius: 12px;
box-shadow: 0 2px 12px rgba(0,0,0,0.1);
overflow: hidden;
}
.header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: #fff;
padding: 20px 24px;
}
.header h1 { font-size: 20px; font-weight: 600; }
.header p { font-size: 13px; opacity: 0.8; margin-top: 4px; }
.section {
padding: 16px 24px;
border-bottom: 1px solid #eee;
}
.section:last-child { border-bottom: none; }
.section-title {
font-size: 13px;
font-weight: 600;
color: #888;
text-transform: uppercase;
letter-spacing: 0.5px;
margin-bottom: 12px;
}
.form-row {
display: flex;
align-items: center;
gap: 12px;
margin-bottom: 10px;
}
.form-row:last-child { margin-bottom: 0; }
.form-row label {
min-width: 80px;
font-size: 14px;
font-weight: 500;
color: #555;
}
.form-row input {
flex: 1;
padding: 8px 12px;
border: 1px solid #ddd;
border-radius: 6px;
font-size: 14px;
font-family: 'SF Mono', Monaco, monospace;
outline: none;
transition: border-color 0.2s;
}
.form-row input:focus { border-color: #667eea; }
.btn {
padding: 8px 18px;
border: none;
border-radius: 6px;
font-size: 13px;
font-weight: 600;
cursor: pointer;
transition: all 0.2s;
display: inline-flex;
align-items: center;
gap: 6px;
}
.btn:disabled { opacity: 0.5; cursor: not-allowed; }
.btn-primary { background: #667eea; color: #fff; }
.btn-primary:hover:not(:disabled) { background: #5a6fd6; }
.btn-danger { background: #e74c3c; color: #fff; }
.btn-danger:hover:not(:disabled) { background: #c0392b; }
.btn-success { background: #27ae60; color: #fff; }
.btn-success:hover:not(:disabled) { background: #219a52; }
.btn-secondary { background: #95a5a6; color: #fff; }
.btn-secondary:hover:not(:disabled) { background: #7f8c8d; }
.btn-small { padding: 4px 10px; font-size: 12px; }
.controls {
display: flex;
align-items: center;
gap: 12px;
flex-wrap: wrap;
}
.status-indicator {
display: inline-flex;
align-items: center;
gap: 6px;
font-size: 14px;
font-weight: 500;
}
.status-dot {
width: 10px;
height: 10px;
border-radius: 50%;
background: #bdc3c7;
transition: background 0.3s;
}
.status-dot.connected { background: #27ae60; }
.status-dot.connecting { background: #f39c12; animation: pulse 1s infinite; }
.status-dot.error { background: #e74c3c; }
@keyframes pulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.4; }
}
.log-container {
background: #1e1e1e;
color: #d4d4d4;
border-radius: 8px;
padding: 12px;
height: 400px;
overflow-y: auto;
font-family: 'SF Mono', Monaco, 'Cascadia Code', monospace;
font-size: 12px;
line-height: 1.6;
}
.log-container::-webkit-scrollbar { width: 6px; }
.log-container::-webkit-scrollbar-track { background: transparent; }
.log-container::-webkit-scrollbar-thumb { background: #555; border-radius: 3px; }
.log-entry { white-space: pre-wrap; word-break: break-all; }
.log-time { color: #858585; }
.log-send { color: #dcdcaa; }
.log-recv { color: #9cdcfe; }
.log-binary { color: #ce9178; }
.log-audio { color: #c586c0; }
.log-error { color: #f44747; }
.log-success { color: #6a9955; }
.log-warning { color: #d7ba7d; }
.stats-bar {
display: flex;
gap: 24px;
padding: 12px 24px;
background: #fafafa;
font-size: 13px;
color: #666;
}
.stats-bar span { font-weight: 600; color: #333; }
.log-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 12px;
}
</style>
<!-- Opus WASM 解码库 -->
<script src="libopus.js"></script>
</head>
<body>
<div class="container">
<!-- 标题 -->
<div class="header">
<h1>hw_service_go 硬件通讯测试</h1>
<p>模拟 ESP32 硬件,测试 WebSocket 故事推送与 Opus 音频播放</p>
</div>
<!-- 连接配置 -->
<div class="section">
<div class="section-title">连接配置</div>
<div class="form-row">
<label>服务地址</label>
<input type="text" id="wsUrl" value="ws://localhost:8888/xiaozhi/v1/">
<button class="btn btn-secondary btn-small" id="btnEnvToggle" onclick="toggleEnv()">切换线上</button>
</div>
<div class="form-row">
<label>device-id</label>
<input type="text" id="deviceId" value="20:6E:F1:B9:AF:A2">
</div>
<div class="form-row">
<label>client-id</label>
<input type="text" id="clientId" value="">
<button class="btn btn-secondary btn-small" onclick="generateClientId()">随机生成</button>
</div>
</div>
<!-- 控制面板 -->
<div class="section">
<div class="controls">
<button class="btn btn-primary" id="btnConnect" onclick="connect()">连接</button>
<button class="btn btn-danger" id="btnDisconnect" onclick="disconnect()" disabled>断开</button>
<div style="width: 1px; height: 24px; background: #ddd;"></div>
<button class="btn btn-success" id="btnStory" onclick="triggerStory()" disabled>&#9654; 触发故事播放</button>
<button class="btn btn-danger" id="btnStop" onclick="stopPlayback()" disabled>&#9632; 停止</button>
<div style="flex:1"></div>
<div class="status-indicator">
<div class="status-dot" id="statusDot"></div>
<span id="statusText">未连接</span>
</div>
</div>
</div>
<!-- 统计栏 -->
<div class="stats-bar">
<div>Opus 帧: <span id="statFrames">0</span></div>
<div>音频时长: <span id="statDuration">0.0s</span></div>
<div>Opus 库: <span id="statOpus">加载中...</span></div>
</div>
<!-- 消息日志 -->
<div class="section">
<div class="log-header">
<div class="section-title" style="margin-bottom:0">消息日志</div>
<button class="btn btn-secondary btn-small" onclick="clearLog()">清空</button>
</div>
<div class="log-container" id="logContainer"></div>
</div>
</div>
<script>
// ============================================================
// 全局状态
// ============================================================
let ws = null;
let audioCtx = null;
let opusDecoder = null;
let opusReady = false;
let handshaked = false; // hello 握手是否完成
// 播放状态
let opusFrameCount = 0;
let pcmBufferQueue = []; // Float32Array 队列
let isPlaying = false;
let nextPlayTime = 0;
// ============================================================
// 工具函数
// ============================================================
function $(id) { return document.getElementById(id); }
function log(msg, type = 'info') {
const container = $('logContainer');
const now = new Date();
const ts = `${now.toLocaleTimeString()}.${String(now.getMilliseconds()).padStart(3, '0')}`;
const entry = document.createElement('div');
entry.className = 'log-entry';
const typeClass = {
send: 'log-send',
recv: 'log-recv',
binary: 'log-binary',
audio: 'log-audio',
error: 'log-error',
success: 'log-success',
warning: 'log-warning',
}[type] || '';
const arrow = type === 'send' ? '→ ' : type === 'recv' ? '← ' : type === 'binary' ? '← ' : '';
entry.innerHTML = `<span class="log-time">[${ts}]</span> <span class="${typeClass}">${arrow}${escapeHtml(msg)}</span>`;
container.appendChild(entry);
container.scrollTop = container.scrollHeight;
}
function escapeHtml(str) {
const div = document.createElement('div');
div.textContent = str;
return div.innerHTML;
}
function updateStatus(state, text) {
const dot = $('statusDot');
dot.className = 'status-dot';
if (state) dot.classList.add(state);
$('statusText').textContent = text;
}
function updateStats() {
$('statFrames').textContent = opusFrameCount;
const duration = (opusFrameCount * 60 / 1000).toFixed(1);
$('statDuration').textContent = `${duration}s`;
}
function generateClientId() {
const id = 'test-' + Math.random().toString(36).substring(2, 10);
$('clientId').value = id;
}
const ENV_LOCAL = { url: 'ws://localhost:8888/xiaozhi/v1/', label: '切换线上' };
const ENV_PROD = { url: 'wss://qiyuan-rtc-api.airlabs.art/xiaozhi/v1/', label: '切换本地' };
let currentEnv = 'local';
function toggleEnv() {
if (currentEnv === 'local') {
$('wsUrl').value = ENV_PROD.url;
$('btnEnvToggle').textContent = ENV_PROD.label;
currentEnv = 'prod';
} else {
$('wsUrl').value = ENV_LOCAL.url;
$('btnEnvToggle').textContent = ENV_LOCAL.label;
currentEnv = 'local';
}
}
function clearLog() {
$('logContainer').innerHTML = '';
}
// ============================================================
// Opus 解码器初始化
// ============================================================
function initOpusDecoder() {
try {
let mod = null;
// 检查 Module.instance 或全局 Module
if (typeof Module !== 'undefined') {
if (Module.instance && typeof Module.instance._opus_decoder_get_size === 'function') {
mod = Module.instance;
} else if (typeof Module._opus_decoder_get_size === 'function') {
mod = Module;
}
}
if (!mod) {
log('Opus 库未就绪,等待加载...', 'warning');
$('statOpus').textContent = '加载失败';
return false;
}
const SAMPLE_RATE = 16000;
const CHANNELS = 1;
const FRAME_SIZE = 960; // 60ms @ 16kHz
// 获取解码器大小并分配内存
const decoderSize = mod._opus_decoder_get_size(CHANNELS);
const decoderPtr = mod._malloc(decoderSize);
if (!decoderPtr) throw new Error('无法分配解码器内存');
// 初始化解码器
const err = mod._opus_decoder_init(decoderPtr, SAMPLE_RATE, CHANNELS);
if (err < 0) throw new Error(`Opus 解码器初始化失败: ${err}`);
opusDecoder = {
mod,
decoderPtr,
frameSize: FRAME_SIZE,
decode(opusData) {
// 为 Opus 数据分配内存
const opusPtr = mod._malloc(opusData.length);
mod.HEAPU8.set(opusData, opusPtr);
// 为 PCM 输出分配内存 (Int16 = 2 bytes)
const pcmPtr = mod._malloc(FRAME_SIZE * 2);
// 解码
const decodedSamples = mod._opus_decode(
decoderPtr, opusPtr, opusData.length,
pcmPtr, FRAME_SIZE, 0
);
if (decodedSamples < 0) {
mod._free(opusPtr);
mod._free(pcmPtr);
throw new Error(`Opus 解码失败: ${decodedSamples}`);
}
// 读取 Int16 并转为 Float32
const float32 = new Float32Array(decodedSamples);
for (let i = 0; i < decodedSamples; i++) {
const sample = mod.HEAP16[(pcmPtr >> 1) + i];
float32[i] = sample / (sample < 0 ? 0x8000 : 0x7FFF);
}
mod._free(opusPtr);
mod._free(pcmPtr);
return float32;
},
destroy() {
if (decoderPtr) mod._free(decoderPtr);
}
};
opusReady = true;
log('Opus 解码器初始化成功 (16kHz, 单声道, 60ms/帧)', 'success');
$('statOpus').textContent = '就绪';
$('statOpus').style.color = '#27ae60';
return true;
} catch (e) {
log(`Opus 初始化失败: ${e.message}`, 'error');
$('statOpus').textContent = '失败';
$('statOpus').style.color = '#e74c3c';
return false;
}
}
// ============================================================
// WebSocket 连接
// ============================================================
function connect() {
if (ws && ws.readyState === WebSocket.OPEN) {
log('已经连接,请先断开', 'warning');
return;
}
const baseUrl = $('wsUrl').value.trim();
const deviceId = $('deviceId').value.trim();
const clientId = $('clientId').value.trim();
if (!deviceId) { log('请输入 device-id (MAC 地址)', 'error'); return; }
if (!clientId) { log('请输入 client-id', 'error'); return; }
// 确保 AudioContext 存在(需要用户交互后创建)
if (!audioCtx) {
audioCtx = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 });
}
// 确保 Opus 解码器已初始化
if (!opusReady) {
if (!initOpusDecoder()) {
log('Opus 解码器未就绪,无法连接', 'error');
return;
}
}
// 构建 URL
const url = new URL(baseUrl);
url.searchParams.set('device-id', deviceId);
url.searchParams.set('client-id', clientId);
const connUrl = url.toString();
log(`正在连接: ${connUrl}`, 'info');
updateStatus('connecting', '连接中...');
$('btnConnect').disabled = true;
ws = new WebSocket(connUrl);
ws.binaryType = 'arraybuffer';
ws.onopen = () => {
log('WebSocket 连接成功,发送 hello 握手...', 'success');
updateStatus('connecting', '握手中...');
$('btnConnect').disabled = true;
$('btnDisconnect').disabled = false;
handshaked = false;
// 发送 hello 握手消息
const helloMsg = JSON.stringify({ type: 'hello', mac: deviceId });
ws.send(helloMsg);
log(`发送: ${helloMsg}`, 'send');
};
ws.onmessage = (event) => {
if (event.data instanceof ArrayBuffer) {
handleBinaryMessage(event.data);
} else {
handleTextMessage(event.data);
}
};
ws.onerror = (err) => {
log('WebSocket 错误', 'error');
updateStatus('error', '连接错误');
};
ws.onclose = (event) => {
log(`WebSocket 已关闭 (code=${event.code}, reason=${event.reason || '无'})`, 'warning');
updateStatus(null, '未连接');
$('btnConnect').disabled = false;
$('btnDisconnect').disabled = true;
$('btnStory').disabled = true;
$('btnStop').disabled = true;
ws = null;
};
}
function disconnect() {
if (ws) {
ws.close();
log('主动断开连接', 'info');
}
handshaked = false;
stopPlayback();
}
// ============================================================
// 消息处理
// ============================================================
function handleTextMessage(data) {
try {
const msg = JSON.parse(data);
log(`收到: ${JSON.stringify(msg)}`, 'recv');
// 处理 hello 握手响应
if (msg.type === 'hello' && msg.status === 'ok') {
handshaked = true;
log(`握手成功session_id=${msg.session_id}`, 'success');
updateStatus('connected', '已连接');
$('btnStory').disabled = false;
return;
}
if (msg.type === 'tts') {
switch (msg.state) {
case 'start':
log('故事推送开始', 'audio');
resetPlaybackState();
$('btnStop').disabled = false;
break;
case 'sentence_start':
if (msg.text) {
log(`故事标题: ${msg.text}`, 'audio');
}
break;
case 'stop':
log('故事推送结束', 'audio');
$('btnStop').disabled = true;
// 标记流结束,等待播放完成
log(`共接收 ${opusFrameCount} 个 Opus 帧,约 ${(opusFrameCount * 60 / 1000).toFixed(1)}s 音频`, 'success');
break;
}
}
} catch (e) {
log(`收到文本: ${data}`, 'recv');
}
}
function handleBinaryMessage(data) {
const frame = new Uint8Array(data);
opusFrameCount++;
updateStats();
// 每 20 帧打印一次,避免刷屏
if (opusFrameCount <= 3 || opusFrameCount % 20 === 0) {
log(`[Binary] Opus 帧 #${opusFrameCount} (${frame.length} bytes)`, 'binary');
}
if (!opusDecoder) {
log('Opus 解码器未初始化,丢弃帧', 'error');
return;
}
try {
const pcmFloat32 = opusDecoder.decode(frame);
if (pcmFloat32 && pcmFloat32.length > 0) {
pcmBufferQueue.push(pcmFloat32);
schedulePlayback();
}
} catch (e) {
log(`解码帧 #${opusFrameCount} 失败: ${e.message}`, 'error');
}
}
// ============================================================
// 音频播放(按时序排队)
// ============================================================
function resetPlaybackState() {
opusFrameCount = 0;
pcmBufferQueue = [];
isPlaying = false;
nextPlayTime = 0;
updateStats();
}
function schedulePlayback() {
// 预缓冲:等待至少 3 帧再开始播放
if (!isPlaying && pcmBufferQueue.length < 3) return;
if (!isPlaying) {
isPlaying = true;
log('开始音频播放...', 'audio');
}
// 如果 AudioContext 被暂停(浏览器策略),恢复它
if (audioCtx && audioCtx.state === 'suspended') {
audioCtx.resume();
}
// 直接把队列中所有帧排入播放时间线
while (pcmBufferQueue.length > 0) {
playPcmChunk(pcmBufferQueue.shift());
}
}
function playPcmChunk(pcmFloat32) {
const buffer = audioCtx.createBuffer(1, pcmFloat32.length, 16000);
buffer.copyToChannel(pcmFloat32, 0);
const source = audioCtx.createBufferSource();
source.buffer = buffer;
const now = audioCtx.currentTime;
const startTime = Math.max(now, nextPlayTime);
source.connect(audioCtx.destination);
source.start(startTime);
// 下一帧紧接当前帧播放
nextPlayTime = startTime + buffer.duration;
}
function stopPlayback() {
pcmBufferQueue = [];
isPlaying = false;
nextPlayTime = 0;
if (audioCtx) {
// 创建新的 AudioContext 来停止所有播放
audioCtx.close();
audioCtx = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: 16000 });
}
log('播放已停止', 'audio');
$('btnStop').disabled = true;
}
// ============================================================
// 触发故事
// ============================================================
function triggerStory() {
if (!ws || ws.readyState !== WebSocket.OPEN) {
log('WebSocket 未连接', 'error');
return;
}
if (!handshaked) {
log('握手尚未完成,请等待', 'warning');
return;
}
const msg = JSON.stringify({ type: 'story' });
ws.send(msg);
log(`发送: ${msg}`, 'send');
// 重置统计
resetPlaybackState();
$('btnStop').disabled = false;
}
// ============================================================
// 页面初始化
// ============================================================
window.addEventListener('DOMContentLoaded', () => {
// 生成默认 client-id
generateClientId();
// 延迟初始化 Opus等 WASM 加载完)
const checkOpus = () => {
if (typeof Module !== 'undefined' &&
((Module.instance && typeof Module.instance._opus_decoder_get_size === 'function') ||
typeof Module._opus_decoder_get_size === 'function')) {
initOpusDecoder();
} else {
setTimeout(checkOpus, 200);
}
};
setTimeout(checkOpus, 500);
log('页面加载完成,等待 Opus 库初始化...', 'info');
});
</script>
</body>
</html>

View File

@ -1,25 +0,0 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
.idea/
*.iml

View File

@ -1,9 +0,0 @@
# This is the official list of Gorilla WebSocket authors for copyright
# purposes.
#
# Please keep the list sorted.
Gary Burd <gary@beagledreams.com>
Google LLC (https://opensource.google.com/)
Joachim Bauch <mail@joachim-bauch.de>

View File

@ -1,22 +0,0 @@
Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,33 +0,0 @@
# Gorilla WebSocket
[![GoDoc](https://godoc.org/github.com/gorilla/websocket?status.svg)](https://godoc.org/github.com/gorilla/websocket)
[![CircleCI](https://circleci.com/gh/gorilla/websocket.svg?style=svg)](https://circleci.com/gh/gorilla/websocket)
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
### Documentation
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
* [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat)
* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command)
* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo)
* [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch)
### Status
The Gorilla WebSocket package provides a complete and tested implementation of
the [WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol. The
package API is stable.
### Installation
go get github.com/gorilla/websocket
### Protocol Compliance
The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).

View File

@ -1,434 +0,0 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptrace"
"net/url"
"strings"
"time"
)
// ErrBadHandshake is returned when the server response to opening handshake is
// invalid.
var ErrBadHandshake = errors.New("websocket: bad handshake")
var errInvalidCompression = errors.New("websocket: invalid compression negotiation")
// NewClient creates a new client connection using the given net connection.
// The URL u specifies the host and request URI. Use requestHeader to specify
// the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
// (Cookie). Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication,
// etc.
//
// Deprecated: Use Dialer instead.
func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
d := Dialer{
ReadBufferSize: readBufSize,
WriteBufferSize: writeBufSize,
NetDial: func(net, addr string) (net.Conn, error) {
return netConn, nil
},
}
return d.Dial(u.String(), requestHeader)
}
// A Dialer contains options for connecting to WebSocket server.
//
// It is safe to call Dialer's methods concurrently.
type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error)
// NetDialContext specifies the dial function for creating TCP connections. If
// NetDialContext is nil, NetDial is used.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
// NetDialTLSContext is nil, NetDialContext is used.
// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
// TLSClientConfig is ignored.
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
// If Proxy is nil or returns a nil *URL, no proxy is used.
Proxy func(*http.Request) (*url.URL, error)
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used.
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
// is done there and TLSClientConfig is ignored.
TLSClientConfig *tls.Config
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
// size is zero, then a useful default size is used. The I/O buffer sizes
// do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int
// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool
// Subprotocols specifies the client's requested subprotocols.
Subprotocols []string
// EnableCompression specifies if the client should attempt to negotiate
// per message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
// Jar specifies the cookie jar.
// If Jar is nil, cookies are not sent in requests and ignored
// in responses.
Jar http.CookieJar
}
// Dial creates a new client connection by calling DialContext with a background context.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
return d.DialContext(context.Background(), urlStr, requestHeader)
}
var errMalformedURL = errors.New("malformed ws or wss URL")
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
hostPort = u.Host
hostNoPort = u.Host
if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
hostNoPort = hostNoPort[:i]
} else {
switch u.Scheme {
case "wss":
hostPort += ":443"
case "https":
hostPort += ":443"
default:
hostPort += ":80"
}
}
return hostPort, hostNoPort
}
// DefaultDialer is a dialer with all fields set to the default values.
var DefaultDialer = &Dialer{
Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 45 * time.Second,
}
// nilDialer is dialer to use when receiver is nil.
var nilDialer = *DefaultDialer
// DialContext creates a new client connection. Use requestHeader to specify the
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
// Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
//
// The context will be used in the request and in the Dialer.
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication,
// etcetera. The response body may not contain the entire response and does not
// need to be closed by the application.
func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
if d == nil {
d = &nilDialer
}
challengeKey, err := generateChallengeKey()
if err != nil {
return nil, nil, err
}
u, err := url.Parse(urlStr)
if err != nil {
return nil, nil, err
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
default:
return nil, nil, errMalformedURL
}
if u.User != nil {
// User name and password are not allowed in websocket URIs.
return nil, nil, errMalformedURL
}
req := &http.Request{
Method: http.MethodGet,
URL: u,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
Host: u.Host,
}
req = req.WithContext(ctx)
// Set the cookies present in the cookie jar of the dialer
if d.Jar != nil {
for _, cookie := range d.Jar.Cookies(u) {
req.AddCookie(cookie)
}
}
// Set the request headers using the capitalization for names and values in
// RFC examples. Although the capitalization shouldn't matter, there are
// servers that depend on it. The Header.Set method is not used because the
// method canonicalizes the header names.
req.Header["Upgrade"] = []string{"websocket"}
req.Header["Connection"] = []string{"Upgrade"}
req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
req.Header["Sec-WebSocket-Version"] = []string{"13"}
if len(d.Subprotocols) > 0 {
req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
}
for k, vs := range requestHeader {
switch {
case k == "Host":
if len(vs) > 0 {
req.Host = vs[0]
}
case k == "Upgrade" ||
k == "Connection" ||
k == "Sec-Websocket-Key" ||
k == "Sec-Websocket-Version" ||
k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
case k == "Sec-Websocket-Protocol":
req.Header["Sec-WebSocket-Protocol"] = vs
default:
req.Header[k] = vs
}
}
if d.EnableCompression {
req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
}
if d.HandshakeTimeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
defer cancel()
}
// Get network dial function.
var netDial func(network, add string) (net.Conn, error)
switch u.Scheme {
case "http":
if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
}
case "https":
if d.NetDialTLSContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialTLSContext(ctx, network, addr)
}
} else if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
}
default:
return nil, nil, errMalformedURL
}
if netDial == nil {
netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr)
}
}
// If needed, wrap the dial function to set the connection deadline.
if deadline, ok := ctx.Deadline(); ok {
forwardDial := netDial
netDial = func(network, addr string) (net.Conn, error) {
c, err := forwardDial(network, addr)
if err != nil {
return nil, err
}
err = c.SetDeadline(deadline)
if err != nil {
c.Close()
return nil, err
}
return c, nil
}
}
// If needed, wrap the dial function to connect through a proxy.
if d.Proxy != nil {
proxyURL, err := d.Proxy(req)
if err != nil {
return nil, nil, err
}
if proxyURL != nil {
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
if err != nil {
return nil, nil, err
}
netDial = dialer.Dial
}
}
hostPort, hostNoPort := hostPortNoPort(u)
trace := httptrace.ContextClientTrace(ctx)
if trace != nil && trace.GetConn != nil {
trace.GetConn(hostPort)
}
netConn, err := netDial("tcp", hostPort)
if err != nil {
return nil, nil, err
}
if trace != nil && trace.GotConn != nil {
trace.GotConn(httptrace.GotConnInfo{
Conn: netConn,
})
}
defer func() {
if netConn != nil {
netConn.Close()
}
}()
if u.Scheme == "https" && d.NetDialTLSContext == nil {
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
}
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := doHandshake(ctx, tlsConn, cfg)
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
if err != nil {
return nil, nil, err
}
}
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
if err := req.Write(netConn); err != nil {
return nil, nil, err
}
if trace != nil && trace.GotFirstResponseByte != nil {
if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 {
trace.GotFirstResponseByte()
}
}
resp, err := http.ReadResponse(conn.br, req)
if err != nil {
if d.TLSClientConfig != nil {
for _, proto := range d.TLSClientConfig.NextProtos {
if proto != "http/1.1" {
return nil, nil, fmt.Errorf(
"websocket: protocol %q was given but is not supported;"+
"sharing tls.Config with net/http Transport can cause this error: %w",
proto, err,
)
}
}
}
return nil, nil, err
}
if d.Jar != nil {
if rc := resp.Cookies(); len(rc) > 0 {
d.Jar.SetCookies(u, rc)
}
}
if resp.StatusCode != 101 ||
!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
// Before closing the network connection on return from this
// function, slurp up some of the response to aid application
// debugging.
buf := make([]byte, 1024)
n, _ := io.ReadFull(resp.Body, buf)
resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
return nil, resp, ErrBadHandshake
}
for _, ext := range parseExtensions(resp.Header) {
if ext[""] != "permessage-deflate" {
continue
}
_, snct := ext["server_no_context_takeover"]
_, cnct := ext["client_no_context_takeover"]
if !snct || !cnct {
return nil, resp, errInvalidCompression
}
conn.newCompressionWriter = compressNoContextTakeover
conn.newDecompressionReader = decompressNoContextTakeover
break
}
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
netConn.SetDeadline(time.Time{})
netConn = nil // to avoid close in defer.
return conn, resp, nil
}
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}

View File

@ -1,148 +0,0 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"compress/flate"
"errors"
"io"
"strings"
"sync"
)
const (
minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6
maxCompressionLevel = flate.BestCompression
defaultCompressionLevel = 1
)
var (
flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
flateReaderPool = sync.Pool{New: func() interface{} {
return flate.NewReader(nil)
}}
)
func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"
fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
return &flateReadWrapper{fr}
}
func isValidCompressionLevel(level int) bool {
return minCompressionLevel <= level && level <= maxCompressionLevel
}
func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
p := &flateWriterPools[level-minCompressionLevel]
tw := &truncWriter{w: w}
fw, _ := p.Get().(*flate.Writer)
if fw == nil {
fw, _ = flate.NewWriter(tw, level)
} else {
fw.Reset(tw)
}
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
}
// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer.
type truncWriter struct {
w io.WriteCloser
n int
p [4]byte
}
func (w *truncWriter) Write(p []byte) (int, error) {
n := 0
// fill buffer first for simplicity.
if w.n < len(w.p) {
n = copy(w.p[w.n:], p)
p = p[n:]
w.n += n
if len(p) == 0 {
return n, nil
}
}
m := len(p)
if m > len(w.p) {
m = len(w.p)
}
if nn, err := w.w.Write(w.p[:m]); err != nil {
return n + nn, err
}
copy(w.p[:], w.p[m:])
copy(w.p[len(w.p)-m:], p[len(p)-m:])
nn, err := w.w.Write(p[:len(p)-m])
return n + nn, err
}
type flateWriteWrapper struct {
fw *flate.Writer
tw *truncWriter
p *sync.Pool
}
func (w *flateWriteWrapper) Write(p []byte) (int, error) {
if w.fw == nil {
return 0, errWriteClosed
}
return w.fw.Write(p)
}
func (w *flateWriteWrapper) Close() error {
if w.fw == nil {
return errWriteClosed
}
err1 := w.fw.Flush()
w.p.Put(w.fw)
w.fw = nil
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
}
err2 := w.tw.w.Close()
if err1 != nil {
return err1
}
return err2
}
type flateReadWrapper struct {
fr io.ReadCloser
}
func (r *flateReadWrapper) Read(p []byte) (int, error) {
if r.fr == nil {
return 0, io.ErrClosedPipe
}
n, err := r.fr.Read(p)
if err == io.EOF {
// Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after
// this final read.
r.Close()
}
return n, err
}
func (r *flateReadWrapper) Close() error {
if r.fr == nil {
return io.ErrClosedPipe
}
err := r.fr.Close()
flateReaderPool.Put(r.fr)
r.fr = nil
return err
}

File diff suppressed because it is too large Load Diff

View File

@ -1,227 +0,0 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package websocket implements the WebSocket protocol defined in RFC 6455.
//
// Overview
//
// The Conn type represents a WebSocket connection. A server application calls
// the Upgrader.Upgrade method from an HTTP request handler to get a *Conn:
//
// var upgrader = websocket.Upgrader{
// ReadBufferSize: 1024,
// WriteBufferSize: 1024,
// }
//
// func handler(w http.ResponseWriter, r *http.Request) {
// conn, err := upgrader.Upgrade(w, r, nil)
// if err != nil {
// log.Println(err)
// return
// }
// ... Use conn to send and receive messages.
// }
//
// Call the connection's WriteMessage and ReadMessage methods to send and
// receive messages as a slice of bytes. This snippet of code shows how to echo
// messages using these methods:
//
// for {
// messageType, p, err := conn.ReadMessage()
// if err != nil {
// log.Println(err)
// return
// }
// if err := conn.WriteMessage(messageType, p); err != nil {
// log.Println(err)
// return
// }
// }
//
// In above snippet of code, p is a []byte and messageType is an int with value
// websocket.BinaryMessage or websocket.TextMessage.
//
// An application can also send and receive messages using the io.WriteCloser
// and io.Reader interfaces. To send a message, call the connection NextWriter
// method to get an io.WriteCloser, write the message to the writer and close
// the writer when done. To receive a message, call the connection NextReader
// method to get an io.Reader and read until io.EOF is returned. This snippet
// shows how to echo messages using the NextWriter and NextReader methods:
//
// for {
// messageType, r, err := conn.NextReader()
// if err != nil {
// return
// }
// w, err := conn.NextWriter(messageType)
// if err != nil {
// return err
// }
// if _, err := io.Copy(w, r); err != nil {
// return err
// }
// if err := w.Close(); err != nil {
// return err
// }
// }
//
// Data Messages
//
// The WebSocket protocol distinguishes between text and binary data messages.
// Text messages are interpreted as UTF-8 encoded text. The interpretation of
// binary messages is left to the application.
//
// This package uses the TextMessage and BinaryMessage integer constants to
// identify the two data message types. The ReadMessage and NextReader methods
// return the type of the received message. The messageType argument to the
// WriteMessage and NextWriter methods specifies the type of a sent message.
//
// It is the application's responsibility to ensure that text messages are
// valid UTF-8 encoded text.
//
// Control Messages
//
// The WebSocket protocol defines three types of control messages: close, ping
// and pong. Call the connection WriteControl, WriteMessage or NextWriter
// methods to send a control message to the peer.
//
// Connections handle received close messages by calling the handler function
// set with the SetCloseHandler method and by returning a *CloseError from the
// NextReader, ReadMessage or the message Read method. The default close
// handler sends a close message to the peer.
//
// Connections handle received ping messages by calling the handler function
// set with the SetPingHandler method. The default ping handler sends a pong
// message to the peer.
//
// Connections handle received pong messages by calling the handler function
// set with the SetPongHandler method. The default pong handler does nothing.
// If an application sends ping messages, then the application should set a
// pong handler to receive the corresponding pong.
//
// The control message handler functions are called from the NextReader,
// ReadMessage and message reader Read methods. The default close and ping
// handlers can block these methods for a short time when the handler writes to
// the connection.
//
// The application must read the connection to process close, ping and pong
// messages sent from the peer. If the application is not otherwise interested
// in messages from the peer, then the application should start a goroutine to
// read and discard messages from the peer. A simple example is:
//
// func readLoop(c *websocket.Conn) {
// for {
// if _, _, err := c.NextReader(); err != nil {
// c.Close()
// break
// }
// }
// }
//
// Concurrency
//
// Connections support one concurrent reader and one concurrent writer.
//
// Applications are responsible for ensuring that no more than one goroutine
// calls the write methods (NextWriter, SetWriteDeadline, WriteMessage,
// WriteJSON, EnableWriteCompression, SetCompressionLevel) concurrently and
// that no more than one goroutine calls the read methods (NextReader,
// SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, SetPingHandler)
// concurrently.
//
// The Close and WriteControl methods can be called concurrently with all other
// methods.
//
// Origin Considerations
//
// Web browsers allow Javascript applications to open a WebSocket connection to
// any host. It's up to the server to enforce an origin policy using the Origin
// request header sent by the browser.
//
// The Upgrader calls the function specified in the CheckOrigin field to check
// the origin. If the CheckOrigin function returns false, then the Upgrade
// method fails the WebSocket handshake with HTTP status 403.
//
// If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail
// the handshake if the Origin request header is present and the Origin host is
// not equal to the Host request header.
//
// The deprecated package-level Upgrade function does not perform origin
// checking. The application is responsible for checking the Origin header
// before calling the Upgrade function.
//
// Buffers
//
// Connections buffer network input and output to reduce the number
// of system calls when reading or writing messages.
//
// Write buffers are also used for constructing WebSocket frames. See RFC 6455,
// Section 5 for a discussion of message framing. A WebSocket frame header is
// written to the network each time a write buffer is flushed to the network.
// Decreasing the size of the write buffer can increase the amount of framing
// overhead on the connection.
//
// The buffer sizes in bytes are specified by the ReadBufferSize and
// WriteBufferSize fields in the Dialer and Upgrader. The Dialer uses a default
// size of 4096 when a buffer size field is set to zero. The Upgrader reuses
// buffers created by the HTTP server when a buffer size field is set to zero.
// The HTTP server buffers have a size of 4096 at the time of this writing.
//
// The buffer sizes do not limit the size of a message that can be read or
// written by a connection.
//
// Buffers are held for the lifetime of the connection by default. If the
// Dialer or Upgrader WriteBufferPool field is set, then a connection holds the
// write buffer only when writing a message.
//
// Applications should tune the buffer sizes to balance memory use and
// performance. Increasing the buffer size uses more memory, but can reduce the
// number of system calls to read or write the network. In the case of writing,
// increasing the buffer size can reduce the number of frame headers written to
// the network.
//
// Some guidelines for setting buffer parameters are:
//
// Limit the buffer sizes to the maximum expected message size. Buffers larger
// than the largest message do not provide any benefit.
//
// Depending on the distribution of message sizes, setting the buffer size to
// a value less than the maximum expected message size can greatly reduce memory
// use with a small impact on performance. Here's an example: If 99% of the
// messages are smaller than 256 bytes and the maximum message size is 512
// bytes, then a buffer size of 256 bytes will result in 1.01 more system calls
// than a buffer size of 512 bytes. The memory savings is 50%.
//
// A write buffer pool is useful when the application has a modest number
// writes over a large number of connections. when buffers are pooled, a larger
// buffer size has a reduced impact on total memory use and has the benefit of
// reducing system calls and frame overhead.
//
// Compression EXPERIMENTAL
//
// Per message compression extensions (RFC 7692) are experimentally supported
// by this package in a limited capacity. Setting the EnableCompression option
// to true in Dialer or Upgrader will attempt to negotiate per message deflate
// support.
//
// var upgrader = websocket.Upgrader{
// EnableCompression: true,
// }
//
// If compression was successfully negotiated with the connection's peer, any
// message received in compressed form will be automatically decompressed.
// All Read methods will return uncompressed bytes.
//
// Per message compression of messages written to a connection can be enabled
// or disabled by calling the corresponding Conn method:
//
// conn.EnableWriteCompression(false)
//
// Currently this package does not support compression with "context takeover".
// This means that messages must be compressed and decompressed in isolation,
// without retaining sliding window or dictionary state across messages. For
// more details refer to RFC 7692.
//
// Use of compression is experimental and may result in decreased performance.
package websocket

View File

@ -1,42 +0,0 @@
// Copyright 2019 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"io"
"strings"
)
// JoinMessages concatenates received messages to create a single io.Reader.
// The string term is appended to each message. The returned reader does not
// support concurrent calls to the Read method.
func JoinMessages(c *Conn, term string) io.Reader {
return &joinReader{c: c, term: term}
}
type joinReader struct {
c *Conn
term string
r io.Reader
}
func (r *joinReader) Read(p []byte) (int, error) {
if r.r == nil {
var err error
_, r.r, err = r.c.NextReader()
if err != nil {
return 0, err
}
if r.term != "" {
r.r = io.MultiReader(r.r, strings.NewReader(r.term))
}
}
n, err := r.r.Read(p)
if err == io.EOF {
err = nil
r.r = nil
}
return n, err
}

View File

@ -1,60 +0,0 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"encoding/json"
"io"
)
// WriteJSON writes the JSON encoding of v as a message.
//
// Deprecated: Use c.WriteJSON instead.
func WriteJSON(c *Conn, v interface{}) error {
return c.WriteJSON(v)
}
// WriteJSON writes the JSON encoding of v as a message.
//
// See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON.
func (c *Conn) WriteJSON(v interface{}) error {
w, err := c.NextWriter(TextMessage)
if err != nil {
return err
}
err1 := json.NewEncoder(w).Encode(v)
err2 := w.Close()
if err1 != nil {
return err1
}
return err2
}
// ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// Deprecated: Use c.ReadJSON instead.
func ReadJSON(c *Conn, v interface{}) error {
return c.ReadJSON(v)
}
// ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// See the documentation for the encoding/json Unmarshal function for details
// about the conversion of JSON to a Go value.
func (c *Conn) ReadJSON(v interface{}) error {
_, r, err := c.NextReader()
if err != nil {
return err
}
err = json.NewDecoder(r).Decode(v)
if err == io.EOF {
// One value is expected in the message.
err = io.ErrUnexpectedEOF
}
return err
}

View File

@ -1,55 +0,0 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
//go:build !appengine
// +build !appengine
package websocket
import "unsafe"
const wordSize = int(unsafe.Sizeof(uintptr(0)))
func maskBytes(key [4]byte, pos int, b []byte) int {
// Mask one byte at a time for small buffers.
if len(b) < 2*wordSize {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}
// Mask one byte at a time to word boundary.
if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 {
n = wordSize - n
for i := range b[:n] {
b[i] ^= key[pos&3]
pos++
}
b = b[n:]
}
// Create aligned word size key.
var k [wordSize]byte
for i := range k {
k[i] = key[(pos+i)&3]
}
kw := *(*uintptr)(unsafe.Pointer(&k))
// Mask one word at a time.
n := (len(b) / wordSize) * wordSize
for i := 0; i < n; i += wordSize {
*(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw
}
// Mask one byte at a time for remaining bytes.
b = b[n:]
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}

View File

@ -1,16 +0,0 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. Use of
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
//go:build appengine
// +build appengine
package websocket
func maskBytes(key [4]byte, pos int, b []byte) int {
for i := range b {
b[i] ^= key[pos&3]
pos++
}
return pos & 3
}

View File

@ -1,102 +0,0 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bytes"
"net"
"sync"
"time"
)
// PreparedMessage caches on the wire representations of a message payload.
// Use PreparedMessage to efficiently send a message payload to multiple
// connections. PreparedMessage is especially useful when compression is used
// because the CPU and memory expensive compression operation can be executed
// once for a given set of compression options.
type PreparedMessage struct {
messageType int
data []byte
mu sync.Mutex
frames map[prepareKey]*preparedFrame
}
// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage.
type prepareKey struct {
isServer bool
compress bool
compressionLevel int
}
// preparedFrame contains data in wire representation.
type preparedFrame struct {
once sync.Once
data []byte
}
// NewPreparedMessage returns an initialized PreparedMessage. You can then send
// it to connection using WritePreparedMessage method. Valid wire
// representation will be calculated lazily only once for a set of current
// connection options.
func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) {
pm := &PreparedMessage{
messageType: messageType,
frames: make(map[prepareKey]*preparedFrame),
data: data,
}
// Prepare a plain server frame.
_, frameData, err := pm.frame(prepareKey{isServer: true, compress: false})
if err != nil {
return nil, err
}
// To protect against caller modifying the data argument, remember the data
// copied to the plain server frame.
pm.data = frameData[len(frameData)-len(data):]
return pm, nil
}
func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) {
pm.mu.Lock()
frame, ok := pm.frames[key]
if !ok {
frame = &preparedFrame{}
pm.frames[key] = frame
}
pm.mu.Unlock()
var err error
frame.once.Do(func() {
// Prepare a frame using a 'fake' connection.
// TODO: Refactor code in conn.go to allow more direct construction of
// the frame.
mu := make(chan struct{}, 1)
mu <- struct{}{}
var nc prepareConn
c := &Conn{
conn: &nc,
mu: mu,
isServer: key.isServer,
compressionLevel: key.compressionLevel,
enableWriteCompression: true,
writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize),
}
if key.compress {
c.newCompressionWriter = compressNoContextTakeover
}
err = c.WriteMessage(pm.messageType, pm.data)
frame.data = nc.buf.Bytes()
})
return pm.messageType, frame.data, err
}
type prepareConn struct {
buf bytes.Buffer
net.Conn
}
func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) }
func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil }

View File

@ -1,77 +0,0 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"encoding/base64"
"errors"
"net"
"net/http"
"net/url"
"strings"
)
type netDialerFunc func(network, addr string) (net.Conn, error)
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
return fn(network, addr)
}
func init() {
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
return &httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDialer.Dial}, nil
})
}
type httpProxyDialer struct {
proxyURL *url.URL
forwardDial func(network, addr string) (net.Conn, error)
}
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
hostPort, _ := hostPortNoPort(hpd.proxyURL)
conn, err := hpd.forwardDial(network, hostPort)
if err != nil {
return nil, err
}
connectHeader := make(http.Header)
if user := hpd.proxyURL.User; user != nil {
proxyUser := user.Username()
if proxyPassword, passwordSet := user.Password(); passwordSet {
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
}
}
connectReq := &http.Request{
Method: http.MethodConnect,
URL: &url.URL{Opaque: addr},
Host: addr,
Header: connectHeader,
}
if err := connectReq.Write(conn); err != nil {
conn.Close()
return nil, err
}
// Read response. It's OK to use and discard buffered reader here becaue
// the remote server does not speak until spoken to.
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, connectReq)
if err != nil {
conn.Close()
return nil, err
}
if resp.StatusCode != 200 {
conn.Close()
f := strings.SplitN(resp.Status, " ", 2)
return nil, errors.New(f[1])
}
return conn, nil
}

View File

@ -1,365 +0,0 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"errors"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// HandshakeError describes an error with the handshake from the peer.
type HandshakeError struct {
message string
}
func (e HandshakeError) Error() string { return e.message }
// Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection.
//
// It is safe to call Upgrader's methods concurrently.
type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
// size is zero, then buffers allocated by the HTTP server are used. The
// I/O buffer sizes do not limit the size of the messages that can be sent
// or received.
ReadBufferSize, WriteBufferSize int
// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool
// Subprotocols specifies the server's supported protocols in order of
// preference. If this field is not nil, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol
// requested by the client. If there's no match, then no protocol is
// negotiated (the Sec-Websocket-Protocol header is not included in the
// handshake response).
Subprotocols []string
// Error specifies the function for generating HTTP error responses. If Error
// is nil, then http.Error is used to generate the HTTP response.
Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
// CheckOrigin returns true if the request Origin header is acceptable. If
// CheckOrigin is nil, then a safe default is used: return false if the
// Origin request header is present and the origin host is not equal to
// request Host header.
//
// A CheckOrigin function should carefully validate the request origin to
// prevent cross-site request forgery.
CheckOrigin func(r *http.Request) bool
// EnableCompression specify if the server should attempt to negotiate per
// message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
}
func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
err := HandshakeError{reason}
if u.Error != nil {
u.Error(w, r, status, err)
} else {
w.Header().Set("Sec-Websocket-Version", "13")
http.Error(w, http.StatusText(status), status)
}
return nil, err
}
// checkSameOrigin returns true if the origin is not set or is equal to the request host.
func checkSameOrigin(r *http.Request) bool {
origin := r.Header["Origin"]
if len(origin) == 0 {
return true
}
u, err := url.Parse(origin[0])
if err != nil {
return false
}
return equalASCIIFold(u.Host, r.Host)
}
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
if u.Subprotocols != nil {
clientProtocols := Subprotocols(r)
for _, serverProtocol := range u.Subprotocols {
for _, clientProtocol := range clientProtocols {
if clientProtocol == serverProtocol {
return clientProtocol
}
}
}
} else if responseHeader != nil {
return responseHeader.Get("Sec-Websocket-Protocol")
}
return ""
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie). To specify
// subprotocols supported by the server, set Upgrader.Subprotocols directly.
//
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
const badHandshake = "websocket: the client is not using the websocket protocol: "
if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
}
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
}
if r.Method != http.MethodGet {
return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
}
if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
}
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
}
checkOrigin := u.CheckOrigin
if checkOrigin == nil {
checkOrigin = checkSameOrigin
}
if !checkOrigin(r) {
return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
}
challengeKey := r.Header.Get("Sec-Websocket-Key")
if !isValidChallengeKey(challengeKey) {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
}
subprotocol := u.selectSubprotocol(r, responseHeader)
// Negotiate PMCE
var compress bool
if u.EnableCompression {
for _, ext := range parseExtensions(r.Header) {
if ext[""] != "permessage-deflate" {
continue
}
compress = true
break
}
}
h, ok := w.(http.Hijacker)
if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
}
var brw *bufio.ReadWriter
netConn, brw, err := h.Hijack()
if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error())
}
if brw.Reader.Buffered() > 0 {
netConn.Close()
return nil, errors.New("websocket: client sent data before handshake is complete")
}
var br *bufio.Reader
if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
// Reuse hijacked buffered reader as connection reader.
br = brw.Reader
}
buf := bufioWriterBuffer(netConn, brw.Writer)
var writeBuf []byte
if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
// Reuse hijacked write buffer as connection buffer.
writeBuf = buf
}
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
c.subprotocol = subprotocol
if compress {
c.newCompressionWriter = compressNoContextTakeover
c.newDecompressionReader = decompressNoContextTakeover
}
// Use larger of hijacked buffer and connection write buffer for header.
p := buf
if len(c.writeBuf) > len(p) {
p = c.writeBuf
}
p = p[:0]
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n"...)
if c.subprotocol != "" {
p = append(p, "Sec-WebSocket-Protocol: "...)
p = append(p, c.subprotocol...)
p = append(p, "\r\n"...)
}
if compress {
p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
}
for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" {
continue
}
for _, v := range vs {
p = append(p, k...)
p = append(p, ": "...)
for i := 0; i < len(v); i++ {
b := v[i]
if b <= 31 {
// prevent response splitting.
b = ' '
}
p = append(p, b)
}
p = append(p, "\r\n"...)
}
}
p = append(p, "\r\n"...)
// Clear deadlines set by HTTP server.
netConn.SetDeadline(time.Time{})
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
}
if _, err = netConn.Write(p); err != nil {
netConn.Close()
return nil, err
}
if u.HandshakeTimeout > 0 {
netConn.SetWriteDeadline(time.Time{})
}
return c, nil
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// Deprecated: Use websocket.Upgrader instead.
//
// Upgrade does not perform origin checking. The application is responsible for
// checking the Origin header before calling Upgrade. An example implementation
// of the same origin policy check is:
//
// if req.Header.Get("Origin") != "http://"+req.Host {
// http.Error(w, "Origin not allowed", http.StatusForbidden)
// return
// }
//
// If the endpoint supports subprotocols, then the application is responsible
// for negotiating the protocol used on the connection. Use the Subprotocols()
// function to get the subprotocols requested by the client. Use the
// Sec-Websocket-Protocol response header to specify the subprotocol selected
// by the application.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// negotiated subprotocol (Sec-Websocket-Protocol).
//
// The connection buffers IO to the underlying network connection. The
// readBufSize and writeBufSize parameters specify the size of the buffers to
// use. Messages can be larger than the buffers.
//
// If the request is not a valid WebSocket handshake, then Upgrade returns an
// error of type HandshakeError. Applications should handle this error by
// replying to the client with an HTTP error response.
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
// don't return errors to maintain backwards compatibility
}
u.CheckOrigin = func(r *http.Request) bool {
// allow all connections by default
return true
}
return u.Upgrade(w, r, responseHeader)
}
// Subprotocols returns the subprotocols requested by the client in the
// Sec-Websocket-Protocol header.
func Subprotocols(r *http.Request) []string {
h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
if h == "" {
return nil
}
protocols := strings.Split(h, ",")
for i := range protocols {
protocols[i] = strings.TrimSpace(protocols[i])
}
return protocols
}
// IsWebSocketUpgrade returns true if the client requested upgrade to the
// WebSocket protocol.
func IsWebSocketUpgrade(r *http.Request) bool {
return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
tokenListContainsValue(r.Header, "Upgrade", "websocket")
}
// bufioReaderSize size returns the size of a bufio.Reader.
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
// This code assumes that peek on a reset reader returns
// bufio.Reader.buf[:0].
// TODO: Use bufio.Reader.Size() after Go 1.10
br.Reset(originalReader)
if p, err := br.Peek(0); err == nil {
return cap(p)
}
return 0
}
// writeHook is an io.Writer that records the last slice passed to it vio
// io.Writer.Write.
type writeHook struct {
p []byte
}
func (wh *writeHook) Write(p []byte) (int, error) {
wh.p = p
return len(p), nil
}
// bufioWriterBuffer grabs the buffer from a bufio.Writer.
func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
// This code assumes that bufio.Writer.buf[:1] is passed to the
// bufio.Writer's underlying writer.
var wh writeHook
bw.Reset(&wh)
bw.WriteByte(0)
bw.Flush()
bw.Reset(originalWriter)
return wh.p[:cap(wh.p)]
}

View File

@ -1,21 +0,0 @@
//go:build go1.17
// +build go1.17
package websocket
import (
"context"
"crypto/tls"
)
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.HandshakeContext(ctx); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

View File

@ -1,21 +0,0 @@
//go:build !go1.17
// +build !go1.17
package websocket
import (
"context"
"crypto/tls"
)
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

View File

@ -1,298 +0,0 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"io"
"net/http"
"strings"
"unicode/utf8"
)
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func computeAcceptKey(challengeKey string) string {
h := sha1.New()
h.Write([]byte(challengeKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
func generateChallengeKey() (string, error) {
p := make([]byte, 16)
if _, err := io.ReadFull(rand.Reader, p); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(p), nil
}
// Token octets per RFC 2616.
var isTokenOctet = [256]bool{
'!': true,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'*': true,
'+': true,
'-': true,
'.': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'W': true,
'V': true,
'X': true,
'Y': true,
'Z': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'|': true,
'~': true,
}
// skipSpace returns a slice of the string s with all leading RFC 2616 linear
// whitespace removed.
func skipSpace(s string) (rest string) {
i := 0
for ; i < len(s); i++ {
if b := s[i]; b != ' ' && b != '\t' {
break
}
}
return s[i:]
}
// nextToken returns the leading RFC 2616 token of s and the string following
// the token.
func nextToken(s string) (token, rest string) {
i := 0
for ; i < len(s); i++ {
if !isTokenOctet[s[i]] {
break
}
}
return s[:i], s[i:]
}
// nextTokenOrQuoted returns the leading token or quoted string per RFC 2616
// and the string following the token or quoted string.
func nextTokenOrQuoted(s string) (value string, rest string) {
if !strings.HasPrefix(s, "\"") {
return nextToken(s)
}
s = s[1:]
for i := 0; i < len(s); i++ {
switch s[i] {
case '"':
return s[:i], s[i+1:]
case '\\':
p := make([]byte, len(s)-1)
j := copy(p, s[:i])
escape := true
for i = i + 1; i < len(s); i++ {
b := s[i]
switch {
case escape:
escape = false
p[j] = b
j++
case b == '\\':
escape = true
case b == '"':
return string(p[:j]), s[i+1:]
default:
p[j] = b
j++
}
}
return "", ""
}
}
return "", ""
}
// equalASCIIFold returns true if s is equal to t with ASCII case folding as
// defined in RFC 4790.
func equalASCIIFold(s, t string) bool {
for s != "" && t != "" {
sr, size := utf8.DecodeRuneInString(s)
s = s[size:]
tr, size := utf8.DecodeRuneInString(t)
t = t[size:]
if sr == tr {
continue
}
if 'A' <= sr && sr <= 'Z' {
sr = sr + 'a' - 'A'
}
if 'A' <= tr && tr <= 'Z' {
tr = tr + 'a' - 'A'
}
if sr != tr {
return false
}
}
return s == t
}
// tokenListContainsValue returns true if the 1#token header with the given
// name contains a token equal to value with ASCII case folding.
func tokenListContainsValue(header http.Header, name string, value string) bool {
headers:
for _, s := range header[name] {
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
s = skipSpace(s)
if s != "" && s[0] != ',' {
continue headers
}
if equalASCIIFold(t, value) {
return true
}
if s == "" {
continue headers
}
s = s[1:]
}
}
return false
}
// parseExtensions parses WebSocket extensions from a header.
func parseExtensions(header http.Header) []map[string]string {
// From RFC 6455:
//
// Sec-WebSocket-Extensions = extension-list
// extension-list = 1#extension
// extension = extension-token *( ";" extension-param )
// extension-token = registered-token
// registered-token = token
// extension-param = token [ "=" (token | quoted-string) ]
// ;When using the quoted-string syntax variant, the value
// ;after quoted-string unescaping MUST conform to the
// ;'token' ABNF.
var result []map[string]string
headers:
for _, s := range header["Sec-Websocket-Extensions"] {
for {
var t string
t, s = nextToken(skipSpace(s))
if t == "" {
continue headers
}
ext := map[string]string{"": t}
for {
s = skipSpace(s)
if !strings.HasPrefix(s, ";") {
break
}
var k string
k, s = nextToken(skipSpace(s[1:]))
if k == "" {
continue headers
}
s = skipSpace(s)
var v string
if strings.HasPrefix(s, "=") {
v, s = nextTokenOrQuoted(skipSpace(s[1:]))
s = skipSpace(s)
}
if s != "" && s[0] != ',' && s[0] != ';' {
continue headers
}
ext[k] = v
}
if s != "" && s[0] != ',' {
continue headers
}
result = append(result, ext)
if s == "" {
continue headers
}
s = s[1:]
}
}
return result
}
// isValidChallengeKey checks if the argument meets RFC6455 specification.
func isValidChallengeKey(s string) bool {
// From RFC6455:
//
// A |Sec-WebSocket-Key| header field with a base64-encoded (see
// Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
// length.
if s == "" {
return false
}
decoded, err := base64.StdEncoding.DecodeString(s)
return err == nil && len(decoded) == 16
}

View File

@ -1,473 +0,0 @@
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy
// Package proxy provides support for a variety of protocols to proxy network
// data.
//
package websocket
import (
"errors"
"io"
"net"
"net/url"
"os"
"strconv"
"strings"
"sync"
)
type proxy_direct struct{}
// Direct is a direct proxy: one that makes network connections directly.
var proxy_Direct = proxy_direct{}
func (proxy_direct) Dial(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
}
// A PerHost directs connections to a default Dialer unless the host name
// requested matches one of a number of exceptions.
type proxy_PerHost struct {
def, bypass proxy_Dialer
bypassNetworks []*net.IPNet
bypassIPs []net.IP
bypassZones []string
bypassHosts []string
}
// NewPerHost returns a PerHost Dialer that directs connections to either
// defaultDialer or bypass, depending on whether the connection matches one of
// the configured rules.
func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost {
return &proxy_PerHost{
def: defaultDialer,
bypass: bypass,
}
}
// Dial connects to the address addr on the given network through either
// defaultDialer or bypass.
func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return p.dialerForRequest(host).Dial(network, addr)
}
func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer {
if ip := net.ParseIP(host); ip != nil {
for _, net := range p.bypassNetworks {
if net.Contains(ip) {
return p.bypass
}
}
for _, bypassIP := range p.bypassIPs {
if bypassIP.Equal(ip) {
return p.bypass
}
}
return p.def
}
for _, zone := range p.bypassZones {
if strings.HasSuffix(host, zone) {
return p.bypass
}
if host == zone[1:] {
// For a zone ".example.com", we match "example.com"
// too.
return p.bypass
}
}
for _, bypassHost := range p.bypassHosts {
if bypassHost == host {
return p.bypass
}
}
return p.def
}
// AddFromString parses a string that contains comma-separated values
// specifying hosts that should use the bypass proxy. Each value is either an
// IP address, a CIDR range, a zone (*.example.com) or a host name
// (localhost). A best effort is made to parse the string and errors are
// ignored.
func (p *proxy_PerHost) AddFromString(s string) {
hosts := strings.Split(s, ",")
for _, host := range hosts {
host = strings.TrimSpace(host)
if len(host) == 0 {
continue
}
if strings.Contains(host, "/") {
// We assume that it's a CIDR address like 127.0.0.0/8
if _, net, err := net.ParseCIDR(host); err == nil {
p.AddNetwork(net)
}
continue
}
if ip := net.ParseIP(host); ip != nil {
p.AddIP(ip)
continue
}
if strings.HasPrefix(host, "*.") {
p.AddZone(host[1:])
continue
}
p.AddHost(host)
}
}
// AddIP specifies an IP address that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match an IP.
func (p *proxy_PerHost) AddIP(ip net.IP) {
p.bypassIPs = append(p.bypassIPs, ip)
}
// AddNetwork specifies an IP range that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match.
func (p *proxy_PerHost) AddNetwork(net *net.IPNet) {
p.bypassNetworks = append(p.bypassNetworks, net)
}
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains.
func (p *proxy_PerHost) AddZone(zone string) {
if strings.HasSuffix(zone, ".") {
zone = zone[:len(zone)-1]
}
if !strings.HasPrefix(zone, ".") {
zone = "." + zone
}
p.bypassZones = append(p.bypassZones, zone)
}
// AddHost specifies a host name that will use the bypass proxy.
func (p *proxy_PerHost) AddHost(host string) {
if strings.HasSuffix(host, ".") {
host = host[:len(host)-1]
}
p.bypassHosts = append(p.bypassHosts, host)
}
// A Dialer is a means to establish a connection.
type proxy_Dialer interface {
// Dial connects to the given address via the proxy.
Dial(network, addr string) (c net.Conn, err error)
}
// Auth contains authentication parameters that specific Dialers may require.
type proxy_Auth struct {
User, Password string
}
// FromEnvironment returns the dialer specified by the proxy related variables in
// the environment.
func proxy_FromEnvironment() proxy_Dialer {
allProxy := proxy_allProxyEnv.Get()
if len(allProxy) == 0 {
return proxy_Direct
}
proxyURL, err := url.Parse(allProxy)
if err != nil {
return proxy_Direct
}
proxy, err := proxy_FromURL(proxyURL, proxy_Direct)
if err != nil {
return proxy_Direct
}
noProxy := proxy_noProxyEnv.Get()
if len(noProxy) == 0 {
return proxy
}
perHost := proxy_NewPerHost(proxy, proxy_Direct)
perHost.AddFromString(noProxy)
return perHost
}
// proxySchemes is a map from URL schemes to a function that creates a Dialer
// from a URL with such a scheme.
var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
// by FromURL.
func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) {
if proxy_proxySchemes == nil {
proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error))
}
proxy_proxySchemes[scheme] = f
}
// FromURL returns a Dialer given a URL specification and an underlying
// Dialer for it to make network requests.
func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) {
var auth *proxy_Auth
if u.User != nil {
auth = new(proxy_Auth)
auth.User = u.User.Username()
if p, ok := u.User.Password(); ok {
auth.Password = p
}
}
switch u.Scheme {
case "socks5":
return proxy_SOCKS5("tcp", u.Host, auth, forward)
}
// If the scheme doesn't match any of the built-in schemes, see if it
// was registered by another package.
if proxy_proxySchemes != nil {
if f, ok := proxy_proxySchemes[u.Scheme]; ok {
return f(u, forward)
}
}
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
}
var (
proxy_allProxyEnv = &proxy_envOnce{
names: []string{"ALL_PROXY", "all_proxy"},
}
proxy_noProxyEnv = &proxy_envOnce{
names: []string{"NO_PROXY", "no_proxy"},
}
)
// envOnce looks up an environment variable (optionally by multiple
// names) once. It mitigates expensive lookups on some platforms
// (e.g. Windows).
// (Borrowed from net/http/transport.go)
type proxy_envOnce struct {
names []string
once sync.Once
val string
}
func (e *proxy_envOnce) Get() string {
e.once.Do(e.init)
return e.val
}
func (e *proxy_envOnce) init() {
for _, n := range e.names {
e.val = os.Getenv(n)
if e.val != "" {
return
}
}
}
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
// with an optional username and password. See RFC 1928 and RFC 1929.
func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) {
s := &proxy_socks5{
network: network,
addr: addr,
forward: forward,
}
if auth != nil {
s.user = auth.User
s.password = auth.Password
}
return s, nil
}
type proxy_socks5 struct {
user, password string
network, addr string
forward proxy_Dialer
}
const proxy_socks5Version = 5
const (
proxy_socks5AuthNone = 0
proxy_socks5AuthPassword = 2
)
const proxy_socks5Connect = 1
const (
proxy_socks5IP4 = 1
proxy_socks5Domain = 3
proxy_socks5IP6 = 4
)
var proxy_socks5Errors = []string{
"",
"general failure",
"connection forbidden",
"network unreachable",
"host unreachable",
"connection refused",
"TTL expired",
"command not supported",
"address type not supported",
}
// Dial connects to the address addr on the given network via the SOCKS5 proxy.
func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) {
switch network {
case "tcp", "tcp6", "tcp4":
default:
return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
}
conn, err := s.forward.Dial(s.network, s.addr)
if err != nil {
return nil, err
}
if err := s.connect(conn, addr); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
// connect takes an existing connection to a socks5 proxy server,
// and commands the server to extend that connection to target,
// which must be a canonical address with a host and port.
func (s *proxy_socks5) connect(conn net.Conn, target string) error {
host, portStr, err := net.SplitHostPort(target)
if err != nil {
return err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return errors.New("proxy: failed to parse port number: " + portStr)
}
if port < 1 || port > 0xffff {
return errors.New("proxy: port number out of range: " + portStr)
}
// the size here is just an estimate
buf := make([]byte, 0, 6+len(host))
buf = append(buf, proxy_socks5Version)
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword)
} else {
buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone)
}
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[0] != 5 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
}
if buf[1] == 0xff {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
}
// See RFC 1929
if buf[1] == proxy_socks5AuthPassword {
buf = buf[:0]
buf = append(buf, 1 /* password protocol version */)
buf = append(buf, uint8(len(s.user)))
buf = append(buf, s.user...)
buf = append(buf, uint8(len(s.password)))
buf = append(buf, s.password...)
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[1] != 0 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
}
}
buf = buf[:0]
buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */)
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
buf = append(buf, proxy_socks5IP4)
ip = ip4
} else {
buf = append(buf, proxy_socks5IP6)
}
buf = append(buf, ip...)
} else {
if len(host) > 255 {
return errors.New("proxy: destination host name too long: " + host)
}
buf = append(buf, proxy_socks5Domain)
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
buf = append(buf, byte(port>>8), byte(port))
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
failure := "unknown error"
if int(buf[1]) < len(proxy_socks5Errors) {
failure = proxy_socks5Errors[buf[1]]
}
if len(failure) > 0 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
}
bytesToDiscard := 0
switch buf[3] {
case proxy_socks5IP4:
bytesToDiscard = net.IPv4len
case proxy_socks5IP6:
bytesToDiscard = net.IPv6len
case proxy_socks5Domain:
_, err := io.ReadFull(conn, buf[:1])
if err != nil {
return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
bytesToDiscard = int(buf[0])
default:
return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
}
if cap(buf) < bytesToDiscard {
buf = make([]byte, bytesToDiscard)
} else {
buf = buf[:bytesToDiscard]
}
if _, err := io.ReadFull(conn, buf); err != nil {
return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
// Also need to discard the port number
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
return nil
}

View File

@ -1,20 +0,0 @@
# go noise
*.6
*.8
*.o
*.so
*.out
*.go~
*.cgo?.*
_cgo_*
_obj
_test
_testmain.go
*.test
# Vim noise
*.swp
# Just noise
*~
*.orig

View File

@ -1,12 +0,0 @@
All code and content in this project is Copyright © 2015-2022 Go Opus Authors
Go Opus Authors and copyright holders of this package are listed below, in no
particular order. By adding yourself to this list you agree to license your
contributions under the relevant license (see the LICENSE file).
Hraban Luyat <hraban@0brg.net>
Dejian Xu <xudejian2008@gmail.com>
Tobias Wellnitz <tobias.wellnitz@gmail.com>
Elinor Natanzon <stop.start.dev@gmail.com>
Victor Gaydov <victor@enise.org>
Randy Reddig <ydnar@shaderlab.com>

View File

@ -1,19 +0,0 @@
Copyright © 2015-2022 Go Opus Authors (see AUTHORS file)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

View File

@ -1,302 +0,0 @@
[![Test](https://github.com/hraban/opus/workflows/Test/badge.svg)](https://github.com/hraban/opus/actions?query=workflow%3ATest)
## Go wrapper for Opus
This package provides Go bindings for the xiph.org C libraries libopus and
libopusfile.
The C libraries and docs are hosted at https://opus-codec.org/. This package
just handles the wrapping in Go, and is unaffiliated with xiph.org.
Features:
- ✅ encode and decode raw PCM data to raw Opus data
- ✅ useful when you control the recording device, _and_ the playback
- ✅ decode .opus and .ogg files into raw audio data ("PCM")
- ✅ reuse the system libraries for opus decoding (libopus)
- ✅ works easily on Linux, Mac and Docker; needs libs on Windows
- ❌ does not _create_ .opus or .ogg files (but feel free to send a PR)
- ❌ does not work with .wav files (you need a separate .wav library for that)
- ❌ no self-contained binary (you need the xiph.org libopus lib, e.g. through a package manager)
- ❌ no cross compiling (because it uses CGo)
Good use cases:
- 👍 you are writing a music player app in Go, and you want to play back .opus files
- 👍 you record raw wav in a web app or mobile app, you encode it as Opus on the client, you send the opus to a remote webserver written in Go, and you want to decode it back to raw audio data on that server
## Details
This wrapper provides a Go translation layer for three elements from the
xiph.org opus libs:
* encoders
* decoders
* files & streams
### Import
```go
import "gopkg.in/hraban/opus.v2"
```
### Encoding
To encode raw audio to the Opus format, create an encoder first:
```go
const sampleRate = 48000
const channels = 1 // mono; 2 for stereo
enc, err := opus.NewEncoder(sampleRate, channels, opus.AppVoIP)
if err != nil {
...
}
```
Then pass it some raw PCM data to encode.
Make sure that the raw PCM data you want to encode has a legal Opus frame size.
This means it must be exactly 2.5, 5, 10, 20, 40 or 60 ms long. The number of
bytes this corresponds to depends on the sample rate (see the [libopus
documentation](https://www.opus-codec.org/docs/opus_api-1.1.3/group__opus__encoder.html)).
```go
var pcm []int16 = ... // obtain your raw PCM data somewhere
const bufferSize = 1000 // choose any buffer size you like. 1k is plenty.
// Check the frame size. You don't need to do this if you trust your input.
frameSize := len(pcm) // must be interleaved if stereo
frameSizeMs := float32(frameSize) / channels * 1000 / sampleRate
switch frameSizeMs {
case 2.5, 5, 10, 20, 40, 60:
// Good.
default:
return fmt.Errorf("Illegal frame size: %d bytes (%f ms)", frameSize, frameSizeMs)
}
data := make([]byte, bufferSize)
n, err := enc.Encode(pcm, data)
if err != nil {
...
}
data = data[:n] // only the first N bytes are opus data. Just like io.Reader.
```
Note that you must choose a target buffer size, and this buffer size will affect
the encoding process:
> Size of the allocated memory for the output payload. This may be used to
> impose an upper limit on the instant bitrate, but should not be used as the
> only bitrate control. Use `OPUS_SET_BITRATE` to control the bitrate.
-- https://opus-codec.org/docs/opus_api-1.1.3/group__opus__encoder.html
### Decoding
To decode opus data to raw PCM format, first create a decoder:
```go
dec, err := opus.NewDecoder(sampleRate, channels)
if err != nil {
...
}
```
Now pass it the opus bytes, and a buffer to store the PCM sound in:
```go
var frameSizeMs float32 = ... // if you don't know, go with 60 ms.
frameSize := channels * frameSizeMs * sampleRate / 1000
pcm := make([]int16, int(frameSize))
n, err := dec.Decode(data, pcm)
if err != nil {
...
}
// To get all samples (interleaved if multiple channels):
pcm = pcm[:n*channels] // only necessary if you didn't know the right frame size
// or access sample per sample, directly:
for i := 0; i < n; i++ {
ch1 := pcm[i*channels+0]
// For stereo output: copy ch1 into ch2 in mono mode, or deinterleave stereo
ch2 := pcm[(i*channels)+(channels-1)]
}
```
To handle packet loss from an unreliable network, see the
[DecodePLC](https://godoc.org/gopkg.in/hraban/opus.v2#Decoder.DecodePLC) and
[DecodeFEC](https://godoc.org/gopkg.in/hraban/opus.v2#Decoder.DecodeFEC)
options.
### Streams (and Files)
To decode a .opus file (or .ogg with Opus data), or to decode a "Opus stream"
(which is a Ogg stream with Opus data), use the `Stream` interface. It wraps an
io.Reader providing the raw stream bytes and returns the decoded Opus data.
A crude example for reading from a .opus file:
```go
f, err := os.Open(fname)
if err != nil {
...
}
s, err := opus.NewStream(f)
if err != nil {
...
}
defer s.Close()
pcmbuf := make([]int16, 16384)
for {
n, err = s.Read(pcmbuf)
if err == io.EOF {
break
} else if err != nil {
...
}
pcm := pcmbuf[:n*channels]
// send pcm to audio device here, or write to a .wav file
}
```
See https://godoc.org/gopkg.in/hraban/opus.v2#Stream for further info.
### "My .ogg/.opus file doesn't play!" or "How do I play Opus in VLC / mplayer / ...?"
Note: this package only does _encoding_ of your audio, to _raw opus data_. You can't just dump those all in one big file and play it back. You need extra info. First of all, you need to know how big each individual block is. Remember: opus data is a stream of encoded separate blocks, not one big stream of bytes. Second, you need meta-data: how many channels? What's the sampling rate? Frame size? Etc.
Look closely at the decoding sample code (not stream), above: we're passing all that meta-data in, hard-coded. If you just put all your encoded bytes in one big file and gave that to a media player, it wouldn't know what to do with it. It wouldn't even know that it's Opus data. It would just look like `/dev/random`.
What you need is a [container format](https://en.wikipedia.org/wiki/Container_format_(computing)).
Compare it to video:
* Encodings: MPEG[1234], VP9, H26[45], AV1
* Container formats: .mkv, .avi, .mov, .ogv
For Opus audio, the most common container format is OGG, aka .ogg or .opus. You'll know OGG from OGG/Vorbis: that's [Vorbis](https://xiph.org/vorbis/) encoded audio in an OGG container. So for Opus, you'd call it OGG/Opus. But technically you could stick opus data in any container format that supports it, including e.g. Matroska (.mka for audio, you probably know it from .mkv for video).
Note: libopus, the C library that this wraps, technically comes with libopusfile, which can help with the creation of OGG/Opus streams from raw audio data. I just never needed it myself, so I haven't added the necessary code for it. If you find yourself adding it: send me a PR and we'll get it merged.
This libopus wrapper _does_ come with code for _decoding_ an OGG/Opus stream. Just not for writing one.
### API Docs
Go wrapper API reference:
https://godoc.org/gopkg.in/hraban/opus.v2
Full libopus C API reference:
https://www.opus-codec.org/docs/opus_api-1.1.3/
For more examples, see the `_test.go` files.
## Build & Installation
This package requires libopus and libopusfile development packages to be
installed on your system. These are available on Debian based systems from
aptitude as `libopus-dev` and `libopusfile-dev`, and on Mac OS X from homebrew.
They are linked into the app using pkg-config.
Debian, Ubuntu, ...:
```sh
sudo apt-get install pkg-config libopus-dev libopusfile-dev
```
Mac:
```sh
brew install pkg-config opus opusfile
```
### Building Without `libopusfile`
This package can be built without `libopusfile` by using the build tag `nolibopusfile`.
This enables the compilation of statically-linked binaries with no external
dependencies on operating systems without a static `libopusfile`, such as
[Alpine Linux](https://pkgs.alpinelinux.org/contents?branch=edge&name=opusfile-dev&arch=x86_64&repo=main).
**Note:** this will disable all file and `Stream` APIs.
To enable this feature, add `-tags nolibopusfile` to your `go build` or `go test` commands:
```sh
# Build
go build -tags nolibopusfile ...
# Test
go test -tags nolibopusfile ./...
```
### Using in Docker
If your Dockerized app has this library as a dependency (directly or
indirectly), it will need to install the aforementioned packages, too.
This means you can't use the standard `golang:*-onbuild` images, because those
will try to build the app from source before allowing you to install extra
dependencies. Instead, try this as a Dockerfile:
```Dockerfile
# Choose any golang image, just make sure it doesn't have -onbuild
FROM golang:1
RUN apt-get update && apt-get -y install libopus-dev libopusfile-dev
# Everything below is copied manually from the official -onbuild image,
# with the ONBUILD keywords removed.
RUN mkdir -p /go/src/app
WORKDIR /go/src/app
CMD ["go-wrapper", "run"]
COPY . /go/src/app
RUN go-wrapper download
RUN go-wrapper install
```
For more information, see <https://hub.docker.com/_/golang/>.
### Linking libopus and libopusfile
The opus and opusfile libraries will be linked into your application
dynamically. This means everyone who uses the resulting binary will need those
libraries available on their system. E.g. if you use this wrapper to write a
music app in Go, everyone using that music app will need libopus and libopusfile
on their system. On Debian systems the packages are called `libopus0` and
`libopusfile0`.
The "cleanest" way to do this is to publish your software through a package
manager and specify libopus and libopusfile as dependencies of your program. If
that is not an option, you can compile the dynamic libraries yourself and ship
them with your software as seperate (.dll or .so) files.
On Linux, for example, you would need the libopus.so.0 and libopusfile.so.0
files in the same directory as the binary. Set your ELF binary's rpath to
`$ORIGIN` (this is not a shell variable but elf magic):
```sh
patchelf --set-origin '$ORIGIN' your-app-binary
```
Now you can run the binary and it will automatically pick up shared library
files from its own directory.
Wrap it all in a .zip, and ship.
I know there is a similar trick for Mac (involving prefixing the shared library
names with `./`, which is, arguably, better). And Windows... probably just picks
up .dll files from the same dir by default? I don't know. But there are ways.
## License
The licensing terms for the Go bindings are found in the LICENSE file. The
authors and copyright holders are listed in the AUTHORS file.
The copyright notice uses range notation to indicate all years in between are
subject to copyright, as well. This statement is necessary, apparently. For all
those nefarious actors ready to abuse a copyright notice with incorrect
notation, but thwarted by a mention in the README. Pfew!

View File

@ -1,29 +0,0 @@
// +build !nolibopusfile
// Copyright © Go Opus Authors (see AUTHORS file)
//
// License for use of this code is detailed in the LICENSE file
// Allocate callback struct in C to ensure it's not managed by the Go GC. This
// plays nice with the CGo rules and avoids any confusion.
#include <opusfile.h>
#include <stdint.h>
// Defined in Go. Uses the same signature as Go, no need for proxy function.
int go_readcallback(void *p, unsigned char *buf, int nbytes);
static struct OpusFileCallbacks callbacks = {
.read = go_readcallback,
};
// Proxy function for op_open_callbacks, because it takes a void * context but
// we want to pass it non-pointer data, namely an arbitrary uintptr_t
// value. This is legal C, but go test -race (-d=checkptr) complains anyway. So
// we have this wrapper function to shush it.
// https://groups.google.com/g/golang-nuts/c/995uZyRPKlU
OggOpusFile *
my_open_callbacks(uintptr_t p, int *error)
{
return op_open_callbacks((void *)p, &callbacks, NULL, 0, error);
}

View File

@ -1,262 +0,0 @@
// Copyright © Go Opus Authors (see AUTHORS file)
//
// License for use of this code is detailed in the LICENSE file
package opus
import (
"fmt"
"unsafe"
)
/*
#cgo pkg-config: opus
#include <opus.h>
int
bridge_decoder_get_last_packet_duration(OpusDecoder *st, opus_int32 *samples)
{
return opus_decoder_ctl(st, OPUS_GET_LAST_PACKET_DURATION(samples));
}
*/
import "C"
var errDecUninitialized = fmt.Errorf("opus decoder uninitialized")
type Decoder struct {
p *C.struct_OpusDecoder
// Same purpose as encoder struct
mem []byte
sample_rate int
channels int
}
// NewDecoder allocates a new Opus decoder and initializes it with the
// appropriate parameters. All related memory is managed by the Go GC.
func NewDecoder(sample_rate int, channels int) (*Decoder, error) {
var dec Decoder
err := dec.Init(sample_rate, channels)
if err != nil {
return nil, err
}
return &dec, nil
}
func (dec *Decoder) Init(sample_rate int, channels int) error {
if dec.p != nil {
return fmt.Errorf("opus decoder already initialized")
}
if channels != 1 && channels != 2 {
return fmt.Errorf("Number of channels must be 1 or 2: %d", channels)
}
size := C.opus_decoder_get_size(C.int(channels))
dec.sample_rate = sample_rate
dec.channels = channels
dec.mem = make([]byte, size)
dec.p = (*C.OpusDecoder)(unsafe.Pointer(&dec.mem[0]))
errno := C.opus_decoder_init(
dec.p,
C.opus_int32(sample_rate),
C.int(channels))
if errno != 0 {
return Error(errno)
}
return nil
}
// Decode encoded Opus data into the supplied buffer. On success, returns the
// number of samples correctly written to the target buffer.
func (dec *Decoder) Decode(data []byte, pcm []int16) (int, error) {
if dec.p == nil {
return 0, errDecUninitialized
}
if len(data) == 0 {
return 0, fmt.Errorf("opus: no data supplied")
}
if len(pcm) == 0 {
return 0, fmt.Errorf("opus: target buffer empty")
}
if cap(pcm)%dec.channels != 0 {
return 0, fmt.Errorf("opus: target buffer capacity must be multiple of channels")
}
n := int(C.opus_decode(
dec.p,
(*C.uchar)(&data[0]),
C.opus_int32(len(data)),
(*C.opus_int16)(&pcm[0]),
C.int(cap(pcm)/dec.channels),
0))
if n < 0 {
return 0, Error(n)
}
return n, nil
}
// Decode encoded Opus data into the supplied buffer. On success, returns the
// number of samples correctly written to the target buffer.
func (dec *Decoder) DecodeFloat32(data []byte, pcm []float32) (int, error) {
if dec.p == nil {
return 0, errDecUninitialized
}
if len(data) == 0 {
return 0, fmt.Errorf("opus: no data supplied")
}
if len(pcm) == 0 {
return 0, fmt.Errorf("opus: target buffer empty")
}
if cap(pcm)%dec.channels != 0 {
return 0, fmt.Errorf("opus: target buffer capacity must be multiple of channels")
}
n := int(C.opus_decode_float(
dec.p,
(*C.uchar)(&data[0]),
C.opus_int32(len(data)),
(*C.float)(&pcm[0]),
C.int(cap(pcm)/dec.channels),
0))
if n < 0 {
return 0, Error(n)
}
return n, nil
}
// DecodeFEC encoded Opus data into the supplied buffer with forward error
// correction.
//
// It is to be used on the packet directly following the lost one. The supplied
// buffer needs to be exactly the duration of audio that is missing
//
// When a packet is considered "lost", DecodeFEC can be called on the next
// packet in order to try and recover some of the lost data. The PCM needs to be
// exactly the duration of audio that is missing. `LastPacketDuration()` can be
// used on the decoder to get the length of the last packet. Note also that in
// order to use this feature the encoder needs to be configured with
// SetInBandFEC(true) and SetPacketLossPerc(x) options.
//
// Note that DecodeFEC automatically falls back to PLC when no FEC data is
// available in the provided packet.
func (dec *Decoder) DecodeFEC(data []byte, pcm []int16) error {
if dec.p == nil {
return errDecUninitialized
}
if len(data) == 0 {
return fmt.Errorf("opus: no data supplied")
}
if len(pcm) == 0 {
return fmt.Errorf("opus: target buffer empty")
}
if cap(pcm)%dec.channels != 0 {
return fmt.Errorf("opus: target buffer capacity must be multiple of channels")
}
n := int(C.opus_decode(
dec.p,
(*C.uchar)(&data[0]),
C.opus_int32(len(data)),
(*C.opus_int16)(&pcm[0]),
C.int(cap(pcm)/dec.channels),
1))
if n < 0 {
return Error(n)
}
return nil
}
// DecodeFECFloat32 encoded Opus data into the supplied buffer with forward error
// correction. It is to be used on the packet directly following the lost one.
// The supplied buffer needs to be exactly the duration of audio that is missing
func (dec *Decoder) DecodeFECFloat32(data []byte, pcm []float32) error {
if dec.p == nil {
return errDecUninitialized
}
if len(data) == 0 {
return fmt.Errorf("opus: no data supplied")
}
if len(pcm) == 0 {
return fmt.Errorf("opus: target buffer empty")
}
if cap(pcm)%dec.channels != 0 {
return fmt.Errorf("opus: target buffer capacity must be multiple of channels")
}
n := int(C.opus_decode_float(
dec.p,
(*C.uchar)(&data[0]),
C.opus_int32(len(data)),
(*C.float)(&pcm[0]),
C.int(cap(pcm)/dec.channels),
1))
if n < 0 {
return Error(n)
}
return nil
}
// DecodePLC recovers a lost packet using Opus Packet Loss Concealment feature.
//
// The supplied buffer needs to be exactly the duration of audio that is missing.
// When a packet is considered "lost", `DecodePLC` and `DecodePLCFloat32` methods
// can be called in order to obtain something better sounding than just silence.
// The PCM needs to be exactly the duration of audio that is missing.
// `LastPacketDuration()` can be used on the decoder to get the length of the
// last packet.
//
// This option does not require any additional encoder options. Unlike FEC,
// PLC does not introduce additional latency. It is calculated from the previous
// packet, not from the next one.
func (dec *Decoder) DecodePLC(pcm []int16) error {
if dec.p == nil {
return errDecUninitialized
}
if len(pcm) == 0 {
return fmt.Errorf("opus: target buffer empty")
}
if cap(pcm)%dec.channels != 0 {
return fmt.Errorf("opus: output buffer capacity must be multiple of channels")
}
n := int(C.opus_decode(
dec.p,
nil,
0,
(*C.opus_int16)(&pcm[0]),
C.int(cap(pcm)/dec.channels),
0))
if n < 0 {
return Error(n)
}
return nil
}
// DecodePLCFloat32 recovers a lost packet using Opus Packet Loss Concealment feature.
// The supplied buffer needs to be exactly the duration of audio that is missing.
func (dec *Decoder) DecodePLCFloat32(pcm []float32) error {
if dec.p == nil {
return errDecUninitialized
}
if len(pcm) == 0 {
return fmt.Errorf("opus: target buffer empty")
}
if cap(pcm)%dec.channels != 0 {
return fmt.Errorf("opus: output buffer capacity must be multiple of channels")
}
n := int(C.opus_decode_float(
dec.p,
nil,
0,
(*C.float)(&pcm[0]),
C.int(cap(pcm)/dec.channels),
0))
if n < 0 {
return Error(n)
}
return nil
}
// LastPacketDuration gets the duration (in samples)
// of the last packet successfully decoded or concealed.
func (dec *Decoder) LastPacketDuration() (int, error) {
var samples C.opus_int32
res := C.bridge_decoder_get_last_packet_duration(dec.p, &samples)
if res != C.OPUS_OK {
return 0, Error(res)
}
return int(samples), nil
}

View File

@ -1,402 +0,0 @@
// Copyright © Go Opus Authors (see AUTHORS file)
//
// License for use of this code is detailed in the LICENSE file
package opus
import (
"fmt"
"unsafe"
)
/*
#cgo pkg-config: opus
#include <opus.h>
int
bridge_encoder_set_dtx(OpusEncoder *st, opus_int32 use_dtx)
{
return opus_encoder_ctl(st, OPUS_SET_DTX(use_dtx));
}
int
bridge_encoder_get_dtx(OpusEncoder *st, opus_int32 *dtx)
{
return opus_encoder_ctl(st, OPUS_GET_DTX(dtx));
}
int
bridge_encoder_get_in_dtx(OpusEncoder *st, opus_int32 *in_dtx)
{
return opus_encoder_ctl(st, OPUS_GET_IN_DTX(in_dtx));
}
int
bridge_encoder_get_sample_rate(OpusEncoder *st, opus_int32 *sample_rate)
{
return opus_encoder_ctl(st, OPUS_GET_SAMPLE_RATE(sample_rate));
}
int
bridge_encoder_set_bitrate(OpusEncoder *st, opus_int32 bitrate)
{
return opus_encoder_ctl(st, OPUS_SET_BITRATE(bitrate));
}
int
bridge_encoder_get_bitrate(OpusEncoder *st, opus_int32 *bitrate)
{
return opus_encoder_ctl(st, OPUS_GET_BITRATE(bitrate));
}
int
bridge_encoder_set_complexity(OpusEncoder *st, opus_int32 complexity)
{
return opus_encoder_ctl(st, OPUS_SET_COMPLEXITY(complexity));
}
int
bridge_encoder_get_complexity(OpusEncoder *st, opus_int32 *complexity)
{
return opus_encoder_ctl(st, OPUS_GET_COMPLEXITY(complexity));
}
int
bridge_encoder_set_max_bandwidth(OpusEncoder *st, opus_int32 max_bw)
{
return opus_encoder_ctl(st, OPUS_SET_MAX_BANDWIDTH(max_bw));
}
int
bridge_encoder_get_max_bandwidth(OpusEncoder *st, opus_int32 *max_bw)
{
return opus_encoder_ctl(st, OPUS_GET_MAX_BANDWIDTH(max_bw));
}
int
bridge_encoder_set_inband_fec(OpusEncoder *st, opus_int32 fec)
{
return opus_encoder_ctl(st, OPUS_SET_INBAND_FEC(fec));
}
int
bridge_encoder_get_inband_fec(OpusEncoder *st, opus_int32 *fec)
{
return opus_encoder_ctl(st, OPUS_GET_INBAND_FEC(fec));
}
int
bridge_encoder_set_packet_loss_perc(OpusEncoder *st, opus_int32 loss_perc)
{
return opus_encoder_ctl(st, OPUS_SET_PACKET_LOSS_PERC(loss_perc));
}
int
bridge_encoder_get_packet_loss_perc(OpusEncoder *st, opus_int32 *loss_perc)
{
return opus_encoder_ctl(st, OPUS_GET_PACKET_LOSS_PERC(loss_perc));
}
int
bridge_encoder_reset_state(OpusEncoder *st)
{
return opus_encoder_ctl(st, OPUS_RESET_STATE);
}
*/
import "C"
type Bandwidth int
const (
// 4 kHz passband
Narrowband = Bandwidth(C.OPUS_BANDWIDTH_NARROWBAND)
// 6 kHz passband
Mediumband = Bandwidth(C.OPUS_BANDWIDTH_MEDIUMBAND)
// 8 kHz passband
Wideband = Bandwidth(C.OPUS_BANDWIDTH_WIDEBAND)
// 12 kHz passband
SuperWideband = Bandwidth(C.OPUS_BANDWIDTH_SUPERWIDEBAND)
// 20 kHz passband
Fullband = Bandwidth(C.OPUS_BANDWIDTH_FULLBAND)
)
var errEncUninitialized = fmt.Errorf("opus encoder uninitialized")
// Encoder contains the state of an Opus encoder for libopus.
type Encoder struct {
p *C.struct_OpusEncoder
channels int
// Memory for the encoder struct allocated on the Go heap to allow Go GC to
// manage it (and obviate need to free())
mem []byte
}
// NewEncoder allocates a new Opus encoder and initializes it with the
// appropriate parameters. All related memory is managed by the Go GC.
func NewEncoder(sample_rate int, channels int, application Application) (*Encoder, error) {
var enc Encoder
err := enc.Init(sample_rate, channels, application)
if err != nil {
return nil, err
}
return &enc, nil
}
// Init initializes a pre-allocated opus encoder. Unless the encoder has been
// created using NewEncoder, this method must be called exactly once in the
// life-time of this object, before calling any other methods.
func (enc *Encoder) Init(sample_rate int, channels int, application Application) error {
if enc.p != nil {
return fmt.Errorf("opus encoder already initialized")
}
if channels != 1 && channels != 2 {
return fmt.Errorf("Number of channels must be 1 or 2: %d", channels)
}
size := C.opus_encoder_get_size(C.int(channels))
enc.channels = channels
enc.mem = make([]byte, size)
enc.p = (*C.OpusEncoder)(unsafe.Pointer(&enc.mem[0]))
errno := int(C.opus_encoder_init(
enc.p,
C.opus_int32(sample_rate),
C.int(channels),
C.int(application)))
if errno != 0 {
return Error(int(errno))
}
return nil
}
// Encode raw PCM data and store the result in the supplied buffer. On success,
// returns the number of bytes used up by the encoded data.
func (enc *Encoder) Encode(pcm []int16, data []byte) (int, error) {
if enc.p == nil {
return 0, errEncUninitialized
}
if len(pcm) == 0 {
return 0, fmt.Errorf("opus: no data supplied")
}
if len(data) == 0 {
return 0, fmt.Errorf("opus: no target buffer")
}
// libopus talks about samples as 1 sample containing multiple channels. So
// e.g. 20 samples of 2-channel data is actually 40 raw data points.
if len(pcm)%enc.channels != 0 {
return 0, fmt.Errorf("opus: input buffer length must be multiple of channels")
}
samples := len(pcm) / enc.channels
n := int(C.opus_encode(
enc.p,
(*C.opus_int16)(&pcm[0]),
C.int(samples),
(*C.uchar)(&data[0]),
C.opus_int32(cap(data))))
if n < 0 {
return 0, Error(n)
}
return n, nil
}
// Encode raw PCM data and store the result in the supplied buffer. On success,
// returns the number of bytes used up by the encoded data.
func (enc *Encoder) EncodeFloat32(pcm []float32, data []byte) (int, error) {
if enc.p == nil {
return 0, errEncUninitialized
}
if len(pcm) == 0 {
return 0, fmt.Errorf("opus: no data supplied")
}
if len(data) == 0 {
return 0, fmt.Errorf("opus: no target buffer")
}
if len(pcm)%enc.channels != 0 {
return 0, fmt.Errorf("opus: input buffer length must be multiple of channels")
}
samples := len(pcm) / enc.channels
n := int(C.opus_encode_float(
enc.p,
(*C.float)(&pcm[0]),
C.int(samples),
(*C.uchar)(&data[0]),
C.opus_int32(cap(data))))
if n < 0 {
return 0, Error(n)
}
return n, nil
}
// SetDTX configures the encoder's use of discontinuous transmission (DTX).
func (enc *Encoder) SetDTX(dtx bool) error {
i := 0
if dtx {
i = 1
}
res := C.bridge_encoder_set_dtx(enc.p, C.opus_int32(i))
if res != C.OPUS_OK {
return Error(res)
}
return nil
}
// DTX reports whether this encoder is configured to use discontinuous
// transmission (DTX).
func (enc *Encoder) DTX() (bool, error) {
var dtx C.opus_int32
res := C.bridge_encoder_get_dtx(enc.p, &dtx)
if res != C.OPUS_OK {
return false, Error(res)
}
return dtx != 0, nil
}
// InDTX returns whether the last encoded frame was either a comfort noise update
// during DTX or not encoded because of DTX.
func (enc *Encoder) InDTX() (bool, error) {
var inDTX C.opus_int32
res := C.bridge_encoder_get_in_dtx(enc.p, &inDTX)
if res != C.OPUS_OK {
return false, Error(res)
}
return inDTX != 0, nil
}
// SampleRate returns the encoder sample rate in Hz.
func (enc *Encoder) SampleRate() (int, error) {
var sr C.opus_int32
res := C.bridge_encoder_get_sample_rate(enc.p, &sr)
if res != C.OPUS_OK {
return 0, Error(res)
}
return int(sr), nil
}
// SetBitrate sets the bitrate of the Encoder
func (enc *Encoder) SetBitrate(bitrate int) error {
res := C.bridge_encoder_set_bitrate(enc.p, C.opus_int32(bitrate))
if res != C.OPUS_OK {
return Error(res)
}
return nil
}
// SetBitrateToAuto will allow the encoder to automatically set the bitrate
func (enc *Encoder) SetBitrateToAuto() error {
res := C.bridge_encoder_set_bitrate(enc.p, C.opus_int32(C.OPUS_AUTO))
if res != C.OPUS_OK {
return Error(res)
}
return nil
}
// SetBitrateToMax causes the encoder to use as much rate as it can. This can be
// useful for controlling the rate by adjusting the output buffer size.
func (enc *Encoder) SetBitrateToMax() error {
res := C.bridge_encoder_set_bitrate(enc.p, C.opus_int32(C.OPUS_BITRATE_MAX))
if res != C.OPUS_OK {
return Error(res)
}
return nil
}
// Bitrate returns the bitrate of the Encoder
func (enc *Encoder) Bitrate() (int, error) {
var bitrate C.opus_int32
res := C.bridge_encoder_get_bitrate(enc.p, &bitrate)
if res != C.OPUS_OK {
return 0, Error(res)
}
return int(bitrate), nil
}
// SetComplexity sets the encoder's computational complexity
func (enc *Encoder) SetComplexity(complexity int) error {
res := C.bridge_encoder_set_complexity(enc.p, C.opus_int32(complexity))
if res != C.OPUS_OK {
return Error(res)
}
return nil
}
// Complexity returns the computational complexity used by the encoder
func (enc *Encoder) Complexity() (int, error) {
var complexity C.opus_int32
res := C.bridge_encoder_get_complexity(enc.p, &complexity)
if res != C.OPUS_OK {
return 0, Error(res)
}
return int(complexity), nil
}
// SetMaxBandwidth configures the maximum bandpass that the encoder will select
// automatically
func (enc *Encoder) SetMaxBandwidth(maxBw Bandwidth) error {
res := C.bridge_encoder_set_max_bandwidth(enc.p, C.opus_int32(maxBw))
if res != C.OPUS_OK {
return Error(res)
}
return nil
}
// MaxBandwidth gets the encoder's configured maximum allowed bandpass.
func (enc *Encoder) MaxBandwidth() (Bandwidth, error) {
var maxBw C.opus_int32
res := C.bridge_encoder_get_max_bandwidth(enc.p, &maxBw)
if res != C.OPUS_OK {
return 0, Error(res)
}
return Bandwidth(maxBw), nil
}
// SetInBandFEC configures the encoder's use of inband forward error
// correction (FEC)
func (enc *Encoder) SetInBandFEC(fec bool) error {
i := 0
if fec {
i = 1
}
res := C.bridge_encoder_set_inband_fec(enc.p, C.opus_int32(i))
if res != C.OPUS_OK {
return Error(res)
}
return nil
}
// InBandFEC gets the encoder's configured inband forward error correction (FEC)
func (enc *Encoder) InBandFEC() (bool, error) {
var fec C.opus_int32
res := C.bridge_encoder_get_inband_fec(enc.p, &fec)
if res != C.OPUS_OK {
return false, Error(res)
}
return fec != 0, nil
}
// SetPacketLossPerc configures the encoder's expected packet loss percentage.
func (enc *Encoder) SetPacketLossPerc(lossPerc int) error {
res := C.bridge_encoder_set_packet_loss_perc(enc.p, C.opus_int32(lossPerc))
if res != C.OPUS_OK {
return Error(res)
}
return nil
}
// PacketLossPerc gets the encoder's configured packet loss percentage.
func (enc *Encoder) PacketLossPerc() (int, error) {
var lossPerc C.opus_int32
res := C.bridge_encoder_get_packet_loss_perc(enc.p, &lossPerc)
if res != C.OPUS_OK {
return 0, Error(res)
}
return int(lossPerc), nil
}
// Reset resets the codec state to be equivalent to a freshly initialized state.
func (enc *Encoder) Reset() error {
res := C.bridge_encoder_reset_state(enc.p)
if res != C.OPUS_OK {
return Error(res)
}
return nil
}

View File

@ -1,36 +0,0 @@
// Copyright © Go Opus Authors (see AUTHORS file)
//
// License for use of this code is detailed in the LICENSE file
package opus
import (
"fmt"
)
/*
#cgo pkg-config: opus
#include <opus.h>
*/
import "C"
type Error int
var _ error = Error(0)
// Libopus errors
const (
ErrOK = Error(C.OPUS_OK)
ErrBadArg = Error(C.OPUS_BAD_ARG)
ErrBufferTooSmall = Error(C.OPUS_BUFFER_TOO_SMALL)
ErrInternalError = Error(C.OPUS_INTERNAL_ERROR)
ErrInvalidPacket = Error(C.OPUS_INVALID_PACKET)
ErrUnimplemented = Error(C.OPUS_UNIMPLEMENTED)
ErrInvalidState = Error(C.OPUS_INVALID_STATE)
ErrAllocFail = Error(C.OPUS_ALLOC_FAIL)
)
// Error string (in human readable format) for libopus errors.
func (e Error) Error() string {
return fmt.Sprintf("opus: %s", C.GoString(C.opus_strerror(C.int(e))))
}

View File

@ -1,36 +0,0 @@
// Copyright © Go Opus Authors (see AUTHORS file)
//
// License for use of this code is detailed in the LICENSE file
package opus
/*
// Link opus using pkg-config.
#cgo pkg-config: opus
#include <opus.h>
*/
import "C"
type Application int
const (
// Optimize encoding for VoIP
AppVoIP = Application(C.OPUS_APPLICATION_VOIP)
// Optimize encoding for non-voice signals like music
AppAudio = Application(C.OPUS_APPLICATION_AUDIO)
// Optimize encoding for low latency applications
AppRestrictedLowdelay = Application(C.OPUS_APPLICATION_RESTRICTED_LOWDELAY)
)
const (
xMAX_BITRATE = 48000
xMAX_FRAME_SIZE_MS = 60
xMAX_FRAME_SIZE = xMAX_BITRATE * xMAX_FRAME_SIZE_MS / 1000
// Maximum size of an encoded frame. I actually have no idea, but this
// looks like it's big enough.
maxEncodedFrameSize = 10000
)
func Version() string {
return C.GoString(C.opus_get_version_string())
}

View File

@ -1,183 +0,0 @@
// Copyright © Go Opus Authors (see AUTHORS file)
//
// License for use of this code is detailed in the LICENSE file
// +build !nolibopusfile
package opus
import (
"fmt"
"io"
"unsafe"
)
/*
#cgo pkg-config: opusfile
#include <opusfile.h>
#include <stdint.h>
#include <string.h>
OggOpusFile *my_open_callbacks(uintptr_t p, int *error);
*/
import "C"
// Stream wraps a io.Reader in a decoding layer. It provides an API similar to
// io.Reader, but it provides raw PCM data instead of the encoded Opus data.
//
// This is not the same as directly decoding the bytes on the io.Reader; opus
// streams are Ogg Opus audio streams, which package raw Opus data.
//
// This wraps libopusfile. For more information, see the api docs on xiph.org:
//
// https://www.opus-codec.org/docs/opusfile_api-0.7/index.html
type Stream struct {
id uintptr
oggfile *C.OggOpusFile
read io.Reader
// Preallocated buffer to pass to the reader
buf []byte
}
var streams = newStreamsMap()
//export go_readcallback
func go_readcallback(p unsafe.Pointer, cbuf *C.uchar, cmaxbytes C.int) C.int {
streamId := uintptr(p)
stream := streams.Get(streamId)
if stream == nil {
// This is bad
return -1
}
maxbytes := int(cmaxbytes)
if maxbytes > cap(stream.buf) {
maxbytes = cap(stream.buf)
}
// Don't bother cleaning up old data because that's not required by the
// io.Reader API.
n, err := stream.read.Read(stream.buf[:maxbytes])
// Go allows returning non-nil error (like EOF) and n>0, libopusfile doesn't
// expect that. So return n first to indicate the valid bytes, let the
// subsequent call (which will be n=0, same-error) handle the actual error.
if n == 0 && err != nil {
if err == io.EOF {
return 0
} else {
return -1
}
}
C.memcpy(unsafe.Pointer(cbuf), unsafe.Pointer(&stream.buf[0]), C.size_t(n))
return C.int(n)
}
// NewStream creates and initializes a new stream. Don't call .Init() on this.
func NewStream(read io.Reader) (*Stream, error) {
var s Stream
err := s.Init(read)
if err != nil {
return nil, err
}
return &s, nil
}
// Init initializes a stream with an io.Reader to fetch opus encoded data from
// on demand. Errors from the reader are all transformed to an EOF, any actual
// error information is lost. The same happens when a read returns succesfully,
// but with zero bytes.
func (s *Stream) Init(read io.Reader) error {
if s.oggfile != nil {
return fmt.Errorf("opus stream is already initialized")
}
if read == nil {
return fmt.Errorf("Reader must be non-nil")
}
s.read = read
s.buf = make([]byte, maxEncodedFrameSize)
s.id = streams.NextId()
var errno C.int
// Immediately delete the stream after .Init to avoid leaking if the
// caller forgets to (/ doesn't want to) call .Close(). No need for that,
// since the callback is only ever called during a .Read operation; just
// Save and Delete from the map around that every time a reader function is
// called.
streams.Save(s)
defer streams.Del(s)
oggfile := C.my_open_callbacks(C.uintptr_t(s.id), &errno)
if errno != 0 {
return StreamError(errno)
}
s.oggfile = oggfile
return nil
}
// Read a chunk of raw opus data from the stream and decode it. Returns the
// number of decoded samples per channel. This means that a dual channel
// (stereo) feed will have twice as many samples as the value returned.
//
// Read may successfully read less bytes than requested, but it will never read
// exactly zero bytes succesfully if a non-zero buffer is supplied.
//
// The number of channels in the output data must be known in advance. It is
// possible to extract this information from the stream itself, but I'm not
// motivated to do that. Feel free to send a pull request.
func (s *Stream) Read(pcm []int16) (int, error) {
if s.oggfile == nil {
return 0, fmt.Errorf("opus stream is uninitialized or already closed")
}
if len(pcm) == 0 {
return 0, nil
}
streams.Save(s)
defer streams.Del(s)
n := C.op_read(
s.oggfile,
(*C.opus_int16)(&pcm[0]),
C.int(len(pcm)),
nil)
if n < 0 {
return 0, StreamError(n)
}
if n == 0 {
return 0, io.EOF
}
return int(n), nil
}
// ReadFloat32 is the same as Read, but decodes to float32 instead of int16.
func (s *Stream) ReadFloat32(pcm []float32) (int, error) {
if s.oggfile == nil {
return 0, fmt.Errorf("opus stream is uninitialized or already closed")
}
if len(pcm) == 0 {
return 0, nil
}
streams.Save(s)
defer streams.Del(s)
n := C.op_read_float(
s.oggfile,
(*C.float)(&pcm[0]),
C.int(len(pcm)),
nil)
if n < 0 {
return 0, StreamError(n)
}
if n == 0 {
return 0, io.EOF
}
return int(n), nil
}
func (s *Stream) Close() error {
if s.oggfile == nil {
return fmt.Errorf("opus stream is uninitialized or already closed")
}
C.op_free(s.oggfile)
if closer, ok := s.read.(io.Closer); ok {
return closer.Close()
}
return nil
}

View File

@ -1,75 +0,0 @@
// Copyright © 2015-2017 Go Opus Authors (see AUTHORS file)
//
// License for use of this code is detailed in the LICENSE file
// +build !nolibopusfile
package opus
/*
#cgo pkg-config: opusfile
#include <opusfile.h>
*/
import "C"
// StreamError represents an error from libopusfile.
type StreamError int
var _ error = StreamError(0)
// Libopusfile errors. The names are copied verbatim from the libopusfile
// library.
const (
ErrStreamFalse = StreamError(C.OP_FALSE)
ErrStreamEOF = StreamError(C.OP_EOF)
ErrStreamHole = StreamError(C.OP_HOLE)
ErrStreamRead = StreamError(C.OP_EREAD)
ErrStreamFault = StreamError(C.OP_EFAULT)
ErrStreamImpl = StreamError(C.OP_EIMPL)
ErrStreamInval = StreamError(C.OP_EINVAL)
ErrStreamNotFormat = StreamError(C.OP_ENOTFORMAT)
ErrStreamBadHeader = StreamError(C.OP_EBADHEADER)
ErrStreamVersion = StreamError(C.OP_EVERSION)
ErrStreamNotAudio = StreamError(C.OP_ENOTAUDIO)
ErrStreamBadPacked = StreamError(C.OP_EBADPACKET)
ErrStreamBadLink = StreamError(C.OP_EBADLINK)
ErrStreamNoSeek = StreamError(C.OP_ENOSEEK)
ErrStreamBadTimestamp = StreamError(C.OP_EBADTIMESTAMP)
)
func (i StreamError) Error() string {
switch i {
case ErrStreamFalse:
return "OP_FALSE"
case ErrStreamEOF:
return "OP_EOF"
case ErrStreamHole:
return "OP_HOLE"
case ErrStreamRead:
return "OP_EREAD"
case ErrStreamFault:
return "OP_EFAULT"
case ErrStreamImpl:
return "OP_EIMPL"
case ErrStreamInval:
return "OP_EINVAL"
case ErrStreamNotFormat:
return "OP_ENOTFORMAT"
case ErrStreamBadHeader:
return "OP_EBADHEADER"
case ErrStreamVersion:
return "OP_EVERSION"
case ErrStreamNotAudio:
return "OP_ENOTAUDIO"
case ErrStreamBadPacked:
return "OP_EBADPACKET"
case ErrStreamBadLink:
return "OP_EBADLINK"
case ErrStreamNoSeek:
return "OP_ENOSEEK"
case ErrStreamBadTimestamp:
return "OP_EBADTIMESTAMP"
default:
return "libopusfile error: %d (unknown code)"
}
}

View File

@ -1,64 +0,0 @@
// Copyright © Go Opus Authors (see AUTHORS file)
//
// License for use of this code is detailed in the LICENSE file
// +build !nolibopusfile
package opus
import (
"sync"
"sync/atomic"
)
// A map of simple integers to the actual pointers to stream structs. Avoids
// passing pointers into the Go heap to C.
//
// As per the CGo pointers design doc for go 1.6:
//
// A particular unsafe area is C code that wants to hold on to Go func and
// pointer values for future callbacks from C to Go. This works today but is not
// permitted by the invariant. It is hard to detect. One safe approach is: Go
// code that wants to preserve funcs/pointers stores them into a map indexed by
// an int. Go code calls the C code, passing the int, which the C code may store
// freely. When the C code wants to call into Go, it passes the int to a Go
// function that looks in the map and makes the call. An explicit call is
// required to release the value from the map if it is no longer needed, but
// that was already true before.
//
// - https://github.com/golang/proposal/blob/master/design/12416-cgo-pointers.md
type streamsMap struct {
sync.RWMutex
m map[uintptr]*Stream
counter uintptr
}
func (sm *streamsMap) Get(id uintptr) *Stream {
sm.RLock()
defer sm.RUnlock()
return sm.m[id]
}
func (sm *streamsMap) Del(s *Stream) {
sm.Lock()
defer sm.Unlock()
delete(sm.m, s.id)
}
// NextId returns a unique ID for each call.
func (sm *streamsMap) NextId() uintptr {
return atomic.AddUintptr(&sm.counter, 1)
}
func (sm *streamsMap) Save(s *Stream) {
sm.Lock()
defer sm.Unlock()
sm.m[s.id] = s
}
func newStreamsMap() *streamsMap {
return &streamsMap{
counter: 0,
m: map[uintptr]*Stream{},
}
}

View File

@ -1,6 +0,0 @@
# github.com/gorilla/websocket v1.5.3
## explicit; go 1.12
github.com/gorilla/websocket
# github.com/hraban/opus v0.0.0-20230925203106-0188a62cb302
## explicit
github.com/hraban/opus

Some files were not shown because too many files have changed in this diff Show More