Add a New AI Provider Processor

How to add a new AI provider to the Compute Server. A provider processor handles communication with an external AI service (e.g., Fal.ai, Replicate, Stability AI) and maps its capabilities to the task system.


Architecture Overview

Providers are registered in a central registry and matched to tasks based on their capabilities:

Task arrives
  └─▶ ProcessorRegistry
        └─▶ matches ProviderType + ProviderCapability
              └─▶ ProviderProcessor.process()   ◀── you implement this
                    └─▶ External AI service API

Key files:

File                                                           Purpose
app/models/tasks/provider_type.py:12-50                        ProviderType enum (13 providers)
app/models/tasks/provider_capability.py:6-36                   ProviderCapability enum (16 capabilities)
app/services/task/processor/provider_processor.py:20-54        ProviderProcessor base class
app/services/task/processor/processor_registry.py:12-107       ProcessorRegistry (13 registered processors)
app/services/task/processor/fal_standard_processor.py:35-101+  Concrete example (FalStandardProcessor)
app/models/tasks/provider_type.py:77-90                        Queue routing (provider → Celery queue)

Step 1: Add the Enum Value

# app/models/tasks/provider_type.py

from enum import Enum


class ProviderType(str, Enum):
    FAL_STANDARD = "fal_standard"
    REPLICATE = "replicate"
    # ... existing values ...
    YOUR_PROVIDER = "your_provider"  # ← add here
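
ProviderType mixes in str, so each member's value string is what appears when a task is serialized (for example, into JSON). It is also the key you use to look the member up again afterwards. The self-contained check below illustrates that round-trip with a stand-in enum holding only the values shown above. It is a sketch, not the project's own code.

```python
import json
from enum import Enum


class ProviderType(str, Enum):
    # Stand-in mirroring the snippet above; the real enum has more members
    FAL_STANDARD = "fal_standard"
    REPLICATE = "replicate"
    YOUR_PROVIDER = "your_provider"


# A str-mixin member serializes to its plain string value...
payload = json.dumps({"provider": ProviderType.YOUR_PROVIDER})
print(payload)  # {"provider": "your_provider"}

# ...and that string maps back to the same member
restored = ProviderType(json.loads(payload)["provider"])
assert restored is ProviderType.YOUR_PROVIDER

# Members also compare equal to their raw string values
assert ProviderType.YOUR_PROVIDER == "your_provider"
```

Stored or queued data likely refers to providers by this string, so choose the value carefully up front. Renaming it later would probably break lookups of records written with the old value.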

Step 2: Add Capabilities (If Needed)

If your provider offers capabilities not yet represented, add them to the enum:

# app/models/tasks/provider_capability.py

from enum import Enum


class ProviderCapability(str, Enum):
    IMAGE_GENERATION = "image_generation"
    VIDEO_GENERATION = "video_generation"
    SUPER_RESOLUTION = "super_resolution"
    # ... existing values ...
    YOUR_CAPABILITY = "your_capability"  # ← add if needed

Skip this if your provider only implements existing capabilities (e.g., another image generation service).


Step 3: Create Provider Config Schema

Define the configuration your provider needs:

# app/schemas/providers/your_provider_config.py

from app.schemas.providers.base_provider_config import BaseProviderConfig


class YourProviderConfig(BaseProviderConfig):
    """Configuration for YourProvider."""

    api_key: str
    base_url: str = "https://api.yourprovider.com/v1"
    # Add provider-specific config fields
    max_retries: int = 3
    timeout_seconds: int = 300
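
BaseProviderConfig is not shown in this guide, so the following is only a stdlib sketch of the same idea. The config object reads its secret from the environment instead of hard-coding it, and it rejects obviously invalid values at construction time. A plain dataclass stands in for BaseProviderConfig. The environment variable names (YOUR_PROVIDER_API_KEY, YOUR_PROVIDER_BASE_URL) and the from_env() helper are hypothetical.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass
class YourProviderConfig:
    """Stdlib stand-in for the BaseProviderConfig subclass above."""

    api_key: str
    base_url: str = "https://api.yourprovider.com/v1"
    max_retries: int = 3
    timeout_seconds: int = 300

    def __post_init__(self) -> None:
        # Fail fast on values that would otherwise surface as confusing API errors
        if not self.api_key:
            raise ValueError("api_key must be set")
        if self.max_retries < 0:
            raise ValueError("max_retries must be >= 0")
        if self.timeout_seconds <= 0:
            raise ValueError("timeout_seconds must be positive")

    @classmethod
    def from_env(cls, env: Optional[Mapping[str, str]] = None) -> "YourProviderConfig":
        # Hypothetical variable names; substitute whatever your deployment uses
        source = os.environ if env is None else env
        return cls(
            api_key=source.get("YOUR_PROVIDER_API_KEY", ""),
            base_url=source.get("YOUR_PROVIDER_BASE_URL", cls.base_url),
        )


config = YourProviderConfig.from_env({"YOUR_PROVIDER_API_KEY": "test-key"})
print(config.base_url)  # https://api.yourprovider.com/v1
```

Passing an explicit mapping to from_env() keeps tests independent of the real environment.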

Step 4: Implement the Processor

Create your processor extending ProviderProcessor:

# app/services/task/processor/your_provider_processor.py

from app.models.tasks.provider_type import ProviderType
from app.models.tasks.task_type import TaskType
from app.services.task.processor.provider_processor import ProviderProcessor


class YourProviderProcessor(ProviderProcessor):
    """Processor for YourProvider AI service."""

    def supported_task_types(self) -> list[TaskType]:
        """Task types this provider can handle."""
        return [
            TaskType.GENERATION,
            # Add other supported task types
        ]

    def provider_type(self) -> ProviderType:
        """The provider type this processor handles."""
        return ProviderType.YOUR_PROVIDER

    async def process(self, task) -> None:
        """Execute the task against the provider's API.

        Typical flow:
        1. Extract parameters from the task
        2. Build the provider-specific API request
        3. Submit to the external service
        4. Poll or wait for results
        5. Download/store outputs
        6. Update task status
        """
        await self._set_status(task, "processing")
        try:
            # Steps 1-2: map task parameters to the provider's request format
            request_params = self._build_request(task)

            # Steps 3-4: submit to the external API and wait for the result
            result = await self._submit_to_provider(request_params)

            # Step 5: download/store outputs
            await self._handle_result(task, result)

            # Step 6: mark the task completed
            await self._set_status(task, "completed")
        except Exception as e:
            # Record the failure on the task, then re-raise so the worker sees it
            await self._set_status(task, "failed", error=str(e))
            raise

    def _build_request(self, task) -> dict:
        """Map task parameters to provider API format."""
        ...

    async def _submit_to_provider(self, params: dict):
        """Call the external provider API."""
        ...

    async def _handle_result(self, task, result) -> None:
        """Process and store provider results."""
        ...

Use FalStandardProcessor at app/services/task/processor/fal_standard_processor.py:35-101+ as a reference for the full pattern.
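
The "poll or wait for results" step is often where most of the work happens. Suppose your provider exposes an asynchronous job API that returns a job ID plus a status endpoint. Under that assumption, the polling loop can be sketched as below. The fetch_status callable, the status strings, and the poll interval are hypothetical. The overall deadline would come from timeout_seconds in the Step 3 config.

```python
import asyncio
from typing import Any, Awaitable, Callable

# Hypothetical status strings; map these to whatever your provider returns
TERMINAL_OK = {"succeeded"}
TERMINAL_ERROR = {"failed", "canceled"}


async def poll_until_done(
    fetch_status: Callable[[], Awaitable[dict[str, Any]]],
    timeout_seconds: float,
    interval_seconds: float = 2.0,
) -> dict[str, Any]:
    """Poll a provider job until it finishes, fails, or the deadline passes."""

    async def _loop() -> dict[str, Any]:
        while True:
            job = await fetch_status()
            status = job.get("status")
            if status in TERMINAL_OK:
                return job
            if status in TERMINAL_ERROR:
                raise RuntimeError(f"provider job ended with status {status!r}")
            # Still queued or running: wait before asking again
            await asyncio.sleep(interval_seconds)

    # asyncio.wait_for cancels the loop and raises asyncio.TimeoutError at the deadline
    return await asyncio.wait_for(_loop(), timeout=timeout_seconds)


async def main() -> None:
    # Fake status endpoint: reports "running" twice, then "succeeded"
    responses = iter([
        {"status": "running"},
        {"status": "running"},
        {"status": "succeeded", "output_url": "https://example.com/out.png"},
    ])

    async def fake_fetch() -> dict[str, Any]:
        return next(responses)

    job = await poll_until_done(fake_fetch, timeout_seconds=5, interval_seconds=0.01)
    print(job["status"])  # succeeded


asyncio.run(main())
```

A RuntimeError or timeout raised here would reach the except block in process(), so the task still ends up marked failed.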

Error Handling

Always use _set_status() to transition task state. This handles DB persistence, Redis event publishing, and webhook notifications in one call.
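
The Step 4 template marks the task failed on the first exception. Some provider APIs have transient failures, such as rate limits or brief outages. For those, you may want to retry the submission a bounded number of times before the error reaches the except block that calls _set_status(). Here is a minimal sketch that uses the max_retries field from Step 3. TransientProviderError is a hypothetical exception that your client would raise for retryable responses, and with_retries() is an illustrative helper rather than an existing project function.

```python
import asyncio
import random
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


class TransientProviderError(Exception):
    """Hypothetical: raise this from your client for retryable failures."""


async def with_retries(
    call: Callable[[], Awaitable[T]],
    max_retries: int = 3,
    base_delay: float = 1.0,
) -> T:
    """Run call(), retrying transient errors with exponential backoff and jitter."""
    attempt = 0
    while True:
        try:
            return await call()
        except TransientProviderError:
            if attempt >= max_retries:
                raise  # Out of retries: let process() mark the task failed
            # Delay doubles each attempt (base, 2x, 4x, ...) plus up to 10% jitter
            delay = base_delay * (2 ** attempt)
            await asyncio.sleep(delay + random.uniform(0, delay * 0.1))
            attempt += 1
```

Inside process(), the submission could then be wrapped as with_retries(lambda: self._submit_to_provider(request_params), max_retries=...). Here the max_retries value would come from your provider config, however your processor holds it. Non-transient exceptions pass straight through, so they are not retried.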


Step 5: Register in ProcessorRegistry

Add your processor to the default registrations:

# app/services/task/processor/processor_registry.py

from .your_provider_processor import YourProviderProcessor


class ProcessorRegistry:
    def _register_default_processors(self):
        # ... existing registrations ...
        self.register(YourProviderProcessor())

The registry at processor_registry.py:12-107 maintains the mapping of provider types to processor instances and resolves which processor handles a given task.
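
The registry's internals are not reproduced in this guide. To illustrate the kind of mapping it maintains, here is a self-contained sketch of a registry keyed by provider type. It refuses duplicate registrations and checks task-type support on lookup. The stand-in enums (including their string values and the UPSCALE member), the Processor protocol, SketchRegistry, and resolve() are all assumptions. They are not the real ProcessorRegistry API.

```python
from enum import Enum
from typing import Protocol


class ProviderType(str, Enum):
    YOUR_PROVIDER = "your_provider"


class TaskType(str, Enum):
    # Stand-in values; the real TaskType enum may differ
    GENERATION = "generation"
    UPSCALE = "upscale"


class Processor(Protocol):
    def provider_type(self) -> ProviderType: ...
    def supported_task_types(self) -> list[TaskType]: ...


class SketchRegistry:
    """Illustrative registry: one processor instance per provider type."""

    def __init__(self) -> None:
        self._processors: dict[ProviderType, Processor] = {}

    def register(self, processor: Processor) -> None:
        key = processor.provider_type()
        # Catch copy-paste mistakes where two processors claim the same provider
        if key in self._processors:
            raise ValueError(f"processor already registered for {key.value}")
        self._processors[key] = processor

    def resolve(self, provider: ProviderType, task_type: TaskType) -> Processor:
        processor = self._processors.get(provider)
        if processor is None:
            raise LookupError(f"no processor registered for {provider.value}")
        if task_type not in processor.supported_task_types():
            raise LookupError(f"{provider.value} does not support {task_type.value}")
        return processor


class YourProviderProcessor:
    # Minimal stand-in exposing the two methods the registry relies on
    def provider_type(self) -> ProviderType:
        return ProviderType.YOUR_PROVIDER

    def supported_task_types(self) -> list[TaskType]:
        return [TaskType.GENERATION]


registry = SketchRegistry()
registry.register(YourProviderProcessor())
print(type(registry.resolve(ProviderType.YOUR_PROVIDER, TaskType.GENERATION)).__name__)
# YourProviderProcessor
```

Raising on duplicates at startup surfaces a wrong provider_type() return value immediately, instead of silently routing tasks to whichever processor registered last.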


Step 6: Add Queue Routing

Map your provider to a Celery queue so tasks are routed to the correct workers:

# app/models/tasks/provider_type.py (queue map section, lines 77-90)

PROVIDER_QUEUE_MAP: dict[ProviderType, str] = {
    ProviderType.FAL_STANDARD: "fal_queue",
    # ... existing mappings ...
    ProviderType.YOUR_PROVIDER: "your_provider_queue",  # ← add here
}

This determines which Celery worker pool processes your provider's tasks. You may reuse an existing queue if your provider has similar resource requirements.
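
A missing mapping is easy to overlook, because nothing fails until a task for the new provider is dispatched. One way to catch it early is a completeness check over the enum, sketched here with stand-in definitions. fal_queue and your_provider_queue come from the snippet above, replicate_queue is a hypothetical name, and unmapped_providers() and queue_for() are illustrative helpers. With Celery, the resolved name would typically be passed as the queue argument of apply_async().

```python
from enum import Enum


class ProviderType(str, Enum):
    FAL_STANDARD = "fal_standard"
    REPLICATE = "replicate"
    YOUR_PROVIDER = "your_provider"


PROVIDER_QUEUE_MAP: dict[ProviderType, str] = {
    ProviderType.FAL_STANDARD: "fal_queue",
    ProviderType.REPLICATE: "replicate_queue",  # hypothetical queue name
    ProviderType.YOUR_PROVIDER: "your_provider_queue",
}


def unmapped_providers(queue_map: dict[ProviderType, str]) -> list[ProviderType]:
    """Return enum members that have no queue assigned, in definition order."""
    return [p for p in ProviderType if p not in queue_map]


def queue_for(provider: ProviderType) -> str:
    """Look up a provider's Celery queue, failing loudly if it was never mapped."""
    try:
        return PROVIDER_QUEUE_MAP[provider]
    except KeyError:
        raise LookupError(f"no queue mapped for provider {provider.value}") from None


# Suitable as a unit test: fails as soon as a new enum value lacks a queue
assert unmapped_providers(PROVIDER_QUEUE_MAP) == []
print(queue_for(ProviderType.YOUR_PROVIDER))  # your_provider_queue
```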


Step 7: Add Tests

Write tests covering:

  • Provider type registration: ProcessorRegistry resolves your processor
  • Capability matching: Provider is matched for the correct task types
  • Process flow: Success path produces expected outputs
  • Error handling: Failures transition task to correct status
  • API mocking: External provider calls are properly mocked
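
# tests/services/task/processor/test_your_provider_processor.py

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.models.tasks.provider_type import ProviderType
from app.models.tasks.task_type import TaskType
from app.services.task.processor.your_provider_processor import (
    YourProviderProcessor,
)


@pytest.fixture
def mock_task():
    # Minimal stand-in for a task; set the attributes your _build_request() reads
    return MagicMock()


class TestYourProviderProcessor:
    def test_provider_type(self):
        processor = YourProviderProcessor()
        assert processor.provider_type() == ProviderType.YOUR_PROVIDER

    def test_supported_task_types(self):
        processor = YourProviderProcessor()
        assert TaskType.GENERATION in processor.supported_task_types()

    @pytest.mark.asyncio
    async def test_process_success(self, mock_task):
        processor = YourProviderProcessor()
        # Mock the external call, result handling, and status updates so the
        # test touches neither the provider API nor the DB/Redis/webhooks
        with patch.object(
            processor, "_submit_to_provider", new_callable=AsyncMock
        ) as submit, patch.object(
            processor, "_handle_result", new_callable=AsyncMock
        ), patch.object(
            processor, "_set_status", new_callable=AsyncMock
        ) as set_status:
            await processor.process(mock_task)

        submit.assert_awaited_once()
        set_status.assert_any_await(mock_task, "processing")
        set_status.assert_awaited_with(mock_task, "completed")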
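
The example above covers only the success path. You can test the error-handling bullet the same way. Make the provider call raise, then check that the task ends in the failed status and that the exception still propagates. To stay self-contained, this sketch uses a hypothetical StandInProcessor whose process() mirrors the Step 4 template, instead of importing YourProviderProcessor. In your suite, import the real class and reuse the mock_task fixture. The test calls asyncio.run() directly, so it does not need pytest-asyncio.

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch


class StandInProcessor:
    """Mirrors the process() template from Step 4 (stand-in, not the real class)."""

    async def _set_status(self, task, status, error=None): ...

    def _build_request(self, task) -> dict:
        return {}

    async def _submit_to_provider(self, params: dict): ...

    async def _handle_result(self, task, result) -> None: ...

    async def process(self, task) -> None:
        await self._set_status(task, "processing")
        try:
            request_params = self._build_request(task)
            result = await self._submit_to_provider(request_params)
            await self._handle_result(task, result)
            await self._set_status(task, "completed")
        except Exception as e:
            await self._set_status(task, "failed", error=str(e))
            raise


def test_process_failure_marks_task_failed():
    processor = StandInProcessor()
    task = MagicMock()
    with patch.object(
        processor, "_submit_to_provider",
        new_callable=AsyncMock, side_effect=RuntimeError("provider down"),
    ), patch.object(processor, "_set_status", new_callable=AsyncMock) as set_status:
        try:
            asyncio.run(processor.process(task))
            raise AssertionError("expected RuntimeError to propagate")
        except RuntimeError:
            pass

    # The last transition is "failed", carrying the provider's error message
    set_status.assert_awaited_with(task, "failed", error="provider down")
    # "completed" is never reached on the failure path
    statuses = [c.args[1] for c in set_status.await_args_list]
    assert statuses == ["processing", "failed"]
```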

Checklist

  • Enum value added to ProviderType
  • Capabilities added to ProviderCapability (if new ones needed)
  • Provider config schema created
  • Processor implements supported_task_types(), provider_type(), process()
  • Registered in ProcessorRegistry._register_default_processors()
  • Queue routing added in provider queue map
  • Tests for registration, capability matching, process flow, and error handling