import json
import os
import re
import sys
import time

import requests
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

# Force UTF-8 stdout/stderr on Windows (avoids GBK encoding errors)
if sys.platform == "win32":
    sys.stdout.reconfigure(encoding="utf-8", errors="replace")
    sys.stderr.reconfigure(encoding="utf-8", errors="replace")

# Load Environment Variables
load_dotenv()
MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY")
VOLCENGINE_API_KEY = os.getenv("VOLCENGINE_API_KEY")
if not MINIMAX_API_KEY:
    print("Warning: MINIMAX_API_KEY not found in .env")
if not VOLCENGINE_API_KEY:
    print("Warning: VOLCENGINE_API_KEY not found in .env")

# Directory holding this script. Generated music is saved under it, and prompt
# files are looked up here before falling back to the current working directory.
_BASE_DIR = os.path.dirname(__file__) or "."

# Initialize FastAPI
app = FastAPI()

# Allow CORS for local frontend.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; restrict allow_origins before exposing this server beyond localhost.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request Models
class MusicRequest(BaseModel):
    text: str
    mood: str = "custom"  # 'chill', 'happy', 'sleepy', 'random', 'custom'


class StoryRequest(BaseModel):
    characters: list[str] = []
    scenes: list[str] = []
    props: list[str] = []


# Minimax Constants
MINIMAX_GROUP_ID = "YOUR_GROUP_ID"  # placeholder; not referenced elsewhere in this module
BASE_URL_CHAT = "https://api.minimax.chat/v1/text/chatcompletion_v2"
BASE_URL_MUSIC = "https://api.minimaxi.com/v1/music_generation"


