# Add a New AI Provider Processor
This guide explains how to add a new AI provider to the Compute Server. A provider processor handles communication with an external AI service (e.g., Fal.ai, Replicate, Stability AI) and maps that service's capabilities onto the task system.
## Architecture Overview
Providers are registered in a central registry and matched to tasks based on their capabilities:
```text
Task arrives
 └─▶ ProcessorRegistry
      └─▶ matches ProviderType + ProviderCapability
           └─▶ ProviderProcessor.process()   ◀── you implement this
                └─▶ External AI service API
```
Key files:
| File | Purpose |
|---|---|
| `app/models/tasks/provider_type.py:12-50` | `ProviderType` enum (13 providers) |
| `app/models/tasks/provider_capability.py:6-36` | `ProviderCapability` enum (16 capabilities) |
| `app/services/task/processor/provider_processor.py:20-54` | `ProviderProcessor` base class |
| `app/services/task/processor/processor_registry.py:12-107` | `ProcessorRegistry` (13 registered processors) |
| `app/services/task/processor/fal_standard_processor.py:35-101+` | Concrete example (`FalStandardProcessor`) |
| `app/models/tasks/provider_type.py:77-90` | Queue routing (provider → Celery queue) |
## Step 1: Add the Enum Value
```python
# app/models/tasks/provider_type.py
from enum import Enum


class ProviderType(str, Enum):
    FAL_STANDARD = "fal_standard"
    REPLICATE = "replicate"
    # ... existing values ...
    YOUR_PROVIDER = "your_provider"  # ← add here
```
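Because `ProviderType` subclasses both `str` and `Enum`, its members compare equal to their raw string values and can be rebuilt from the string stored in a task payload. The self-contained sketch below illustrates that behavior with a two-member stand-in for the real enum:

```python
import json
from enum import Enum


class ProviderType(str, Enum):
    # Two-member stand-in for the real enum
    FAL_STANDARD = "fal_standard"
    YOUR_PROVIDER = "your_provider"


# Members compare equal to their raw string values
assert ProviderType.YOUR_PROVIDER == "your_provider"

# A string read from a JSON payload converts back to the member
payload = json.loads('{"provider": "your_provider"}')
provider = ProviderType(payload["provider"])
print(provider is ProviderType.YOUR_PROVIDER)  # True

# json.dumps writes the member as its plain string value
print(json.dumps({"provider": provider}))  # {"provider": "your_provider"}

# An unknown string raises ValueError, which surfaces typos early
try:
    ProviderType("your_provder")
except ValueError as exc:
    print("rejected:", exc)
```

When you need the raw string in an f-string or log message, prefer `.value`. The `format()`/f-string output of `str`-mixin enums changed in Python 3.11 and now renders as `ProviderType.YOUR_PROVIDER`.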
## Step 2: Add Capabilities (If Needed)
If your provider offers capabilities not yet represented, add them to the enum:
```python
# app/models/tasks/provider_capability.py
from enum import Enum


class ProviderCapability(str, Enum):
    IMAGE_GENERATION = "image_generation"
    VIDEO_GENERATION = "video_generation"
    SUPER_RESOLUTION = "super_resolution"
    # ... existing values ...
    YOUR_CAPABILITY = "your_capability"  # ← add if needed
```
Skip this if your provider only implements existing capabilities (e.g., another image generation service).
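To see why capabilities matter for matching, the sketch below assumes each provider's capabilities are known as a set and looks up which providers can serve a given capability. `PROVIDER_CAPABILITIES` and `providers_for()` are hypothetical, and the capability assignments are illustrative rather than the real ones. The actual matching happens inside `ProcessorRegistry`.

```python
from enum import Enum


class ProviderCapability(str, Enum):
    # Stand-in subset of the real enum
    IMAGE_GENERATION = "image_generation"
    VIDEO_GENERATION = "video_generation"
    SUPER_RESOLUTION = "super_resolution"


# Hypothetical, illustrative capability sets keyed by provider value
PROVIDER_CAPABILITIES: dict[str, frozenset[ProviderCapability]] = {
    "fal_standard": frozenset(
        {ProviderCapability.IMAGE_GENERATION, ProviderCapability.VIDEO_GENERATION}
    ),
    "your_provider": frozenset({ProviderCapability.SUPER_RESOLUTION}),
}


def providers_for(capability: ProviderCapability) -> list[str]:
    """Return the providers offering `capability`, sorted for stable output."""
    return sorted(
        name for name, caps in PROVIDER_CAPABILITIES.items() if capability in caps
    )
```

An empty result would mean no registered provider can serve that capability yet.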
## Step 3: Create Provider Config Schema
Define the configuration your provider needs:
```python
# app/schemas/providers/your_provider_config.py
from app.schemas.providers.base_provider_config import BaseProviderConfig


class YourProviderConfig(BaseProviderConfig):
    """Configuration for YourProvider."""

    api_key: str
    base_url: str = "https://api.yourprovider.com/v1"

    # Add provider-specific config fields
    max_retries: int = 3
    timeout_seconds: int = 300
```
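One plausible way for a processor to use `max_retries` and `timeout_seconds` is to bound each provider call with a timeout and retry transient failures with exponential backoff. The sketch below makes that assumption explicit. `RetrySettings` is a hypothetical stand-in for the two config fields, `call_with_retries()` is an illustrative helper rather than part of the Compute Server, and the set of retryable errors is a guess you should tune.

```python
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")

# Errors treated as transient (an assumption; tune for your provider)
RETRYABLE = (asyncio.TimeoutError, ConnectionError)


@dataclass
class RetrySettings:
    """Hypothetical stand-in for the retry fields of YourProviderConfig."""

    max_retries: int = 3
    timeout_seconds: float = 300


async def call_with_retries(
    make_call: Callable[[], Awaitable[T]],
    settings: RetrySettings,
    base_delay: float = 0.5,
) -> T:
    """Await make_call(), retrying transient failures with exponential backoff.

    Makes at most settings.max_retries + 1 attempts, each bounded by
    settings.timeout_seconds. Delays between attempts: base_delay * 1, 2, 4, ...
    """
    if settings.max_retries < 0:
        raise ValueError("max_retries must be >= 0")
    for attempt in range(settings.max_retries + 1):
        try:
            # make_call() creates a fresh coroutine for every attempt
            return await asyncio.wait_for(
                make_call(), timeout=settings.timeout_seconds
            )
        except RETRYABLE:
            if attempt == settings.max_retries:
                raise  # out of retries: surface the last error
            await asyncio.sleep(base_delay * 2**attempt)
    raise AssertionError("unreachable")
```

Inside a processor this might be called as `await call_with_retries(lambda: self._submit_to_provider(params), settings)`. The callable is a factory because a coroutine object can only be awaited once.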
## Step 4: Implement the Processor
Create your processor by extending `ProviderProcessor`:
```python
# app/services/task/processor/your_provider_processor.py
from app.models.tasks.provider_type import ProviderType
from app.models.tasks.task_type import TaskType
from app.services.task.processor.provider_processor import ProviderProcessor


class YourProviderProcessor(ProviderProcessor):
    """Processor for YourProvider AI service."""

    def supported_task_types(self) -> list[TaskType]:
        """Task types this provider can handle."""
        return [
            TaskType.GENERATION,
            # Add other supported task types
        ]

    def provider_type(self) -> ProviderType:
        """The provider type this processor handles."""
        return ProviderType.YOUR_PROVIDER

    async def process(self, task) -> None:
        """Execute the task against the provider's API.

        Typical flow:
        1. Extract parameters from the task
        2. Build the provider-specific API request
        3. Submit to the external service
        4. Poll or wait for results
        5. Download/store outputs
        6. Update task status
        """
        await self._set_status(task, "processing")
        try:
            # 1. Build request from task parameters
            request_params = self._build_request(task)
            # 2. Submit to external API
            result = await self._submit_to_provider(request_params)
            # 3. Handle results (download, store, etc.)
            await self._handle_result(task, result)
            await self._set_status(task, "completed")
        except Exception as e:
            # Record the failure, then re-raise so the worker sees it
            await self._set_status(task, "failed", error=str(e))
            raise

    def _build_request(self, task) -> dict:
        """Map task parameters to provider API format."""
        ...

    async def _submit_to_provider(self, params: dict):
        """Call the external provider API."""
        ...

    async def _handle_result(self, task, result) -> None:
        """Process and store provider results."""
        ...
```
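For providers that return a job ID and finish asynchronously, the "poll or wait for results" step could look like the sketch below. It calls a status function until the job reaches a terminal state or a deadline passes. `poll_until_done()`, the `fetch_status` callable, and the status strings are hypothetical, so map them to your provider's real API.

```python
import asyncio
from typing import Awaitable, Callable

# Hypothetical terminal states reported by the provider
DONE_STATES = {"succeeded"}
FAILED_STATES = {"failed", "canceled"}


async def poll_until_done(
    fetch_status: Callable[[], Awaitable[str]],
    interval: float = 2.0,
    timeout: float = 300.0,
) -> str:
    """Poll fetch_status() until a terminal state, or raise after `timeout`."""
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while True:
        status = await fetch_status()
        if status in DONE_STATES:
            return status
        if status in FAILED_STATES:
            raise RuntimeError(f"provider job ended with status {status!r}")
        # Stop if the next poll would land past the deadline
        if loop.time() + interval > deadline:
            raise TimeoutError(f"job still {status!r} after {timeout}s")
        await asyncio.sleep(interval)
```

Raising on failure and timeout lets the `except` branch in `process()` mark the task as failed without extra plumbing.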
Use `FalStandardProcessor` in `app/services/task/processor/fal_standard_processor.py:35-101+` as a reference for the full pattern.

**Error handling:** Always use `_set_status()` to transition task state. It handles DB persistence, Redis event publishing, and webhook notifications in one call.
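Routing every transition through a single call keeps persistence, event publishing, and webhook delivery from drifting apart. The rough sketch below illustrates that fan-out idea, with in-memory lists standing in for the database, Redis, and webhooks. `StatusFanout` and its sinks are hypothetical and do not reflect how `_set_status()` is actually implemented.

```python
import asyncio
from typing import Awaitable, Callable, Optional

# A sink applies one side effect for a status change: (task_id, status, error)
Sink = Callable[[str, str, Optional[str]], Awaitable[None]]


class StatusFanout:
    """Hypothetical helper: apply one status change to every sink, in order."""

    def __init__(self, sinks: list[Sink]) -> None:
        self._sinks = sinks

    async def set_status(
        self, task_id: str, status: str, error: Optional[str] = None
    ) -> None:
        # Sequential on purpose: persist first, so anything reacting to the
        # published event can already read the new state
        for sink in self._sinks:
            await sink(task_id, status, error)


# In-memory stand-ins for the DB, Redis pub/sub, and webhook calls
db_rows: list[tuple[str, str]] = []
events: list[str] = []
webhooks: list[dict] = []


async def persist(task_id: str, status: str, error: Optional[str]) -> None:
    db_rows.append((task_id, status))


async def publish(task_id: str, status: str, error: Optional[str]) -> None:
    events.append(f"{task_id}:{status}")


async def notify(task_id: str, status: str, error: Optional[str]) -> None:
    webhooks.append({"task_id": task_id, "status": status, "error": error})


fanout = StatusFanout([persist, publish, notify])
asyncio.run(fanout.set_status("task-1", "failed", error="provider down"))
```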
## Step 5: Register in ProcessorRegistry
Add your processor to the default registrations:
```python
# app/services/task/processor/processor_registry.py
from .your_provider_processor import YourProviderProcessor


class ProcessorRegistry:
    def _register_default_processors(self):
        # ... existing registrations ...
        self.register(YourProviderProcessor())
```
The registry at processor_registry.py:12-107 maintains the mapping of provider types to processor instances and resolves which processor handles a given task.
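Conceptually, the registry is a lookup from `ProviderType` to a processor instance. The self-contained sketch below uses a minimal stand-in processor that exposes the same `provider_type()` method. `MiniRegistry`, `resolve()`, and the duplicate check are illustrative assumptions, so check `processor_registry.py` for the actual behavior.

```python
from enum import Enum


class ProviderType(str, Enum):
    FAL_STANDARD = "fal_standard"
    YOUR_PROVIDER = "your_provider"


class StubProcessor:
    """Stand-in exposing the same provider_type() method as ProviderProcessor."""

    def __init__(self, provider: ProviderType) -> None:
        self._provider = provider

    def provider_type(self) -> ProviderType:
        return self._provider


class MiniRegistry:
    def __init__(self) -> None:
        self._by_provider: dict[ProviderType, StubProcessor] = {}

    def register(self, processor: StubProcessor) -> None:
        key = processor.provider_type()
        # Refuse silent overwrites so a copy-paste mistake surfaces early
        if key in self._by_provider:
            raise ValueError(f"processor already registered for {key.value}")
        self._by_provider[key] = processor

    def resolve(self, provider: ProviderType) -> StubProcessor:
        try:
            return self._by_provider[provider]
        except KeyError:
            raise LookupError(f"no processor registered for {provider.value}") from None


registry = MiniRegistry()
registry.register(StubProcessor(ProviderType.YOUR_PROVIDER))
print(registry.resolve(ProviderType.YOUR_PROVIDER).provider_type().value)  # your_provider
```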
## Step 6: Add Queue Routing
Map your provider to a Celery queue so tasks are routed to the correct workers:
```python
# app/models/tasks/provider_type.py (queue map section, L77-90)
PROVIDER_QUEUE_MAP: dict[ProviderType, str] = {
    ProviderType.FAL_STANDARD: "fal_queue",
    # ... existing mappings ...
    ProviderType.YOUR_PROVIDER: "your_provider_queue",  # ← add here
}
```
This determines which Celery worker pool processes your provider's tasks. You may reuse an existing queue if your provider has similar resource requirements.
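At dispatch time the map is a plain dictionary lookup. The sketch below adds a fallback for unmapped providers and a helper that lists them. The `"default"` queue name, `queue_for()`, and `unmapped_providers()` are assumptions. The resulting queue name is what you would pass as the `queue` option to Celery's `apply_async()`.

```python
from enum import Enum


class ProviderType(str, Enum):
    FAL_STANDARD = "fal_standard"
    REPLICATE = "replicate"
    YOUR_PROVIDER = "your_provider"


PROVIDER_QUEUE_MAP: dict[ProviderType, str] = {
    ProviderType.FAL_STANDARD: "fal_queue",
    ProviderType.YOUR_PROVIDER: "your_provider_queue",
}

DEFAULT_QUEUE = "default"  # hypothetical fallback queue name


def queue_for(provider: ProviderType) -> str:
    """Return the Celery queue for `provider`, or DEFAULT_QUEUE if unmapped."""
    return PROVIDER_QUEUE_MAP.get(provider, DEFAULT_QUEUE)


def unmapped_providers() -> list[ProviderType]:
    """Providers without an explicit queue; useful in a test guarding Step 6."""
    return [p for p in ProviderType if p not in PROVIDER_QUEUE_MAP]


# e.g. some_celery_task.apply_async(args=[task_id], queue=queue_for(provider))
```

Whether a silent fallback is desirable is a judgment call. Indexing the map directly instead makes a forgotten Step 6 fail loudly with a `KeyError`.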
## Step 7: Add Tests
Write tests covering:
- Provider type registration: `ProcessorRegistry` resolves your processor
- Capability matching: Provider is matched for the correct task types
- Process flow: Success path produces expected outputs
- Error handling: Failures transition task to correct status
- API mocking: External provider calls are properly mocked
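```python
# tests/services/task/processor/test_your_provider_processor.py
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.models.tasks.provider_type import ProviderType
from app.models.tasks.task_type import TaskType
from app.services.task.processor.your_provider_processor import (
    YourProviderProcessor,
)


@pytest.fixture
def mock_task():
    # Stand-in task object; replace with your project's task fixture if one exists
    return MagicMock()


class TestYourProviderProcessor:
    def test_provider_type(self):
        processor = YourProviderProcessor()
        assert processor.provider_type() == ProviderType.YOUR_PROVIDER

    def test_supported_task_types(self):
        processor = YourProviderProcessor()
        assert TaskType.GENERATION in processor.supported_task_types()

    @pytest.mark.asyncio  # requires the pytest-asyncio plugin
    async def test_process_success(self, mock_task):
        processor = YourProviderProcessor()
        # Patch _set_status too, so the test does not touch the DB, Redis, or webhooks
        with patch.object(
            processor, "_set_status", new_callable=AsyncMock
        ) as set_status, patch.object(
            processor, "_submit_to_provider", new_callable=AsyncMock
        ):
            await processor.process(mock_task)
        # Assert task status transitions and outputs
        set_status.assert_any_await(mock_task, "processing")
        set_status.assert_awaited_with(mock_task, "completed")
```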
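The error-handling bullet can be tested the same way as the success path: make the provider call fail, then check that the task ends in `failed` and that the exception still propagates. The version below is self-contained and runs under pytest or when called directly. `FakeProcessor` is a stand-in that mirrors the `process()` flow from Step 4, not the real class.

```python
import asyncio
from unittest.mock import AsyncMock


class FakeProcessor:
    """Stand-in mirroring the process() flow from Step 4."""

    def __init__(self) -> None:
        self._set_status = AsyncMock()          # records status transitions
        self._submit_to_provider = AsyncMock()  # behavior set per test

    async def process(self, task) -> None:
        await self._set_status(task, "processing")
        try:
            await self._submit_to_provider({})
            await self._set_status(task, "completed")
        except Exception as e:
            await self._set_status(task, "failed", error=str(e))
            raise


def test_process_failure_marks_task_failed() -> None:
    processor = FakeProcessor()
    # Awaiting an AsyncMock whose side_effect is an exception raises it
    processor._submit_to_provider.side_effect = RuntimeError("provider down")
    task = object()

    try:
        asyncio.run(processor.process(task))
    except RuntimeError as exc:
        assert str(exc) == "provider down"
    else:
        raise AssertionError("expected RuntimeError to propagate")

    processor._set_status.assert_any_await(task, "processing")
    processor._set_status.assert_awaited_with(task, "failed", error="provider down")
```

In the real test file, `pytest.raises(RuntimeError)` is the more idiomatic way to express the `try`/`else` check.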
## Checklist
- Enum value added to `ProviderType`
- Capabilities added to `ProviderCapability` (if new ones are needed)
- Provider config schema created
- Processor implements `supported_task_types()`, `provider_type()`, and `process()`
- Registered in `ProcessorRegistry._register_default_processors()`
- Queue routing added in the provider queue map
- Tests for registration, capability matching, process flow, and error handling