video-shuoshan/backend/tests/test_ark_prompt_format.py
seaislee1209 13440f2709
All checks were successful
Build and Deploy / build-and-deploy (push) Successful in 5m58s
feat: v0.19.2 prompt 里 @素材名 按火山规范转为「图片N/视频N/音频N」
火山 Seedance 模型只能理解"素材类型+序号"的指代(官方文档 FAQ Q3);
对文件名 / asset id / URL 类字符串一律读不懂,只能按 content 数组里
图片出现顺序瞎猜谁是谁,导致用户看到的"人物颠倒"概率性现象(典型
任务 cgt-20260422163517-4k8x6)。

改动
- backend/apps/generation/views.py:
  - 新增 _format_prompt_for_ark(prompt, label_placeholders) helper
    用 str.replace 避 regex 元字符崩溃, 按 label 长度降序防子串吞噬
  - video_generate_view references 循环同步维护 image_n/video_n/audio_n
    三个独立计数器 + label_to_placeholder 映射
  - 关键不变量: 任意时刻 counter == content_items 里该类型 *_url 已 push 数
    group 老路径 counter 照推但不登记 label + WARNING, 避免编号错位
  - 调 create_task 前构造 api_prompt 传给火山, DB.prompt 保留用户原文
    (带 @xxx.jpg) 以便 reEdit 重建带缩略图标签

测试覆盖 14 项 (airlabs-test MySQL 全绿)
- 单元 9 项: 基础替换 / 多类型独立计数 / 重复 @ / 子串冲突 / 正则元字符 /
  空 mapping / label 未 @ / 中文标点 / 空 label 跳过
- 集成 5 项: local 正常替换 / DB 原文保留 / group 老路径不换 + WARN /
  混用 local+group counter 对齐 (关键回归) / 图片音频独立计数

兼容性
- reEdit: DB 保留原文, PromptInput.rebuildMentionSpans 按 @label 正则
  仍能重建 span, 缩略图正常
- regenerate: 走同一 POST /api/v1/video/generate, 二次过转换
- Celery: 只 query 不重发, 不受影响

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-23 20:52:32 +08:00

