- 接入豆包TTS V1 WebSocket API,支持故事朗读语音合成 - 新增 PillProgressButton 组件(药丸形进度按钮) - 新增 TTSService 单例,后台生成不中断 - 音频保存到 Capybara audio/ 目录 - 唱片架当前播放歌曲高亮(金色卡片+音波动效+喇叭图标) - 播放时气泡持续显示当前歌名,暂停后隐藏 - 音乐总监Prompt去固定模板,歌名不再重复 - 新增 API 参考文档(豆包语音合成) Co-authored-by: Cursor <cursoragent@cursor.com>
859 lines
34 KiB
Python
859 lines
34 KiB
Python
import os
|
||
import re
|
||
import sys
|
||
import time
|
||
import uuid
|
||
import struct
|
||
import asyncio
|
||
import uvicorn
|
||
import requests
|
||
import json
|
||
import websockets
|
||
from fastapi import FastAPI, HTTPException, Query
|
||
from fastapi.responses import StreamingResponse
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from pydantic import BaseModel
|
||
from dotenv import load_dotenv
|
||
|
||
# Windows consoles default to a GBK code page; re-open both standard streams
# as UTF-8 (replacing unencodable chars) so Chinese log lines never crash.
if sys.platform == "win32":
    for _std_stream in (sys.stdout, sys.stderr):
        _std_stream.reconfigure(encoding="utf-8", errors="replace")

# Pull configuration from a local .env file (if any) into os.environ.
load_dotenv()

MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY")        # MiniMax chat + music
VOLCENGINE_API_KEY = os.getenv("VOLCENGINE_API_KEY")  # Volcengine Ark (Doubao chat)
TTS_APP_ID = os.getenv("TTS_APP_ID")                  # Doubao TTS application id
TTS_ACCESS_TOKEN = os.getenv("TTS_ACCESS_TOKEN")      # Doubao TTS access token

# Missing credentials only produce a startup warning; nothing is aborted here.
if not MINIMAX_API_KEY:
    print("Warning: MINIMAX_API_KEY not found in .env")
if not VOLCENGINE_API_KEY:
    print("Warning: VOLCENGINE_API_KEY not found in .env")
if not (TTS_APP_ID and TTS_ACCESS_TOKEN):
    print("Warning: TTS_APP_ID or TTS_ACCESS_TOKEN not found in .env")
|
||
|
||
# The single FastAPI application every route in this module registers on.
app = FastAPI()

# CORS is left fully open (any origin / method / header, credentials allowed)
# so a locally served frontend can call the API without extra configuration.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
||
|
||
# ── Request bodies ──

class MusicRequest(BaseModel):
    """JSON body of POST /api/create_music."""

    # Free-text scene / idea typed by the user.
    text: str
    # Mood preset hint: 'chill', 'happy', 'sleepy', 'random', 'custom'
    mood: str = "custom"


class StoryRequest(BaseModel):
    """JSON body of POST /api/create_story; every material list is optional."""

    characters: list[str] = []  # character names to feature
    scenes: list[str] = []      # settings / locations
    props: list[str] = []       # objects to weave into the story
|
||
|
||
# ── MiniMax (chat completion + music generation) ──
MINIMAX_GROUP_ID = "YOUR_GROUP_ID"  # placeholder value
BASE_URL_CHAT = "https://api.minimax.chat/v1/text/chatcompletion_v2"
BASE_URL_MUSIC = "https://api.minimaxi.com/v1/music_generation"

# ── System prompts ──
# Each prompt is read from prompts/*.md (relative to the working directory);
# a short built-in prompt is used, with a warning, when the file is missing.
try:
    with open("prompts/music_director.md", "r", encoding="utf-8") as prompt_file:
        SYSTEM_PROMPT = prompt_file.read()
except FileNotFoundError:
    SYSTEM_PROMPT = (
        "You are a music director AI. Convert user input into JSON with "
        "'style' (English description) and 'lyrics' (Chinese, structured)."
    )
    print("Warning: prompts/music_director.md not found, using default.")

try:
    with open("prompts/story_director.md", "r", encoding="utf-8") as prompt_file:
        STORY_SYSTEM_PROMPT = prompt_file.read()
except FileNotFoundError:
    STORY_SYSTEM_PROMPT = (
        "你是一个儿童故事大师。根据用户提供的角色、场景、道具素材创作一个300-600字的儿童故事。"
        '只返回JSON格式:{"title": "标题", "content": "正文"}'
    )
    print("Warning: prompts/story_director.md not found, using default.")

# ── Volcengine Ark / Doubao chat (used for story generation) ──
DOUBAO_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
DOUBAO_MODEL = "doubao-seed-1-6-lite-251015"  # Doubao-Seed-1.6-lite
|
||
|
||
|
||
def sse_event(data):
    """Serialize *data* as a single Server-Sent Events ``data:`` message.

    The JSON is produced with ``ensure_ascii=True``, so every non-ASCII
    character becomes a ``\\uXXXX`` escape and the stream stays pure ASCII,
    which sidesteps Windows GBK encoding problems.
    """
    encoded = json.dumps(data, ensure_ascii=True)
    return "data: " + encoded + "\n\n"
