from unittest.mock import patch

from django.test import TestCase
from rest_framework.test import APIClient

from apps.accounts.models import Team, TeamMember, User
from apps.ai.models import ModelConfig, ModelProvider
from apps.billing.models import CreditAccount, CreditLedger
from apps.products.models import Product
from apps.projects.models import Project, ProjectStage, ScriptVersion, VideoSegment


class ProjectApiTests(TestCase):
    """API tests for project creation and AI script generation."""

    # Pipeline stages, in the order the project-creation assertion below
    # expects them to be returned.
    PIPELINE_STAGES = (
        ProjectStage.Stage.SCRIPT,
        ProjectStage.Stage.BASE_ASSETS,
        ProjectStage.Stage.STORYBOARD,
        ProjectStage.Stage.VIDEO,
        ProjectStage.Stage.EXPORT,
    )

    # Mocked LLM output: four numbered script beats. The script-generation
    # test expects this to yield four script segments.
    SCRIPT_TEXT = "1. 开场吸引\n2. 展示卖点\n3. 使用场景\n4. 促单转化"

    def setUp(self):
        """Create an owner, their team with credits, a product, an AI model, and an authenticated client."""
        self.user = User.objects.create_user(username="owner", password="pass")
        self.team = Team.objects.create(name="E2E Team", owner=self.user)
        TeamMember.objects.create(team=self.team, user=self.user, role=TeamMember.Role.OWNER)
        # Seed the team with a balance so the charge in script generation can succeed.
        CreditAccount.objects.create(team=self.team, balance="100.0000")
        self.product = Product.objects.create(team=self.team, created_by=self.user, title="Test Product")
        self.provider = ModelProvider.objects.create(
            name="volcengine",
            display_name="Volcano",
            base_url="https://ark.cn-beijing.volces.com/api/v3",
        )
        # Not referenced directly by the tests; presumably looked up by the
        # script-generation service as its text model — confirm.
        self.model = ModelConfig.objects.create(
            provider=self.provider,
            name="doubao-seed-2-0-pro-260215",
            display_name="Doubao",
            capability=ModelConfig.Capability.TEXT,
            endpoint="chat/completions",
            unit_price="2.0000",
        )
        self.client = APIClient()
        self.client.force_authenticate(self.user)

    def test_create_project_initializes_pipeline(self):
        """POST /api/projects/ creates the project with its stages and four default video segments."""
        response = self.client.post(
            "/api/projects/",
            {"name": "Launch Video", "product": str(self.product.id)},
            format="json",
        )

        self.assertEqual(response.status_code, 201)
        project = Project.objects.get(id=response.data["id"])
        self.assertEqual(project.team, self.team)
        self.assertEqual(project.created_by, self.user)
        self.assertEqual(project.stages.count(), 5)
        self.assertEqual(project.video_segments.count(), 4)
        # NOTE(review): relies on ProjectStage's default ordering matching the
        # pipeline order — confirm Meta.ordering if this becomes flaky.
        self.assertEqual(
            list(project.stages.values_list("stage", flat=True)),
            list(self.PIPELINE_STAGES),
        )
        # Order explicitly so the positional comparisons do not depend on the
        # model's default ordering.
        segments = project.video_segments.order_by("sort_order")
        self.assertEqual(
            list(segments.values_list("target_duration_seconds", flat=True)),
            [15, 15, 15, 15],
        )
        self.assertEqual(
            list(segments.values_list("sort_order", flat=True)),
            [0, 1, 2, 3],
        )
        self.assertTrue(VideoSegment.objects.filter(project=project).exists())

    @patch("apps.ai.services.VolcanoArkProvider")
    def test_generate_script_creates_script_segments_and_charges_credit(self, provider_cls):
        """Generating a script stores one version with four segments and records one credit charge."""
        # Stub the provider so no network call is made; both the raw response
        # and the extracted text carry the same four-beat script.
        provider = provider_cls.return_value
        provider.chat_completion.return_value = {
            "choices": [
                {
                    "message": {
                        "content": self.SCRIPT_TEXT,
                    }
                }
            ]
        }
        provider.extract_text.return_value = self.SCRIPT_TEXT

        project = Project.objects.create(
            team=self.team, created_by=self.user, product=self.product, name="Launch Video"
        )
        # Projects created directly via the ORM do not get stages here, so
        # create the full pipeline by hand.
        for stage in self.PIPELINE_STAGES:
            ProjectStage.objects.create(project=project, stage=stage)

        response = self.client.post(
            f"/api/projects/{project.id}/generate-script/",
            {"prompt": "突出高转化"},
            format="json",
        )

        self.assertEqual(response.status_code, 201)
        # .get() also asserts exactly one ScriptVersion was created.
        script = ScriptVersion.objects.get(project=project)
        self.assertEqual(script.segments.count(), 4)
        self.assertEqual(
            CreditLedger.objects.filter(team=self.team, ledger_type=CreditLedger.Type.CHARGE).count(),
            1,
        )