diff --git a/apps/devices/views.py b/apps/devices/views.py index c87c081..d9acb7a 100644 --- a/apps/devices/views.py +++ b/apps/devices/views.py @@ -306,6 +306,73 @@ class DeviceViewSet(viewsets.ViewSet): return success(message='WiFi 配置成功') + @action( + detail=False, methods=['get'], + url_path='stories', + authentication_classes=[], permission_classes=[AllowAny] + ) + def stories_by_mac(self, request): + """ + 获取设备关联用户的随机故事(公开接口,无需认证) + GET /api/v1/devices/stories/?mac_address=AA:BB:CC:DD:EE:FF + 供 hw-ws-service 调用。 + 优先返回用户自己的故事,无则兜底返回系统默认故事(is_default=True)。 + """ + mac = request.query_params.get('mac_address', '').strip() + if not mac: + return error(message='mac_address 参数不能为空') + + mac = mac.upper().replace('-', ':') + + try: + device = Device.objects.get(mac_address=mac) + except Device.DoesNotExist: + return error( + code=ErrorCode.DEVICE_NOT_FOUND, + message='未找到对应设备', + status_code=status.HTTP_404_NOT_FOUND + ) + + user_device = ( + UserDevice.objects + .filter(device=device, is_active=True, bind_type='owner') + .select_related('user') + .first() + ) + if not user_device: + return error( + code=ErrorCode.NOT_FOUND, + message='该设备尚未绑定用户', + status_code=status.HTTP_404_NOT_FOUND + ) + + from apps.stories.models import Story + # 优先随机取用户自己有 audio_url 的故事 + story = ( + Story.objects + .filter(user=user_device.user) + .exclude(audio_url='') + .order_by('?') + .first() + ) + # 兜底:用户暂无故事时使用系统默认故事 + if not story: + story = ( + Story.objects + .filter(is_default=True) + .exclude(audio_url='') + .order_by('?') + .first() + ) + if not story: + return error( + code=ErrorCode.STORY_NOT_FOUND, + message='暂无可播放的故事', + status_code=status.HTTP_404_NOT_FOUND + ) + + return success(data={'title': story.title, 'audio_url': story.audio_url}) + @action(detail=False, methods=['post'], url_path='report-status', authentication_classes=[], permission_classes=[AllowAny]) def report_status(self, request): diff --git a/hw_service_go/internal/connection/connection.go b/hw_service_go/internal/connection/connection.go index 30911fd..f606c6e 100644 --- a/hw_service_go/internal/connection/connection.go +++ b/hw_service_go/internal/connection/connection.go @@ -12,12 +12,14 @@ import ( // Connection 保存单个硬件连接的状态,所有方法并发安全。 type Connection struct { WS *websocket.Conn - DeviceID string // MAC 地址,来自 URL 参数 device-id - ClientID string // 来自 URL 参数 client-id + DeviceID string // MAC 地址,来自 URL 参数 device-id + ClientID string // 来自 URL 参数 client-id + SessionID string // 握手后分配的会话 ID - mu sync.Mutex - isPlaying bool - abortCh chan struct{} // close(abortCh) 通知流控 goroutine 中止播放 + mu sync.Mutex + handshaked bool // 是否已完成 hello 握手 + isPlaying bool + abortCh chan struct{} // close(abortCh) 通知流控 goroutine 中止播放 writeMu sync.Mutex // gorilla/websocket 写操作不并发安全,需独立锁 } @@ -31,6 +33,30 @@ func New(ws *websocket.Conn, deviceID, clientID string) *Connection { } } +// Handshake 完成 hello 握手,存储 session_id。 +func (c *Connection) Handshake(sessionID string) { + c.mu.Lock() + defer c.mu.Unlock() + c.SessionID = sessionID + c.handshaked = true +} + +// IsHandshaked 返回连接是否已完成 hello 握手。 +func (c *Connection) IsHandshaked() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.handshaked +} + +// SendCmd 向硬件发送控制指令,并发安全。 +func (c *Connection) SendCmd(action string, params any) error { + return c.SendJSON(map[string]any{ + "type": "cmd", + "action": action, + "params": params, + }) +} + // StartPlayback 开始新一轮播放,返回 abortCh 供流控 goroutine 监听。 // 若已在播放,先中止上一轮再开始新的。 func (c *Connection) StartPlayback() <-chan struct{} { diff --git a/hw_service_go/internal/handler/abort.go b/hw_service_go/internal/handler/abort.go new file mode 100644 index 0000000..c5767d5 --- /dev/null +++ b/hw_service_go/internal/handler/abort.go @@ -0,0 +1,13 @@ +package handler + +import ( + "log" + + "github.com/qy/hw-ws-service/internal/connection" +) + +// HandleAbort 处理硬件发来的 {"type":"abort"} 指令,中止当前播放。 +func HandleAbort(conn *connection.Connection) { + log.Printf("[abort][%s] stopping playback", conn.DeviceID) + conn.StopPlayback() +} diff --git a/hw_service_go/internal/handler/hello.go b/hw_service_go/internal/handler/hello.go new file mode 100644 index 0000000..ef8b600 --- /dev/null +++ b/hw_service_go/internal/handler/hello.go @@ -0,0 +1,45 @@ +package handler + +import ( + "crypto/rand" + "encoding/json" + "fmt" + "log" + "strings" + + "github.com/qy/hw-ws-service/internal/connection" +) + +// helloMessage 是硬件发来的 hello 握手消息。 +type helloMessage struct { + MAC string `json:"mac"` +} + +// HandleHello 处理硬件的 hello 握手消息。 +// 校验 MAC 地址,分配 session_id,返回握手响应。 +func HandleHello(conn *connection.Connection, raw []byte) error { + var msg helloMessage + if err := json.Unmarshal(raw, &msg); err != nil { + return fmt.Errorf("hello: invalid json: %w", err) + } + + // MAC 地址与 URL 参数不一致时记录警告,但不拒绝连接 + if msg.MAC != "" && !strings.EqualFold(msg.MAC, conn.DeviceID) { + log.Printf("[hello][%s] MAC mismatch: url=%s body=%s", conn.DeviceID, conn.DeviceID, msg.MAC) + } + + sessionID := newSessionID() + conn.Handshake(sessionID) + + return conn.SendJSON(map[string]string{ + "type": "hello", + "status": "ok", + "session_id": sessionID, + }) +} + +func newSessionID() string { + b := make([]byte, 4) + rand.Read(b) //nolint:errcheck // crypto/rand.Read 在标准库中不会返回错误 + return fmt.Sprintf("%x", b) +} diff --git a/hw_service_go/internal/server/server.go b/hw_service_go/internal/server/server.go index 5ba90c5..6f4fcaf 100644 --- a/hw_service_go/internal/server/server.go +++ b/hw_service_go/internal/server/server.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "sync" + "time" "github.com/gorilla/websocket" "github.com/qy/hw-ws-service/internal/connection" @@ -22,6 +23,8 @@ const ( maxConnections = 500 // maxMessageBytes WebSocket 单条消息上限(4KB),防止内存耗尽攻击。 maxMessageBytes = 4 * 1024 + // helloTimeout 握手超时:连接建立后必须在此时间内发送 hello,否则断开。 + helloTimeout = 10 * time.Second ) var upgrader = websocket.Upgrader{ @@ -131,7 +134,17 @@ func (s *Server) handleConn(w http.ResponseWriter, r *http.Request) { log.Printf("server: device %s connected, active=%d", deviceID, s.activeCount()) - // 消息读取循环 + // 阶段1:等待 hello 握手(超时 helloTimeout) + ws.SetReadDeadline(time.Now().Add(helloTimeout)) //nolint:errcheck + if !s.waitForHello(conn) { + log.Printf("server: device %s hello timeout or failed", deviceID) + return + } + ws.SetReadDeadline(time.Time{}) //nolint:errcheck // 握手成功,取消读超时 + + log.Printf("server: device %s handshaked, session=%s", deviceID, conn.SessionID) + + // 阶段2:正常消息循环 for { msgType, raw, err := ws.ReadMessage() if err != nil { @@ -143,7 +156,7 @@ func (s *Server) handleConn(w http.ResponseWriter, r *http.Request) { return } - // 只处理文本消息(二进制为上行音频,本服务暂不处理) + // 只处理文本消息 if msgType != websocket.TextMessage { continue } @@ -159,12 +172,40 @@ func (s *Server) handleConn(w http.ResponseWriter, r *http.Request) { switch envelope.Type { case "story": go handler.HandleStory(conn, s.client) + case "abort": + handler.HandleAbort(conn) default: log.Printf("server: unhandled message type %q from %s", envelope.Type, deviceID) } } } +// waitForHello 等待并处理第一条 hello 消息,成功返回 true。 +func (s *Server) waitForHello(conn *connection.Connection) bool { + msgType, raw, err := conn.WS.ReadMessage() + if err != nil { + return false + } + if msgType != websocket.TextMessage { + log.Printf("server: device %s sent non-text as first message", conn.DeviceID) + return false + } + + var envelope struct { + Type string `json:"type"` + } + if err := json.Unmarshal(raw, &envelope); err != nil || envelope.Type != "hello" { + log.Printf("server: device %s first message is not hello (got %q)", conn.DeviceID, envelope.Type) + return false + } + + if err := handler.HandleHello(conn, raw); err != nil { + log.Printf("server: device %s hello failed: %v", conn.DeviceID, err) + return false + } + return true +} + // register 注册连接,若同一设备已有连接则踢掉旧连接。 func (s *Server) register(conn *connection.Connection) error { s.mu.Lock() @@ -184,6 +225,21 @@ func (s *Server) register(conn *connection.Connection) error { return nil } +// SendCmd 向指定设备发送控制指令。 +// 若设备不在线或未握手,返回 error。 +func (s *Server) SendCmd(deviceID, action string, params any) error { + s.mu.Lock() + conn, ok := s.conns[deviceID] + s.mu.Unlock() + if !ok { + return fmt.Errorf("server: device %s not connected", deviceID) + } + if !conn.IsHandshaked() { + return fmt.Errorf("server: device %s not handshaked", deviceID) + } + return conn.SendCmd(action, params) +} + func (s *Server) unregister(deviceID string) { s.mu.Lock() defer s.mu.Unlock()