压测工具
All checks were successful
Build and Deploy Backend / build-and-deploy (push) Successful in 4m14s

This commit is contained in:
repair-agent 2026-03-03 16:22:16 +08:00
parent 0bf556018e
commit a3222d1fe5
5 changed files with 424 additions and 29 deletions

View File

@ -324,38 +324,30 @@ class DeviceViewSet(viewsets.ViewSet):
mac = mac.upper().replace('-', ':')
from apps.stories.models import Story
story = None
# 1. 尝试查找设备 → 绑定用户 → 用户故事
try:
device = Device.objects.get(mac_address=mac)
user_device = (
UserDevice.objects
.filter(device=device, is_active=True, bind_type='owner')
.select_related('user')
.first()
)
if user_device:
story = (
Story.objects
.filter(user=user_device.user)
.exclude(audio_url='')
.order_by('?')
.first()
)
except Device.DoesNotExist:
return error(
code=ErrorCode.DEVICE_NOT_FOUND,
message='未找到对应设备',
status_code=status.HTTP_404_NOT_FOUND
)
pass
user_device = (
UserDevice.objects
.filter(device=device, is_active=True, bind_type='owner')
.select_related('user')
.first()
)
if not user_device:
return error(
code=ErrorCode.NOT_FOUND,
message='该设备尚未绑定用户',
status_code=status.HTTP_404_NOT_FOUND
)
from apps.stories.models import Story
# 优先随机取用户自己有 audio_url 的故事
story = (
Story.objects
.filter(user=user_device.user)
.exclude(audio_url='')
.order_by('?')
.first()
)
# 兜底:用户暂无故事时使用系统默认故事
# 2. 兜底:设备不存在/未绑定/用户无故事 → 使用系统默认故事
if not story:
story = (
Story.objects

View File

@ -0,0 +1,5 @@
module stress
go 1.23
require github.com/gorilla/websocket v1.5.3

View File

@ -0,0 +1,2 @@
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=

View File

@ -0,0 +1,379 @@
// hw_service_go 并发压力测试工具
//
// 用法:
//
// go run main.go -conns 100 -stories 0 # 100 个空闲连接
// go run main.go -conns 50 -stories 10 # 50 连接10 个触发故事
// go run main.go -url wss://example.com/xiaozhi/v1/ -conns 50
package main
import (
"encoding/json"
"flag"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"os/signal"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/gorilla/websocket"
)
// ── 命令行参数 ─────────────────────────────────────────────────
var (
flagURL = flag.String("url", "ws://localhost:8888/xiaozhi/v1/", "WebSocket 服务地址")
flagConns = flag.Int("conns", 100, "总连接数")
flagStories = flag.Int("stories", 10, "同时触发故事的连接数")
flagRamp = flag.Int("ramp", 20, "每秒建立的连接数")
flagDuration = flag.Duration("duration", 60*time.Second, "测试持续时间")
flagMACPrefix = flag.String("mac-prefix", "AA:BB:CC:DD", "模拟 MAC 前缀")
)
// ── 统计指标原子操作goroutine 安全) ──────────────────────
type stats struct {
connAttempts atomic.Int64
connSuccess atomic.Int64
connFailed atomic.Int64
handshaked atomic.Int64
handshakeFail atomic.Int64
storySent atomic.Int64
ttsStart atomic.Int64
ttsStop atomic.Int64
opusFrames atomic.Int64
errors atomic.Int64
firstFrameNs atomic.Int64 // 所有设备首帧延迟总和(纳秒),用于算均值
firstFrameCnt atomic.Int64 // 收到首帧的设备数
}
var s stats
// ── 模拟设备 ────────────────────────────────────────────────
type device struct {
id int
mac string
clientID string
ws *websocket.Conn
triggerStory bool
}
func newDevice(id int, macPrefix string, triggerStory bool) *device {
hi := byte((id >> 8) & 0xFF)
lo := byte(id & 0xFF)
mac := fmt.Sprintf("%s:%02X:%02X", macPrefix, hi, lo)
return &device{
id: id,
mac: mac,
clientID: fmt.Sprintf("stress-%d", id),
triggerStory: triggerStory,
}
}
func (d *device) run(baseURL string, wg *sync.WaitGroup, done <-chan struct{}) {
defer wg.Done()
s.connAttempts.Add(1)
// 1. 建立 WebSocket 连接
u, err := url.Parse(baseURL)
if err != nil {
log.Printf("[dev-%d] invalid URL: %v", d.id, err)
s.connFailed.Add(1)
return
}
q := u.Query()
q.Set("device-id", d.mac)
q.Set("client-id", d.clientID)
u.RawQuery = q.Encode()
dialer := websocket.Dialer{
HandshakeTimeout: 10 * time.Second,
}
ws, _, err := dialer.Dial(u.String(), nil)
if err != nil {
log.Printf("[dev-%d] connect failed: %v", d.id, err)
s.connFailed.Add(1)
return
}
d.ws = ws
s.connSuccess.Add(1)
defer ws.Close()
// 2. 发送 hello 握手
helloMsg, _ := json.Marshal(map[string]string{
"type": "hello",
"mac": d.mac,
})
if err := ws.WriteMessage(websocket.TextMessage, helloMsg); err != nil {
log.Printf("[dev-%d] hello send failed: %v", d.id, err)
s.handshakeFail.Add(1)
return
}
// 3. 等待 hello 响应5s 超时)
ws.SetReadDeadline(time.Now().Add(5 * time.Second))
_, msg, err := ws.ReadMessage()
if err != nil {
log.Printf("[dev-%d] hello read failed: %v", d.id, err)
s.handshakeFail.Add(1)
return
}
ws.SetReadDeadline(time.Time{}) // 清除超时
var helloResp struct {
Type string `json:"type"`
Status string `json:"status"`
}
if err := json.Unmarshal(msg, &helloResp); err != nil || helloResp.Type != "hello" || helloResp.Status != "ok" {
log.Printf("[dev-%d] hello failed: %s", d.id, string(msg))
s.handshakeFail.Add(1)
return
}
s.handshaked.Add(1)
// 4. 如果被选为活跃设备,触发故事
var storySentTime time.Time
var gotFirstFrame bool
if d.triggerStory {
storyMsg, _ := json.Marshal(map[string]string{"type": "story"})
if err := ws.WriteMessage(websocket.TextMessage, storyMsg); err != nil {
log.Printf("[dev-%d] story send failed: %v", d.id, err)
s.errors.Add(1)
} else {
s.storySent.Add(1)
storySentTime = time.Now()
}
}
// 5. 消息接收循环
msgCh := make(chan struct{}, 1) // 用于通知有新消息
go func() {
for {
msgType, data, err := ws.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
select {
case <-done:
// 正常关闭,不算错误
default:
s.errors.Add(1)
}
}
return
}
if msgType == websocket.BinaryMessage {
// Opus 帧
s.opusFrames.Add(1)
if d.triggerStory && !gotFirstFrame && !storySentTime.IsZero() {
gotFirstFrame = true
latency := time.Since(storySentTime)
s.firstFrameNs.Add(latency.Nanoseconds())
s.firstFrameCnt.Add(1)
}
_ = data // 不需要解码,只计数
} else {
// 文本消息
var envelope struct {
Type string `json:"type"`
State string `json:"state"`
}
if json.Unmarshal(data, &envelope) == nil {
if envelope.Type == "tts" {
switch envelope.State {
case "start":
s.ttsStart.Add(1)
case "stop":
s.ttsStop.Add(1)
}
}
}
}
select {
case msgCh <- struct{}{}:
default:
}
}
}()
// 6. 等待测试结束
<-done
ws.WriteMessage(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
)
}
// ── healthz 查询 ─────────────────────────────────────────────
func queryHealthz(baseURL string) string {
// 从 ws:// URL 推导 http:// URL
u, err := url.Parse(baseURL)
if err != nil {
return "N/A"
}
switch u.Scheme {
case "ws":
u.Scheme = "http"
case "wss":
u.Scheme = "https"
}
// 去掉 /xiaozhi/v1/ 路径,换成 /healthz
u.Path = "/healthz"
u.RawQuery = ""
client := &http.Client{Timeout: 3 * time.Second}
resp, err := client.Get(u.String())
if err != nil {
return "N/A"
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
return strings.TrimSpace(string(body))
}
// ── 主函数 ──────────────────────────────────────────────────
func main() {
flag.Parse()
if *flagStories > *flagConns {
*flagStories = *flagConns
}
fmt.Println("========================================")
fmt.Println(" hw_service_go 并发压力测试")
fmt.Println("========================================")
fmt.Printf(" 目标地址: %s\n", *flagURL)
fmt.Printf(" 总连接数: %d\n", *flagConns)
fmt.Printf(" 触发故事: %d\n", *flagStories)
fmt.Printf(" 建连速率: %d/s\n", *flagRamp)
fmt.Printf(" 测试时长: %s\n", *flagDuration)
fmt.Printf(" MAC 前缀: %s\n", *flagMACPrefix)
fmt.Println("========================================")
fmt.Println()
done := make(chan struct{})
var wg sync.WaitGroup
// 信号处理Ctrl+C 提前结束
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
fmt.Println("\n收到退出信号正在停止...")
close(done)
}()
// 实时统计输出
startTime := time.Now()
go func() {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
elapsed := time.Since(startTime).Truncate(time.Second)
health := queryHealthz(*flagURL)
fmt.Printf("\r\033[K[%s] conns: %d/%d handshaked: %d stories: %d sent frames: %d errors: %d healthz: %s",
elapsed,
s.connSuccess.Load(), *flagConns,
s.handshaked.Load(),
s.storySent.Load(),
s.opusFrames.Load(),
s.errors.Load(),
health,
)
case <-done:
return
}
}
}()
// 按 ramp 速率建立连接
rampInterval := time.Second / time.Duration(*flagRamp)
for i := 1; i <= *flagConns; i++ {
select {
case <-done:
goto waitDone
default:
}
triggerStory := i <= *flagStories
dev := newDevice(i, *flagMACPrefix, triggerStory)
wg.Add(1)
go dev.run(*flagURL, &wg, done)
// 控制建连速率
if i < *flagConns {
time.Sleep(rampInterval)
}
}
// 所有连接建立后,等待 duration 到期
fmt.Printf("\n所有连接已发起等待 %s...\n", *flagDuration)
select {
case <-time.After(*flagDuration):
fmt.Println("\n测试时长到期正在停止...")
close(done)
case <-done:
}
waitDone:
// 等待所有 goroutine 退出(最多 10s
waitCh := make(chan struct{})
go func() {
wg.Wait()
close(waitCh)
}()
select {
case <-waitCh:
case <-time.After(10 * time.Second):
fmt.Println("等待超时,强制退出")
}
// 最终报告
printReport()
}
func printReport() {
fmt.Println()
fmt.Println("========== 测试报告 ==========")
fmt.Printf("目标连接数: %d\n", *flagConns)
fmt.Printf("连接尝试: %d\n", s.connAttempts.Load())
fmt.Printf("成功连接: %d\n", s.connSuccess.Load())
fmt.Printf("连接失败: %d\n", s.connFailed.Load())
fmt.Printf("握手成功: %d\n", s.handshaked.Load())
fmt.Printf("握手失败: %d\n", s.handshakeFail.Load())
fmt.Println("------------------------------")
fmt.Printf("触发故事数: %d\n", s.storySent.Load())
fmt.Printf("收到 tts start: %d\n", s.ttsStart.Load())
fmt.Printf("收到 tts stop: %d\n", s.ttsStop.Load())
fmt.Printf("Opus 帧总数: %d\n", s.opusFrames.Load())
if s.storySent.Load() > 0 {
fmt.Printf("平均帧数/故事: %d\n", s.opusFrames.Load()/max(s.ttsStop.Load(), 1))
}
if s.firstFrameCnt.Load() > 0 {
avgMs := s.firstFrameNs.Load() / s.firstFrameCnt.Load() / 1e6
fmt.Printf("首帧延迟(avg): %dms\n", avgMs)
}
fmt.Printf("错误总数: %d\n", s.errors.Load())
fmt.Println("==============================")
}
func max(a, b int64) int64 {
if a > b {
return a
}
return b
}

