From a3222d1fe5dc4475576f901baecd5f66069156a5 Mon Sep 17 00:00:00 2001 From: repair-agent Date: Tue, 3 Mar 2026 16:22:16 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8E=8B=E6=B5=8B=E5=B7=A5=E5=85=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/devices/views.py | 48 ++-- hw_service_go/test/stress/go.mod | 5 + hw_service_go/test/stress/go.sum | 2 + hw_service_go/test/stress/main.go | 379 ++++++++++++++++++++++++++++++ hw_service_go/test/test.html | 19 +- 5 files changed, 424 insertions(+), 29 deletions(-) create mode 100644 hw_service_go/test/stress/go.mod create mode 100644 hw_service_go/test/stress/go.sum create mode 100644 hw_service_go/test/stress/main.go diff --git a/apps/devices/views.py b/apps/devices/views.py index d9acb7a..092325c 100644 --- a/apps/devices/views.py +++ b/apps/devices/views.py @@ -324,38 +324,30 @@ class DeviceViewSet(viewsets.ViewSet): mac = mac.upper().replace('-', ':') + from apps.stories.models import Story + story = None + + # 1. 尝试查找设备 → 绑定用户 → 用户故事 try: device = Device.objects.get(mac_address=mac) + user_device = ( + UserDevice.objects + .filter(device=device, is_active=True, bind_type='owner') + .select_related('user') + .first() + ) + if user_device: + story = ( + Story.objects + .filter(user=user_device.user) + .exclude(audio_url='') + .order_by('?') + .first() + ) except Device.DoesNotExist: - return error( - code=ErrorCode.DEVICE_NOT_FOUND, - message='未找到对应设备', - status_code=status.HTTP_404_NOT_FOUND - ) + pass - user_device = ( - UserDevice.objects - .filter(device=device, is_active=True, bind_type='owner') - .select_related('user') - .first() - ) - if not user_device: - return error( - code=ErrorCode.NOT_FOUND, - message='该设备尚未绑定用户', - status_code=status.HTTP_404_NOT_FOUND - ) - - from apps.stories.models import Story - # 优先随机取用户自己有 audio_url 的故事 - story = ( - Story.objects - .filter(user=user_device.user) - .exclude(audio_url='') - .order_by('?') - .first() - ) - # 兜底:用户暂无故事时使用系统默认故事 + # 2. 兜底:设备不存在/未绑定/用户无故事 → 使用系统默认故事 if not story: story = ( Story.objects diff --git a/hw_service_go/test/stress/go.mod b/hw_service_go/test/stress/go.mod new file mode 100644 index 0000000..315a80d --- /dev/null +++ b/hw_service_go/test/stress/go.mod @@ -0,0 +1,5 @@ +module stress + +go 1.23 + +require github.com/gorilla/websocket v1.5.3 diff --git a/hw_service_go/test/stress/go.sum b/hw_service_go/test/stress/go.sum new file mode 100644 index 0000000..25a9fc4 --- /dev/null +++ b/hw_service_go/test/stress/go.sum @@ -0,0 +1,2 @@ +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/hw_service_go/test/stress/main.go b/hw_service_go/test/stress/main.go new file mode 100644 index 0000000..6a5a03b --- /dev/null +++ b/hw_service_go/test/stress/main.go @@ -0,0 +1,379 @@ +// hw_service_go 并发压力测试工具 +// +// 用法: +// +// go run main.go -conns 100 -stories 0 # 100 个空闲连接 +// go run main.go -conns 50 -stories 10 # 50 连接,10 个触发故事 +// go run main.go -url wss://example.com/xiaozhi/v1/ -conns 50 +package main + +import ( + "encoding/json" + "flag" + "fmt" + "io" + "log" + "net/http" + "net/url" + "os" + "os/signal" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/gorilla/websocket" +) + +// ── 命令行参数 ───────────────────────────────────────────────── + +var ( + flagURL = flag.String("url", "ws://localhost:8888/xiaozhi/v1/", "WebSocket 服务地址") + flagConns = flag.Int("conns", 100, "总连接数") + flagStories = flag.Int("stories", 10, "同时触发故事的连接数") + flagRamp = flag.Int("ramp", 20, "每秒建立的连接数") + flagDuration = flag.Duration("duration", 60*time.Second, "测试持续时间") + flagMACPrefix = flag.String("mac-prefix", "AA:BB:CC:DD", "模拟 MAC 前缀") +) + +// ── 统计指标(原子操作,goroutine 安全) ────────────────────── + +type stats struct { + connAttempts atomic.Int64 + connSuccess atomic.Int64 + connFailed atomic.Int64 + handshaked atomic.Int64 + handshakeFail atomic.Int64 + storySent atomic.Int64 + ttsStart atomic.Int64 + ttsStop atomic.Int64 + opusFrames atomic.Int64 + errors atomic.Int64 + firstFrameNs atomic.Int64 // 所有设备首帧延迟总和(纳秒),用于算均值 + firstFrameCnt atomic.Int64 // 收到首帧的设备数 +} + +var s stats + +// ── 模拟设备 ──────────────────────────────────────────────── + +type device struct { + id int + mac string + clientID string + ws *websocket.Conn + triggerStory bool +} + +func newDevice(id int, macPrefix string, triggerStory bool) *device { + hi := byte((id >> 8) & 0xFF) + lo := byte(id & 0xFF) + mac := fmt.Sprintf("%s:%02X:%02X", macPrefix, hi, lo) + return &device{ + id: id, + mac: mac, + clientID: fmt.Sprintf("stress-%d", id), + triggerStory: triggerStory, + } +} + +func (d *device) run(baseURL string, wg *sync.WaitGroup, done <-chan struct{}) { + defer wg.Done() + + s.connAttempts.Add(1) + + // 1. 建立 WebSocket 连接 + u, err := url.Parse(baseURL) + if err != nil { + log.Printf("[dev-%d] invalid URL: %v", d.id, err) + s.connFailed.Add(1) + return + } + q := u.Query() + q.Set("device-id", d.mac) + q.Set("client-id", d.clientID) + u.RawQuery = q.Encode() + + dialer := websocket.Dialer{ + HandshakeTimeout: 10 * time.Second, + } + ws, _, err := dialer.Dial(u.String(), nil) + if err != nil { + log.Printf("[dev-%d] connect failed: %v", d.id, err) + s.connFailed.Add(1) + return + } + d.ws = ws + s.connSuccess.Add(1) + defer ws.Close() + + // 2. 发送 hello 握手 + helloMsg, _ := json.Marshal(map[string]string{ + "type": "hello", + "mac": d.mac, + }) + if err := ws.WriteMessage(websocket.TextMessage, helloMsg); err != nil { + log.Printf("[dev-%d] hello send failed: %v", d.id, err) + s.handshakeFail.Add(1) + return + } + + // 3. 等待 hello 响应(5s 超时) + ws.SetReadDeadline(time.Now().Add(5 * time.Second)) + _, msg, err := ws.ReadMessage() + if err != nil { + log.Printf("[dev-%d] hello read failed: %v", d.id, err) + s.handshakeFail.Add(1) + return + } + ws.SetReadDeadline(time.Time{}) // 清除超时 + + var helloResp struct { + Type string `json:"type"` + Status string `json:"status"` + } + if err := json.Unmarshal(msg, &helloResp); err != nil || helloResp.Type != "hello" || helloResp.Status != "ok" { + log.Printf("[dev-%d] hello failed: %s", d.id, string(msg)) + s.handshakeFail.Add(1) + return + } + s.handshaked.Add(1) + + // 4. 如果被选为活跃设备,触发故事 + var storySentTime time.Time + var gotFirstFrame bool + + if d.triggerStory { + storyMsg, _ := json.Marshal(map[string]string{"type": "story"}) + if err := ws.WriteMessage(websocket.TextMessage, storyMsg); err != nil { + log.Printf("[dev-%d] story send failed: %v", d.id, err) + s.errors.Add(1) + } else { + s.storySent.Add(1) + storySentTime = time.Now() + } + } + + // 5. 消息接收循环 + msgCh := make(chan struct{}, 1) // 用于通知有新消息 + go func() { + for { + msgType, data, err := ws.ReadMessage() + if err != nil { + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + select { + case <-done: + // 正常关闭,不算错误 + default: + s.errors.Add(1) + } + } + return + } + + if msgType == websocket.BinaryMessage { + // Opus 帧 + s.opusFrames.Add(1) + if d.triggerStory && !gotFirstFrame && !storySentTime.IsZero() { + gotFirstFrame = true + latency := time.Since(storySentTime) + s.firstFrameNs.Add(latency.Nanoseconds()) + s.firstFrameCnt.Add(1) + } + _ = data // 不需要解码,只计数 + } else { + // 文本消息 + var envelope struct { + Type string `json:"type"` + State string `json:"state"` + } + if json.Unmarshal(data, &envelope) == nil { + if envelope.Type == "tts" { + switch envelope.State { + case "start": + s.ttsStart.Add(1) + case "stop": + s.ttsStop.Add(1) + } + } + } + } + + select { + case msgCh <- struct{}{}: + default: + } + } + }() + + // 6. 等待测试结束 + <-done + ws.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), + ) +} + +// ── healthz 查询 ───────────────────────────────────────────── + +func queryHealthz(baseURL string) string { + // 从 ws:// URL 推导 http:// URL + u, err := url.Parse(baseURL) + if err != nil { + return "N/A" + } + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + } + // 去掉 /xiaozhi/v1/ 路径,换成 /healthz + u.Path = "/healthz" + u.RawQuery = "" + + client := &http.Client{Timeout: 3 * time.Second} + resp, err := client.Get(u.String()) + if err != nil { + return "N/A" + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + return strings.TrimSpace(string(body)) +} + +// ── 主函数 ────────────────────────────────────────────────── + +func main() { + flag.Parse() + + if *flagStories > *flagConns { + *flagStories = *flagConns + } + + fmt.Println("========================================") + fmt.Println(" hw_service_go 并发压力测试") + fmt.Println("========================================") + fmt.Printf(" 目标地址: %s\n", *flagURL) + fmt.Printf(" 总连接数: %d\n", *flagConns) + fmt.Printf(" 触发故事: %d\n", *flagStories) + fmt.Printf(" 建连速率: %d/s\n", *flagRamp) + fmt.Printf(" 测试时长: %s\n", *flagDuration) + fmt.Printf(" MAC 前缀: %s\n", *flagMACPrefix) + fmt.Println("========================================") + fmt.Println() + + done := make(chan struct{}) + var wg sync.WaitGroup + + // 信号处理:Ctrl+C 提前结束 + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + go func() { + <-sigCh + fmt.Println("\n收到退出信号,正在停止...") + close(done) + }() + + // 实时统计输出 + startTime := time.Now() + go func() { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + elapsed := time.Since(startTime).Truncate(time.Second) + health := queryHealthz(*flagURL) + fmt.Printf("\r\033[K[%s] conns: %d/%d handshaked: %d stories: %d sent frames: %d errors: %d healthz: %s", + elapsed, + s.connSuccess.Load(), *flagConns, + s.handshaked.Load(), + s.storySent.Load(), + s.opusFrames.Load(), + s.errors.Load(), + health, + ) + case <-done: + return + } + } + }() + + // 按 ramp 速率建立连接 + rampInterval := time.Second / time.Duration(*flagRamp) + for i := 1; i <= *flagConns; i++ { + select { + case <-done: + goto waitDone + default: + } + + triggerStory := i <= *flagStories + dev := newDevice(i, *flagMACPrefix, triggerStory) + wg.Add(1) + go dev.run(*flagURL, &wg, done) + + // 控制建连速率 + if i < *flagConns { + time.Sleep(rampInterval) + } + } + + // 所有连接建立后,等待 duration 到期 + fmt.Printf("\n所有连接已发起,等待 %s...\n", *flagDuration) + select { + case <-time.After(*flagDuration): + fmt.Println("\n测试时长到期,正在停止...") + close(done) + case <-done: + } + +waitDone: + // 等待所有 goroutine 退出(最多 10s) + waitCh := make(chan struct{}) + go func() { + wg.Wait() + close(waitCh) + }() + select { + case <-waitCh: + case <-time.After(10 * time.Second): + fmt.Println("等待超时,强制退出") + } + + // 最终报告 + printReport() +} + +func printReport() { + fmt.Println() + fmt.Println("========== 测试报告 ==========") + fmt.Printf("目标连接数: %d\n", *flagConns) + fmt.Printf("连接尝试: %d\n", s.connAttempts.Load()) + fmt.Printf("成功连接: %d\n", s.connSuccess.Load()) + fmt.Printf("连接失败: %d\n", s.connFailed.Load()) + fmt.Printf("握手成功: %d\n", s.handshaked.Load()) + fmt.Printf("握手失败: %d\n", s.handshakeFail.Load()) + fmt.Println("------------------------------") + fmt.Printf("触发故事数: %d\n", s.storySent.Load()) + fmt.Printf("收到 tts start: %d\n", s.ttsStart.Load()) + fmt.Printf("收到 tts stop: %d\n", s.ttsStop.Load()) + fmt.Printf("Opus 帧总数: %d\n", s.opusFrames.Load()) + if s.storySent.Load() > 0 { + fmt.Printf("平均帧数/故事: %d\n", s.opusFrames.Load()/max(s.ttsStop.Load(), 1)) + } + if s.firstFrameCnt.Load() > 0 { + avgMs := s.firstFrameNs.Load() / s.firstFrameCnt.Load() / 1e6 + fmt.Printf("首帧延迟(avg): %dms\n", avgMs) + } + fmt.Printf("错误总数: %d\n", s.errors.Load()) + fmt.Println("==============================") +} + +func max(a, b int64) int64 { + if a > b { + return a + } + return b +} diff --git a/hw_service_go/test/test.html b/hw_service_go/test/test.html index dc96656..0a20333 100644 --- a/hw_service_go/test/test.html +++ b/hw_service_go/test/test.html @@ -177,10 +177,11 @@
+
- +
@@ -290,6 +291,22 @@ function generateClientId() { $('clientId').value = id; } +const ENV_LOCAL = { url: 'ws://localhost:8888/xiaozhi/v1/', label: '切换线上' }; +const ENV_PROD = { url: 'wss://qiyuan-rtc-api.airlabs.art/xiaozhi/v1/', label: '切换本地' }; +let currentEnv = 'local'; + +function toggleEnv() { + if (currentEnv === 'local') { + $('wsUrl').value = ENV_PROD.url; + $('btnEnvToggle').textContent = ENV_PROD.label; + currentEnv = 'prod'; + } else { + $('wsUrl').value = ENV_LOCAL.url; + $('btnEnvToggle').textContent = ENV_LOCAL.label; + currentEnv = 'local'; + } +} + function clearLog() { $('logContainer').innerHTML = ''; }