package connection_test

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/gorilla/websocket"
	"github.com/qy/hw-ws-service/internal/connection"
)

// makeWSPair creates a real WebSocket pair for testing.
// Returns the server-side conn (what our code uses) and the client-side conn
// (what simulates the hardware). Call cleanup() after the test.
//
// cleanup must be called exactly once: it closes an internal channel, and
// closing it a second time would panic.
func makeWSPair(t *testing.T) (svrWS *websocket.Conn, cliWS *websocket.Conn, cleanup func()) {
	t.Helper()

	// Buffered so the handler can hand over the upgraded conn without waiting
	// for the receive further down.
	ch := make(chan *websocket.Conn, 1)
	done := make(chan struct{})

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		up := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
		c, err := up.Upgrade(w, r, nil)
		if err != nil {
			t.Logf("upgrade error: %v", err)
			return
		}
		ch <- c
		<-done // hold handler open until cleanup
	}))

	// srv.URL starts with "http"; swapping that prefix for "ws" yields the
	// WebSocket URL of the same listener.
	wsURL := "ws" + strings.TrimPrefix(srv.URL, "http")
	cli, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		close(done)
		srv.Close()
		t.Fatalf("dial error: %v", err)
	}

	svr := <-ch
	return svr, cli, func() {
		close(done)
		// Close errors are deliberately ignored: teardown is best-effort.
		svr.Close()
		cli.Close()
		srv.Close()
	}
}

// TestConnection_InitialState verifies that a freshly constructed connection
// exposes the IDs it was given and reports that it is not playing.
func TestConnection_InitialState(t *testing.T) {
	svrWS, _, cleanup := makeWSPair(t)
	defer cleanup()

	conn := connection.New(svrWS, "AA:BB:CC:DD:EE:FF", "client-uuid")
	if conn.DeviceID != "AA:BB:CC:DD:EE:FF" {
		t.Errorf("DeviceID = %q", conn.DeviceID)
	}
	if conn.ClientID != "client-uuid" {
		t.Errorf("ClientID = %q", conn.ClientID)
	}
	if conn.IsPlaying() {
		t.Error("new connection should not be playing")
	}
}

// TestConnection_StartStopPlayback verifies that StartPlayback returns an open
// channel and flips IsPlaying to true, and that StopPlayback flips it back.
func TestConnection_StartStopPlayback(t *testing.T) {
	svrWS, _, cleanup := makeWSPair(t)
	defer cleanup()

	conn := connection.New(svrWS, "dev1", "cli1")
	ch := conn.StartPlayback()
	if ch == nil {
		t.Fatal("StartPlayback should return a non-nil channel")
	}
	if !conn.IsPlaying() {
		t.Error("IsPlaying should be true after StartPlayback")
	}

	// Channel must still be open
	select {
	case <-ch:
		t.Error("abortCh should not be closed yet")
	default:
	}

	conn.StopPlayback()
	if conn.IsPlaying() {
		t.Error("IsPlaying should be false after StopPlayback")
	}
}

// TestConnection_StartPlayback_AbortsOld verifies that calling StartPlayback a second
// time closes the previous abort channel, stopping any in-progress streaming.
func TestConnection_StartPlayback_AbortsOld(t *testing.T) {
	svrWS, _, cleanup := makeWSPair(t)
	defer cleanup()

	conn := connection.New(svrWS, "dev1", "cli1")
	ch1 := conn.StartPlayback()
	ch2 := conn.StartPlayback() // should close ch1

	// ch1 must be closed now
	select {
	case <-ch1:
		// expected
	case <-time.After(100 * time.Millisecond):
		t.Error("first abortCh should be closed by second StartPlayback call")
	}

	// ch2 must still be open
	select {
	case <-ch2:
		t.Error("second abortCh should not be closed yet")
	default:
	}
}

// TestConnection_SendJSON verifies JSON messages are delivered to the client.
func TestConnection_SendJSON(t *testing.T) {
	svrWS, cliWS, cleanup := makeWSPair(t)
	defer cleanup()

	conn := connection.New(svrWS, "dev1", "cli1")
	if err := conn.SendJSON(map[string]string{"type": "tts", "state": "start"}); err != nil {
		t.Fatalf("SendJSON error: %v", err)
	}

	// Bound the read so a missing message fails the test instead of hanging;
	// if the deadline cannot be armed, stop here rather than risk blocking.
	if err := cliWS.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
		t.Fatalf("SetReadDeadline error: %v", err)
	}
	msgType, data, err := cliWS.ReadMessage()
	if err != nil {
		t.Fatalf("client read error: %v", err)
	}
	if msgType != websocket.TextMessage {
		t.Errorf("message type = %d, want TextMessage (%d)", msgType, websocket.TextMessage)
	}

	var got map[string]string
	if err := json.Unmarshal(data, &got); err != nil {
		t.Fatalf("json.Unmarshal error: %v", err)
	}
	if got["type"] != "tts" {
		t.Errorf("type = %q, want %q", got["type"], "tts")
	}
	if got["state"] != "start" {
		t.Errorf("state = %q, want %q", got["state"], "start")
	}
}

// TestConnection_SendBinary verifies binary (Opus) frames are delivered to the client.
func TestConnection_SendBinary(t *testing.T) { svrWS, cliWS, cleanup := makeWSPair(t) defer cleanup() conn := connection.New(svrWS, "dev1", "cli1") payload := []byte{0x01, 0x02, 0x03, 0x04} if err := conn.SendBinary(payload); err != nil { t.Fatalf("SendBinary error: %v", err) } cliWS.SetReadDeadline(time.Now().Add(2 * time.Second)) msgType, data, err := cliWS.ReadMessage() if err != nil { t.Fatalf("client read error: %v", err) } if msgType != websocket.BinaryMessage { t.Errorf("message type = %d, want BinaryMessage (%d)", msgType, websocket.BinaryMessage) } if string(data) != string(payload) { t.Errorf("payload = %v, want %v", data, payload) } }