103 lines
4.2 KiB
Python
103 lines
4.2 KiB
Python
from django.test import TestCase
|
|
from unittest.mock import patch
|
|
|
|
from rest_framework.test import APIClient
|
|
|
|
from apps.accounts.models import Team, TeamMember, User
|
|
from apps.ai.models import ModelConfig, ModelProvider
|
|
from apps.billing.models import CreditAccount, CreditLedger
|
|
from apps.products.models import Product
|
|
from apps.projects.models import Project, ProjectStage, ScriptVersion, VideoSegment
|
|
|
|
|
|
class ProjectApiTests(TestCase):
    """API tests for project creation and script generation.

    ``setUp`` provisions an authenticated team owner whose team has a credit
    account, a product, and a text-capability model configuration, so each
    test can call the ``/api/projects/`` endpoints directly.
    """

    # Pipeline stages in the order the API is expected to create them.
    # Shared by both tests so the creation expectation and the hand-built
    # fixture in the script test cannot drift apart.
    PIPELINE_STAGES = (
        ProjectStage.Stage.SCRIPT,
        ProjectStage.Stage.BASE_ASSETS,
        ProjectStage.Stage.STORYBOARD,
        ProjectStage.Stage.VIDEO,
        ProjectStage.Stage.EXPORT,
    )

    # Four-line numbered outline returned by the mocked model
    # (opening hook / show selling points / usage scenario / drive conversion).
    MOCK_SCRIPT_TEXT = "1. 开场吸引\n2. 展示卖点\n3. 使用场景\n4. 促单转化"

    def setUp(self):
        """Create the owner, team, credit account, product, and model config."""
        self.user = User.objects.create_user(username="owner", password="pass")
        self.team = Team.objects.create(name="E2E Team", owner=self.user)
        TeamMember.objects.create(team=self.team, user=self.user, role=TeamMember.Role.OWNER)
        CreditAccount.objects.create(team=self.team, balance="100.0000")
        self.product = Product.objects.create(team=self.team, created_by=self.user, title="Test Product")
        self.provider = ModelProvider.objects.create(
            name="volcengine",
            display_name="Volcano",
            base_url="https://ark.cn-beijing.volces.com/api/v3",
        )
        # NOTE(review): presumably the script-generation service looks up a
        # TEXT-capability ModelConfig and bills its unit_price; confirm
        # against apps.ai.services.
        self.model = ModelConfig.objects.create(
            provider=self.provider,
            name="doubao-seed-2-0-pro-260215",
            display_name="Doubao",
            capability=ModelConfig.Capability.TEXT,
            endpoint="chat/completions",
            unit_price="2.0000",
        )
        # DRF's APIClient replaces Django's default test client; force_authenticate
        # lets requests run as self.user without a login round-trip.
        self.client = APIClient()
        self.client.force_authenticate(self.user)

    def test_create_project_initializes_pipeline(self):
        """POST /api/projects/ creates the project with its stages and video segments."""
        response = self.client.post(
            "/api/projects/",
            {"name": "Launch Video", "product": str(self.product.id)},
            format="json",
        )

        self.assertEqual(response.status_code, 201)
        project = Project.objects.get(id=response.data["id"])
        # The API assigns the project to the requester's team and records the creator.
        self.assertEqual(project.team, self.team)
        self.assertEqual(project.created_by, self.user)
        self.assertEqual(project.stages.count(), len(self.PIPELINE_STAGES))
        self.assertEqual(project.video_segments.count(), 4)
        # NOTE(review): the ordered comparisons below rely on each model's
        # default ordering (no explicit order_by); confirm Meta.ordering.
        self.assertEqual(
            list(project.stages.values_list("stage", flat=True)),
            list(self.PIPELINE_STAGES),
        )
        self.assertEqual(
            list(project.video_segments.values_list("target_duration_seconds", flat=True)),
            [15, 15, 15, 15],
        )
        self.assertEqual(
            list(project.video_segments.values_list("sort_order", flat=True)),
            [0, 1, 2, 3],
        )
        self.assertTrue(VideoSegment.objects.filter(project=project).exists())

    @patch("apps.ai.services.VolcanoArkProvider")
    def test_generate_script_creates_script_segments_and_charges_credit(self, provider_cls):
        """Generating a script stores one version with four segments and one charge entry.

        The Volcano Ark provider class is patched, so the endpoint receives the
        canned four-line outline instead of calling the real provider.
        """
        provider = provider_cls.return_value
        provider.chat_completion.return_value = {
            "choices": [
                {"message": {"content": self.MOCK_SCRIPT_TEXT}},
            ],
        }
        provider.extract_text.return_value = self.MOCK_SCRIPT_TEXT

        # Build the project and its stages directly (not through the create
        # endpoint) so this test exercises only generate-script.
        project = Project.objects.create(
            team=self.team,
            created_by=self.user,
            product=self.product,
            name="Launch Video",
        )
        for stage in self.PIPELINE_STAGES:
            ProjectStage.objects.create(project=project, stage=stage)

        response = self.client.post(
            f"/api/projects/{project.id}/generate-script/",
            {"prompt": "突出高转化"},  # prompt: "emphasize high conversion"
            format="json",
        )

        self.assertEqual(response.status_code, 201)
        # Guard against the patch target drifting: if the service stopped
        # instantiating the patched class, the canned responses configured
        # above would never be used and this test would prove nothing.
        provider_cls.assert_called()
        script = ScriptVersion.objects.get(project=project)
        self.assertEqual(script.segments.count(), 4)
        self.assertEqual(
            CreditLedger.objects.filter(team=self.team, ledger_type=CreditLedger.Type.CHARGE).count(),
            1,
        )
|
|
|