|
||
|
||
|
||
def clean_lyrics(raw: str) -> str:
    """Normalize lyrics text pulled out of an LLM's JSON reply.

    Converts escaped ``\\n`` sequences into real line breaks, glues split JSON
    string fragments back together and drops stray double quotes, removes
    song-structure tags such as ``[verse 1]`` / ``[chorus]`` / ``[inst]``,
    trims every line and keeps at most one blank line between stanzas.
    Falsy input (``""`` or ``None``) is returned unchanged.
    """
    if not raw:
        return raw

    text = raw.replace("\\n", "\n")
    # '"a" "b"' style concatenation artifacts first, then any leftover quotes.
    text = re.sub(r'"\s*"', '', text).replace('"', '')
    text = re.sub(
        r'\[(?:verse|chorus|bridge|outro|intro|hook|pre-chorus|interlude|inst)\s*\d*\]\s*',
        '',
        text,
        flags=re.IGNORECASE,
    )
    text = "\n".join(map(str.strip, text.split("\n")))
    # Three or more newlines collapse to a single blank line.
    return re.sub(r'\n{3,}', '\n\n', text).strip()
|
||
|
||
|
||
@app.post("/api/create_music")
def create_music(req: MusicRequest):
    """SSE streaming endpoint – pushes progress events to the frontend.

    Stages (each announced with an ``sse_event`` carrying ``stage``,
    ``progress`` and ``message``):

    1. "Music director": the MiniMax chat model turns ``req.text`` /
       ``req.mood`` into JSON with ``song_title``, ``style`` and ``lyrics``.
       Malformed or truncated JSON is repaired with regexes; if this stage
       fails entirely, a fixed lo-fi instrumental style is used instead.
    2. Music generation through the MiniMax music API, which returns the
       audio as a hex string.
    3. The MP3 is written to ``Capybara music/<stem>.mp3`` and the cleaned
       lyrics to ``Capybara music/lyrics/<stem>.txt`` using the SAME stem, so
       ``/api/playlist`` (which looks lyrics up by the MP3 stem) can pair them.

    The final event has ``stage`` ``"done"`` (plus ``file_path`` and
    ``metadata``) or ``"error"``.
    """
    print(f"[Music] Received request: {req.text} [{req.mood}]", flush=True)

    def event_stream():
        def log(msg):
            # flush so progress appears immediately in the server console
            print(msg, flush=True)

        # ── Stage 1: LLM "Music Director" ────────────────────────
        log("[Stage 1] Starting LLM call...")
        yield sse_event({
            "stage": "lyrics",
            "progress": 10,
            "message": "AI 正在创作词曲..."
        })

        director_input = f"用户场景描述: {req.text}。 (预设氛围参考: {req.mood})"

        try:
            chat_resp = requests.post(
                BASE_URL_CHAT,
                headers={
                    "Authorization": f"Bearer {MINIMAX_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": "abab6.5s-chat",
                    "messages": [
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": director_input}
                    ],
                    "max_tokens": 2048  # Enough for long lyrics
                },
                timeout=60
            )

            chat_data = chat_resp.json()
            log(f"[Debug] Chat API status: {chat_resp.status_code}, resp keys: {list(chat_data.keys())}")
            if "choices" not in chat_data or not chat_data["choices"]:
                base = chat_data.get("base_resp", {})
                raise ValueError(f"Chat API error ({base.get('status_code')}): {base.get('status_msg')}")
            content_str = chat_data["choices"][0]["message"]["content"]
            log(f"[Debug] LLM raw output (first 200): {content_str[:200]}")

            # Strip markdown code fences if present
            content_str = content_str.strip()
            if content_str.startswith("```"):
                content_str = re.sub(r'^```\w*\n?', '', content_str)
                content_str = re.sub(r'```\s*$', '', content_str).strip()

            # Try to extract JSON from response (robust parsing)
            json_match = re.search(r'\{[\s\S]*\}', content_str)
            if json_match:
                json_str = json_match.group()
                try:
                    metadata = json.loads(json_str)
                except json.JSONDecodeError:
                    # JSON might have unescaped newlines in string values;
                    # fall back to pulling each field out with a regex.
                    log("[Warn] JSON parse failed, attempting repair...")
                    title_m = re.search(r'"song_title"\s*:\s*"([^"]*)"', json_str)
                    style_m = re.search(r'"style"\s*:\s*"([^"]*)"', json_str)
                    lyrics_m = re.search(r'"lyrics"\s*:\s*"([\s\S]*)', json_str)
                    lyrics_val = ""
                    if lyrics_m:
                        # Everything after "lyrics": " minus the trailing quote/brace
                        lyrics_val = lyrics_m.group(1)
                        lyrics_val = re.sub(r'"\s*\}\s*$', '', lyrics_val).strip()
                    metadata = {
                        "song_title": title_m.group(1) if title_m else "",
                        "style": style_m.group(1) if style_m else "Pop music, cheerful",
                        "lyrics": lyrics_val
                    }
                    log(f"[Repaired] title={metadata['song_title']}, style={metadata['style'][:60]}")
            elif content_str.strip().startswith("{"):
                # JSON is incomplete (missing closing brace), e.g. the reply
                # was cut off by max_tokens — try closing it.
                log("[Warn] Incomplete JSON, attempting to close...")
                try:
                    metadata = json.loads(content_str + '"}\n}')
                except json.JSONDecodeError:
                    # Manual extraction as last resort
                    title_m = re.search(r'"song_title"\s*:\s*"([^"]*)"', content_str)
                    style_m = re.search(r'"style"\s*:\s*"([^"]*)"', content_str)
                    lyrics_m = re.search(r'"lyrics"\s*:\s*"([\s\S]*)', content_str)
                    lyrics_val = lyrics_m.group(1).rstrip('"} \n') if lyrics_m else "[Inst]"
                    metadata = {
                        "song_title": title_m.group(1) if title_m else "",
                        "style": style_m.group(1) if style_m else "Pop music, cheerful",
                        "lyrics": lyrics_val
                    }
                    log(f"[Repaired] title={metadata.get('song_title')}")
            else:
                raise ValueError(f"No JSON in LLM response: {content_str[:100]}")

            style_val = metadata.get("style", "")
            lyrics_val = clean_lyrics(metadata.get("lyrics", ""))
            metadata["lyrics"] = lyrics_val  # Store cleaned version
            log(f"[Director] Style: {style_val[:80]}")
            log(f"[Director] Lyrics (first 60): {lyrics_val[:60]}")

            yield sse_event({
                "stage": "lyrics_done",
                "progress": 25,
                "message": "词曲创作完成!准备生成音乐..."
            })

        except Exception as e:
            log(f"[Error] Director LLM Failed: {e}")
            metadata = {
                "style": "Lofi hip hop, relaxing, slow tempo, water sounds",
                "lyrics": "[Inst]"
            }
            yield sse_event({
                "stage": "lyrics_fallback",
                "progress": 25,
                "message": "使用默认风格,准备生成音乐..."
            })

        # ── Stage 2: Music Generation ────────────────────────────
        yield sse_event({
            "stage": "music",
            "progress": 30,
            "message": "正在生成音乐,请耐心等待..."
        })

        try:
            raw_lyrics = metadata.get("lyrics") or ""
            # API requires lyrics >= 1 char
            if not raw_lyrics.strip() or "[instrumental]" in raw_lyrics.lower():
                raw_lyrics = "[Inst]"

            music_payload = {
                "model": "music-2.5",
                "prompt": metadata.get("style", "Pop music"),
                "lyrics": raw_lyrics,
                "audio_setting": {
                    "sample_rate": 44100,
                    "bitrate": 256000,
                    "format": "mp3"
                }
            }
            log(f"[Debug] Music payload prompt: {music_payload['prompt'][:80]}")
            log(f"[Debug] Music payload lyrics (first 60): {music_payload['lyrics'][:60]}")

            music_resp = requests.post(
                BASE_URL_MUSIC,
                headers={
                    "Authorization": f"Bearer {MINIMAX_API_KEY}",
                    "Content-Type": "application/json"
                },
                json=music_payload,
                timeout=300  # 5 min — music generation can be slow
            )

            music_data = music_resp.json()
            base_resp = music_data.get("base_resp", {})
            log(f"[Debug] Music API status: {music_resp.status_code}, base_resp: {base_resp}")

            if music_data.get("data") and music_data["data"].get("audio"):
                hex_audio = music_data["data"]["audio"]
                log(f"[OK] Music generated! Audio hex length: {len(hex_audio)}")

                # ── Stage 3: Saving ──────────────────────────────
                yield sse_event({
                    "stage": "saving",
                    "progress": 90,
                    "message": "音乐生成完成,正在保存..."
                })

                save_dir = os.path.join(os.path.dirname(__file__) or ".", "Capybara music")
                os.makedirs(save_dir, exist_ok=True)

                # Prefer song_title from LLM; fallback to user input
                raw_title = metadata.get("song_title") or req.text
                safe_name = re.sub(r'[^\w\u4e00-\u9fff]', '', raw_title)[:20] or "ai_song"
                # Compute the stem ONCE: /api/playlist pairs "<stem>.mp3" with
                # "lyrics/<stem>.txt", so both files must share one timestamp
                # (calling time.time() again after the write could tick over).
                file_stem = f"{safe_name}_{int(time.time())}"
                filename = f"{file_stem}.mp3"
                filepath = os.path.join(save_dir, filename)

                audio_bytes = bytes.fromhex(hex_audio)
                with open(filepath, "wb") as f:
                    f.write(audio_bytes)
                log(f"[Saved] {filepath}")

                # Save lyrics txt
                lyrics_text = metadata.get("lyrics", "")
                if lyrics_text:
                    lyrics_dir = os.path.join(save_dir, "lyrics")
                    os.makedirs(lyrics_dir, exist_ok=True)
                    lyrics_filename = f"{file_stem}.txt"
                    with open(os.path.join(lyrics_dir, lyrics_filename), "w", encoding="utf-8") as lf:
                        lf.write(lyrics_text)

                relative_path = f"Capybara music/{filename}"

                # ── Done ─────────────────────────────────────────
                yield sse_event({
                    "stage": "done",
                    "progress": 100,
                    "message": "新歌出炉!",
                    "status": "success",
                    "file_path": relative_path,
                    "metadata": metadata
                })
            else:
                error_msg = base_resp.get("status_msg", "unknown")
                error_code = base_resp.get("status_code", -1)
                log(f"[Error] Music Gen failed: {error_code} - {error_msg}")
                yield sse_event({
                    "stage": "error",
                    "progress": 0,
                    "message": f"生成失败 ({error_code}): {error_msg}"
                })

        except requests.exceptions.Timeout:
            log("[Error] Music Gen Timeout")
            yield sse_event({
                "stage": "error",
                "progress": 0,
                "message": "音乐生成超时,请稍后再试"
            })
        except Exception as e:
            log(f"[Error] API exception: {e}")
            yield sse_event({
                "stage": "error",
                "progress": 0,
                "message": f"服务器错误: {str(e)}"
            })

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive"
        }
    )
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════
# ── Story Generation (Doubao / Volcengine) ──
# ═══════════════════════════════════════════════════════════════════