def _load_prompt(rel_path: str, default: str) -> str:
    """Return the contents of a system-prompt file, or *default* if it is missing.

    The file is looked up next to this script first and then relative to the
    current working directory, so the server finds its prompts no matter which
    directory it is launched from.
    """
    for candidate in (os.path.join(_BASE_DIR, rel_path), rel_path):
        try:
            with open(candidate, "r", encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError:
            continue
    print(f"Warning: {rel_path} not found, using default.")
    return default


# Load System Prompts
SYSTEM_PROMPT = _load_prompt(
    "prompts/music_director.md",
    "You are a music director AI. Convert user input into JSON with 'style' (English description) and 'lyrics' (Chinese, structured).",
)
STORY_SYSTEM_PROMPT = _load_prompt(
    "prompts/story_director.md",
    "你是一个儿童故事大师。根据用户提供的角色、场景、道具素材创作一个300-600字的儿童故事。只返回JSON格式:{\"title\": \"标题\", \"content\": \"正文\"}",
)

# Volcengine / Doubao constants
DOUBAO_BASE_URL = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
DOUBAO_MODEL = "doubao-seed-1-6-lite-251015"  # Doubao-Seed-1.6-lite


def _log(msg: str) -> None:
    """Print a log line and flush immediately so it shows up while a stream is running."""
    print(msg, flush=True)


def sse_event(data):
    """Format a dict as an SSE data line.
    Use ensure_ascii=True so all non-ASCII chars become \\uXXXX escapes,
    avoiding Windows GBK encoding issues in the SSE stream."""
    return f"data: {json.dumps(data, ensure_ascii=True)}\n\n"


def clean_lyrics(raw: str) -> str:
    """Clean lyrics extracted from LLM JSON output.

    Removes JSON artifacts, structure tags, and normalizes formatting.
    Falsy input (``""`` / ``None``) is returned unchanged.
    """
    if not raw:
        return raw
    s = raw
    # Replace literal \n with real newlines
    s = s.replace("\\n", "\n")
    # Remove JSON string quotes and concatenation artifacts (" ")
    s = re.sub(r'"\s*"', '', s)
    s = s.replace('"', '')
    # Remove structure tags like [verse 1], [chorus], [outro], [bridge], [intro], etc.
    s = re.sub(r'\[(?:verse|chorus|bridge|outro|intro|hook|pre-chorus|interlude|inst)\s*\d*\]\s*', '', s, flags=re.IGNORECASE)
    # Strip leading/trailing whitespace from each line
    s = '\n'.join(line.strip() for line in s.split('\n'))
    # Collapse 3+ consecutive newlines into 2 (one blank line between paragraphs)
    s = re.sub(r'\n{3,}', '\n\n', s)
    # Remove leading/trailing blank lines
    return s.strip()


def _extract_director_fields(text: str, default_lyrics: str) -> dict:
    """Regex-scrape ``song_title`` / ``style`` / ``lyrics`` from non-parseable LLM JSON.

    ``lyrics`` may contain raw newlines (the usual reason ``json.loads`` fails),
    so everything after its opening quote is taken and trailing quote / brace /
    whitespace residue is trimmed. Missing fields get fallbacks: ``""`` for the
    title, ``"Pop music, cheerful"`` for the style, *default_lyrics* for lyrics.
    """
    title_m = re.search(r'"song_title"\s*:\s*"([^"]*)"', text)
    style_m = re.search(r'"style"\s*:\s*"([^"]*)"', text)
    lyrics_m = re.search(r'"lyrics"\s*:\s*"([\s\S]*)', text)
    return {
        "song_title": title_m.group(1) if title_m else "",
        "style": style_m.group(1) if style_m else "Pop music, cheerful",
        "lyrics": lyrics_m.group(1).rstrip('"} \t\r\n') if lyrics_m else default_lyrics,
    }


def _parse_director_output(content_str: str) -> dict:
    """Turn the music-director LLM's raw reply into a metadata dict.

    Markdown code fences are stripped first. Then, in order: strict JSON on the
    outermost ``{...}`` span; regex scraping if that span is invalid JSON;
    closing a truncated object (no ``}`` anywhere); regex scraping of the
    truncated text.

    Raises:
        ValueError: if the reply contains no JSON object at all.
    """
    text = content_str.strip()
    if text.startswith("```"):
        text = re.sub(r'^```\w*\n?', '', text)
        text = re.sub(r'```\s*$', '', text).strip()

    json_match = re.search(r'\{[\s\S]*\}', text)
    if json_match:
        json_str = json_match.group()
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            # JSON might have unescaped newlines in string values — scrape fields instead
            _log("[Warn] JSON parse failed, attempting repair...")
            metadata = _extract_director_fields(json_str, default_lyrics="")
            _log(f"[Repaired] title={metadata['song_title']}, style={metadata['style'][:60]}")
            return metadata

    if text.startswith("{"):
        # Output was cut off before any closing brace (e.g. max_tokens reached).
        _log("[Warn] Incomplete JSON, attempting to close...")
        # '"}' closes an unterminated string value, '}' an object cut right after a
        # complete value, '"}\n}' a truncated one-level nested object.
        for suffix in ('"}', '}', '"}\n}'):
            try:
                return json.loads(text + suffix)
            except json.JSONDecodeError:
                continue
        # Manual extraction as last resort
        metadata = _extract_director_fields(text, default_lyrics="[Inst]")
        _log(f"[Repaired] title={metadata.get('song_title')}")
        return metadata

    raise ValueError(f"No JSON in LLM response: {text[:100]}")


def _request_director_metadata(req: MusicRequest) -> dict:
    """Stage 1: ask the MiniMax chat model for ``song_title`` / ``style`` / ``lyrics``.

    Returns the metadata dict with ``style`` coerced to ``str`` and ``lyrics``
    cleaned via :func:`clean_lyrics`. Transport errors, non-JSON responses,
    API errors and unparseable output all propagate as exceptions so the
    caller can fall back to a default style.
    """
    director_input = f"用户场景描述: {req.text}。 (预设氛围参考: {req.mood})"
    chat_resp = requests.post(
        BASE_URL_CHAT,
        headers={
            "Authorization": f"Bearer {MINIMAX_API_KEY}",
            "Content-Type": "application/json",
        },
        json={
            "model": "abab6.5s-chat",
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": director_input},
            ],
            "max_tokens": 2048,  # Enough for long lyrics
        },
        timeout=60,
    )
    chat_data = chat_resp.json()
    _log(f"[Debug] Chat API status: {chat_resp.status_code}, resp keys: {list(chat_data.keys())}")

    if not chat_data.get("choices"):
        base = chat_data.get("base_resp") or {}
        raise ValueError(f"Chat API error ({base.get('status_code')}): {base.get('status_msg')}")

    content_str = chat_data["choices"][0]["message"]["content"]
    _log(f"[Debug] LLM raw output (first 200): {content_str[:200]}")

    metadata = _parse_director_output(content_str)
    # The model occasionally returns null / non-string fields; normalise so the
    # slicing below and the music payload always see strings.
    style_val = str(metadata.get("style") or "")
    lyrics_val = clean_lyrics(str(metadata.get("lyrics") or ""))
    metadata["style"] = style_val
    metadata["lyrics"] = lyrics_val  # Store cleaned version
    _log(f"[Director] Style: {style_val[:80]}")
    _log(f"[Director] Lyrics (first 60): {lyrics_val[:60]}")
    return metadata


