# Source listing metadata: 174 lines, 6.4 KiB, Python
from dataclasses import dataclass
|
|
import base64
|
|
from io import BytesIO
|
|
from typing import Any
|
|
|
|
import requests
|
|
from django.conf import settings
|
|
|
|
from .base import AIProviderResult
|
|
|
|
|
|
@dataclass
|
|
class VolcanoArkProvider:
|
|
api_key: str | None = None
|
|
base_url: str | None = None
|
|
|
|
def __post_init__(self) -> None:
|
|
self.api_key = self.api_key or settings.VOLCANO.get("ark_api_key")
|
|
self.base_url = self.base_url or settings.VOLCANO.get("ark_base_url")
|
|
|
|
def submit(self, payload: dict[str, Any]) -> AIProviderResult:
|
|
# The exact endpoint is resolved by ModelConfig; this adapter keeps IO centralized.
|
|
endpoint = payload.get("endpoint")
|
|
if not endpoint:
|
|
raise ValueError("Volcano request payload requires endpoint")
|
|
|
|
response = requests.post(
|
|
f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}",
|
|
headers={"Authorization": f"Bearer {self.api_key}"},
|
|
json=payload.get("body", {}),
|
|
timeout=60,
|
|
)
|
|
response.raise_for_status()
|
|
data = response.json()
|
|
return AIProviderResult(
|
|
provider_task_id=str(data.get("id") or data.get("task_id") or ""),
|
|
status=str(data.get("status") or "submitted"),
|
|
payload=data,
|
|
)
|
|
|
|
def chat_completion(self, *, model: str, messages: list[dict[str, str]], endpoint: str = "chat/completions") -> dict[str, Any]:
|
|
if not self.api_key:
|
|
raise ValueError("VOLCANO_ARK_API_KEY is not configured")
|
|
|
|
response = requests.post(
|
|
f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}",
|
|
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
|
|
json={"model": model, "messages": messages},
|
|
timeout=120,
|
|
)
|
|
response.raise_for_status()
|
|
return response.json()
|
|
|
|
@staticmethod
|
|
def extract_text(data: dict[str, Any]) -> str:
|
|
choices = data.get("choices") or []
|
|
if choices:
|
|
message = choices[0].get("message") or {}
|
|
content = message.get("content")
|
|
if isinstance(content, str):
|
|
return content
|
|
if isinstance(content, list):
|
|
return "\n".join(str(item.get("text", "")) for item in content if isinstance(item, dict))
|
|
output = data.get("output")
|
|
if isinstance(output, str):
|
|
return output
|
|
raise ValueError("Volcano response does not contain text content")
|
|
|
|
def poll(self, provider_task_id: str) -> AIProviderResult:
|
|
if not provider_task_id:
|
|
raise ValueError("provider_task_id is required")
|
|
|
|
return AIProviderResult(provider_task_id=provider_task_id, status="polling", payload={})
|
|
|
|
def image_generation(
|
|
self,
|
|
*,
|
|
model: str,
|
|
prompt: str,
|
|
endpoint: str = "images/generations",
|
|
image: str | list[str] | None = None,
|
|
size: str = "2K",
|
|
) -> dict[str, Any]:
|
|
if not self.api_key:
|
|
raise ValueError("VOLCANO_ARK_API_KEY is not configured")
|
|
body: dict[str, Any] = {
|
|
"model": model,
|
|
"prompt": prompt,
|
|
"response_format": "url",
|
|
"watermark": False,
|
|
"size": size,
|
|
"sequential_image_generation": "disabled",
|
|
}
|
|
if image:
|
|
body["image"] = image
|
|
response = requests.post(
|
|
f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}",
|
|
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
|
|
json=body,
|
|
timeout=180,
|
|
)
|
|
response.raise_for_status()
|
|
return response.json()
|
|
|
|
def create_video_task(
|
|
self,
|
|
*,
|
|
model: str,
|
|
endpoint: str,
|
|
prompt: str,
|
|
ratio: str = "9:16",
|
|
duration: int = 15,
|
|
resolution: str = "720p",
|
|
reference_images: list[str] | None = None,
|
|
) -> dict[str, Any]:
|
|
if not self.api_key:
|
|
raise ValueError("VOLCANO_ARK_API_KEY is not configured")
|
|
content: list[dict[str, Any]] = [{"type": "text", "text": prompt}]
|
|
for image_url in reference_images or []:
|
|
content.append({"type": "image_url", "image_url": {"url": image_url}, "role": "reference_image"})
|
|
body = {
|
|
"model": model,
|
|
"content": content,
|
|
"ratio": ratio,
|
|
"duration": duration,
|
|
"resolution": resolution,
|
|
"watermark": False,
|
|
"generate_audio": False,
|
|
}
|
|
response = requests.post(
|
|
f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}",
|
|
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
|
|
json=body,
|
|
timeout=120,
|
|
)
|
|
response.raise_for_status()
|
|
return response.json()
|
|
|
|
def poll_video_task(self, *, endpoint: str, provider_task_id: str) -> dict[str, Any]:
|
|
if not self.api_key:
|
|
raise ValueError("VOLCANO_ARK_API_KEY is not configured")
|
|
response = requests.get(
|
|
f"{self.base_url.rstrip('/')}/{endpoint.rstrip('/')}/{provider_task_id}",
|
|
headers={"Authorization": f"Bearer {self.api_key}"},
|
|
timeout=60,
|
|
)
|
|
response.raise_for_status()
|
|
return response.json()
|
|
|
|
@staticmethod
|
|
def extract_first_media_url(data: dict[str, Any]) -> str:
|
|
items = data.get("data") or []
|
|
for item in items:
|
|
if item.get("url"):
|
|
return item["url"]
|
|
if item.get("b64_json"):
|
|
return item["b64_json"]
|
|
content = data.get("content") or {}
|
|
if content.get("video_url"):
|
|
return content["video_url"]
|
|
raise ValueError("Volcano response does not contain media url")
|
|
|
|
@staticmethod
|
|
def media_to_bytes(media: str) -> tuple[BytesIO, str]:
|
|
if media.startswith("http://") or media.startswith("https://"):
|
|
response = requests.get(media, timeout=180)
|
|
response.raise_for_status()
|
|
return BytesIO(response.content), response.headers.get("content-type", "application/octet-stream")
|
|
if "," in media and media.startswith("data:"):
|
|
header, raw = media.split(",", 1)
|
|
content_type = header.split(";")[0].replace("data:", "") or "application/octet-stream"
|
|
return BytesIO(base64.b64decode(raw)), content_type
|
|
return BytesIO(base64.b64decode(media)), "image/png"
|