174 lines
6.4 KiB
Python

import base64
from dataclasses import dataclass, field
from io import BytesIO
from typing import Any
from urllib.parse import unquote_to_bytes

import requests
from django.conf import settings

from .base import AIProviderResult
@dataclass
class VolcanoArkProvider:
api_key: str | None = None
base_url: str | None = None
def __post_init__(self) -> None:
self.api_key = self.api_key or settings.VOLCANO.get("ark_api_key")
self.base_url = self.base_url or settings.VOLCANO.get("ark_base_url")
def submit(self, payload: dict[str, Any]) -> AIProviderResult:
# The exact endpoint is resolved by ModelConfig; this adapter keeps IO centralized.
endpoint = payload.get("endpoint")
if not endpoint:
raise ValueError("Volcano request payload requires endpoint")
response = requests.post(
f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}",
headers={"Authorization": f"Bearer {self.api_key}"},
json=payload.get("body", {}),
timeout=60,
)
response.raise_for_status()
data = response.json()
return AIProviderResult(
provider_task_id=str(data.get("id") or data.get("task_id") or ""),
status=str(data.get("status") or "submitted"),
payload=data,
)
def chat_completion(self, *, model: str, messages: list[dict[str, str]], endpoint: str = "chat/completions") -> dict[str, Any]:
if not self.api_key:
raise ValueError("VOLCANO_ARK_API_KEY is not configured")
response = requests.post(
f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}",
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
json={"model": model, "messages": messages},
timeout=120,
)
response.raise_for_status()
return response.json()
@staticmethod
def extract_text(data: dict[str, Any]) -> str:
choices = data.get("choices") or []
if choices:
message = choices[0].get("message") or {}
content = message.get("content")
if isinstance(content, str):
return content
if isinstance(content, list):
return "\n".join(str(item.get("text", "")) for item in content if isinstance(item, dict))
output = data.get("output")
if isinstance(output, str):
return output
raise ValueError("Volcano response does not contain text content")
def poll(self, provider_task_id: str) -> AIProviderResult:
if not provider_task_id:
raise ValueError("provider_task_id is required")
return AIProviderResult(provider_task_id=provider_task_id, status="polling", payload={})
def image_generation(
self,
*,
model: str,
prompt: str,
endpoint: str = "images/generations",
image: str | list[str] | None = None,
size: str = "2K",
) -> dict[str, Any]:
if not self.api_key:
raise ValueError("VOLCANO_ARK_API_KEY is not configured")
body: dict[str, Any] = {
"model": model,
"prompt": prompt,
"response_format": "url",
"watermark": False,
"size": size,
"sequential_image_generation": "disabled",
}
if image:
body["image"] = image
response = requests.post(
f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}",
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
json=body,
timeout=180,
)
response.raise_for_status()
return response.json()
def create_video_task(
self,
*,
model: str,
endpoint: str,
prompt: str,
ratio: str = "9:16",
duration: int = 15,
resolution: str = "720p",
reference_images: list[str] | None = None,
) -> dict[str, Any]:
if not self.api_key:
raise ValueError("VOLCANO_ARK_API_KEY is not configured")
content: list[dict[str, Any]] = [{"type": "text", "text": prompt}]
for image_url in reference_images or []:
content.append({"type": "image_url", "image_url": {"url": image_url}, "role": "reference_image"})
body = {
"model": model,
"content": content,
"ratio": ratio,
"duration": duration,
"resolution": resolution,
"watermark": False,
"generate_audio": False,
}
response = requests.post(
f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}",
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
json=body,
timeout=120,
)
response.raise_for_status()
return response.json()
def poll_video_task(self, *, endpoint: str, provider_task_id: str) -> dict[str, Any]:
if not self.api_key:
raise ValueError("VOLCANO_ARK_API_KEY is not configured")
response = requests.get(
f"{self.base_url.rstrip('/')}/{endpoint.rstrip('/')}/{provider_task_id}",
headers={"Authorization": f"Bearer {self.api_key}"},
timeout=60,
)
response.raise_for_status()
return response.json()
@staticmethod
def extract_first_media_url(data: dict[str, Any]) -> str:
items = data.get("data") or []
for item in items:
if item.get("url"):
return item["url"]
if item.get("b64_json"):
return item["b64_json"]
content = data.get("content") or {}
if content.get("video_url"):
return content["video_url"]
raise ValueError("Volcano response does not contain media url")
@staticmethod
def media_to_bytes(media: str) -> tuple[BytesIO, str]:
if media.startswith("http://") or media.startswith("https://"):
response = requests.get(media, timeout=180)
response.raise_for_status()
return BytesIO(response.content), response.headers.get("content-type", "application/octet-stream")
if "," in media and media.startswith("data:"):
header, raw = media.split(",", 1)
content_type = header.split(";")[0].replace("data:", "") or "application/octet-stream"
return BytesIO(base64.b64decode(raw)), content_type
return BytesIO(base64.b64decode(media)), "image/png"