@app.post("/api/create_story")
def create_story(req: StoryRequest):
    """SSE streaming endpoint – generates a children's story via Doubao LLM.

    Builds a prompt from the selected characters / scenes / props and streams
    the Doubao chat completion, emitting ``generating`` progress events as
    chunks arrive. The reply is then parsed as ``{"title", "content"}`` JSON
    (with a regex fallback, and finally treating the whole reply as the
    story). The result is saved to ``Capybara stories/<title>_<ts>.txt``.
    The final event is ``stage == "done"`` (with ``title`` and ``content``)
    or ``stage == "error"``.
    """
    print(f"[Story] Received request: characters={req.characters}, scenes={req.scenes}, props={req.props}", flush=True)

    def event_stream():
        def log(msg):
            print(msg, flush=True)

        # ── Stage 1: Connecting ──
        yield sse_event({"stage": "connecting", "progress": 5, "message": "正在连接 AI..."})

        # Build user prompt from selected elements
        parts = []
        if req.characters:
            parts.append(f"角色=[{', '.join(req.characters)}]")
        if req.scenes:
            parts.append(f"场景=[{', '.join(req.scenes)}]")
        if req.props:
            parts.append(f"道具=[{', '.join(req.props)}]")
        if parts:
            user_prompt = "请用这些素材创作一个故事:" + ",".join(parts)
        else:
            user_prompt = "请随机创作一个有趣的儿童故事"

        log(f"[Story] User prompt: {user_prompt}")

        # ── Stage 2: Generating (streaming) ──
        yield sse_event({"stage": "generating", "progress": 10, "message": "故事正在诞生..."})

        try:
            # Explicitly encode as UTF-8 to avoid Windows GBK encoding issues
            payload = json.dumps({
                "model": DOUBAO_MODEL,
                "messages": [
                    {"role": "system", "content": STORY_SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                "max_tokens": 2048,
                "stream": True,
                "thinking": {"type": "disabled"},
            }, ensure_ascii=False)

            # With stream=True the connection stays checked out until the body
            # is consumed or the response is closed. The ``with`` releases it
            # on every exit path: early return, break on [DONE], exceptions,
            # and the generator being closed when the client disconnects.
            with requests.post(
                DOUBAO_BASE_URL,
                headers={
                    "Authorization": f"Bearer {VOLCENGINE_API_KEY}",
                    "Content-Type": "application/json; charset=utf-8",
                },
                data=payload.encode("utf-8"),
                stream=True,
                timeout=120,
            ) as resp:
                if resp.status_code != 200:
                    log(f"[Error] Doubao API returned {resp.status_code}: {resp.text[:300]}")
                    yield sse_event({"stage": "error", "progress": 0, "message": f"AI 服务返回异常 ({resp.status_code})"})
                    return

                # Force UTF-8 decoding (requests defaults to ISO-8859-1 which garbles Chinese)
                resp.encoding = "utf-8"

                # Parse SSE stream from Doubao
                pieces = []
                chunk_count = 0
                for line in resp.iter_lines(decode_unicode=True):
                    if not line or not line.startswith("data: "):
                        continue
                    data_str = line[6:]  # strip "data: "
                    if data_str.strip() == "[DONE]":
                        break

                    try:
                        chunk_data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue

                    choices = chunk_data.get("choices", [])
                    if not choices:
                        continue
                    delta_content = choices[0].get("delta", {}).get("content", "")
                    if not delta_content:
                        continue

                    pieces.append(delta_content)
                    chunk_count += 1
                    # Send progress updates every 5 chunks
                    if chunk_count % 5 == 0:
                        progress = min(10 + int(chunk_count * 0.8), 85)
                        yield sse_event({
                            "stage": "generating",
                            "progress": progress,
                            "message": "故事正在诞生...",
                        })

            full_content = "".join(pieces)
            log(f"[Story] Stream done. Total chunks: {chunk_count}, content length: {len(full_content)}")
            log(f"[Story] Raw output (first 200): {full_content[:200]}")

            if not full_content.strip():
                yield sse_event({"stage": "error", "progress": 0, "message": "AI 未返回故事内容"})
                return

            # ── Stage 3: Parse response ──
            yield sse_event({"stage": "parsing", "progress": 90, "message": "正在整理故事..."})

            # Clean up response — strip markdown fences if present
            cleaned = full_content.strip()
            if cleaned.startswith("```"):
                cleaned = re.sub(r'^```\w*\n?', '', cleaned)
                cleaned = re.sub(r'```\s*$', '', cleaned).strip()

            # Try to parse JSON
            title = ""
            content = ""

            json_match = re.search(r'\{[\s\S]*\}', cleaned)
            if json_match:
                try:
                    story_json = json.loads(json_match.group())
                    title = story_json.get("title", "")
                    content = story_json.get("content", "")
                except json.JSONDecodeError:
                    log("[Warn] JSON parse failed, extracting manually...")
                    title_m = re.search(r'"title"\s*:\s*"([^"]*)"', cleaned)
                    content_m = re.search(r'"content"\s*:\s*"([\s\S]*)', cleaned)
                    title = title_m.group(1) if title_m else "卡皮巴拉的故事"
                    if content_m:
                        content = content_m.group(1)
                        content = re.sub(r'"\s*\}\s*$', '', content).strip()

            if not title and not content:
                # Not JSON at all — treat entire output as story content
                title = "卡皮巴拉的故事"
                content = cleaned

            # Clean content: replace literal \n with real newlines
            content = content.replace("\\n", "\n").strip()
            # Collapse 3+ newlines into 2
            content = re.sub(r'\n{3,}', '\n\n', content)

            log(f"[Story] Title: {title}")
            log(f"[Story] Content (first 100): {content[:100]}")

            # ── Save story to disk ──
            save_dir = os.path.join(os.path.dirname(__file__) or ".", "Capybara stories")
            os.makedirs(save_dir, exist_ok=True)
            safe_name = re.sub(r'[^\w\u4e00-\u9fff]', '', title)[:20] or "story"
            filename = f"{safe_name}_{int(time.time())}.txt"
            filepath = os.path.join(save_dir, filename)
            with open(filepath, "w", encoding="utf-8") as f:
                f.write(f"# {title}\n\n{content}")
            log(f"[Saved] {filepath}")

            # ── Done ──
            yield sse_event({
                "stage": "done",
                "progress": 100,
                "message": "故事创作完成!",
                "title": title,
                "content": content,
            })

        except requests.exceptions.Timeout:
            log("[Error] Doubao API Timeout")
            yield sse_event({"stage": "error", "progress": 0, "message": "AI 响应超时,请稍后再试"})
        except Exception as e:
            log(f"[Error] Story generation exception: {e}")
            yield sse_event({"stage": "error", "progress": 0, "message": f"故事生成失败: {str(e)}"})

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )
|
||
|
||
|
||
@app.get("/api/stories")
def get_stories():
    """Scan Capybara stories/ directory and return all saved stories.

    ``create_story`` writes files named ``<safe_title>_<unix_ts>.txt`` whose
    text is ``"# <title>"``, a blank line, then the story content.

    Returns ``{"stories": [{"title", "content", "filename"}, ...]}`` ordered
    newest first. Order comes from the timestamp embedded in the filename,
    falling back to the file's modification time. Files whose title contains
    no CJK characters are skipped as presumed mojibake. Unreadable files are
    logged and skipped.
    """
    stories_dir = os.path.join(os.path.dirname(__file__) or ".", "Capybara stories")
    if not os.path.isdir(stories_dir):
        return {"stories": []}

    def newest_first_key(fname):
        """Sort key ``(creation_timestamp, filename)`` for one story file."""
        # Sorting by the bare filename would order by title, not by time.
        ts_match = re.search(r'_(\d{10,})\.txt$', fname, flags=re.IGNORECASE)
        if ts_match:
            return int(ts_match.group(1)), fname
        try:
            return int(os.path.getmtime(os.path.join(stories_dir, fname))), fname
        except OSError:
            return 0, fname

    txt_files = [f for f in os.listdir(stories_dir) if f.lower().endswith(".txt")]

    stories = []
    for f in sorted(txt_files, key=newest_first_key, reverse=True):  # newest first
        filepath = os.path.join(stories_dir, f)
        try:
            with open(filepath, "r", encoding="utf-8") as fh:
                raw = fh.read()

            # Parse: first line is "# Title", then a blank line, then content
            lines = raw.strip().split("\n", 2)
            title = lines[0].lstrip("# ").strip() if lines else f[:-4]
            content = lines[2].strip() if len(lines) > 2 else ""

            # Skip garbled files: a title with no CJK chars at all is taken to
            # be mojibake from an earlier encoding bug.
            if title and not any('\u4e00' <= c <= '\u9fff' for c in title):
                continue

            # Display title: strip timestamp suffix like _1770647563
            display_title = re.sub(r'_\d{10,}$', '', f[:-4])
            if title:
                display_title = title

            stories.append({
                "title": display_title,
                "content": content,
                "filename": f,
            })
        except Exception as e:
            # Best effort: one unreadable/undecodable file must not break the
            # whole listing, but leave a trace instead of failing silently.
            print(f"[Stories] Skipping {f}: {e}", flush=True)

    return {"stories": stories}
|
||
|
||
|
||
@app.get("/api/playlist")
def get_playlist():
    """Scan Capybara music/ directory and return full playlist with lyrics.

    Every ``*.mp3`` (sorted by filename) becomes an entry
    ``{"title", "audioUrl", "lyrics"}``. The title is the file stem with its
    ``_<timestamp>`` suffix removed. The lyrics are read from
    ``lyrics/<stem>.txt`` when present and readable, otherwise ``""``.
    """
    music_dir = os.path.join(os.path.dirname(__file__) or ".", "Capybara music")
    lyrics_dir = os.path.join(music_dir, "lyrics")

    if not os.path.isdir(music_dir):
        return {"playlist": []}

    def read_lyrics(stem):
        """Return the lyrics stored next to *stem*, or "" if unavailable."""
        lyrics_path = os.path.join(lyrics_dir, stem + ".txt")
        if not os.path.isfile(lyrics_path):
            return ""
        try:
            with open(lyrics_path, "r", encoding="utf-8") as handle:
                return handle.read()
        except Exception:
            return ""

    playlist = []
    for entry in sorted(os.listdir(music_dir)):
        if not entry.lower().endswith(".mp3"):
            continue
        stem = entry[:-4]
        song_lyrics = read_lyrics(stem)
        playlist.append({
            "title": re.sub(r'_\d{10,}$', '', stem),
            "audioUrl": f"Capybara music/{entry}",
            "lyrics": song_lyrics,
        })

    return {"playlist": playlist}
|
||
|
||
|
||
# ═══════════════════════════════════════════════════════════════════
# ── TTS: Doubao speech synthesis over the V1 WebSocket binary protocol ──
# ═══════════════════════════════════════════════════════════════════
# The endpoint is the V1 ``/api/v1/tts/ws_binary`` interface, and the frame
# helpers below build and parse V1-style binary frames.

TTS_WS_URL = "wss://openspeech.bytedance.com/api/v1/tts/ws_binary"
TTS_CLUSTER = "volcano_tts"                     # sent as app.cluster
TTS_SPEAKER = "ICL_zh_female_keainvsheng_tob"   # sent as audio.voice_type

# Directory for generated TTS MP3 files (created at import time).
_audio_dir = os.path.join(os.path.dirname(__file__) or ".", "Capybara audio")
os.makedirs(_audio_dir, exist_ok=True)
|
||
|
||
|
||
def _build_tts_v1_request(payload_json: dict) -> bytes:
    """Encode *payload_json* as a V1 full-client-request binary frame.

    Layout: a fixed 4-byte header ``11 10 10 00`` (protocol v1, 4-byte header,
    full-client-request, JSON serialization, no compression). It is followed
    by the payload length as a big-endian uint32, then the UTF-8 JSON payload
    (non-ASCII kept as-is).
    """
    body = json.dumps(payload_json, ensure_ascii=False).encode("utf-8")
    return b"".join((
        b"\x11\x10\x10\x00",
        struct.pack(">I", len(body)),
        body,
    ))
|
||
|
||
|
||
def _parse_tts_v1_response(data: bytes):
    """Parse a V1 TTS response binary frame.

    Header (big-endian):
      byte 0: protocol version (high nibble) | header size in 4-byte units (low)
      byte 1: message type (high nibble) | message-type-specific flags (low)
      byte 2: serialization (high nibble) | compression (low, 0x1 = gzip)
      byte 3: reserved
    Any header bytes beyond the first 4 (per the header-size nibble) are
    skipped before the body.

    Bodies handled:
      0xF error       uint32 error code, uint32 message size, message
      0xB audio-only  [int32 sequence if flags != 0], uint32 size, audio bytes.
                      The frame is last if flags bit 1 is set or the sequence
                      number is negative.
      0x9 / 0xC JSON  uint32 size, JSON text (logged only)

    Returns (audio_bytes_or_none, is_last, is_error, error_msg).
    """
    import gzip
    import zlib

    if len(data) < 4:
        return None, False, True, "Frame too short"

    # Honour the header-size nibble; a (malformed) 0 falls back to 4 bytes.
    body_start = max((data[0] & 0x0F) * 4, 4)
    msg_type = (data[1] >> 4) & 0x0F
    msg_flags = data[1] & 0x0F
    compression = data[2] & 0x0F

    def maybe_gunzip(raw: bytes) -> bytes:
        """Decompress *raw* if the header says gzip; return it unchanged on failure."""
        if compression != 0x01:
            return raw
        try:
            return gzip.decompress(raw)
        except (OSError, EOFError, zlib.error):
            return raw

    # Error frame: msg_type = 0xF
    if msg_type == 0x0F:
        offset = body_start
        error_code = 0
        if len(data) >= offset + 4:
            error_code = struct.unpack(">I", data[offset:offset + 4])[0]
            offset += 4
        if len(data) >= offset + 4:
            msg_len = struct.unpack(">I", data[offset:offset + 4])[0]
            offset += 4
            raw_msg = maybe_gunzip(data[offset:offset + msg_len])
            error_msg = raw_msg.decode("utf-8", errors="replace")
        else:
            error_msg = f"error code {error_code}"
        print(f"[TTS Error] code={error_code}, msg={error_msg}", flush=True)
        return None, False, True, error_msg

    # Audio-only response: msg_type = 0xB
    if msg_type == 0x0B:
        # flags: 0b0000=no seq, 0b0001=seq>0, 0b0010/0b0011=last (seq<0)
        is_last = (msg_flags & 0x02) != 0
        offset = body_start
        # If flags != 0, there's a 4-byte (signed) sequence number
        if msg_flags != 0:
            if len(data) >= offset + 4:
                sequence = struct.unpack(">i", data[offset:offset + 4])[0]
                is_last = is_last or sequence < 0
            offset += 4

        if len(data) < offset + 4:
            return None, is_last, False, ""

        payload_size = struct.unpack(">I", data[offset:offset + 4])[0]
        offset += 4
        return data[offset:offset + payload_size], is_last, False, ""

    # Server JSON (0x9) / frontend JSON (0xC): metadata only, just log it
    if msg_type in (0x09, 0x0C):
        offset = body_start
        if len(data) >= offset + 4:
            payload_size = struct.unpack(">I", data[offset:offset + 4])[0]
            offset += 4
            json_bytes = maybe_gunzip(data[offset:offset + payload_size])
            json_str = json_bytes.decode("utf-8", errors="replace")
            print(f"[TTS] Server JSON: {json_str[:200]}", flush=True)
        return None, False, False, ""

    return None, False, False, ""
|
||
|
||
|
||
async def tts_synthesize(text: str) -> bytes:
    """Connect to Doubao TTS V1 WebSocket and synthesize text to MP3 bytes.

    Sends one full-client-request frame (``operation: "submit"``, the
    streaming mode) and concatenates the audio payloads of the binary frames
    until the frame marked as last arrives.

    Raises:
        RuntimeError: the server sent an error frame, or the connection ended
            before the final audio frame. Returning the partial buffer would
            otherwise yield a truncated MP3 that the caller caches as complete.
    """
    headers = {
        # Note the "Bearer;<token>" form (semicolon, no space) used here.
        "Authorization": f"Bearer;{TTS_ACCESS_TOKEN}",
    }

    payload = {
        "app": {
            "appid": TTS_APP_ID,
            # NOTE(review): authentication is carried by the Authorization
            # header; this field is a placeholder — confirm the API accepts
            # any non-empty value here.
            "token": "placeholder",
            "cluster": TTS_CLUSTER,
        },
        "user": {
            "uid": "airhub_user",
        },
        "audio": {
            "voice_type": TTS_SPEAKER,
            "encoding": "mp3",
            "speed_ratio": 1.0,
            "rate": 24000,
        },
        "request": {
            "reqid": str(uuid.uuid4()),
            # NOTE(review): the V1 API documents a per-request text length
            # limit (believed to be 1024 UTF-8 bytes); long stories may need
            # splitting before this call — verify against the API reference.
            "text": text,
            "operation": "submit",  # streaming mode
        },
    }

    audio_buffer = bytearray()
    request_frame = _build_tts_v1_request(payload)
    received_last = False

    print(f"[TTS] Connecting to V1 WebSocket... text length={len(text)}", flush=True)

    # NOTE(review): ``extra_headers`` is the keyword of the legacy websockets
    # client; newer releases name it ``additional_headers`` — confirm against
    # the installed websockets version.
    async with websockets.connect(
        TTS_WS_URL,
        extra_headers=headers,
        max_size=10 * 1024 * 1024,  # 10MB max frame
        ping_interval=None,
    ) as ws:
        # Send request
        await ws.send(request_frame)
        print("[TTS] Request sent, waiting for audio...", flush=True)

        # Receive audio chunks
        chunk_count = 0
        async for message in ws:
            if not isinstance(message, bytes):
                continue  # text frames carry no audio

            audio_data, is_last, is_error, error_msg = _parse_tts_v1_response(message)

            if is_error:
                raise RuntimeError(f"TTS error: {error_msg}")

            if audio_data:
                audio_buffer.extend(audio_data)
                chunk_count += 1

            if is_last:
                received_last = True
                print(f"[TTS] Last frame received. chunks={chunk_count}, "
                      f"audio size={len(audio_buffer)} bytes", flush=True)
                break

    if not received_last:
        raise RuntimeError(
            f"TTS stream closed before the final frame "
            f"(received {len(audio_buffer)} bytes)"
        )

    return bytes(audio_buffer)
|
||
|
||
|
||
class TTSRequest(BaseModel):
    """JSON body of POST /api/create_tts."""

    title: str    # story title; also used to name / look up the saved MP3
    content: str  # text handed to the speech synthesizer
|
||
|
||
|
||
@app.get("/api/tts_check")
def tts_check(title: str = Query(...)):
    """Check if audio already exists for a story title.

    ``create_tts`` stores audio as ``<sanitized title>_<unix_ts>.mp3``. The
    sanitized title has Windows-forbidden filename characters removed and is
    truncated to 50 chars. This lookup applies the same sanitization, so long
    titles or titles with such characters are found too. It also accepts an
    exact match on the raw title or on the whole file stem.

    Returns ``{"exists": bool, "audio_url": str | None}``.
    """
    # Must mirror the filename sanitization performed by create_tts.
    safe_title = re.sub(r'[<>:"/\\|?*]', '', title)[:50]

    for f in os.listdir(_audio_dir):
        if not f.lower().endswith(".mp3"):
            continue
        name = f[:-4]  # strip .mp3
        # Match by title prefix (before timestamp)
        name_without_ts = re.sub(r'_\d{10,}$', '', name)
        if name_without_ts in (title, safe_title) or name == title:
            return {
                "exists": True,
                "audio_url": f"Capybara audio/{f}",
            }

    return {"exists": False, "audio_url": None}
|
||
|
||
|
||
@app.post("/api/create_tts")
def create_tts(req: TTSRequest):
    """Generate TTS audio for a story. Returns SSE stream with progress.

    If an MP3 for this title already exists in ``Capybara audio/``, it is
    reused. The lookup matches on the sanitized title under which files are
    saved, as well as the raw title. Otherwise ``req.content`` is synthesized
    and saved as ``<sanitized title>_<unix_ts>.mp3``.

    The final event is ``stage == "done"`` with ``audio_url``, or
    ``stage == "error"``.
    """
    # Windows-forbidden filename chars removed, capped at 50 chars. The cache
    # lookup below must use the same name that the file is saved under.
    safe_title = re.sub(r'[<>:"/\\|?*]', '', req.title)[:50]

    def event_stream():
        yield sse_event({"stage": "connecting", "progress": 10,
                         "message": "正在连接语音合成服务..."})

        # Check if audio already exists
        for f in os.listdir(_audio_dir):
            if not f.lower().endswith(".mp3"):
                continue
            name_without_ts = re.sub(r'_\d{10,}$', '', f[:-4])
            if name_without_ts in (req.title, safe_title):
                yield sse_event({"stage": "done", "progress": 100,
                                 "message": "语音已存在",
                                 "audio_url": f"Capybara audio/{f}"})
                return

        yield sse_event({"stage": "generating", "progress": 30,
                         "message": "AI 正在朗读故事..."})

        try:
            # This is a sync generator, so drive the coroutine on a private
            # event loop and always close it, even when synthesis fails.
            loop = asyncio.new_event_loop()
            try:
                audio_bytes = loop.run_until_complete(tts_synthesize(req.content))
            finally:
                loop.close()

            if not audio_bytes or len(audio_bytes) < 100:
                yield sse_event({"stage": "error", "progress": 0,
                                 "message": "语音合成返回了空音频"})
                return

            yield sse_event({"stage": "saving", "progress": 80,
                             "message": "正在保存音频..."})

            # Save audio file
            timestamp = int(time.time())
            filename = f"{safe_title}_{timestamp}.mp3"
            filepath = os.path.join(_audio_dir, filename)

            with open(filepath, "wb") as f:
                f.write(audio_bytes)

            print(f"[TTS Saved] {filepath} ({len(audio_bytes)} bytes)", flush=True)

            yield sse_event({"stage": "done", "progress": 100,
                             "message": "语音生成完成!",
                             "audio_url": f"Capybara audio/{filename}"})

        except Exception as e:
            print(f"[TTS Error] {e}", flush=True)
            yield sse_event({"stage": "error", "progress": 0,
                             "message": f"语音合成失败: {str(e)}"})

    # Same anti-caching / anti-buffering headers as the other SSE endpoints.
    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )
|
||
|
||
|
||
# ── Static file serving ──
from fastapi.staticfiles import StaticFiles

# Generated songs, served under the same "Capybara music/<file>" prefix used
# in the file_path / audioUrl values returned by the music endpoints.
_music_dir = os.path.join(os.path.dirname(__file__) or ".", "Capybara music")
os.makedirs(_music_dir, exist_ok=True)

for _mount_path, _mount_dir, _mount_name in (
    ("/Capybara music", _music_dir, "music_files"),
    ("/Capybara audio", _audio_dir, "audio_files"),  # TTS generated
):
    app.mount(_mount_path, StaticFiles(directory=_mount_dir), name=_mount_name)
|
||
|
||
|
||
if __name__ == "__main__":
    # Development entry point: listen on all interfaces, port 3000.
    _HOST, _PORT = "0.0.0.0", 3000
    print("[Server] Music Server running on http://localhost:3000")
    uvicorn.run(app, host=_HOST, port=_PORT)
|