240 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试 prompt 转火山「图片N/视频N/音频N」格式 — v0.19.1+
火山模型无法理解文件名/asset id必须用「素材类型+序号」指代(官方文档 FAQ Q3
本测试文件覆盖:
单元测试:纯函数 _format_prompt_for_ark
集成测试video_generate_view 端到端(含 counter 对齐关键回归)
"""
import os
import sys
import django
from unittest import mock
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
django.setup()
import unittest
from django.test import TestCase, override_settings
from django.contrib.auth import get_user_model
from rest_framework.test import APIClient
from apps.accounts.models import Team
from apps.generation.models import QuotaConfig, AssetGroup, Asset, GenerationRecord
from apps.generation.views import _format_prompt_for_ark
User = get_user_model()
# ────────────────────────────────────────────────
# 单元测试:纯函数 _format_prompt_for_ark
# ────────────────────────────────────────────────
class TestFormatPromptForArk(unittest.TestCase):
"""覆盖各种 label 替换场景(字符串级别)。"""
def test_basic_replacement(self):
"""@label 替换为 placeholder。"""
out = _format_prompt_for_ark('@碧碧.jpg 是碧儿', [('碧碧.jpg', '图片1')])
self.assertEqual(out, '图片1 是碧儿')
def test_multi_type_independent_counters(self):
"""图片/视频/音频各自独立编号。"""
pairs = [
('img1.jpg', '图片1'),
('video1.mp4', '视频1'),
('audio1.mp3', '音频1'),
]
out = _format_prompt_for_ark('用 @img1.jpg @video1.mp4 和 @audio1.mp3', pairs)
self.assertEqual(out, '用 图片1 视频1 和 音频1')
def test_same_label_multiple_at_signs(self):
"""同一 label 在 prompt 里 @ 多次,全部替换成同一 placeholderstr.replace 全局)。"""
out = _format_prompt_for_ark('@foo 然后 @foo 再 @foo', [('foo', '图片1')])
self.assertEqual(out, '图片1 然后 图片1 再 图片1')
def test_substring_conflict_long_first(self):
"""当存在子串关系('''碧碧' 的子串),长 label 必须先替换。"""
# 模拟调用方已经按长度降序传入
pairs = [('碧碧', '图片2'), ('', '图片1')]
out = _format_prompt_for_ark('@碧碧 和 @碧 是姐妹', pairs)
self.assertEqual(out, '图片2 和 图片1 是姐妹')
def test_label_with_regex_metachars(self):
"""label 含正则元字符([ ] + . * ? 等str.replace 不当正则处理。"""
pairs = [
('[test].png', '图片1'),
('a+b.png', '图片2'),
('a.b*.png', '图片3'),
]
prompt = '@[test].png 和 @a+b.png 还有 @a.b*.png'
out = _format_prompt_for_ark(prompt, pairs)
self.assertEqual(out, '图片1 和 图片2 还有 图片3')
def test_empty_mapping(self):
"""无 @ 素材的 prompt 原样返回。"""
out = _format_prompt_for_ark('今天天气真好', [])
self.assertEqual(out, '今天天气真好')
def test_label_in_mapping_not_in_prompt(self):
"""mapping 里有 label 但 prompt 里没 @ → 不动。"""
out = _format_prompt_for_ark('一段普通文字', [('foo.jpg', '图片1')])
self.assertEqual(out, '一段普通文字')
def test_chinese_punctuation_around_label(self):
"""中文标点不影响替换。"""
out = _format_prompt_for_ark('@碧碧.jpg"你好。"', [('碧碧.jpg', '图片1')])
self.assertEqual(out, '图片1"你好。"')
def test_empty_label_skipped(self):
"""label 为空字符串时跳过,不崩溃。"""
out = _format_prompt_for_ark('@real.jpg 内容', [('', '图片0'), ('real.jpg', '图片1')])
self.assertEqual(out, '图片1 内容')
# ────────────────────────────────────────────────
# 集成测试video_generate_view
# ────────────────────────────────────────────────
@override_settings(SEEDANCE_ENABLED=True, ARK_API_KEY='fake-test-key')
class TestVideoGenerateArkPrompt(TestCase):
"""经 POST /api/v1/video/generate 验证 prompt 转换 + DB 原文保留 + counter 对齐。"""
def setUp(self):
QuotaConfig.objects.get_or_create(pk=1)
self.team = Team.objects.create(
name='test-ark-prompt',
is_active=True,
monthly_spending_limit=10000,
markup_percentage=0,
balance=10000,
frozen_amount=0,
)
self.user = User.objects.create_user(
username='ark_prompt_user',
email='arkprompt@example.com',
password='testpass123',
team=self.team,
spending_limit=-1,
daily_generation_limit=-1,
monthly_generation_limit=-1,
)
self.client = APIClient()
self.client.force_authenticate(user=self.user)
# 建两个 local asset 方便多场景复用
self.group_a = AssetGroup.objects.create(
team=self.team, remote_group_id='group-fake-a', name='角色A',
)
self.asset_bibi = Asset.objects.create(
group=self.group_a, remote_asset_id='asset-fake-bibi', name='碧碧.jpg',
url='https://fake/bibi.jpg', asset_type='Image', status='active',
)
self.asset_bubu = Asset.objects.create(
group=self.group_a, remote_asset_id='asset-fake-bubu', name='布布.jpg',
url='https://fake/bubu.jpg', asset_type='Image', status='active',
)
def _post_generate(self, prompt, references):
return self.client.post('/api/v1/video/generate', {
'prompt': prompt,
'mode': 'universal',
'model': 'seedance_2.0',
'aspect_ratio': '9:16',
'duration': 5,
'resolution': '720p',
'references': references,
}, format='json')
@mock.patch('apps.generation.tasks.poll_video_task')
@mock.patch('apps.generation.views.create_task')
def test_view_converts_prompt_for_local_assets(self, mock_create_task, mock_poll):
"""prompt 里两个 @local 素材 → 发给火山的 prompt 变成「图片1/图片2」。"""
mock_create_task.return_value = {'id': 'ark-mock-1'}
prompt = '@碧碧.jpg 是碧儿,@布布.jpg 是步若'
resp = self._post_generate(prompt, [
{'url': f'asset://local-{self.asset_bibi.id}', 'type': 'image', 'label': '碧碧.jpg'},
{'url': f'asset://local-{self.asset_bubu.id}', 'type': 'image', 'label': '布布.jpg'},
])
self.assertEqual(resp.status_code, 202, resp.content)
self.assertTrue(mock_create_task.called, 'create_task must be called')
sent_prompt = mock_create_task.call_args.kwargs['prompt']
self.assertEqual(sent_prompt, '图片1 是碧儿图片2 是步若')
@mock.patch('apps.generation.tasks.poll_video_task')
@mock.patch('apps.generation.views.create_task')
def test_view_db_prompt_unchanged_for_reedit(self, mock_create_task, mock_poll):
"""DB.prompt 必须保留用户原文(含 @xxx.jpgreEdit 才能重建带缩略图的标签。"""
mock_create_task.return_value = {'id': 'ark-mock-2'}
prompt = '@碧碧.jpg 走过来'
resp = self._post_generate(prompt, [
{'url': f'asset://local-{self.asset_bibi.id}', 'type': 'image', 'label': '碧碧.jpg'},
])
self.assertEqual(resp.status_code, 202, resp.content)
rec = GenerationRecord.objects.filter(user=self.user).order_by('-id').first()
self.assertIsNotNone(rec)
self.assertEqual(rec.prompt, prompt) # 原文,不含 '图片1'
self.assertIn('@碧碧.jpg', rec.prompt)
self.assertNotIn('图片1', rec.prompt)
@mock.patch('apps.generation.tasks.poll_video_task')
@mock.patch('apps.generation.views.create_task')
def test_legacy_group_url_skips_replacement(self, mock_create_task, mock_poll):
"""asset://group-{id} 老路径counter 递增但不登记 labelWARNING 日志,@组名 原样留在 prompt。"""
mock_create_task.return_value = {'id': 'ark-mock-3'}
prompt = '@角色A 做动作'
with self.assertLogs('apps.generation.views', level='WARNING') as cm:
resp = self._post_generate(prompt, [
{'url': f'asset://group-{self.group_a.id}', 'type': 'image', 'label': '角色A'},
])
self.assertEqual(resp.status_code, 202, resp.content)
sent_prompt = mock_create_task.call_args.kwargs['prompt']
self.assertEqual(sent_prompt, prompt) # 未替换
# 验证 WARNING log
self.assertTrue(any('legacy asset://group-' in line for line in cm.output),
f'expected legacy warning, got: {cm.output}')
@mock.patch('apps.generation.tasks.poll_video_task')
@mock.patch('apps.generation.views.create_task')
def test_counter_alignment_with_mixed_local_and_group(self, mock_create_task, mock_poll):
"""关键回归group 展开 2 张图后,紧跟的 local asset 的 label 必须映射到「图片3」不是「图片1」。"""
mock_create_task.return_value = {'id': 'ark-mock-4'}
prompt = '@foo 是主角'
resp = self._post_generate(prompt, [
{'url': f'asset://group-{self.group_a.id}', 'type': 'image', 'label': '角色A'}, # 展开 2 张图
{'url': f'asset://local-{self.asset_bibi.id}', 'type': 'image', 'label': 'foo'},
])
self.assertEqual(resp.status_code, 202, resp.content)
sent_prompt = mock_create_task.call_args.kwargs['prompt']
# foo 对应第 3 张 imagegroup 两张在前)
self.assertEqual(sent_prompt, '图片3 是主角')
# content_items 验证长度
sent_content_items = mock_create_task.call_args.kwargs['content_items']
image_items = [it for it in sent_content_items if it['type'] == 'image_url']
self.assertEqual(len(image_items), 3)
@mock.patch('apps.generation.tasks.poll_video_task')
@mock.patch('apps.generation.views.create_task')
def test_counter_alignment_mixed_types(self, mock_create_task, mock_poll):
"""图片/音频独立计数 — 图片序号不因音频夹在中间而跳变。"""
mock_create_task.return_value = {'id': 'ark-mock-5'}
# 新建一个 audio asset
asset_audio = Asset.objects.create(
group=self.group_a, remote_asset_id='asset-fake-audio', name='speech.mp3',
url='https://fake/speech.mp3', asset_type='Audio', status='active',
)
prompt = '@碧碧.jpg 说 @speech.mp3 的话,@布布.jpg 听'
resp = self._post_generate(prompt, [
{'url': f'asset://local-{self.asset_bibi.id}', 'type': 'image', 'label': '碧碧.jpg'},
{'url': f'asset://local-{asset_audio.id}', 'type': 'audio', 'label': 'speech.mp3'},
{'url': f'asset://local-{self.asset_bubu.id}', 'type': 'image', 'label': '布布.jpg'},
])
self.assertEqual(resp.status_code, 202, resp.content)
sent_prompt = mock_create_task.call_args.kwargs['prompt']
# 碧碧=图片1, speech=音频1, 布布=图片2图片/音频独立计数)
self.assertEqual(sent_prompt, '图片1 说 音频1 的话图片2 听')
if __name__ == '__main__':
unittest.main(verbosity=2)