rtc_backend/apps/music/services/music_generation_service.py
repair-agent f1bead86f6
Some checks failed
Build and Deploy Backend / build-and-deploy (push) Failing after 56s
fix music
2026-02-12 17:35:54 +08:00

318 lines
11 KiB
Python

"""
音乐生成服务 - 从 server.py 迁移
调用 MiniMax Chat API 生成歌词 + MiniMax Music API 生成音频
"""
import json
import logging
import os
import re
import time
import requests
from django.conf import settings
from django.db import transaction
from apps.users.models import PointsRecord
logger = logging.getLogger(__name__)
# MiniMax API endpoints
BASE_URL_CHAT = "https://api.minimax.chat/v1/text/chatcompletion_v2"
BASE_URL_MUSIC = "https://api.minimaxi.com/v1/music_generation"
# Load system prompt
_PROMPT_PATH = os.path.join(os.path.dirname(__file__), 'music_director_prompt.md')
try:
with open(_PROMPT_PATH, 'r', encoding='utf-8') as f:
_raw = f.read()
# Extract the prompt between the ``` fences under "## System Prompt"
_match = re.search(r'## System Prompt\s*\n\s*```\n([\s\S]*?)```', _raw)
SYSTEM_PROMPT = _match.group(1).strip() if _match else _raw
except FileNotFoundError:
SYSTEM_PROMPT = (
"You are a music director AI. Convert user input into JSON with "
"'song_title', 'style' (English description) and 'lyrics' (Chinese, structured)."
)
logger.warning("music_director_prompt.md not found, using default")
def _get_api_key():
return os.environ.get('MINIMAX_API_KEY', '')
def sse_event(data: dict) -> str:
    """Serialise *data* as one Server-Sent Events ``data:`` frame.

    Non-ASCII characters are \\u-escaped, so the frame is pure ASCII; the
    trailing blank line terminates the event.
    """
    payload = json.dumps(data, ensure_ascii=True)
    return "data: " + payload + "\n\n"
def clean_lyrics(raw: "str | list[str] | None") -> str:
    """Normalise lyrics text extracted from the director LLM's JSON output.

    Converts literal ``\\n`` escape sequences into real newlines, removes
    straight double quotes (first gluing ``"a" "b"`` fragments together),
    strips section tags such as ``[Verse]``, ``[Chorus 2]``, ``[Pre-Chorus]``
    or ``[Inst]``, trims every line and collapses runs of 3+ newlines into a
    single blank line.

    Args:
        raw: Lyrics from the LLM. Normally a string; if the model emitted a
            JSON array of lines, the items are joined with newlines; ``None``
            is treated as empty.

    Returns:
        The cleaned lyrics, ``""`` when there is nothing to clean.
    """
    if raw is None:
        return ""
    if isinstance(raw, (list, tuple)):
        # A JSON array of lines would otherwise crash on str methods below.
        raw = "\n".join(str(part) for part in raw if part is not None)
    elif not isinstance(raw, str):
        raw = str(raw)
    if not raw:
        return raw
    # Literal backslash-n survives when lyrics were regex-extracted from raw JSON.
    s = raw.replace("\\n", "\n")
    # Glue adjacent quoted fragments, then drop any remaining quotes.
    s = re.sub(r'"\s*"', '', s)
    s = s.replace('"', '')
    s = re.sub(
        r'\[(?:verse|chorus|bridge|outro|intro|hook|pre-chorus|interlude|inst)\s*\d*\]\s*',
        '', s, flags=re.IGNORECASE
    )
    s = '\n'.join(line.strip() for line in s.split('\n'))
    s = re.sub(r'\n{3,}', '\n\n', s)
    return s.strip()
