// hw_service_go 并发压力测试工具 // // 用法: // // go run main.go -conns 100 -stories 0 # 100 个空闲连接 // go run main.go -conns 50 -stories 10 # 50 连接,10 个触发故事 // go run main.go -url wss://example.com/xiaozhi/v1/ -conns 50 package main import ( "encoding/json" "flag" "fmt" "io" "log" "net/http" "net/url" "os" "os/signal" "strings" "sync" "sync/atomic" "syscall" "time" "github.com/gorilla/websocket" ) // ── 命令行参数 ───────────────────────────────────────────────── var ( flagURL = flag.String("url", "ws://localhost:8888/xiaozhi/v1/", "WebSocket 服务地址") flagConns = flag.Int("conns", 100, "总连接数") flagStories = flag.Int("stories", 10, "同时触发故事的连接数") flagRamp = flag.Int("ramp", 20, "每秒建立的连接数") flagDuration = flag.Duration("duration", 60*time.Second, "测试持续时间") flagMACPrefix = flag.String("mac-prefix", "AA:BB:CC:DD", "模拟 MAC 前缀") ) // ── 统计指标(原子操作,goroutine 安全) ────────────────────── type stats struct { connAttempts atomic.Int64 connSuccess atomic.Int64 connFailed atomic.Int64 handshaked atomic.Int64 handshakeFail atomic.Int64 storySent atomic.Int64 ttsStart atomic.Int64 ttsStop atomic.Int64 opusFrames atomic.Int64 errors atomic.Int64 firstFrameNs atomic.Int64 // 所有设备首帧延迟总和(纳秒),用于算均值 firstFrameCnt atomic.Int64 // 收到首帧的设备数 } var s stats // ── 模拟设备 ──────────────────────────────────────────────── type device struct { id int mac string clientID string ws *websocket.Conn triggerStory bool } func newDevice(id int, macPrefix string, triggerStory bool) *device { hi := byte((id >> 8) & 0xFF) lo := byte(id & 0xFF) mac := fmt.Sprintf("%s:%02X:%02X", macPrefix, hi, lo) return &device{ id: id, mac: mac, clientID: fmt.Sprintf("stress-%d", id), triggerStory: triggerStory, } } func (d *device) run(baseURL string, wg *sync.WaitGroup, done <-chan struct{}) { defer wg.Done() s.connAttempts.Add(1) // 1. 建立 WebSocket 连接 u, err := url.Parse(baseURL) if err != nil { log.Printf("[dev-%d] invalid URL: %v", d.id, err) s.connFailed.Add(1) return } q := u.Query() q.Set("device-id", d.mac) q.Set("client-id", d.clientID) u.RawQuery = q.Encode() dialer := websocket.Dialer{ HandshakeTimeout: 10 * time.Second, } ws, _, err := dialer.Dial(u.String(), nil) if err != nil { log.Printf("[dev-%d] connect failed: %v", d.id, err) s.connFailed.Add(1) return } d.ws = ws s.connSuccess.Add(1) defer ws.Close() // 2. 发送 hello 握手 helloMsg, _ := json.Marshal(map[string]string{ "type": "hello", "mac": d.mac, }) if err := ws.WriteMessage(websocket.TextMessage, helloMsg); err != nil { log.Printf("[dev-%d] hello send failed: %v", d.id, err) s.handshakeFail.Add(1) return } // 3. 等待 hello 响应(5s 超时) ws.SetReadDeadline(time.Now().Add(5 * time.Second)) _, msg, err := ws.ReadMessage() if err != nil { log.Printf("[dev-%d] hello read failed: %v", d.id, err) s.handshakeFail.Add(1) return } ws.SetReadDeadline(time.Time{}) // 清除超时 var helloResp struct { Type string `json:"type"` Status string `json:"status"` } if err := json.Unmarshal(msg, &helloResp); err != nil || helloResp.Type != "hello" || helloResp.Status != "ok" { log.Printf("[dev-%d] hello failed: %s", d.id, string(msg)) s.handshakeFail.Add(1) return } s.handshaked.Add(1) // 4. 如果被选为活跃设备,触发故事 var storySentTime time.Time var gotFirstFrame bool if d.triggerStory { storyMsg, _ := json.Marshal(map[string]string{"type": "story"}) if err := ws.WriteMessage(websocket.TextMessage, storyMsg); err != nil { log.Printf("[dev-%d] story send failed: %v", d.id, err) s.errors.Add(1) } else { s.storySent.Add(1) storySentTime = time.Now() } } // 5. 消息接收循环 msgCh := make(chan struct{}, 1) // 用于通知有新消息 go func() { for { msgType, data, err := ws.ReadMessage() if err != nil { if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { select { case <-done: // 正常关闭,不算错误 default: s.errors.Add(1) } } return } if msgType == websocket.BinaryMessage { // Opus 帧 s.opusFrames.Add(1) if d.triggerStory && !gotFirstFrame && !storySentTime.IsZero() { gotFirstFrame = true latency := time.Since(storySentTime) s.firstFrameNs.Add(latency.Nanoseconds()) s.firstFrameCnt.Add(1) } _ = data // 不需要解码,只计数 } else { // 文本消息 var envelope struct { Type string `json:"type"` State string `json:"state"` } if json.Unmarshal(data, &envelope) == nil { if envelope.Type == "tts" { switch envelope.State { case "start": s.ttsStart.Add(1) case "stop": s.ttsStop.Add(1) } } } } select { case msgCh <- struct{}{}: default: } } }() // 6. 等待测试结束 <-done ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), ) } // ── healthz 查询 ───────────────────────────────────────────── func queryHealthz(baseURL string) string { // 从 ws:// URL 推导 http:// URL u, err := url.Parse(baseURL) if err != nil { return "N/A" } switch u.Scheme { case "ws": u.Scheme = "http" case "wss": u.Scheme = "https" } // 去掉 /xiaozhi/v1/ 路径,换成 /healthz u.Path = "/healthz" u.RawQuery = "" client := &http.Client{Timeout: 3 * time.Second} resp, err := client.Get(u.String()) if err != nil { return "N/A" } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) return strings.TrimSpace(string(body)) } // ── 主函数 ────────────────────────────────────────────────── func main() { flag.Parse() if *flagStories > *flagConns { *flagStories = *flagConns } fmt.Println("========================================") fmt.Println(" hw_service_go 并发压力测试") fmt.Println("========================================") fmt.Printf(" 目标地址: %s\n", *flagURL) fmt.Printf(" 总连接数: %d\n", *flagConns) fmt.Printf(" 触发故事: %d\n", *flagStories) fmt.Printf(" 建连速率: %d/s\n", *flagRamp) fmt.Printf(" 测试时长: %s\n", *flagDuration) fmt.Printf(" MAC 前缀: %s\n", *flagMACPrefix) fmt.Println("========================================") fmt.Println() done := make(chan struct{}) var wg sync.WaitGroup // 信号处理:Ctrl+C 提前结束 sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) go func() { <-sigCh fmt.Println("\n收到退出信号,正在停止...") close(done) }() // 实时统计输出 startTime := time.Now() go func() { ticker := time.NewTicker(2 * time.Second) defer ticker.Stop() for { select { case <-ticker.C: elapsed := time.Since(startTime).Truncate(time.Second) health := queryHealthz(*flagURL) fmt.Printf("\r\033[K[%s] conns: %d/%d handshaked: %d stories: %d sent frames: %d errors: %d healthz: %s", elapsed, s.connSuccess.Load(), *flagConns, s.handshaked.Load(), s.storySent.Load(), s.opusFrames.Load(), s.errors.Load(), health, ) case <-done: return } } }() // 按 ramp 速率建立连接 rampInterval := time.Second / time.Duration(*flagRamp) for i := 1; i <= *flagConns; i++ { select { case <-done: goto waitDone default: } triggerStory := i <= *flagStories dev := newDevice(i, *flagMACPrefix, triggerStory) wg.Add(1) go dev.run(*flagURL, &wg, done) // 控制建连速率 if i < *flagConns { time.Sleep(rampInterval) } } // 所有连接建立后,等待 duration 到期 fmt.Printf("\n所有连接已发起,等待 %s...\n", *flagDuration) select { case <-time.After(*flagDuration): fmt.Println("\n测试时长到期,正在停止...") close(done) case <-done: } waitDone: // 等待所有 goroutine 退出(最多 10s) waitCh := make(chan struct{}) go func() { wg.Wait() close(waitCh) }() select { case <-waitCh: case <-time.After(10 * time.Second): fmt.Println("等待超时,强制退出") } // 最终报告 printReport() } func printReport() { fmt.Println() fmt.Println("========== 测试报告 ==========") fmt.Printf("目标连接数: %d\n", *flagConns) fmt.Printf("连接尝试: %d\n", s.connAttempts.Load()) fmt.Printf("成功连接: %d\n", s.connSuccess.Load()) fmt.Printf("连接失败: %d\n", s.connFailed.Load()) fmt.Printf("握手成功: %d\n", s.handshaked.Load()) fmt.Printf("握手失败: %d\n", s.handshakeFail.Load()) fmt.Println("------------------------------") fmt.Printf("触发故事数: %d\n", s.storySent.Load()) fmt.Printf("收到 tts start: %d\n", s.ttsStart.Load()) fmt.Printf("收到 tts stop: %d\n", s.ttsStop.Load()) fmt.Printf("Opus 帧总数: %d\n", s.opusFrames.Load()) if s.storySent.Load() > 0 { fmt.Printf("平均帧数/故事: %d\n", s.opusFrames.Load()/max(s.ttsStop.Load(), 1)) } if s.firstFrameCnt.Load() > 0 { avgMs := s.firstFrameNs.Load() / s.firstFrameCnt.Load() / 1e6 fmt.Printf("首帧延迟(avg): %dms\n", avgMs) } fmt.Printf("错误总数: %d\n", s.errors.Load()) fmt.Println("==============================") } func max(a, b int64) int64 { if a > b { return a } return b }