# Listing metadata carried over from the file viewer: 103 lines, 4.2 KiB, Python

from django.test import TestCase
from unittest.mock import patch
from rest_framework.test import APIClient
from apps.accounts.models import Team, TeamMember, User
from apps.ai.models import ModelConfig, ModelProvider
from apps.billing.models import CreditAccount, CreditLedger
from apps.products.models import Product
from apps.projects.models import Project, ProjectStage, ScriptVersion, VideoSegment
class ProjectApiTests(TestCase):
    """API tests for project creation and script generation.

    Every request goes through a DRF ``APIClient`` that is force-authenticated
    as the team owner created in ``setUp``.
    """

    # Pipeline stages, in the order the tests assert / create them.
    PIPELINE_STAGES = (
        ProjectStage.Stage.SCRIPT,
        ProjectStage.Stage.BASE_ASSETS,
        ProjectStage.Stage.STORYBOARD,
        ProjectStage.Stage.VIDEO,
        ProjectStage.Stage.EXPORT,
    )

    # Canned four-item script handed back by the mocked provider.
    MOCK_SCRIPT_TEXT = "1. 开场吸引\n2. 展示卖点\n3. 使用场景\n4. 促单转化"

    def setUp(self):
        """Create the owner, team, credit account, product, model config and client."""
        self.user = User.objects.create_user(username="owner", password="pass")
        self.team = Team.objects.create(name="E2E Team", owner=self.user)
        TeamMember.objects.create(
            team=self.team,
            user=self.user,
            role=TeamMember.Role.OWNER,
        )
        CreditAccount.objects.create(team=self.team, balance="100.0000")
        self.product = Product.objects.create(
            team=self.team,
            created_by=self.user,
            title="Test Product",
        )
        self.provider = ModelProvider.objects.create(
            name="volcengine",
            display_name="Volcano",
            base_url="https://ark.cn-beijing.volces.com/api/v3",
        )
        self.model = ModelConfig.objects.create(
            provider=self.provider,
            name="doubao-seed-2-0-pro-260215",
            display_name="Doubao",
            capability=ModelConfig.Capability.TEXT,
            endpoint="chat/completions",
            unit_price="2.0000",
        )
        # Replace Django's default test client with an authenticated DRF client.
        self.client = APIClient()
        self.client.force_authenticate(self.user)

    def test_create_project_initializes_pipeline(self):
        """POST /api/projects/ creates the project with its stages and segments."""
        response = self.client.post(
            "/api/projects/",
            {"name": "Launch Video", "product": str(self.product.id)},
            format="json",
        )
        self.assertEqual(response.status_code, 201)

        project = Project.objects.get(id=response.data["id"])
        self.assertEqual(project.team, self.team)
        self.assertEqual(project.created_by, self.user)

        # Five pipeline stages, in the expected order.
        self.assertEqual(project.stages.count(), 5)
        self.assertEqual(
            list(project.stages.values_list("stage", flat=True)),
            list(self.PIPELINE_STAGES),
        )

        # Four segments of 15 seconds each, numbered from zero.
        self.assertEqual(project.video_segments.count(), 4)
        self.assertEqual(
            list(project.video_segments.values_list("target_duration_seconds", flat=True)),
            [15, 15, 15, 15],
        )
        self.assertEqual(
            list(project.video_segments.values_list("sort_order", flat=True)),
            [0, 1, 2, 3],
        )
        # Same check through the model manager instead of the reverse relation.
        self.assertTrue(VideoSegment.objects.filter(project=project).exists())

    @patch("apps.ai.services.VolcanoArkProvider")
    def test_generate_script_creates_script_segments_and_charges_credit(self, provider_cls):
        """POST generate-script stores a four-segment script and one charge entry.

        ``apps.ai.services.VolcanoArkProvider`` is replaced with a mock whose
        ``chat_completion`` and ``extract_text`` return a canned four-item script.
        """
        provider = provider_cls.return_value
        provider.chat_completion.return_value = {
            "choices": [
                {
                    "message": {
                        "content": self.MOCK_SCRIPT_TEXT,
                    }
                }
            ]
        }
        provider.extract_text.return_value = self.MOCK_SCRIPT_TEXT

        # Build the project and its stages directly, bypassing the create endpoint.
        project = Project.objects.create(
            team=self.team,
            created_by=self.user,
            product=self.product,
            name="Launch Video",
        )
        for stage in self.PIPELINE_STAGES:
            ProjectStage.objects.create(project=project, stage=stage)

        response = self.client.post(
            f"/api/projects/{project.id}/generate-script/",
            {"prompt": "突出高转化"},
            format="json",
        )
        self.assertEqual(response.status_code, 201)

        # .get() raises unless exactly one ScriptVersion exists for the project.
        script = ScriptVersion.objects.get(project=project)
        self.assertEqual(script.segments.count(), 4)

        charge_entries = CreditLedger.objects.filter(
            team=self.team,
            ledger_type=CreditLedger.Type.CHARGE,
        )
        self.assertEqual(charge_entries.count(), 1)