repair-agent a3222d1fe5
All checks were successful
Build and Deploy Backend / build-and-deploy (push) Successful in 4m14s
压测工具
2026-03-03 16:22:16 +08:00

380 lines
9.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

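// hw_service_go concurrent stress-test tool (hw_service_go 并发压力测试工具).
//
// Usage:
//
//	go run main.go -conns 100 -stories 0                         # 100 idle connections
//	go run main.go -conns 50 -stories 10                         # 50 connections, 10 of which trigger a story
//	go run main.go -url wss://example.com/xiaozhi/v1/ -conns 50  # target a remote server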
package main
import (
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"os/signal"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/gorilla/websocket"
)
// ── Command-line flags ───────────────────────────────────────
var (
	// WebSocket endpoint; each device appends its own device-id and
	// client-id query parameters before dialing.
	flagURL = flag.String("url", "ws://localhost:8888/xiaozhi/v1/", "WebSocket 服务地址")

	// Total number of simulated devices (one connection each).
	flagConns = flag.Int("conns", 100, "总连接数")

	// How many of those devices request a story right after the hello handshake.
	flagStories = flag.Int("stories", 10, "同时触发故事的连接数")

	// Connection ramp-up rate, in new connections per second.
	flagRamp = flag.Int("ramp", 20, "每秒建立的连接数")

	// How long to hold the connections open once ramp-up has finished.
	flagDuration = flag.Duration("duration", time.Minute, "测试持续时间")

	// First four octets of the fake MAC addresses; the last two encode the device id.
	flagMACPrefix = flag.String("mac-prefix", "AA:BB:CC:DD", "模拟 MAC 前缀")
)
// ── Metrics (atomic counters, safe for concurrent goroutines) ──

// stats aggregates counters that every device goroutine updates
// concurrently. Each field is an atomic.Int64, so no extra lock is needed.
type stats struct {
	// Connection phase.
	connAttempts atomic.Int64 // devices that started running
	connSuccess  atomic.Int64 // successful WebSocket dials
	connFailed   atomic.Int64 // bad URL or failed dial

	// hello handshake.
	handshaked    atomic.Int64 // hello answered with status "ok"
	handshakeFail atomic.Int64 // hello send/read failure or unexpected reply

	// Story playback.
	storySent  atomic.Int64 // story requests written successfully
	ttsStart   atomic.Int64 // {"type":"tts","state":"start"} messages received
	ttsStop    atomic.Int64 // {"type":"tts","state":"stop"} messages received
	opusFrames atomic.Int64 // binary (Opus) messages received

	// Unexpected failures (story send errors, abnormal read errors).
	errors atomic.Int64

	// First-frame latency: sum in nanoseconds and number of samples,
	// so the mean is firstFrameNs / firstFrameCnt.
	firstFrameNs  atomic.Int64
	firstFrameCnt atomic.Int64
}

// s is the single process-wide metrics instance.
var s stats
// ── Simulated device ─────────────────────────────────────────

// device is one simulated client that owns a single WebSocket connection.
type device struct {
	id           int
	mac          string          // fake MAC: sent as device-id and in the hello message
	clientID     string          // sent as the client-id query parameter
	ws           *websocket.Conn // set once the dial succeeds
	triggerStory bool            // request a story after the hello handshake
}

// newDevice builds device number id. Its MAC is macPrefix followed by the
// low 16 bits of id as two upper-case hex octets (id 258 → "<prefix>:01:02").
// Ids that differ by a multiple of 65536 therefore share a MAC.
func newDevice(id int, macPrefix string, triggerStory bool) *device {
	d := &device{id: id, triggerStory: triggerStory}
	// uint8 conversion keeps only the low 8 bits of each shifted value.
	d.mac = fmt.Sprintf("%s:%02X:%02X", macPrefix, uint8(id>>8), uint8(id))
	d.clientID = fmt.Sprintf("stress-%d", id)
	return d
}
// run drives one simulated device from connect to shutdown and is meant to
// be started as a goroutine. wg.Done is called when it returns; by then the
// device's reader goroutine (if one was started) has exited too.
//
// Steps:
//  1. dial baseURL with device-id/client-id query parameters (10s timeout);
//  2. send hello and wait up to 5s for {"type":"hello","status":"ok"};
//  3. if selected, send {"type":"story"};
//  4. count incoming frames until done is closed;
//  5. perform a bounded WebSocket close handshake.
//
// Every outcome is recorded in the package-level counters s.
func (d *device) run(baseURL string, wg *sync.WaitGroup, done <-chan struct{}) {
	defer wg.Done()
	s.connAttempts.Add(1)

	// 1. Open the WebSocket connection.
	target, err := d.dialURL(baseURL)
	if err != nil {
		log.Printf("[dev-%d] invalid URL: %v", d.id, err)
		s.connFailed.Add(1)
		return
	}
	dialer := websocket.Dialer{HandshakeTimeout: 10 * time.Second}
	ws, resp, err := dialer.Dial(target, nil)
	if err != nil {
		if resp != nil {
			// On a rejected upgrade the HTTP status is the most useful
			// diagnostic (e.g. 503 when the server is shedding load).
			err = fmt.Errorf("%w (HTTP %d)", err, resp.StatusCode)
		}
		log.Printf("[dev-%d] connect failed: %v", d.id, err)
		s.connFailed.Add(1)
		return
	}
	d.ws = ws
	s.connSuccess.Add(1)
	defer ws.Close()

	// 2. hello handshake.
	if err := d.handshake(ws); err != nil {
		log.Printf("[dev-%d] %v", d.id, err)
		s.handshakeFail.Add(1)
		return
	}
	s.handshaked.Add(1)

	// 3. Active devices request a story; storySentTime stays zero otherwise.
	var storySentTime time.Time
	if d.triggerStory {
		storyMsg, _ := json.Marshal(map[string]string{"type": "story"}) // cannot fail for map[string]string
		if err := ws.WriteMessage(websocket.TextMessage, storyMsg); err != nil {
			log.Printf("[dev-%d] story send failed: %v", d.id, err)
			s.errors.Add(1)
		} else {
			s.storySent.Add(1)
			storySentTime = time.Now()
		}
	}

	// 4. Receive loop. storySentTime is written before the go statement, so
	// the goroutine reads it without a race.
	readerDone := make(chan struct{})
	go func() {
		defer close(readerDone)
		d.readLoop(ws, storySentTime, done)
	}()

	// 5. Wait for the end of the test, then close politely but with bounds.
	<-done
	// WriteControl carries its own deadline, so a peer that stopped reading
	// cannot block shutdown. Best effort: the connection may already be gone.
	_ = ws.WriteControl(websocket.CloseMessage,
		websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
		time.Now().Add(time.Second),
	)
	select {
	case <-readerDone:
		// The server answered our close frame, or the connection had already ended.
	case <-time.After(2 * time.Second):
		ws.Close() // unblocks ReadMessage; the deferred Close is then harmless
		<-readerDone
	}
}

// dialURL returns baseURL with this device's device-id and client-id query
// parameters set (other existing query parameters are kept).
func (d *device) dialURL(baseURL string) (string, error) {
	u, err := url.Parse(baseURL)
	if err != nil {
		return "", err
	}
	q := u.Query()
	q.Set("device-id", d.mac)
	q.Set("client-id", d.clientID)
	u.RawQuery = q.Encode()
	return u.String(), nil
}

// handshake sends {"type":"hello","mac":<mac>} and waits up to 5s for the
// reply, which must be {"type":"hello","status":"ok"}. The returned error
// text is meant to be logged right after the "[dev-N] " prefix.
func (d *device) handshake(ws *websocket.Conn) error {
	helloMsg, _ := json.Marshal(map[string]string{ // cannot fail for map[string]string
		"type": "hello",
		"mac":  d.mac,
	})
	if err := ws.WriteMessage(websocket.TextMessage, helloMsg); err != nil {
		return fmt.Errorf("hello send failed: %w", err)
	}

	ws.SetReadDeadline(time.Now().Add(5 * time.Second))
	_, msg, err := ws.ReadMessage()
	if err != nil {
		return fmt.Errorf("hello read failed: %w", err)
	}
	ws.SetReadDeadline(time.Time{}) // the receive loop waits until shutdown, so no deadline there

	var resp struct {
		Type   string `json:"type"`
		Status string `json:"status"`
	}
	if err := json.Unmarshal(msg, &resp); err != nil || resp.Type != "hello" || resp.Status != "ok" {
		return fmt.Errorf("hello failed: %s", msg)
	}
	return nil
}

// readLoop consumes messages until a read fails, updating the counters in s:
// binary messages count as Opus frames, and text messages of type "tts" with
// state "start"/"stop" bump ttsStart/ttsStop. If storySentTime is non-zero,
// the delay until the first binary frame is recorded as first-frame latency.
// A read error counts in s.errors unless it is a normal/going-away close or
// done has already been closed (i.e. we are shutting down).
func (d *device) readLoop(ws *websocket.Conn, storySentTime time.Time, done <-chan struct{}) {
	gotFirstFrame := false
	for {
		msgType, data, err := ws.ReadMessage()
		if err != nil {
			if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
				select {
				case <-done: // expected during shutdown
				default:
					s.errors.Add(1)
				}
			}
			return
		}

		if msgType == websocket.BinaryMessage {
			// Opus frame: only counted, never decoded.
			s.opusFrames.Add(1)
			if !gotFirstFrame && !storySentTime.IsZero() {
				gotFirstFrame = true
				s.firstFrameNs.Add(time.Since(storySentTime).Nanoseconds())
				s.firstFrameCnt.Add(1)
			}
			continue
		}

		var envelope struct {
			Type  string `json:"type"`
			State string `json:"state"`
		}
		if json.Unmarshal(data, &envelope) != nil || envelope.Type != "tts" {
			continue
		}
		switch envelope.State {
		case "start":
			s.ttsStart.Add(1)
		case "stop":
			s.ttsStop.Add(1)
		}
	}
}
// ── healthz probe ────────────────────────────────────────────

// healthzClient is shared by every probe. The progress reporter polls every
// 2s, and reusing one client (and its transport) lets those polls share a
// keep-alive connection instead of creating a new client each time.
var healthzClient = &http.Client{Timeout: 3 * time.Second}

// healthzMaxBody caps how much of the /healthz response is read. The result
// is shown on a single status line, so anything longer is not useful.
const healthzMaxBody = 4 << 10

// queryHealthz derives the server's /healthz URL from the WebSocket URL
// baseURL (ws→http, wss→https, path replaced by /healthz, query dropped)
// and returns the response body with whitespace runs collapsed to single
// spaces, so a pretty-printed JSON body still fits the one-line progress
// display. A non-200 response is prefixed with "HTTP <code>". A URL,
// network or read error yields "N/A".
func queryHealthz(baseURL string) string {
	u, err := url.Parse(baseURL)
	if err != nil {
		return "N/A"
	}
	switch u.Scheme {
	case "ws":
		u.Scheme = "http"
	case "wss":
		u.Scheme = "https"
	}
	// The WebSocket endpoint sits under a sub-path (e.g. /xiaozhi/v1/);
	// the health endpoint is at the root.
	u.Path, u.RawPath = "/healthz", ""
	u.RawQuery, u.Fragment = "", ""

	resp, err := healthzClient.Get(u.String())
	if err != nil {
		return "N/A"
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, healthzMaxBody))
	if err != nil {
		return "N/A"
	}
	text := strings.Join(strings.Fields(string(body)), " ")
	if resp.StatusCode != http.StatusOK {
		return strings.TrimSpace(fmt.Sprintf("HTTP %d %s", resp.StatusCode, text))
	}
	return text
}
// ── 主函数 ──────────────────────────────────────────────────
func main() {
flag.Parse()
if *flagStories > *flagConns {
*flagStories = *flagConns
}
fmt.Println("========================================")
fmt.Println(" hw_service_go 并发压力测试")
fmt.Println("========================================")
fmt.Printf(" 目标地址: %s\n", *flagURL)
fmt.Printf(" 总连接数: %d\n", *flagConns)
fmt.Printf(" 触发故事: %d\n", *flagStories)
fmt.Printf(" 建连速率: %d/s\n", *flagRamp)
fmt.Printf(" 测试时长: %s\n", *flagDuration)
fmt.Printf(" MAC 前缀: %s\n", *flagMACPrefix)
fmt.Println("========================================")
fmt.Println()
done := make(chan struct{})
var wg sync.WaitGroup
// 信号处理Ctrl+C 提前结束
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
fmt.Println("\n收到退出信号正在停止...")
close(done)
}()
// 实时统计输出
startTime := time.Now()
go func() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
elapsed := time.Since(startTime).Truncate(time.Second)
health := queryHealthz(*flagURL)
fmt.Printf("\r\033[K[%s] conns: %d/%d handshaked: %d stories: %d sent frames: %d errors: %d healthz: %s",
elapsed,
s.connSuccess.Load(), *flagConns,
s.handshaked.Load(),
s.storySent.Load(),
s.opusFrames.Load(),
s.errors.Load(),
health,
)
case <-done:
return
}
}
}()
// 按 ramp 速率建立连接
rampInterval := time.Second / time.Duration(*flagRamp)
for i := 1; i <= *flagConns; i++ {
select {
case <-done:
goto waitDone
default:
}
triggerStory := i <= *flagStories
dev := newDevice(i, *flagMACPrefix, triggerStory)
wg.Add(1)
go dev.run(*flagURL, &wg, done)
// 控制建连速率
if i < *flagConns {
time.Sleep(rampInterval)
}
}
// 所有连接建立后,等待 duration 到期
fmt.Printf("\n所有连接已发起等待 %s...\n", *flagDuration)
select {
case <-time.After(*flagDuration):
fmt.Println("\n测试时长到期正在停止...")
close(done)
case <-done:
}
waitDone:
// 等待所有 goroutine 退出(最多 10s
waitCh := make(chan struct{})
go func() {
wg.Wait()
close(waitCh)
}()
select {
case <-waitCh:
case <-time.After(10 * time.Second):
fmt.Println("等待超时,强制退出")
}
// 最终报告
printReport()
}
// printReport writes the final summary of the counters in s to stdout.
// The per-story frame average and the first-frame latency are printed only
// when a story was sent / a first frame was observed, respectively.
func printReport() {
	// line prints one "label: value" row of the report.
	line := func(label string, v int64) {
		fmt.Printf("%s: %d\n", label, v)
	}

	fmt.Println()
	fmt.Println("========== 测试报告 ==========")
	line("目标连接数", int64(*flagConns))
	line("连接尝试", s.connAttempts.Load())
	line("成功连接", s.connSuccess.Load())
	line("连接失败", s.connFailed.Load())
	line("握手成功", s.handshaked.Load())
	line("握手失败", s.handshakeFail.Load())

	fmt.Println("------------------------------")
	line("触发故事数", s.storySent.Load())
	line("收到 tts start", s.ttsStart.Load())
	line("收到 tts stop", s.ttsStop.Load())
	line("Opus 帧总数", s.opusFrames.Load())
	if s.storySent.Load() > 0 {
		// Divides by completed stories (tts stop), guarded against zero.
		line("平均帧数/故事", s.opusFrames.Load()/max(s.ttsStop.Load(), 1))
	}
	if cnt := s.firstFrameCnt.Load(); cnt > 0 {
		fmt.Printf("首帧延迟(avg): %dms\n", s.firstFrameNs.Load()/cnt/1e6)
	}
	line("错误总数", s.errors.Load())
	fmt.Println("==============================")
}
// max returns the larger of a and b. On Go 1.21+ this package-level
// definition shadows the builtin max inside this package.
func max(a, b int64) int64 {
	if a <= b {
		return b
	}
	return a
}