def _save_song(metadata: dict, fallback_title: str, hex_audio: str) -> str:
    """Stage 3: write the generated song and its lyrics under ``Capybara music/``.

    Args:
        metadata: director metadata. ``song_title`` names the file; non-empty
            ``lyrics`` are written to ``lyrics/<stem>.txt``.
        fallback_title: used for the file name when there is no ``song_title``.
        hex_audio: hex-encoded audio bytes as returned by the music API.

    Returns:
        The mp3 path relative to this script (``Capybara music/<stem>.mp3``).

    Raises:
        ValueError: if ``hex_audio`` is not valid hex.
        OSError: if a file cannot be written.
    """
    save_dir = os.path.join(_BASE_DIR, "Capybara music")
    os.makedirs(save_dir, exist_ok=True)

    # Prefer song_title from LLM; fallback to user input
    raw_title = str(metadata.get("song_title") or fallback_title)
    safe_name = re.sub(r'[^\w\u4e00-\u9fff]', '', raw_title)[:20] or "ai_song"
    # Use a single timestamp for both files: /api/playlist pairs <stem>.mp3 with
    # lyrics/<stem>.txt, and two separate time.time() calls can straddle a second.
    stem = f"{safe_name}_{int(time.time())}"
    filename = f"{stem}.mp3"
    filepath = os.path.join(save_dir, filename)

    audio_bytes = bytes.fromhex(hex_audio)
    with open(filepath, "wb") as f:
        f.write(audio_bytes)
    _log(f"[Saved] {filepath}")

    # Save lyrics txt
    lyrics_text = metadata.get("lyrics", "")
    if lyrics_text:
        lyrics_dir = os.path.join(save_dir, "lyrics")
        os.makedirs(lyrics_dir, exist_ok=True)
        with open(os.path.join(lyrics_dir, f"{stem}.txt"), "w", encoding="utf-8") as lf:
            lf.write(lyrics_text)

    return f"Capybara music/{filename}"


