2026-03-02 17:33:56 +08:00

273 lines
7.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Package server implements the WebSocket server and manages the lifecycle of hardware device connections.
package server
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/qy/hw-ws-service/internal/connection"
"github.com/qy/hw-ws-service/internal/handler"
"github.com/qy/hw-ws-service/internal/rtcclient"
)
// Limits applied to every device connection.
const (
	// helloTimeout is the handshake deadline: after connecting, a device must
	// send its hello message within this window or it is disconnected.
	helloTimeout = 10 * time.Second

	// maxMessageBytes caps a single WebSocket message at 4 KiB to guard
	// against memory-exhaustion attacks.
	maxMessageBytes = 4 << 10

	// maxConnections is the maximum number of concurrent connections; it
	// protects the service from resource exhaustion.
	maxConnections = 500
)
// upgrader turns incoming HTTP requests into WebSocket connections, using
// 1 KiB read and write buffers.
var upgrader = websocket.Upgrader{
	// IoT devices are not browsers and carry no browser Origin, so every
	// origin is accepted.
	CheckOrigin:     func(*http.Request) bool { return true },
	WriteBufferSize: 1 << 10,
	ReadBufferSize:  1 << 10,
}
// Server manages all active device connections.
type Server struct {
// client is the RTC client given to New; handleConn passes it to the story handler.
client *rtcclient.Client
// httpServer is the HTTP server built in New that routes the WebSocket and health endpoints.
httpServer *http.Server
// mu guards conns.
mu sync.Mutex
conns map[string]*connection.Connection // key: DeviceID
wg sync.WaitGroup // tracks all connection goroutines
}
// New creates a Server bound to addr (for example "0.0.0.0:8888"). client is
// stored on the Server and handed to the story handler by handleConn.
//
// Registered routes:
//   - /xiaozhi/v1/healthz  JSON status including the active connection count
//   - /xiaozhi/v1/         WebSocket endpoint for devices
//   - /healthz             bare 200 response
//
// The returned Server does not listen until ListenAndServe is called.
func New(addr string, client *rtcclient.Client) *Server {
	s := &Server{
		client: client,
		conns:  make(map[string]*connection.Connection),
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/xiaozhi/v1/healthz", s.handleStatus)
	mux.HandleFunc("/xiaozhi/v1/", s.handleConn)
	mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	s.httpServer = &http.Server{
		Addr:    addr,
		Handler: mux,
		// Bound the time a client may take to send its request headers so a
		// slow or stalled peer cannot pin a connection forever. Only the
		// header phase is covered, so long-lived WebSocket sessions are not
		// limited by this; ReadTimeout/WriteTimeout are intentionally unset.
		ReadHeaderTimeout: 10 * time.Second,
	}
	return s
}
// ListenAndServe starts the HTTP server and blocks until it stops.
// A stop caused by the server being closed (http.ErrServerClosed) is reported
// as nil; any other error is returned unchanged.
func (s *Server) ListenAndServe() error {
	log.Printf("server: listening on %s", s.httpServer.Addr)
	if serveErr := s.httpServer.ListenAndServe(); !errors.Is(serveErr, http.ErrServerClosed) {
		return serveErr
	}
	return nil
}
// Shutdown stops the server gracefully: it first stops accepting new
// connections, then waits for every connection goroutine to exit.
//
// ctx bounds the whole operation. If it expires while connections are still
// registered, those connections are closed so their goroutines can unwind;
// Shutdown then returns without waiting any further.
func (s *Server) Shutdown(ctx context.Context) {
	log.Println("server: shutting down...")
	if err := s.httpServer.Shutdown(ctx); err != nil {
		log.Printf("server: http shutdown: %v", err)
	}

	// Wait for all connection goroutines in the background so the wait can
	// race against ctx.
	done := make(chan struct{})
	go func() {
		s.wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		log.Println("server: all connections closed gracefully")
	case <-ctx.Done():
		log.Println("server: shutdown timeout, forcing close")
		// Actually force the close announced above. Close is the same call
		// register uses (also under s.mu) to kick a stale connection.
		s.mu.Lock()
		for _, c := range s.conns {
			c.Close()
		}
		s.mu.Unlock()
	}
}
// handleConn runs the full lifecycle of one device WebSocket connection:
// upgrade, registration, hello handshake, then the message loop.
//
// URL format: /xiaozhi/v1/?device-id=<MAC>&client-id=<UUID>
// device-id is required; a request without it is answered with 400 before
// any upgrade is attempted.
//
// After the handshake only text messages are processed: "story" is dispatched
// to handler.HandleStory in its own goroutine, "abort" is handled inline, and
// any other type is logged and ignored.
func (s *Server) handleConn(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()
	deviceID := query.Get("device-id")
	clientID := query.Get("client-id")
	if deviceID == "" {
		http.Error(w, "missing device-id", http.StatusBadRequest)
		return
	}

	ws, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("server: upgrade failed for %s: %v", deviceID, err)
		return
	}
	// Cap the size of a single incoming message.
	ws.SetReadLimit(maxMessageBytes)

	conn := connection.New(ws, deviceID, clientID)
	if err := s.register(conn); err != nil {
		log.Printf("server: register %s failed: %v", deviceID, err)
		ws.Close()
		return
	}

	s.wg.Add(1)
	defer func() {
		conn.StopPlayback()
		// Remove the registry entry only if it still refers to this
		// connection. When a device reconnects, register replaces the entry
		// and closes the old connection; the old handler then lands here and
		// must not delete the newer connection's entry.
		s.mu.Lock()
		if s.conns[deviceID] == conn {
			delete(s.conns, deviceID)
		}
		s.mu.Unlock()
		ws.Close()
		s.wg.Done()
		log.Printf("server: device %s disconnected, active=%d", deviceID, s.activeCount())
	}()
	log.Printf("server: device %s connected, active=%d", deviceID, s.activeCount())

	// Phase 1: wait for the hello handshake, bounded by helloTimeout.
	ws.SetReadDeadline(time.Now().Add(helloTimeout)) //nolint:errcheck
	if !s.waitForHello(conn) {
		log.Printf("server: device %s hello timeout or failed", deviceID)
		return
	}
	ws.SetReadDeadline(time.Time{}) //nolint:errcheck // handshake done, clear the read deadline
	log.Printf("server: device %s handshaked, session=%s", deviceID, conn.SessionID)

	// Phase 2: regular message loop.
	for {
		msgType, raw, err := ws.ReadMessage()
		if err != nil {
			// Normal/going-away closes and plain network errors are expected
			// disconnects and are not logged.
			expectedClose := websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway)
			if !expectedClose && !isNetworkClose(err) {
				log.Printf("server: read error for %s: %v", deviceID, err)
			}
			return
		}
		// Only text messages are processed.
		if msgType != websocket.TextMessage {
			continue
		}

		var envelope struct {
			Type string `json:"type"`
		}
		if err := json.Unmarshal(raw, &envelope); err != nil {
			log.Printf("server: invalid json from %s: %v", deviceID, err)
			continue
		}

		switch envelope.Type {
		case "story":
			// Runs in its own goroutine so the read loop can keep receiving
			// messages such as "abort".
			go handler.HandleStory(conn, s.client)
		case "abort":
			handler.HandleAbort(conn)
		default:
			log.Printf("server: unhandled message type %q from %s", envelope.Type, deviceID)
		}
	}
}
// waitForHello reads the first message from conn and processes it as the
// hello handshake. It returns true only when the message is a text frame
// whose JSON "type" is "hello" and handler.HandleHello accepts it.
func (s *Server) waitForHello(conn *connection.Connection) bool {
	kind, payload, readErr := conn.WS.ReadMessage()
	switch {
	case readErr != nil:
		return false
	case kind != websocket.TextMessage:
		log.Printf("server: device %s sent non-text as first message", conn.DeviceID)
		return false
	}

	var first struct {
		Type string `json:"type"`
	}
	isHello := json.Unmarshal(payload, &first) == nil && first.Type == "hello"
	if !isHello {
		log.Printf("server: device %s first message is not hello (got %q)", conn.DeviceID, first.Type)
		return false
	}

	if helloErr := handler.HandleHello(conn, payload); helloErr != nil {
		log.Printf("server: device %s hello failed: %v", conn.DeviceID, helloErr)
		return false
	}
	return true
}
// register records conn under its DeviceID. A device may hold only one
// connection at a time: if one is already registered for the same DeviceID it
// is closed and replaced.
//
// It returns an error when conn would add a new entry while maxConnections
// entries are already registered. Replacing an existing entry does not grow
// the registry, so a reconnecting device is accepted even at capacity.
func (s *Server) register(conn *connection.Connection) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	old, exists := s.conns[conn.DeviceID]
	if !exists && len(s.conns) >= maxConnections {
		return errors.New("server: max connections reached")
	}
	if exists {
		log.Printf("server: kicking old connection for %s", conn.DeviceID)
		old.Close()
	}
	s.conns[conn.DeviceID] = conn
	return nil
}
// SendCmd sends a control command (action plus params) to the device
// identified by deviceID.
// It returns an error if the device is not connected or has not completed
// its handshake; otherwise it returns whatever the connection's SendCmd
// returns.
func (s *Server) SendCmd(deviceID, action string, params any) error {
	// Hold the lock only for the lookup, not for the send itself.
	s.mu.Lock()
	target, online := s.conns[deviceID]
	s.mu.Unlock()

	switch {
	case !online:
		return fmt.Errorf("server: device %s not connected", deviceID)
	case !target.IsHandshaked():
		return fmt.Errorf("server: device %s not handshaked", deviceID)
	default:
		return target.SendCmd(action, params)
	}
}
// unregister removes the registry entry for deviceID, if any.
func (s *Server) unregister(deviceID string) {
	s.mu.Lock()
	delete(s.conns, deviceID)
	s.mu.Unlock()
}
// activeCount returns the number of currently registered connections.
func (s *Server) activeCount() int {
	s.mu.Lock()
	n := len(s.conns)
	s.mu.Unlock()
	return n
}
// handleStatus reports the service status and the number of active
// connections; it is meant for post-deployment verification.
//
//	GET /xiaozhi/v1/healthz → {"status":"ok","active_connections":N}
//
// Any method other than GET is answered with 405.
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
	const statusBody = `{"status":"ok","active_connections":%d}`

	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	count := s.activeCount()
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	fmt.Fprintf(w, statusBody, count)
}
// isNetworkClose reports whether err is, or wraps, a *net.OpError. Such
// errors are treated as ordinary network closes that need no log line.
func isNetworkClose(err error) bool {
	opErr := (*net.OpError)(nil)
	matched := errors.As(err, &opErr)
	return matched
}