All checks were successful
Build and Deploy Backend / build-and-deploy (push) Successful in 4m14s
380 lines
9.9 KiB
Go
380 lines
9.9 KiB
Go
// hw_service_go 并发压力测试工具
|
||
//
|
||
// 用法:
|
||
//
|
||
// go run main.go -conns 100 -stories 0 # 100 个空闲连接
|
||
// go run main.go -conns 50 -stories 10 # 50 连接,10 个触发故事
|
||
// go run main.go -url wss://example.com/xiaozhi/v1/ -conns 50
|
||
package main
|
||
|
||
import (
|
||
"encoding/json"
|
||
"flag"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"net/http"
|
||
"net/url"
|
||
"os"
|
||
"os/signal"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"syscall"
|
||
"time"
|
||
|
||
"github.com/gorilla/websocket"
|
||
)
|
||
|
||
// ── 命令行参数 ─────────────────────────────────────────────────
|
||
|
||
// Command-line flags; all are populated by flag.Parse in main.

// flagURL is the WebSocket endpoint every simulated device connects to.
var flagURL = flag.String("url", "ws://localhost:8888/xiaozhi/v1/", "WebSocket 服务地址")

// flagConns is the total number of simulated devices to launch.
var flagConns = flag.Int("conns", 100, "总连接数")

// flagStories is how many of those devices request a story; main caps it at
// *flagConns.
var flagStories = flag.Int("stories", 10, "同时触发故事的连接数")

// flagRamp is the target number of new connections opened per second.
var flagRamp = flag.Int("ramp", 20, "每秒建立的连接数")

// flagDuration is how long connections are held open once all are launched.
var flagDuration = flag.Duration("duration", 60*time.Second, "测试持续时间")

// flagMACPrefix is prepended to the two generated octets of each device MAC.
var flagMACPrefix = flag.String("mac-prefix", "AA:BB:CC:DD", "模拟 MAC 前缀")
|
||
|
||
// ── 统计指标(原子操作,goroutine 安全) ──────────────────────
|
||
|
||
// stats aggregates the counters shared by every device goroutine. Every field
// is an atomic.Int64, so goroutines update them without extra locking, and the
// zero value is ready to use.
type stats struct {
	// Connection phase: dial attempts, successes and failures.
	connAttempts, connSuccess, connFailed atomic.Int64

	// hello handshake outcomes.
	handshaked, handshakeFail atomic.Int64

	// Story playback: story requests sent, tts start/stop events received,
	// and binary frames received.
	storySent, ttsStart, ttsStop, opusFrames atomic.Int64

	// Failures after the handshake (story send errors, unexpected read errors).
	errors atomic.Int64

	// firstFrameNs is the sum of per-device first-frame latencies in
	// nanoseconds; firstFrameCnt is how many devices contributed. Their ratio
	// is the mean latency.
	firstFrameNs, firstFrameCnt atomic.Int64
}

