// Package server 实现 WebSocket 服务器,管理硬件设备连接的生命周期。 package server import ( "context" "encoding/json" "errors" "fmt" "log" "net" "net/http" "sync" "github.com/gorilla/websocket" "github.com/qy/hw-ws-service/internal/connection" "github.com/qy/hw-ws-service/internal/handler" "github.com/qy/hw-ws-service/internal/rtcclient" ) const ( // maxConnections 最大并发连接数,防止资源耗尽。 maxConnections = 500 // maxMessageBytes WebSocket 单条消息上限(4KB),防止内存耗尽攻击。 maxMessageBytes = 4 * 1024 ) var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, // IoT 设备无浏览器 Origin,允许所有来源 CheckOrigin: func(r *http.Request) bool { return true }, } // Server 管理所有活跃的设备连接。 type Server struct { client *rtcclient.Client httpServer *http.Server mu sync.Mutex conns map[string]*connection.Connection // key: DeviceID wg sync.WaitGroup // 跟踪所有连接 goroutine } // New 创建 Server,addr 形如 "0.0.0.0:8888"。 func New(addr string, client *rtcclient.Client) *Server { s := &Server{ client: client, conns: make(map[string]*connection.Connection), } mux := http.NewServeMux() mux.HandleFunc("/xiaozhi/v1/healthz", s.handleStatus) mux.HandleFunc("/xiaozhi/v1/", s.handleConn) mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) s.httpServer = &http.Server{ Addr: addr, Handler: mux, } return s } // ListenAndServe 启动服务器,阻塞直到服务器关闭。 func (s *Server) ListenAndServe() error { log.Printf("server: listening on %s", s.httpServer.Addr) err := s.httpServer.ListenAndServe() if errors.Is(err, http.ErrServerClosed) { return nil } return err } // Shutdown 优雅关闭:先停止接受新连接,再等待所有连接 goroutine 退出。 func (s *Server) Shutdown(ctx context.Context) { log.Println("server: shutting down...") s.httpServer.Shutdown(ctx) //nolint:errcheck // 等待所有连接 goroutine 退出(由 ctx 超时兜底) done := make(chan struct{}) go func() { s.wg.Wait() close(done) }() select { case <-done: log.Println("server: all connections closed gracefully") case <-ctx.Done(): log.Println("server: shutdown timeout, forcing close") } } // 
// handleConn handles the complete lifecycle of a single device WebSocket connection.
// URL format: /xiaozhi/v1/?device-id=<id>&client-id=<id>
//
// A request without device-id is rejected with 400 before the upgrade. After
// the upgrade, the connection is registered, kicking any earlier connection for
// the same device. Text frames are then read until a read fails; binary frames
// (upstream audio) are ignored. On exit, playback is stopped, this connection
// is unregistered and the socket is closed.
func (s *Server) handleConn(w http.ResponseWriter, r *http.Request) {
	// Count this handler in s.wg before the upgrade. Once the connection is
	// hijacked, http.Server.Shutdown stops waiting for it, so an Add issued
	// after Upgrade could race with s.wg.Wait in Shutdown. WaitGroup requires
	// an Add from a zero counter to happen before Wait.
	s.wg.Add(1)
	defer s.wg.Done()

	query := r.URL.Query()
	deviceID := query.Get("device-id")
	clientID := query.Get("client-id")
	if deviceID == "" {
		http.Error(w, "missing device-id", http.StatusBadRequest)
		return
	}

	ws, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// device-id is client-controlled; %q escapes it so it cannot forge log lines.
		log.Printf("server: upgrade failed for %q: %v", deviceID, err)
		return
	}
	// Cap the size of a single incoming message.
	ws.SetReadLimit(maxMessageBytes)

	conn := connection.New(ws, deviceID, clientID)
	if err := s.register(conn); err != nil {
		log.Printf("server: register %q failed: %v", deviceID, err)
		ws.Close()
		return
	}
	defer func() {
		conn.StopPlayback()
		// Remove only this connection. If a newer connection for the same
		// device kicked this one, the map entry already belongs to it.
		s.unregisterConn(conn)
		ws.Close()
		log.Printf("server: device %q disconnected, active=%d", deviceID, s.activeCount())
	}()

	log.Printf("server: device %q connected, active=%d", deviceID, s.activeCount())

	// Message read loop.
	for {
		msgType, raw, err := ws.ReadMessage()
		if err != nil {
			// Normal closes and plain network-level closes are expected; log the rest.
			if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) && !isNetworkClose(err) {
				log.Printf("server: read error for %q: %v", deviceID, err)
			}
			return
		}

		// Only text messages are handled. Binary frames are upstream audio,
		// which this service does not process yet.
		if msgType != websocket.TextMessage {
			continue
		}

		var envelope struct {
			Type string `json:"type"`
		}
		if err := json.Unmarshal(raw, &envelope); err != nil {
			log.Printf("server: invalid json from %q: %v", deviceID, err)
			continue
		}

		switch envelope.Type {
		case "story":
			// Track the handler goroutine so Shutdown also waits for it. The
			// counter is already >0 here (this handler holds one), so this Add
			// may safely run concurrently with Wait.
			s.wg.Add(1)
			go func() {
				defer s.wg.Done()
				handler.HandleStory(conn, s.client)
			}()
		default:
			log.Printf("server: unhandled message type %q from %q", envelope.Type, deviceID)
		}
	}
}

// register adds conn to the set of active connections.
//
// Only one connection per device is allowed. If the device already has a
// connection, it is replaced and the old one is closed. Replacing does not
// change the connection count, so it is permitted even at maxConnections. A
// connection for a device not yet present is rejected once maxConnections is
// reached. The old connection is closed after s.mu is released, so a
// potentially slow Close does not block the whole server.
func (s *Server) register(conn *connection.Connection) error {
	s.mu.Lock()
	old, replacing := s.conns[conn.DeviceID]
	if !replacing && len(s.conns) >= maxConnections {
		s.mu.Unlock()
		return errors.New("server: max connections reached")
	}
	s.conns[conn.DeviceID] = conn
	s.mu.Unlock()

	if replacing {
		log.Printf("server: kicking old connection for %q", conn.DeviceID)
		old.Close()
	}
	return nil
}

// unregister removes whatever connection is currently registered for deviceID.
// NOTE(review): kept for callers outside this view. handleConn uses
// unregisterConn, which never removes a newer connection that replaced the caller's.
func (s *Server) unregister(deviceID string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.conns, deviceID)
}

// unregisterConn removes conn from the active set, but only if it is still the
// connection registered for its device. A kicked connection must not delete
// the entry of the connection that replaced it.
func (s *Server) unregisterConn(conn *connection.Connection) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if cur, ok := s.conns[conn.DeviceID]; ok && cur == conn {
		delete(s.conns, conn.DeviceID)
	}
}

// activeCount returns the number of registered connections.
func (s *Server) activeCount() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	return len(s.conns)
}

// handleStatus reports service status and the current number of active
// connections, for post-deployment verification. Non-GET requests get 405.
// GET /xiaozhi/v1/healthz → {"status":"ok","active_connections":N}
func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	active := s.activeCount()
	fmt.Fprintf(w, `{"status":"ok","active_connections":%d}`, active)
}

// isNetworkClose reports whether err is, or wraps, a *net.OpError. Such errors
// are treated as ordinary network-level closes that need no log line.
// NOTE(review): this also matches timeouts and connection resets, which are
// silenced too; confirm that is intended.
func isNetworkClose(err error) bool {
	var netErr *net.OpError
	return errors.As(err, &netErr)
}