// Package server 实现 WebSocket 服务器,管理硬件设备连接的生命周期。 package server import ( "context" "encoding/json" "errors" "fmt" "log" "net" "net/http" "sync" "time" "github.com/gorilla/websocket" "github.com/qy/hw-ws-service/internal/connection" "github.com/qy/hw-ws-service/internal/handler" "github.com/qy/hw-ws-service/internal/rtcclient" ) const ( // maxConnections 最大并发连接数,防止资源耗尽。 maxConnections = 500 // maxMessageBytes WebSocket 单条消息上限(4KB),防止内存耗尽攻击。 maxMessageBytes = 4 * 1024 // helloTimeout 握手超时:连接建立后必须在此时间内发送 hello,否则断开。 helloTimeout = 10 * time.Second ) var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, // IoT 设备无浏览器 Origin,允许所有来源 CheckOrigin: func(r *http.Request) bool { return true }, } // Server 管理所有活跃的设备连接。 type Server struct { client *rtcclient.Client httpServer *http.Server mu sync.Mutex conns map[string]*connection.Connection // key: DeviceID wg sync.WaitGroup // 跟踪所有连接 goroutine } // New 创建 Server,addr 形如 "0.0.0.0:8888"。 func New(addr string, client *rtcclient.Client) *Server { s := &Server{ client: client, conns: make(map[string]*connection.Connection), } mux := http.NewServeMux() mux.HandleFunc("/xiaozhi/v1/healthz", s.handleStatus) mux.HandleFunc("/xiaozhi/v1/", s.handleConn) mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) s.httpServer = &http.Server{ Addr: addr, Handler: mux, } return s } // ListenAndServe 启动服务器,阻塞直到服务器关闭。 func (s *Server) ListenAndServe() error { log.Printf("server: listening on %s", s.httpServer.Addr) err := s.httpServer.ListenAndServe() if errors.Is(err, http.ErrServerClosed) { return nil } return err } // Shutdown 优雅关闭:先停止接受新连接,再等待所有连接 goroutine 退出。 func (s *Server) Shutdown(ctx context.Context) { log.Println("server: shutting down...") s.httpServer.Shutdown(ctx) //nolint:errcheck // 等待所有连接 goroutine 退出(由 ctx 超时兜底) done := make(chan struct{}) go func() { s.wg.Wait() close(done) }() select { case <-done: log.Println("server: all connections closed gracefully") case <-ctx.Done(): log.Println("server: shutdown timeout, forcing close") } } // handleConn 处理单个 WebSocket 连接的完整生命周期。 // URL 格式:/xiaozhi/v1/?device-id=&client-id= func (s *Server) handleConn(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/xiaozhi/v1/healthz" { s.handleStatus(w, r) return } deviceID := r.URL.Query().Get("device-id") clientID := r.URL.Query().Get("client-id") if deviceID == "" { http.Error(w, "missing device-id", http.StatusBadRequest) return } ws, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Printf("server: upgrade failed for %s: %v", deviceID, err) return } // 设置单条消息大小上限 ws.SetReadLimit(maxMessageBytes) conn := connection.New(ws, deviceID, clientID) if err := s.register(conn); err != nil { log.Printf("server: register %s failed: %v", deviceID, err) ws.Close() return } s.wg.Add(1) defer func() { conn.StopPlayback() s.unregister(deviceID) ws.Close() s.wg.Done() log.Printf("server: device %s disconnected, active=%d", deviceID, s.activeCount()) }() log.Printf("server: device %s connected, active=%d", deviceID, s.activeCount()) // 阶段1:等待 hello 握手(超时 helloTimeout) ws.SetReadDeadline(time.Now().Add(helloTimeout)) //nolint:errcheck if !s.waitForHello(conn) { log.Printf("server: device %s hello timeout or failed", deviceID) return } ws.SetReadDeadline(time.Time{}) //nolint:errcheck // 握手成功,取消读超时 log.Printf("server: device %s handshaked, session=%s", deviceID, conn.SessionID) // 阶段2:正常消息循环 for { msgType, raw, err := ws.ReadMessage() if err != nil { if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { if !isNetworkClose(err) { log.Printf("server: read error for %s: %v", deviceID, err) } } return } // 只处理文本消息 if msgType != websocket.TextMessage { continue } var envelope struct { Type string `json:"type"` } if err := json.Unmarshal(raw, &envelope); err != nil { log.Printf("server: invalid json from %s: %v", deviceID, err) continue } switch envelope.Type { case "story": go handler.HandleStory(conn, s.client) case "abort": handler.HandleAbort(conn) default: log.Printf("server: unhandled message type %q from %s", envelope.Type, deviceID) } } } // waitForHello 等待并处理第一条 hello 消息,成功返回 true。 func (s *Server) waitForHello(conn *connection.Connection) bool { msgType, raw, err := conn.WS.ReadMessage() if err != nil { return false } if msgType != websocket.TextMessage { log.Printf("server: device %s sent non-text as first message", conn.DeviceID) return false } var envelope struct { Type string `json:"type"` } if err := json.Unmarshal(raw, &envelope); err != nil || envelope.Type != "hello" { log.Printf("server: device %s first message is not hello (got %q)", conn.DeviceID, envelope.Type) return false } if err := handler.HandleHello(conn, raw); err != nil { log.Printf("server: device %s hello failed: %v", conn.DeviceID, err) return false } return true } // register 注册连接,若同一设备已有连接则踢掉旧连接。 func (s *Server) register(conn *connection.Connection) error { s.mu.Lock() defer s.mu.Unlock() if len(s.conns) >= maxConnections { return errors.New("server: max connections reached") } // 同一设备同时只允许一个连接 if old, exists := s.conns[conn.DeviceID]; exists { log.Printf("server: kicking old connection for %s", conn.DeviceID) old.Close() } s.conns[conn.DeviceID] = conn return nil } // SendCmd 向指定设备发送控制指令。 // 若设备不在线或未握手,返回 error。 func (s *Server) SendCmd(deviceID, action string, params any) error { s.mu.Lock() conn, ok := s.conns[deviceID] s.mu.Unlock() if !ok { return fmt.Errorf("server: device %s not connected", deviceID) } if !conn.IsHandshaked() { return fmt.Errorf("server: device %s not handshaked", deviceID) } return conn.SendCmd(action, params) } func (s *Server) unregister(deviceID string) { s.mu.Lock() defer s.mu.Unlock() delete(s.conns, deviceID) } func (s *Server) activeCount() int { s.mu.Lock() defer s.mu.Unlock() return len(s.conns) } // handleStatus 返回服务状态和当前活跃连接数,用于部署后验证。 // GET /xiaozhi/v1/healthz → {"status":"ok","active_connections":N} func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) active := s.activeCount() fmt.Fprintf(w, `{"status":"ok","active_connections":%d}`, active) } // isNetworkClose 判断是否为普通的网络关闭错误(不需要打印日志)。 func isNetworkClose(err error) bool { var netErr *net.OpError return errors.As(err, &netErr) }