// s is the process-wide stats instance, updated by devices and read by the
// live status line and the final report.
var s stats
|
||
|
||
// ── 模拟设备 ────────────────────────────────────────────────
|
||
|
||
type device struct {
|
||
id int
|
||
mac string
|
||
clientID string
|
||
ws *websocket.Conn
|
||
triggerStory bool
|
||
}
|
||
|
||
func newDevice(id int, macPrefix string, triggerStory bool) *device {
|
||
hi := byte((id >> 8) & 0xFF)
|
||
lo := byte(id & 0xFF)
|
||
mac := fmt.Sprintf("%s:%02X:%02X", macPrefix, hi, lo)
|
||
return &device{
|
||
id: id,
|
||
mac: mac,
|
||
clientID: fmt.Sprintf("stress-%d", id),
|
||
triggerStory: triggerStory,
|
||
}
|
||
}
|
||
|
||
func (d *device) run(baseURL string, wg *sync.WaitGroup, done <-chan struct{}) {
|
||
defer wg.Done()
|
||
|
||
s.connAttempts.Add(1)
|
||
|
||
// 1. 建立 WebSocket 连接
|
||
u, err := url.Parse(baseURL)
|
||
if err != nil {
|
||
log.Printf("[dev-%d] invalid URL: %v", d.id, err)
|
||
s.connFailed.Add(1)
|
||
return
|
||
}
|
||
q := u.Query()
|
||
q.Set("device-id", d.mac)
|
||
q.Set("client-id", d.clientID)
|
||
u.RawQuery = q.Encode()
|
||
|
||
dialer := websocket.Dialer{
|
||
HandshakeTimeout: 10 * time.Second,
|
||
}
|
||
ws, _, err := dialer.Dial(u.String(), nil)
|
||
if err != nil {
|
||
log.Printf("[dev-%d] connect failed: %v", d.id, err)
|
||
s.connFailed.Add(1)
|
||
return
|
||
}
|
||
d.ws = ws
|
||
s.connSuccess.Add(1)
|
||
defer ws.Close()
|
||
|
||
// 2. 发送 hello 握手
|
||
helloMsg, _ := json.Marshal(map[string]string{
|
||
"type": "hello",
|
||
"mac": d.mac,
|
||
})
|
||
if err := ws.WriteMessage(websocket.TextMessage, helloMsg); err != nil {
|
||
log.Printf("[dev-%d] hello send failed: %v", d.id, err)
|
||
s.handshakeFail.Add(1)
|
||
return
|
||
}
|
||
|
||
// 3. 等待 hello 响应(5s 超时)
|
||
ws.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||
_, msg, err := ws.ReadMessage()
|
||
if err != nil {
|
||
log.Printf("[dev-%d] hello read failed: %v", d.id, err)
|
||
s.handshakeFail.Add(1)
|
||
return
|
||
}
|
||
ws.SetReadDeadline(time.Time{}) // 清除超时
|
||
|
||
var helloResp struct {
|
||
Type string `json:"type"`
|
||
Status string `json:"status"`
|
||
}
|
||
if err := json.Unmarshal(msg, &helloResp); err != nil || helloResp.Type != "hello" || helloResp.Status != "ok" {
|
||
log.Printf("[dev-%d] hello failed: %s", d.id, string(msg))
|
||
s.handshakeFail.Add(1)
|
||
return
|
||
}
|
||
s.handshaked.Add(1)
|
||
|
||
// 4. 如果被选为活跃设备,触发故事
|
||
var storySentTime time.Time
|
||
var gotFirstFrame bool
|
||
|
||
if d.triggerStory {
|
||
storyMsg, _ := json.Marshal(map[string]string{"type": "story"})
|
||
if err := ws.WriteMessage(websocket.TextMessage, storyMsg); err != nil {
|
||
log.Printf("[dev-%d] story send failed: %v", d.id, err)
|
||
s.errors.Add(1)
|
||
} else {
|
||
s.storySent.Add(1)
|
||
storySentTime = time.Now()
|
||
}
|
||
}
|
||
|
||
// 5. 消息接收循环
|
||
msgCh := make(chan struct{}, 1) // 用于通知有新消息
|
||
go func() {
|
||
for {
|
||
msgType, data, err := ws.ReadMessage()
|
||
if err != nil {
|
||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||
select {
|
||
case <-done:
|
||
// 正常关闭,不算错误
|
||
default:
|
||
s.errors.Add(1)
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
if msgType == websocket.BinaryMessage {
|
||
// Opus 帧
|
||
s.opusFrames.Add(1)
|
||
if d.triggerStory && !gotFirstFrame && !storySentTime.IsZero() {
|
||
gotFirstFrame = true
|
||
latency := time.Since(storySentTime)
|
||
s.firstFrameNs.Add(latency.Nanoseconds())
|
||
s.firstFrameCnt.Add(1)
|
||
}
|
||
_ = data // 不需要解码,只计数
|
||
} else {
|
||
// 文本消息
|
||
var envelope struct {
|
||
Type string `json:"type"`
|
||
State string `json:"state"`
|
||
}
|
||
if json.Unmarshal(data, &envelope) == nil {
|
||
if envelope.Type == "tts" {
|
||
switch envelope.State {
|
||
case "start":
|
||
s.ttsStart.Add(1)
|
||
case "stop":
|
||
s.ttsStop.Add(1)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
select {
|
||
case msgCh <- struct{}{}:
|
||
default:
|
||
}
|
||
}
|
||
}()
|
||
|
||
// 6. 等待测试结束
|
||
<-done
|
||
ws.WriteMessage(websocket.CloseMessage,
|
||
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
|
||
)
|
||
}
|
||
|
||
// ── healthz 查询 ─────────────────────────────────────────────
|
||
|
||
// healthzClient is shared by every queryHealthz call so polls can reuse
// keep-alive connections instead of building a new client each time.
var healthzClient = &http.Client{Timeout: 3 * time.Second}

// maxHealthzBody caps how much of the /healthz response is read and shown.
const maxHealthzBody = 4 << 10

// queryHealthz derives the service's /healthz URL from the WebSocket URL
// baseURL (ws -> http, wss -> https, path replaced with /healthz, query
// dropped) and returns the response body collapsed onto a single line, so it
// fits the carriage-return-refreshed status line. A non-2xx response is
// prefixed with "HTTP <code>" so an unhealthy service is distinguishable.
// Any parse, request or read failure yields "N/A".
func queryHealthz(baseURL string) string {
	u, err := url.Parse(baseURL)
	if err != nil {
		return "N/A"
	}
	switch u.Scheme {
	case "ws":
		u.Scheme = "http"
	case "wss":
		u.Scheme = "https"
	}
	// Replace the WebSocket path (e.g. /xiaozhi/v1/) with /healthz.
	u.Path = "/healthz"
	u.RawQuery = ""

	resp, err := healthzClient.Get(u.String())
	if err != nil {
		return "N/A"
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxHealthzBody))
	if err != nil {
		return "N/A"
	}
	// Collapse all whitespace, including newlines, into single spaces.
	text := strings.Join(strings.Fields(string(body)), " ")
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return strings.TrimSpace(fmt.Sprintf("HTTP %d %s", resp.StatusCode, text))
	}
	return text
}
|
||
|
||
// ── 主函数 ──────────────────────────────────────────────────
|
||
|
||
func main() {
|
||
flag.Parse()
|
||
|
||
if *flagStories > *flagConns {
|
||
*flagStories = *flagConns
|
||
}
|
||
|
||
fmt.Println("========================================")
|
||
fmt.Println(" hw_service_go 并发压力测试")
|
||
fmt.Println("========================================")
|
||
fmt.Printf(" 目标地址: %s\n", *flagURL)
|
||
fmt.Printf(" 总连接数: %d\n", *flagConns)
|
||
fmt.Printf(" 触发故事: %d\n", *flagStories)
|
||
fmt.Printf(" 建连速率: %d/s\n", *flagRamp)
|
||
fmt.Printf(" 测试时长: %s\n", *flagDuration)
|
||
fmt.Printf(" MAC 前缀: %s\n", *flagMACPrefix)
|
||
fmt.Println("========================================")
|
||
fmt.Println()
|
||
|
||
done := make(chan struct{})
|
||
var wg sync.WaitGroup
|
||
|
||
// 信号处理:Ctrl+C 提前结束
|
||
sigCh := make(chan os.Signal, 1)
|
||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||
go func() {
|
||
<-sigCh
|
||
fmt.Println("\n收到退出信号,正在停止...")
|
||
close(done)
|
||
}()
|
||
|
||
// 实时统计输出
|
||
startTime := time.Now()
|
||
go func() {
|
||
ticker := time.NewTicker(2 * time.Second)
|
||
defer ticker.Stop()
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
elapsed := time.Since(startTime).Truncate(time.Second)
|
||
health := queryHealthz(*flagURL)
|
||
fmt.Printf("\r\033[K[%s] conns: %d/%d handshaked: %d stories: %d sent frames: %d errors: %d healthz: %s",
|
||
elapsed,
|
||
s.connSuccess.Load(), *flagConns,
|
||
s.handshaked.Load(),
|
||
s.storySent.Load(),
|
||
s.opusFrames.Load(),
|
||
s.errors.Load(),
|
||
health,
|
||
)
|
||
case <-done:
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
|
||
// 按 ramp 速率建立连接
|
||
rampInterval := time.Second / time.Duration(*flagRamp)
|
||
for i := 1; i <= *flagConns; i++ {
|
||
select {
|
||
case <-done:
|
||
goto waitDone
|
||
default:
|
||
}
|
||
|
||
triggerStory := i <= *flagStories
|
||
dev := newDevice(i, *flagMACPrefix, triggerStory)
|
||
wg.Add(1)
|
||
go dev.run(*flagURL, &wg, done)
|
||
|
||
// 控制建连速率
|
||
if i < *flagConns {
|
||
time.Sleep(rampInterval)
|
||
}
|
||
}
|
||
|
||
// 所有连接建立后,等待 duration 到期
|
||
fmt.Printf("\n所有连接已发起,等待 %s...\n", *flagDuration)
|
||
select {
|
||
case <-time.After(*flagDuration):
|
||
fmt.Println("\n测试时长到期,正在停止...")
|
||
close(done)
|
||
case <-done:
|
||
}
|
||
|
||
waitDone:
|
||
// 等待所有 goroutine 退出(最多 10s)
|
||
waitCh := make(chan struct{})
|
||
go func() {
|
||
wg.Wait()
|
||
close(waitCh)
|
||
}()
|
||
select {
|
||
case <-waitCh:
|
||
case <-time.After(10 * time.Second):
|
||
fmt.Println("等待超时,强制退出")
|
||
}
|
||
|
||
// 最终报告
|
||
printReport()
|
||
}
|
||
|
||
func printReport() {
|
||
fmt.Println()
|
||
fmt.Println("========== 测试报告 ==========")
|
||
fmt.Printf("目标连接数: %d\n", *flagConns)
|
||
fmt.Printf("连接尝试: %d\n", s.connAttempts.Load())
|
||
fmt.Printf("成功连接: %d\n", s.connSuccess.Load())
|
||
fmt.Printf("连接失败: %d\n", s.connFailed.Load())
|
||
fmt.Printf("握手成功: %d\n", s.handshaked.Load())
|
||
fmt.Printf("握手失败: %d\n", s.handshakeFail.Load())
|
||
fmt.Println("------------------------------")
|
||
fmt.Printf("触发故事数: %d\n", s.storySent.Load())
|
||
fmt.Printf("收到 tts start: %d\n", s.ttsStart.Load())
|
||
fmt.Printf("收到 tts stop: %d\n", s.ttsStop.Load())
|
||
fmt.Printf("Opus 帧总数: %d\n", s.opusFrames.Load())
|
||
if s.storySent.Load() > 0 {
|
||
fmt.Printf("平均帧数/故事: %d\n", s.opusFrames.Load()/max(s.ttsStop.Load(), 1))
|
||
}
|
||
if s.firstFrameCnt.Load() > 0 {
|
||
avgMs := s.firstFrameNs.Load() / s.firstFrameCnt.Load() / 1e6
|
||
fmt.Printf("首帧延迟(avg): %dms\n", avgMs)
|
||
}
|
||
fmt.Printf("错误总数: %d\n", s.errors.Load())
|
||
fmt.Println("==============================")
|
||
}
|
||
|
||
// max returns the larger of a and b. It shadows the Go 1.21 builtin so the
// file also builds with older toolchains.
func max(a, b int64) int64 {
	if b > a {
		return b
	}
	return a
}
|