All checks were successful
Build and Deploy Backend / build-and-deploy (push) Successful in 3m48s
278 lines
7.3 KiB
Go
278 lines
7.3 KiB
Go
// Package server 实现 WebSocket 服务器,管理硬件设备连接的生命周期。
|
||
package server
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"log"
|
||
"net"
|
||
"net/http"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/gorilla/websocket"
|
||
"github.com/qy/hw-ws-service/internal/connection"
|
||
"github.com/qy/hw-ws-service/internal/handler"
|
||
"github.com/qy/hw-ws-service/internal/rtcclient"
|
||
)
|
||
|
||
// Limits applied to every device connection.
const (
	// maxConnections caps the number of concurrently registered
	// connections so the process cannot be exhausted of resources.
	maxConnections = 500

	// maxMessageBytes is the upper bound (4 KiB) for a single WebSocket
	// message; it guards against memory-exhaustion attacks.
	maxMessageBytes = 4 << 10

	// helloTimeout is the handshake deadline: a freshly connected device
	// must send its hello message within this window or it is dropped.
	helloTimeout = 10 * time.Second
)
|
||
|
||
var upgrader = websocket.Upgrader{
|
||
ReadBufferSize: 1024,
|
||
WriteBufferSize: 1024,
|
||
// IoT 设备无浏览器 Origin,允许所有来源
|
||
CheckOrigin: func(r *http.Request) bool { return true },
|
||
}
|
||
|
||
// Server 管理所有活跃的设备连接。
|
||
type Server struct {
|
||
client *rtcclient.Client
|
||
httpServer *http.Server
|
||
|
||
mu sync.Mutex
|
||
conns map[string]*connection.Connection // key: DeviceID
|
||
wg sync.WaitGroup // 跟踪所有连接 goroutine
|
||
}
|
||
|
||
// New 创建 Server,addr 形如 "0.0.0.0:8888"。
|
||
func New(addr string, client *rtcclient.Client) *Server {
|
||
s := &Server{
|
||
client: client,
|
||
conns: make(map[string]*connection.Connection),
|
||
}
|
||
|
||
mux := http.NewServeMux()
|
||
mux.HandleFunc("/xiaozhi/v1/healthz", s.handleStatus)
|
||
mux.HandleFunc("/xiaozhi/v1/", s.handleConn)
|
||
mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusOK)
|
||
})
|
||
|
||
s.httpServer = &http.Server{
|
||
Addr: addr,
|
||
Handler: mux,
|
||
}
|
||
return s
|
||
}
|
||
|
||
// ListenAndServe 启动服务器,阻塞直到服务器关闭。
|
||
func (s *Server) ListenAndServe() error {
|
||
log.Printf("server: listening on %s", s.httpServer.Addr)
|
||
err := s.httpServer.ListenAndServe()
|
||
if errors.Is(err, http.ErrServerClosed) {
|
||
return nil
|
||
}
|
||
return err
|
||
}
|
||
|
||
// Shutdown 优雅关闭:先停止接受新连接,再等待所有连接 goroutine 退出。
|
||
func (s *Server) Shutdown(ctx context.Context) {
|
||
log.Println("server: shutting down...")
|
||
s.httpServer.Shutdown(ctx) //nolint:errcheck
|
||
|
||
// 等待所有连接 goroutine 退出(由 ctx 超时兜底)
|
||
done := make(chan struct{})
|
||
go func() {
|
||
s.wg.Wait()
|
||
close(done)
|
||
}()
|
||
|
||
select {
|
||
case <-done:
|
||
log.Println("server: all connections closed gracefully")
|
||
case <-ctx.Done():
|
||
log.Println("server: shutdown timeout, forcing close")
|
||
}
|
||
}
|
||
|
||
// handleConn 处理单个 WebSocket 连接的完整生命周期。
|
||
// URL 格式:/xiaozhi/v1/?device-id=<MAC>&client-id=<UUID>
|
||
func (s *Server) handleConn(w http.ResponseWriter, r *http.Request) {
|
||
if r.URL.Path == "/xiaozhi/v1/healthz" {
|
||
s.handleStatus(w, r)
|
||
return
|
||
}
|
||
|
||
deviceID := r.URL.Query().Get("device-id")
|
||
clientID := r.URL.Query().Get("client-id")
|
||
|
||
if deviceID == "" {
|
||
http.Error(w, "missing device-id", http.StatusBadRequest)
|
||
return
|
||
}
|
||
|
||
ws, err := upgrader.Upgrade(w, r, nil)
|
||
if err != nil {
|
||
log.Printf("server: upgrade failed for %s: %v", deviceID, err)
|
||
return
|
||
}
|
||
|
||
// 设置单条消息大小上限
|
||
ws.SetReadLimit(maxMessageBytes)
|
||
|
||
conn := connection.New(ws, deviceID, clientID)
|
||
|
||
if err := s.register(conn); err != nil {
|
||
log.Printf("server: register %s failed: %v", deviceID, err)
|
||
ws.Close()
|
||
return
|
||
}
|
||
|
||
s.wg.Add(1)
|
||
defer func() {
|
||
conn.StopPlayback()
|
||
s.unregister(deviceID)
|
||
ws.Close()
|
||
s.wg.Done()
|
||
log.Printf("server: device %s disconnected, active=%d", deviceID, s.activeCount())
|
||
}()
|
||
|
||
log.Printf("server: device %s connected, active=%d", deviceID, s.activeCount())
|
||
|
||
// 阶段1:等待 hello 握手(超时 helloTimeout)
|
||
ws.SetReadDeadline(time.Now().Add(helloTimeout)) //nolint:errcheck
|
||
if !s.waitForHello(conn) {
|
||
log.Printf("server: device %s hello timeout or failed", deviceID)
|
||
return
|
||
}
|
||
ws.SetReadDeadline(time.Time{}) //nolint:errcheck // 握手成功,取消读超时
|
||
|
||
log.Printf("server: device %s handshaked, session=%s", deviceID, conn.SessionID)
|
||
|
||
// 阶段2:正常消息循环
|
||
for {
|
||
msgType, raw, err := ws.ReadMessage()
|
||
if err != nil {
|
||
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
||
if !isNetworkClose(err) {
|
||
log.Printf("server: read error for %s: %v", deviceID, err)
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
// 只处理文本消息
|
||
if msgType != websocket.TextMessage {
|
||
continue
|
||
}
|
||
|
||
var envelope struct {
|
||
Type string `json:"type"`
|
||
}
|
||
if err := json.Unmarshal(raw, &envelope); err != nil {
|
||
log.Printf("server: invalid json from %s: %v", deviceID, err)
|
||
continue
|
||
}
|
||
|
||
switch envelope.Type {
|
||
case "story":
|
||
go handler.HandleStory(conn, s.client)
|
||
case "abort":
|
||
handler.HandleAbort(conn)
|
||
default:
|
||
log.Printf("server: unhandled message type %q from %s", envelope.Type, deviceID)
|
||
}
|
||
}
|
||
}
|
||
|
||
// waitForHello 等待并处理第一条 hello 消息,成功返回 true。
|
||
func (s *Server) waitForHello(conn *connection.Connection) bool {
|
||
msgType, raw, err := conn.WS.ReadMessage()
|
||
if err != nil {
|
||
return false
|
||
}
|
||
if msgType != websocket.TextMessage {
|
||
log.Printf("server: device %s sent non-text as first message", conn.DeviceID)
|
||
return false
|
||
}
|
||
|
||
var envelope struct {
|
||
Type string `json:"type"`
|
||
}
|
||
if err := json.Unmarshal(raw, &envelope); err != nil || envelope.Type != "hello" {
|
||
log.Printf("server: device %s first message is not hello (got %q)", conn.DeviceID, envelope.Type)
|
||
return false
|
||
}
|
||
|
||
if err := handler.HandleHello(conn, raw); err != nil {
|
||
log.Printf("server: device %s hello failed: %v", conn.DeviceID, err)
|
||
return false
|
||
}
|
||
return true
|
||
}
|
||
|
||
// register 注册连接,若同一设备已有连接则踢掉旧连接。
|
||
func (s *Server) register(conn *connection.Connection) error {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
if len(s.conns) >= maxConnections {
|
||
return errors.New("server: max connections reached")
|
||
}
|
||
|
||
// 同一设备同时只允许一个连接
|
||
if old, exists := s.conns[conn.DeviceID]; exists {
|
||
log.Printf("server: kicking old connection for %s", conn.DeviceID)
|
||
old.Close()
|
||
}
|
||
|
||
s.conns[conn.DeviceID] = conn
|
||
return nil
|
||
}
|
||
|
||
// SendCmd 向指定设备发送控制指令。
|
||
// 若设备不在线或未握手,返回 error。
|
||
func (s *Server) SendCmd(deviceID, action string, params any) error {
|
||
s.mu.Lock()
|
||
conn, ok := s.conns[deviceID]
|
||
s.mu.Unlock()
|
||
if !ok {
|
||
return fmt.Errorf("server: device %s not connected", deviceID)
|
||
}
|
||
if !conn.IsHandshaked() {
|
||
return fmt.Errorf("server: device %s not handshaked", deviceID)
|
||
}
|
||
return conn.SendCmd(action, params)
|
||
}
|
||
|
||
func (s *Server) unregister(deviceID string) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
delete(s.conns, deviceID)
|
||
}
|
||
|
||
func (s *Server) activeCount() int {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
return len(s.conns)
|
||
}
|
||
|
||
// handleStatus 返回服务状态和当前活跃连接数,用于部署后验证。
|
||
// GET /xiaozhi/v1/healthz → {"status":"ok","active_connections":N}
|
||
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||
if r.Method != http.MethodGet {
|
||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||
return
|
||
}
|
||
w.Header().Set("Content-Type", "application/json")
|
||
w.WriteHeader(http.StatusOK)
|
||
active := s.activeCount()
|
||
fmt.Fprintf(w, `{"status":"ok","active_connections":%d}`, active)
|
||
}
|
||
|
||
// isNetworkClose reports whether err is, or wraps, a *net.OpError. Such
// errors are treated as ordinary network closes that are not worth logging.
func isNetworkClose(err error) bool {
	var opErr *net.OpError
	if errors.As(err, &opErr) {
		return true
	}
	return false
}
|