@app.post("/api/create_music")
def create_music(req: MusicRequest):
    """SSE streaming endpoint – pushes progress events to the frontend.

    Stage 1 asks the MiniMax chat model (the "music director") for song
    metadata; if that fails, a default lo-fi instrumental style is used instead.
    Stage 2 sends style + lyrics to the MiniMax music API. Stage 3 saves the mp3
    (and lyrics) under ``Capybara music/``. Every event carries ``stage``,
    ``progress`` and ``message``. The final ``done`` event adds ``file_path``
    and ``metadata``; failures end with ``stage == "error"``.
    """
    print(f"[Music] Received request: {req.text} [{req.mood}]", flush=True)

    def event_stream():
        # ── Stage 1: LLM "Music Director" ────────────────────────
        _log("[Stage 1] Starting LLM call...")
        yield sse_event({"stage": "lyrics", "progress": 10, "message": "AI 正在创作词曲..."})

        try:
            metadata = _request_director_metadata(req)
        except Exception as e:
            _log(f"[Error] Director LLM Failed: {e}")
            metadata = {
                "style": "Lofi hip hop, relaxing, slow tempo, water sounds",
                "lyrics": "[Inst]",
            }
            yield sse_event({"stage": "lyrics_fallback", "progress": 25, "message": "使用默认风格,准备生成音乐..."})
        else:
            yield sse_event({"stage": "lyrics_done", "progress": 25, "message": "词曲创作完成!准备生成音乐..."})

        # ── Stage 2: Music Generation ────────────────────────────
        yield sse_event({"stage": "music", "progress": 30, "message": "正在生成音乐,请耐心等待..."})

        try:
            raw_lyrics = metadata.get("lyrics") or ""
            # API requires lyrics >= 1 char
            if not raw_lyrics.strip() or "[instrumental]" in raw_lyrics.lower():
                raw_lyrics = "[Inst]"

            music_payload = {
                "model": "music-2.5",
                # `or`: an empty style from the LLM must not become an empty prompt
                "prompt": metadata.get("style") or "Pop music",
                "lyrics": raw_lyrics,
                "audio_setting": {
                    "sample_rate": 44100,
                    "bitrate": 256000,
                    "format": "mp3",
                },
            }
            _log(f"[Debug] Music payload prompt: {music_payload['prompt'][:80]}")
            _log(f"[Debug] Music payload lyrics (first 60): {music_payload['lyrics'][:60]}")

            music_resp = requests.post(
                BASE_URL_MUSIC,
                headers={
                    "Authorization": f"Bearer {MINIMAX_API_KEY}",
                    "Content-Type": "application/json",
                },
                json=music_payload,
                timeout=300,  # 5 min — music generation can be slow
            )
            music_data = music_resp.json()
            base_resp = music_data.get("base_resp") or {}
            _log(f"[Debug] Music API status: {music_resp.status_code}, base_resp: {base_resp}")

            hex_audio = (music_data.get("data") or {}).get("audio")
            if not hex_audio:
                error_msg = base_resp.get("status_msg", "unknown")
                error_code = base_resp.get("status_code", -1)
                _log(f"[Error] Music Gen failed: {error_code} - {error_msg}")
                yield sse_event({"stage": "error", "progress": 0, "message": f"生成失败 ({error_code}): {error_msg}"})
                return

            _log(f"[OK] Music generated! Audio hex length: {len(hex_audio)}")

            # ── Stage 3: Saving ──────────────────────────────
            yield sse_event({"stage": "saving", "progress": 90, "message": "音乐生成完成,正在保存..."})
            relative_path = _save_song(metadata, req.text, hex_audio)

            # ── Done ─────────────────────────────────────────
            yield sse_event({
                "stage": "done",
                "progress": 100,
                "message": "新歌出炉!",
                "status": "success",
                "file_path": relative_path,
                "metadata": metadata,
            })

        except requests.exceptions.Timeout:
            _log("[Error] Music Gen Timeout")
            yield sse_event({"stage": "error", "progress": 0, "message": "音乐生成超时,请稍后再试"})
        except Exception as e:
            _log(f"[Error] API exception: {e}")
            yield sse_event({"stage": "error", "progress": 0, "message": f"服务器错误: {str(e)}"})

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )


# ═══════════════════════════════════════════════════════════════════
# ── Story Generation (Doubao / Volcengine) ──
# ═══════════════════════════════════════════════════════════════════
def _build_story_prompt(req: StoryRequest) -> str:
    """Build the user prompt for the story LLM from the selected story elements.

    Falls back to a generic "random story" request when nothing was selected.
    """
    parts = []
    if req.characters:
        parts.append(f"角色=[{', '.join(req.characters)}]")
    if req.scenes:
        parts.append(f"场景=[{', '.join(req.scenes)}]")
    if req.props:
        parts.append(f"道具=[{', '.join(req.props)}]")
    if not parts:
        return "请随机创作一个有趣的儿童故事"
    return "请用这些素材创作一个故事:" + ",".join(parts)