def _parse_llm_json(content_str: str) -> dict:
"""Robust JSON parsing from LLM output (with fallbacks)."""
content_str = content_str.strip()
# Strip markdown code fences
if content_str.startswith("```"):
content_str = re.sub(r'^```\w*\n?', '', content_str)
content_str = re.sub(r'```\s*$', '', content_str).strip()
json_match = re.search(r'\{[\s\S]*\}', content_str)
if json_match:
try:
return json.loads(json_match.group())
except json.JSONDecodeError:
logger.warning("JSON parse failed, attempting regex extraction")
title_m = re.search(r'"song_title"\s*:\s*"([^"]*)"', json_match.group())
style_m = re.search(r'"style"\s*:\s*"([^"]*)"', json_match.group())
lyrics_m = re.search(r'"lyrics"\s*:\s*"([\s\S]*)', json_match.group())
lyrics_val = ""
if lyrics_m:
lyrics_val = lyrics_m.group(1)
lyrics_val = re.sub(r'"\s*\}\s*$', '', lyrics_val).strip()
return {
"song_title": title_m.group(1) if title_m else "",
"style": style_m.group(1) if style_m else "Pop music, cheerful",
"lyrics": lyrics_val,
}
if content_str.startswith("{"):
# Incomplete JSON
try:
return json.loads(content_str + '"}\n}')
except json.JSONDecodeError:
title_m = re.search(r'"song_title"\s*:\s*"([^"]*)"', content_str)
style_m = re.search(r'"style"\s*:\s*"([^"]*)"', content_str)
lyrics_m = re.search(r'"lyrics"\s*:\s*"([\s\S]*)', content_str)
lyrics_val = lyrics_m.group(1).rstrip('"} \n') if lyrics_m else "[Inst]"
return {
"song_title": title_m.group(1) if title_m else "",
"style": style_m.group(1) if style_m else "Pop music, cheerful",
"lyrics": lyrics_val,
}
raise ValueError(f"No JSON in LLM response: {content_str[:100]}")
def _refund_points(user, track, amount=100):
    """Refund *amount* points for a failed generation and mark *track* failed.

    The balance update, the ``PointsRecord`` ledger entry and the track status
    change all happen inside one ``transaction.atomic()`` block, so either all
    of them are written or none are. The user row is re-read with
    ``select_for_update()`` so a concurrent change to the balance cannot be
    overwritten by this refund (a plain read-modify-save would lose updates).

    Args:
        user: the user to refund. ``user.__class__`` is used instead of
            ``type(user)`` so a lazy proxy object that forwards ``__class__``
            still resolves to the model class.
        track: the Track whose generation failed.
        amount: points to return; 100 presumably matches the price the view
            deducts (TODO confirm against the view).
    """
    with transaction.atomic():
        locked_user = (
            user.__class__._default_manager
            .select_for_update()
            .get(pk=user.pk)
        )
        locked_user.points += amount
        locked_user.save(update_fields=['points'])
        PointsRecord.objects.create(
            user=locked_user,
            amount=amount,
            type='refund_music',
            description=f'音乐生成失败退款「{track.title}」',
        )
        track.generation_status = 'failed'
        track.save(update_fields=['generation_status'])
        # Keep the caller's instance in sync with the stored balance.
        user.points = locked_user.points
def _compose_director_input(text, mood):
    """Build the user message sent to the director LLM.

    A ``'random'`` mood or blank *text* asks for a surprise song instead of
    describing the user's scene.
    """
    if mood == 'random' or not text.strip():
        return "咔咔今天想来点惊喜"
    return f"用户场景描述: {text}。 (预设氛围参考: {mood})"


def _request_song_metadata(api_key, director_input):
    """Ask the MiniMax chat model for song metadata (title / style / lyrics).

    Returns the dict produced by ``_parse_llm_json`` with ``lyrics`` passed
    through ``clean_lyrics``. Raises on transport errors, a response without
    ``choices``, or unparseable model output.
    """
    chat_resp = requests.post(
        BASE_URL_CHAT,
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": "abab6.5s-chat",
            "messages": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": director_input},
            ],
            "max_tokens": 2048,
        },
        timeout=60,
    )
    chat_data = chat_resp.json()
    if "choices" not in chat_data or not chat_data["choices"]:
        base = chat_data.get("base_resp") or {}
        raise ValueError(
            f"Chat API error ({base.get('status_code')}): {base.get('status_msg')}"
        )
    content_str = chat_data["choices"][0]["message"]["content"]
    metadata = _parse_llm_json(content_str)
    metadata["lyrics"] = clean_lyrics(metadata.get("lyrics", ""))
    return metadata


def _request_music_audio(api_key, style, lyrics):
    """POST one generation request to the MiniMax music endpoint; return its decoded JSON body."""
    music_resp = requests.post(
        BASE_URL_MUSIC,
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": "music-2.5",
            "prompt": style,
            "lyrics": lyrics,
            "audio_setting": {
                "sample_rate": 44100,
                "bitrate": 256000,
                "format": "mp3",
            },
        },
        timeout=300,
    )
    return music_resp.json()


def _default_cover_url():
    """Public URL of the shared default cover image in the configured OSS bucket."""
    oss_config = settings.ALIYUN_OSS
    return (
        f"https://{oss_config['BUCKET_NAME']}"
        f".{oss_config['ENDPOINT']}/music/defaults/Capybara.png"
    )


