All checks were successful
Build and Deploy Backend / build-and-deploy (push) Successful in 5m17s
239 lines
8.0 KiB
Python
239 lines
8.0 KiB
Python
"""
|
||
电子吧唧 - AI 图片生成服务
|
||
使用火山引擎豆包 Seedream 文生图模型,与故事封面生成共用同一模型。
|
||
"""
|
||
import base64
|
||
import json
|
||
import logging
|
||
import uuid
|
||
from datetime import datetime
|
||
|
||
import requests as req_lib
|
||
from django.conf import settings
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 风格 → 提示词后缀映射
|
||
STYLE_PROMPT_MAP = {
|
||
'anime': '日系动漫风格,精致细腻,色彩鲜明',
|
||
'realistic': '超写实风格,高清摄影质感,细节丰富',
|
||
'pixel': '像素艺术风格,复古游戏画面,8-bit色彩',
|
||
'watercolor': '水彩画风格,淡雅柔和,笔触自然晕染',
|
||
'cyberpunk': '赛博朋克风格,霓虹灯光,暗色调科幻感',
|
||
'cute': '可爱卡通风格,Q版萌系,圆润造型,柔和配色',
|
||
'ink': '中国水墨画风格,黑白灰韵,留白意境',
|
||
'comic': '漫画风格,粗线条,强对比,夸张表现力',
|
||
}
|
||
|
||
|
||
def sse_event(data: dict) -> str:
|
||
"""格式化 SSE data 行"""
|
||
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||
|
||
|
||
def generate_t2i_stream(user, badge_image, prompt, style=None, width=1920, height=1920):
|
||
"""
|
||
文生图 SSE 流。
|
||
使用豆包 Seedream 模型生成正方形图片,上传 OSS,返回 URL。
|
||
"""
|
||
config = settings.LLM_CONFIG
|
||
|
||
if not config.get('API_KEY'):
|
||
badge_image.generation_status = 'failed'
|
||
badge_image.save(update_fields=['generation_status'])
|
||
yield sse_event({'stage': 'error', 'message': 'AI 服务未配置'})
|
||
return
|
||
|
||
try:
|
||
from volcenginesdkarkruntime import Ark
|
||
except ImportError:
|
||
badge_image.generation_status = 'failed'
|
||
badge_image.save(update_fields=['generation_status'])
|
||
yield sse_event({'stage': 'error', 'message': 'AI SDK 未安装'})
|
||
return
|
||
|
||
# ── Stage 1: 生成中 ──
|
||
yield sse_event({
|
||
'stage': 'generating', 'progress': 20,
|
||
'message': '正在生成图片...',
|
||
})
|
||
|
||
try:
|
||
client = Ark(api_key=config['API_KEY'])
|
||
|
||
# 构建提示词
|
||
full_prompt = _build_prompt(prompt, style)
|
||
image_model = config.get('IMAGE_MODEL_NAME', 'doubao-seedream-4-5-251128')
|
||
image_size = f'{width}x{height}'
|
||
|
||
result = client.images.generate(
|
||
model=image_model,
|
||
prompt=full_prompt,
|
||
size=image_size,
|
||
response_format='url',
|
||
watermark=False,
|
||
)
|
||
|
||
temp_url = result.data[0].url
|
||
|
||
# ── Stage 2: 处理中 ──
|
||
yield sse_event({
|
||
'stage': 'processing', 'progress': 70,
|
||
'message': '正在处理图片...',
|
||
})
|
||
|
||
# 下载临时图片并上传到 OSS
|
||
image_url = _download_and_upload(temp_url)
|
||
|
||
# 更新记录
|
||
badge_image.image_url = image_url
|
||
badge_image.generation_status = 'completed'
|
||
badge_image.save(update_fields=['image_url', 'generation_status'])
|
||
|
||
# ── Stage 3: 完成 ──
|
||
yield sse_event({
|
||
'stage': 'done', 'progress': 100,
|
||
'message': '生成完成!',
|
||
'image_url': image_url,
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f'Badge t2i generation failed: {e}')
|
||
badge_image.generation_status = 'failed'
|
||
badge_image.save(update_fields=['generation_status'])
|
||
yield sse_event({
|
||
'stage': 'error', 'progress': 0,
|
||
'message': f'生成失败: {str(e)}',
|
||
})
|
||
|
||
|
||
def generate_i2i_stream(user, badge_image, image_bytes, prompt='', style=None,
|
||
strength=0.7, width=1920, height=1920):
|
||
"""
|
||
图生图 SSE 流。
|
||
将用户上传的参考图 + 提示词发给豆包模型,生成正方形新图。
|
||
"""
|
||
config = settings.LLM_CONFIG
|
||
|
||
if not config.get('API_KEY'):
|
||
badge_image.generation_status = 'failed'
|
||
badge_image.save(update_fields=['generation_status'])
|
||
yield sse_event({'stage': 'error', 'message': 'AI 服务未配置'})
|
||
return
|
||
|
||
try:
|
||
from volcenginesdkarkruntime import Ark
|
||
except ImportError:
|
||
badge_image.generation_status = 'failed'
|
||
badge_image.save(update_fields=['generation_status'])
|
||
yield sse_event({'stage': 'error', 'message': 'AI SDK 未安装'})
|
||
return
|
||
|
||
# ── Stage 1: 生成中 ──
|
||
yield sse_event({
|
||
'stage': 'generating', 'progress': 20,
|
||
'message': '正在根据参考图生成...',
|
||
})
|
||
|
||
try:
|
||
client = Ark(api_key=config['API_KEY'])
|
||
|
||
# 先上传参考图到 OSS 获取 URL
|
||
ref_url = _upload_reference_image(image_bytes)
|
||
badge_image.reference_image_url = ref_url
|
||
badge_image.save(update_fields=['reference_image_url'])
|
||
|
||
# 构建提示词
|
||
full_prompt = _build_prompt(prompt or '基于参考图生成类似风格的图片', style)
|
||
image_model = config.get('IMAGE_MODEL_NAME', 'doubao-seedream-4-5-251128')
|
||
image_size = f'{width}x{height}'
|
||
|
||
# 将前端 strength(0.1~1.0) 映射到 guidance_scale(1.0~20.0)
|
||
# strength 越大 → 越贴近参考图 → guidance_scale 越低(更依赖图片)
|
||
# strength 越小 → 越自由发挥 → guidance_scale 越高(更依赖提示词)
|
||
guidance = 1.0 + (1.0 - strength) * 19.0
|
||
|
||
result = client.images.generate(
|
||
model=image_model,
|
||
prompt=full_prompt,
|
||
size=image_size,
|
||
response_format='url',
|
||
watermark=False,
|
||
image=ref_url,
|
||
guidance_scale=guidance,
|
||
)
|
||
|
||
temp_url = result.data[0].url
|
||
|
||
# ── Stage 2: 处理中 ──
|
||
yield sse_event({
|
||
'stage': 'processing', 'progress': 70,
|
||
'message': '正在处理图片...',
|
||
})
|
||
|
||
image_url = _download_and_upload(temp_url)
|
||
|
||
badge_image.image_url = image_url
|
||
badge_image.generation_status = 'completed'
|
||
badge_image.save(update_fields=['image_url', 'generation_status'])
|
||
|
||
# ── Stage 3: 完成 ──
|
||
yield sse_event({
|
||
'stage': 'done', 'progress': 100,
|
||
'message': '生成完成!',
|
||
'image_url': image_url,
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f'Badge i2i generation failed: {e}')
|
||
badge_image.generation_status = 'failed'
|
||
badge_image.save(update_fields=['generation_status'])
|
||
yield sse_event({
|
||
'stage': 'error', 'progress': 0,
|
||
'message': f'生成失败: {str(e)}',
|
||
})
|
||
|
||
|
||
def _build_prompt(prompt, style=None):
|
||
"""构建完整提示词:用户描述 + 风格后缀 + 正方形构图提示"""
|
||
parts = [prompt]
|
||
if style and style in STYLE_PROMPT_MAP:
|
||
parts.append(STYLE_PROMPT_MAP[style])
|
||
parts.append('正方形构图,居中主体,适合圆形裁切展示')
|
||
return ','.join(parts)
|
||
|
||
|
||
def _download_and_upload(temp_url):
|
||
"""从临时 URL 下载图片,上传到 OSS,返回持久化 URL"""
|
||
resp = req_lib.get(temp_url, timeout=60)
|
||
resp.raise_for_status()
|
||
|
||
from utils.oss import get_oss_client
|
||
oss_client = get_oss_client()
|
||
key = f"badge/generated/{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}.jpg"
|
||
oss_client.bucket.put_object(
|
||
key, resp.content,
|
||
headers={'Content-Type': 'image/jpeg'},
|
||
)
|
||
|
||
oss_config = settings.ALIYUN_OSS
|
||
if oss_config.get('CUSTOM_DOMAIN'):
|
||
return f"https://{oss_config['CUSTOM_DOMAIN']}/{key}"
|
||
return f"https://{oss_config['BUCKET_NAME']}.{oss_config['ENDPOINT']}/{key}"
|
||
|
||
|
||
def _upload_reference_image(image_bytes):
|
||
"""上传参考图到 OSS,返回 URL"""
|
||
from utils.oss import get_oss_client
|
||
oss_client = get_oss_client()
|
||
key = f"badge/reference/{datetime.now().strftime('%Y%m%d')}/{uuid.uuid4().hex}.jpg"
|
||
oss_client.bucket.put_object(
|
||
key, image_bytes,
|
||
headers={'Content-Type': 'image/jpeg'},
|
||
)
|
||
|
||
oss_config = settings.ALIYUN_OSS
|
||
if oss_config.get('CUSTOM_DOMAIN'):
|
||
return f"https://{oss_config['CUSTOM_DOMAIN']}/{key}"
|
||
return f"https://{oss_config['BUCKET_NAME']}.{oss_config['ENDPOINT']}/{key}"
|