def _parse_story_output(full_content: str, log) -> tuple[str, str]:
    """Extract ``(title, content)`` from the story LLM's raw reply.

    Accepts a ``{"title": ..., "content": ...}`` JSON object, optionally wrapped
    in Markdown code fences. If that object is not valid JSON, the fields are
    scraped with regexes; if neither field is found, the whole reply becomes
    the story body. The title is always a single non-empty line. In the content,
    literal ``\\n`` escapes become real newlines and runs of 3+ newlines
    collapse to 2.
    """
    # Clean up response — strip markdown fences if present
    cleaned = full_content.strip()
    if cleaned.startswith("```"):
        cleaned = re.sub(r'^```\w*\n?', '', cleaned)
        cleaned = re.sub(r'```\s*$', '', cleaned).strip()

    title = ""
    content = ""
    json_match = re.search(r'\{[\s\S]*\}', cleaned)
    if json_match:
        try:
            story_json = json.loads(json_match.group())
            title = story_json.get("title", "")
            content = story_json.get("content", "")
        except json.JSONDecodeError:
            log("[Warn] JSON parse failed, extracting manually...")
            title_m = re.search(r'"title"\s*:\s*"([^"]*)"', cleaned)
            content_m = re.search(r'"content"\s*:\s*"([\s\S]*)', cleaned)
            title = title_m.group(1) if title_m else "卡皮巴拉的故事"
            if content_m:
                content = re.sub(r'"\s*\}\s*$', '', content_m.group(1)).strip()

    # The model can emit null / non-string values; normalise to str so the
    # regex and file-name handling downstream cannot raise TypeError.
    title = "" if title is None else str(title)
    content = "" if content is None else str(content)
    # The title is stored as one "# title" line on disk (read back by
    # /api/stories), so fold embedded newlines / whitespace runs into spaces.
    title = " ".join(title.split())

    if not title and not content:
        # Not JSON at all — treat entire output as story content
        title = "卡皮巴拉的故事"
        content = cleaned
    elif not title:
        title = "卡皮巴拉的故事"

    # Clean content: replace literal \n with real newlines
    content = content.replace("\\n", "\n").strip()
    # Collapse 3+ newlines into 2
    content = re.sub(r'\n{3,}', '\n\n', content)
    return title, content


def _save_story(title: str, content: str) -> str:
    """Write a story to ``Capybara stories/<title>_<unix-ts>.txt`` and return its path.

    The file layout is ``# <title>``, a blank line, then the body — the format
    that ``/api/stories`` parses back.
    """
    save_dir = os.path.join(os.path.dirname(__file__) or ".", "Capybara stories")
    os.makedirs(save_dir, exist_ok=True)
    # Keep only word characters / CJK so the title is safe as a file name.
    safe_name = re.sub(r'[^\w\u4e00-\u9fff]', '', title)[:20] or "story"
    filename = f"{safe_name}_{int(time.time())}.txt"
    filepath = os.path.join(save_dir, filename)
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(f"# {title}\n\n{content}")
    return filepath