def generate_music_stream(user, track, text, mood):
    """SSE generator that produces one AI song for an already-paid track.

    The view has already deducted the user's points and created *track*.
    Progress is streamed as SSE frames through three stages:

    1. the director LLM writes title/style/lyrics (on any error a default
       instrumental style is used instead);
    2. the MiniMax music API renders the audio;
    3. the MP3 is uploaded to OSS and *track* is marked ``completed``.

    Every handled failure refunds the user via ``_refund_points`` (which also
    marks the track ``failed``). A ``finally`` guard additionally refunds when
    the generator stops before reaching a terminal state - e.g. it is closed
    at a ``yield`` (GeneratorExit) or an unexpected error escapes - so paid
    points are not silently lost.

    Args:
        user: the paying user.
        track: the pending Track row to fill in.
        text: the user's scene description; ``None`` is treated as empty.
        mood: preset mood; ``'random'`` requests a surprise song.

    Yields:
        ``str`` SSE frames built by ``sse_event``.
    """
    text = text or ""
    # True once the track is completed or the user has been refunded.
    settled = False
    try:
        api_key = _get_api_key()
        if not api_key:
            _refund_points(user, track)
            settled = True
            yield sse_event({
                "stage": "error", "progress": 0,
                "message": "音乐服务未配置,积分已退还"
            })
            return

        # ── Stage 1: LLM lyrics ──
        yield sse_event({
            "stage": "lyrics", "progress": 10,
            "message": "AI 正在创作词曲..."
        })
        director_input = _compose_director_input(text, mood)
        try:
            metadata = _request_song_metadata(api_key, director_input)
        except Exception as e:
            logger.error("Director LLM failed: %s", e)
            metadata = {
                "style": "Lofi hip hop, relaxing, slow tempo, water sounds",
                "lyrics": "[Inst]",
            }
            yield sse_event({
                "stage": "lyrics_fallback", "progress": 25,
                "message": "使用默认风格,准备生成音乐..."
            })
        else:
            yield sse_event({
                "stage": "lyrics_done", "progress": 25,
                "message": "词曲创作完成!准备生成音乐..."
            })

        # Update track title/lyrics from LLM output. `or ""` keeps a JSON
        # null lyrics value from being persisted as None.
        song_title = metadata.get("song_title", "") or text[:20] or "咔咔新歌"
        track.title = song_title
        track.lyrics = metadata.get("lyrics") or ""
        track.save(update_fields=['title', 'lyrics'])

        # ── Stage 2: audio generation ──
        yield sse_event({
            "stage": "music", "progress": 30,
            "message": "正在生成音乐,请耐心等待..."
        })
        try:
            raw_lyrics = metadata.get("lyrics") or ""
            if not raw_lyrics.strip() or "[instrumental]" in raw_lyrics.lower():
                raw_lyrics = "[Inst]"
            # `or` also covers an explicit null/empty style from the LLM.
            music_data = _request_music_audio(
                api_key, metadata.get("style") or "Pop music", raw_lyrics
            )
            audio_hex = (music_data.get("data") or {}).get("audio")
            if not audio_hex:
                base_resp = music_data.get("base_resp") or {}
                error_msg = base_resp.get("status_msg", "unknown")
                error_code = base_resp.get("status_code", -1)
                logger.error("Music Gen failed: %s - %s", error_code, error_msg)
                _refund_points(user, track)
                settled = True
                yield sse_event({
                    "stage": "error", "progress": 0,
                    "message": f"生成失败 ({error_code}): {error_msg},积分已退还"
                })
                return

            audio_bytes = bytes.fromhex(audio_hex)

            # ── Stage 3: upload to OSS ──
            yield sse_event({
                "stage": "saving", "progress": 90,
                "message": "音乐生成完成,正在保存..."
            })
            audio_url = _upload_to_oss(audio_bytes, song_title)
            track.audio_url = audio_url
            track.cover_url = _default_cover_url()
            track.generation_status = 'completed'
            track.save(update_fields=[
                'audio_url', 'cover_url', 'generation_status'
            ])
            settled = True
            yield sse_event({
                "stage": "done", "progress": 100,
                "message": "新歌出炉!",
                "track_id": track.id,
                "audio_url": audio_url,
                "cover_url": track.cover_url,
                "metadata": {
                    "song_title": song_title,
                    "lyrics": metadata.get("lyrics", ""),
                },
            })
        except requests.exceptions.Timeout:
            logger.error("Music Gen Timeout")
            _refund_points(user, track)
            settled = True
            yield sse_event({
                "stage": "error", "progress": 0,
                "message": "音乐生成超时,积分已退还"
            })
        except Exception as e:
            logger.exception("Music API exception: %s", e)
            _refund_points(user, track)
            settled = True
            yield sse_event({
                "stage": "error", "progress": 0,
                "message": f"服务器错误: {str(e)},积分已退还"
            })
    finally:
        if not settled:
            # Generator closed early or an unhandled error escaped: the user
            # paid but got neither a song nor a refund yet.
            logger.warning(
                "Music generation for track %s stopped before finishing; refunding",
                track.id,
            )
            try:
                _refund_points(user, track)
            except Exception:
                logger.exception("Refund failed for track %s", track.id)
def _upload_to_oss(audio_bytes: bytes, title: str) -> str:
"""Upload MP3 bytes to Aliyun OSS, return public URL."""
try:
import oss2
except ImportError:
logger.error("oss2 not installed")
raise RuntimeError("OSS SDK 未安装")
oss_config = settings.ALIYUN_OSS
auth = oss2.Auth(oss_config['ACCESS_KEY_ID'], oss_config['ACCESS_KEY_SECRET'])
bucket = oss2.Bucket(auth, oss_config['ENDPOINT'], oss_config['BUCKET_NAME'])
safe_name = re.sub(r'[^\w\u4e00-\u9fff]', '', title)[:20] or "ai_song"
filename = f"{safe_name}_{int(time.time())}.mp3"
oss_key = f"music/generated/{filename}"
bucket.put_object(oss_key, audio_bytes)
custom_domain = oss_config.get('CUSTOM_DOMAIN', '')
if custom_domain:
return f"https://{custom_domain}/{oss_key}"
return f"https://{oss_config['BUCKET_NAME']}.{oss_config['ENDPOINT']}/{oss_key}"