View File

@ -177,10 +177,11 @@
<div class="form-row">
<label>服务地址</label>
<input type="text" id="wsUrl" value="ws://localhost:8888/xiaozhi/v1/">
<button class="btn btn-secondary btn-small" id="btnEnvToggle" onclick="toggleEnv()">切换线上</button>
</div>
<div class="form-row">
<label>device-id</label>
<input type="text" id="deviceId" placeholder="AA:BB:CC:DD:EE:FF">
<input type="text" id="deviceId" value="20:6E:F1:B9:AF:A2">
</div>
<div class="form-row">
<label>client-id</label>
@ -290,6 +291,22 @@ function generateClientId() {
$('clientId').value = id;
}
const ENV_LOCAL = { url: 'ws://localhost:8888/xiaozhi/v1/', label: '切换线上' };
const ENV_PROD = { url: 'wss://qiyuan-rtc-api.airlabs.art/xiaozhi/v1/', label: '切换本地' };
let currentEnv = 'local';
function toggleEnv() {
if (currentEnv === 'local') {
$('wsUrl').value = ENV_PROD.url;
$('btnEnvToggle').textContent = ENV_PROD.label;
currentEnv = 'prod';
} else {
$('wsUrl').value = ENV_LOCAL.url;
$('btnEnvToggle').textContent = ENV_LOCAL.label;
currentEnv = 'local';
}
}
function clearLog() {
$('logContainer').innerHTML = '';
}