@app.post("/api/create_story")
def create_story(req: StoryRequest):
    """SSE streaming endpoint – generates a children's story via Doubao LLM.

    Streams the Doubao chat completion and emits ``generating`` progress events
    every 5 content chunks. It then parses the JSON reply, saves the story under
    ``Capybara stories/`` and finishes with a ``done`` event carrying ``title``
    and ``content``. Failures end with ``stage == "error"``.
    """
    print(f"[Story] Received request: characters={req.characters}, scenes={req.scenes}, props={req.props}", flush=True)

    def event_stream():
        def log(msg):
            print(msg, flush=True)

        # ── Stage 1: Connecting ──
        yield sse_event({"stage": "connecting", "progress": 5, "message": "正在连接 AI..."})

        # Build user prompt from selected elements
        user_prompt = _build_story_prompt(req)
        log(f"[Story] User prompt: {user_prompt}")

        # ── Stage 2: Generating (streaming) ──
        yield sse_event({"stage": "generating", "progress": 10, "message": "故事正在诞生..."})

        try:
            # Explicitly encode as UTF-8 to avoid Windows GBK encoding issues
            payload = json.dumps({
                "model": DOUBAO_MODEL,
                "messages": [
                    {"role": "system", "content": STORY_SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                "max_tokens": 2048,
                "stream": True,
                "thinking": {"type": "disabled"},
            }, ensure_ascii=False)

            pieces = []
            chunk_count = 0
            # `with` releases the streamed connection on every exit path: the
            # non-200 early return, an exception, or this generator being closed.
            with requests.post(
                DOUBAO_BASE_URL,
                headers={
                    "Authorization": f"Bearer {VOLCENGINE_API_KEY}",
                    "Content-Type": "application/json; charset=utf-8",
                },
                data=payload.encode("utf-8"),
                stream=True,
                timeout=120,
            ) as resp:
                if resp.status_code != 200:
                    log(f"[Error] Doubao API returned {resp.status_code}: {resp.text[:300]}")
                    yield sse_event({"stage": "error", "progress": 0, "message": f"AI 服务返回异常 ({resp.status_code})"})
                    return

                # Force UTF-8 decoding (requests defaults to ISO-8859-1 which garbles Chinese)
                resp.encoding = "utf-8"

                # Parse SSE stream from Doubao
                for line in resp.iter_lines(decode_unicode=True):
                    if not line or not line.startswith("data: "):
                        continue
                    data_str = line[6:]  # strip "data: "
                    if data_str.strip() == "[DONE]":
                        break
                    try:
                        chunk_data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue
                    choices = chunk_data.get("choices", [])
                    if not choices:
                        continue
                    delta_content = choices[0].get("delta", {}).get("content", "")
                    if not delta_content:
                        continue
                    pieces.append(delta_content)
                    chunk_count += 1
                    # Send progress updates every 5 chunks
                    if chunk_count % 5 == 0:
                        progress = min(10 + int(chunk_count * 0.8), 85)
                        yield sse_event({
                            "stage": "generating",
                            "progress": progress,
                            "message": "故事正在诞生...",
                        })

            full_content = "".join(pieces)
            log(f"[Story] Stream done. Total chunks: {chunk_count}, content length: {len(full_content)}")
            log(f"[Story] Raw output (first 200): {full_content[:200]}")

            if not full_content.strip():
                yield sse_event({"stage": "error", "progress": 0, "message": "AI 未返回故事内容"})
                return

            # ── Stage 3: Parse response ──
            yield sse_event({"stage": "parsing", "progress": 90, "message": "正在整理故事..."})
            title, content = _parse_story_output(full_content, log)
            log(f"[Story] Title: {title}")
            log(f"[Story] Content (first 100): {content[:100]}")

            # ── Save story to disk ──
            filepath = _save_story(title, content)
            log(f"[Saved] {filepath}")

            # ── Done ──
            yield sse_event({
                "stage": "done",
                "progress": 100,
                "message": "故事创作完成!",
                "title": title,
                "content": content,
            })

        except requests.exceptions.Timeout:
            log("[Error] Doubao API Timeout")
            yield sse_event({"stage": "error", "progress": 0, "message": "AI 响应超时,请稍后再试"})
        except Exception as e:
            log(f"[Error] Story generation exception: {e}")
            yield sse_event({"stage": "error", "progress": 0, "message": f"故事生成失败: {str(e)}"})

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
            "Connection": "keep-alive",
        },
    )


def _story_timestamp(stories_dir: str, filename: str) -> int:
    """Return the creation time (unix seconds) of a saved story, for ordering.

    Uses the ``_<unix-ts>.txt`` suffix written by the story endpoint, falling
    back to the file's mtime, or 0 if even that cannot be read.
    """
    m = re.search(r'_(\d{10,})\.txt$', filename, flags=re.IGNORECASE)
    if m:
        return int(m.group(1))
    try:
        return int(os.path.getmtime(os.path.join(stories_dir, filename)))
    except OSError:
        return 0


@app.get("/api/stories")
def get_stories():
    """Scan Capybara stories/ directory and return all saved stories, newest first.

    Each entry has ``title``, ``content`` and ``filename``. Unreadable files and
    files whose stored title contains no CJK characters (treated as garbled)
    are skipped.
    """
    stories_dir = os.path.join(os.path.dirname(__file__) or ".", "Capybara stories")
    if not os.path.isdir(stories_dir):
        return {"stories": []}

    names = [f for f in os.listdir(stories_dir) if f.lower().endswith(".txt")]
    # Sort by creation timestamp rather than file name: names start with the
    # story title, so a reverse name sort orders by title, not by age.
    names.sort(key=lambda f: (_story_timestamp(stories_dir, f), f), reverse=True)

    stories = []
    for f in names:
        filepath = os.path.join(stories_dir, f)
        try:
            with open(filepath, "r", encoding="utf-8") as fh:
                raw = fh.read()
        except (OSError, UnicodeDecodeError) as e:
            print(f"[Warn] Skipping unreadable story file {f}: {e}", flush=True)
            continue

        # Parse: first line is "# Title", second line blank, rest is content
        lines = raw.strip().split("\n", 2)
        title = lines[0].lstrip("# ").strip()
        content = lines[2].strip() if len(lines) > 2 else ""

        # Skip garbled files: normal Chinese chars are in \u4e00-\u9fff, while
        # mojibake is typically Latin-Extended noise with no CJK at all.
        # NOTE(review): this also hides stories whose title is purely non-Chinese.
        if title and not any('\u4e00' <= c <= '\u9fff' for c in title):
            continue

        # Display title: stored title, else file name minus its _<timestamp> suffix
        display_title = title or re.sub(r'_\d{10,}$', '', f[:-4])
        stories.append({
            "title": display_title,
            "content": content,
            "filename": f,
        })
    return {"stories": stories}


@app.get("/api/playlist")
def get_playlist():
    """Scan Capybara music/ directory and return full playlist with lyrics.

    Each ``<name>.mp3`` is paired with ``lyrics/<name>.txt`` when present.
    Entries are ordered by file name.
    """
    music_dir = os.path.join(os.path.dirname(__file__) or ".", "Capybara music")
    lyrics_dir = os.path.join(music_dir, "lyrics")
    if not os.path.isdir(music_dir):
        return {"playlist": []}

    playlist = []
    for f in sorted(os.listdir(music_dir)):
        if not f.lower().endswith(".mp3"):
            continue
        name = f[:-4]  # strip .mp3

        # Read lyrics if available
        lyrics = ""
        lyrics_file = os.path.join(lyrics_dir, name + ".txt")
        if os.path.isfile(lyrics_file):
            try:
                with open(lyrics_file, "r", encoding="utf-8") as lf:
                    lyrics = lf.read()
            except (OSError, UnicodeDecodeError) as e:
                # Best effort: still list the song, just without lyrics.
                print(f"[Warn] Could not read lyrics {lyrics_file}: {e}", flush=True)

        playlist.append({
            # Display title: strip timestamp suffix like _1770367350
            "title": re.sub(r'_\d{10,}$', '', name),
            "audioUrl": f"Capybara music/{f}",
            "lyrics": lyrics,
        })
    return {"playlist": playlist}


# ── Static file serving for generated music ──
from fastapi.staticfiles import StaticFiles

# Create music directory if it doesn't exist (StaticFiles needs it at mount time)
_music_dir = os.path.join(os.path.dirname(__file__) or ".", "Capybara music")
os.makedirs(_music_dir, exist_ok=True)
app.mount("/Capybara music", StaticFiles(directory=_music_dir), name="music_files")


if __name__ == "__main__":
    print("[Server] Music Server running on http://localhost:3000")
    uvicorn.run(app, host="0.0.0.0", port=3000)