Concurrency Patterns
Common patterns for solving concurrent programming challenges.
Overview
| Pattern | Problem Solved | Use Case |
|---|---|---|
| Producer-Consumer | Decouple production from consumption | Task queues, data pipelines |
| Worker Pools | Manage parallel workers | Batch processing, job queues |
Additional Patterns
Fan-Out/Fan-In
Distribute work across concurrent tasks (capping how many run at once), then aggregate their results.
import asyncio
async def fan_out_fan_in(items, process_fn, max_concurrent=10, *, return_exceptions=False):
    """Process *items* concurrently with *process_fn* and collect the results.

    At most ``max_concurrent`` calls of ``process_fn`` run at the same time;
    the others wait on a semaphore.

    Args:
        items: iterable of inputs, each passed to ``process_fn``.
        process_fn: async callable taking a single item.
        max_concurrent: upper bound on simultaneous ``process_fn`` calls.
        return_exceptions: if True, an exception raised for an item is placed
            in the result list instead of propagating (same meaning as in
            ``asyncio.gather``).

    Returns:
        List of results, in the same order as ``items``.

    Raises:
        The first exception raised by ``process_fn`` when ``return_exceptions``
        is False; cancellation is requested for every call still pending.
    """
    semaphore = asyncio.Semaphore(max_concurrent)

    async def bounded_process(item):
        async with semaphore:
            return await process_fn(item)

    # Explicit tasks (rather than bare coroutines) so they can be cancelled.
    tasks = [asyncio.create_task(bounded_process(item)) for item in items]
    try:
        return await asyncio.gather(*tasks, return_exceptions=return_exceptions)
    finally:
        # gather() does not cancel the remaining awaitables when one fails;
        # without this they would keep running unobserved in the background.
        # cancel() is a no-op for tasks that have already finished.
        for task in tasks:
            task.cancel()
# Usage -- this must run inside a coroutine (or an asyncio-aware REPL or
# notebook); a plain module rejects a top-level ``await``. ``urls`` and
# ``fetch_url`` are not defined in this snippet: they stand for the caller's
# inputs and async fetch function.
results = await fan_out_fan_in(urls, fetch_url, max_concurrent=5)
Scatter-Gather
Broadcast a request to several handlers and collect the responses that arrive before a timeout.
async def scatter_gather(request, handlers, timeout=5.0):
    """Send *request* to every handler concurrently and collect the responses.

    Each handler is an async callable taking ``request``. Handlers still
    running after ``timeout`` seconds are cancelled and left out of the
    result. A handler that raised contributes its exception object instead of
    a response, so one failing handler does not hide the others.

    Returns:
        Responses (or exceptions) from the handlers that finished in time, in
        the same order as ``handlers``. Empty when ``handlers`` is empty.
    """
    tasks = [asyncio.create_task(handler(request)) for handler in handlers]
    if not tasks:
        # asyncio.wait() raises ValueError when given an empty collection.
        return []

    done, pending = await asyncio.wait(
        tasks,
        timeout=timeout,
        return_when=asyncio.ALL_COMPLETED,
    )

    # Stop handlers that missed the deadline.
    for task in pending:
        task.cancel()

    # Walk ``tasks`` rather than the unordered ``done`` set so the results
    # line up with ``handlers``.
    results = []
    for task in tasks:
        if task not in done:
            continue
        try:
            results.append(task.result())
        except Exception as exc:
            results.append(exc)
    return results
Pipeline
Chain processing stages: each stage runs as its own worker, connected to the next by a queue.
import asyncio
async def pipeline(*stages, maxsize=0):
    """Wire async *stages* into a chain of workers connected by queues.

    Stage ``i`` reads from ``queues[i]`` and writes ``await stage(item)`` to
    ``queues[i + 1]``, so ``len(stages) + 1`` queues are needed. Put items
    into the returned input queue, followed by ``None`` to shut down: each
    worker forwards the ``None`` sentinel downstream and exits, so the output
    queue ends with ``None`` once every item has been processed.

    If a stage raises, its worker task fails and the sentinel is not
    forwarded; inspect the returned tasks to detect this.

    Must be awaited inside a running event loop, because the workers are
    started with ``asyncio.create_task``.

    Args:
        *stages: async callables, each taking one item and returning the
            transformed item.
        maxsize: bound applied to every queue; 0 (the default) means
            unbounded. A positive value applies backpressure to a fast
            producer.

    Returns:
        ``(input_queue, output_queue, worker_tasks)``.
    """
    queues = [asyncio.Queue(maxsize=maxsize) for _ in range(len(stages) + 1)]

    async def stage_worker(stage_fn, in_queue, out_queue):
        while True:
            item = await in_queue.get()
            try:
                if item is None:
                    await out_queue.put(None)
                    return
                await out_queue.put(await stage_fn(item))
            finally:
                # Mark every dequeued item done, the sentinel included, so
                # in_queue.join() can complete.
                in_queue.task_done()

    tasks = [
        asyncio.create_task(stage_worker(stage, queues[i], queues[i + 1]))
        for i, stage in enumerate(stages)
    ]
    return queues[0], queues[-1], tasks
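Bulkhead
Isolate failures to prevent them from cascading: cap concurrent access to each resource so that one slow dependency cannot tie up every caller.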
import asyncio
from contextlib import asynccontextmanager
class BulkheadFullError(Exception):
    """Raised when a Bulkhead slot could not be acquired within the timeout."""


class Bulkhead:
    """Limit the number of concurrent users of one resource.

    Each instance wraps its own semaphore, so separate bulkheads for separate
    dependencies cannot take each other's slots.
    """

    def __init__(self, max_concurrent: int, name: str = "bulkhead"):
        self.semaphore = asyncio.Semaphore(max_concurrent)
        self.name = name  # used in the BulkheadFullError message

    @asynccontextmanager
    async def acquire(self, timeout: "float | None" = None):
        """Hold one slot for the duration of the ``async with`` body.

        Args:
            timeout: seconds to wait for a free slot; ``None`` waits forever.

        Yields:
            The value returned by ``Semaphore.acquire()``.

        Raises:
            BulkheadFullError: no slot became free within ``timeout``.
        """
        try:
            acquired = await asyncio.wait_for(
                self.semaphore.acquire(),
                timeout=timeout,
            )
        except asyncio.TimeoutError as exc:
            raise BulkheadFullError(f"{self.name} is full") from exc
        # Release only a slot we actually hold. Releasing on the timeout path
        # would add a permit that was never taken and silently raise the
        # limit. Keeping the yield outside the try above also stops a
        # TimeoutError raised by the caller's body from being reported as
        # "bulkhead full".
        try:
            yield acquired
        finally:
            self.semaphore.release()
# Usage: give the database and the external API separate bulkheads, so that
# saturating one of them cannot use up the other's slots.
db_bulkhead = Bulkhead(max_concurrent=20, name="database")
api_bulkhead = Bulkhead(max_concurrent=10, name="external_api")


async def db_query():
    """Run ``execute_query()`` while holding a ``db_bulkhead`` slot."""
    async with db_bulkhead.acquire():
        rows = await execute_query()
    return rows


async def api_call():
    """Run ``call_external_api()`` while holding an ``api_bulkhead`` slot."""
    async with api_bulkhead.acquire():
        response = await call_external_api()
    return response
Circuit Breaker
Fail fast when a service is unhealthy.
import asyncio
import time
from datetime import datetime, timedelta
from enum import Enum


class CircuitOpenError(Exception):
    """Raised by ``CircuitBreaker.call`` while the circuit rejects calls."""


class CircuitState(Enum):
    """Operating states of a CircuitBreaker."""

    CLOSED = "closed"  # calls pass through; consecutive failures are counted
    OPEN = "open"  # calls are rejected until the recovery timeout elapses
    HALF_OPEN = "half_open"  # trial calls pass; any failure re-opens


class CircuitBreaker:
    """Stop calling a dependency after repeated failures, then probe it.

    After ``failure_threshold`` consecutive failures the circuit opens and
    :meth:`call` raises :class:`CircuitOpenError` without invoking the wrapped
    function. Once ``recovery_timeout`` seconds have passed since the last
    failure, the next call switches the circuit to half-open and is let
    through. ``half_open_requests`` successes in that state close the circuit
    again; a single failure re-opens it.
    """

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 30.0,
        half_open_requests: int = 1,
    ):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_requests = half_open_requests
        self.state = CircuitState.CLOSED
        self.failures = 0
        # Wall-clock time of the last failure, kept for logging/inspection.
        self.last_failure_time = None
        # Monotonic timestamp of the last failure. The recovery check uses it
        # because, unlike naive local datetime.now(), it never jumps at
        # DST transitions or clock adjustments.
        self._last_failure_monotonic = None
        self.half_open_successes = 0

    async def call(self, fn, *args, **kwargs):
        """Await ``fn(*args, **kwargs)`` through the breaker.

        Raises:
            CircuitOpenError: the circuit is open and the recovery timeout
                has not elapsed yet.
            Exception: whatever ``fn`` raises (after it is recorded as a
                failure).
        """
        if self.state == CircuitState.OPEN:
            if not self._should_try_recovery():
                raise CircuitOpenError()
            self.state = CircuitState.HALF_OPEN
            self.half_open_successes = 0
        try:
            result = await fn(*args, **kwargs)
        except Exception:
            # BaseException subclasses (e.g. CancelledError) are deliberately
            # not counted as service failures.
            self._on_failure()
            raise
        self._on_success()
        return result

    def _on_success(self):
        if self.state == CircuitState.HALF_OPEN:
            self.half_open_successes += 1
            if self.half_open_successes >= self.half_open_requests:
                self.state = CircuitState.CLOSED
                self.failures = 0
        elif self.state == CircuitState.CLOSED:
            self.failures = 0
        # OPEN: a call started before the circuit tripped finished late; it
        # must not erase the failure history that opened the circuit.

    def _on_failure(self):
        self.failures += 1
        self.last_failure_time = datetime.now()
        self._last_failure_monotonic = time.monotonic()
        # A failed trial call re-opens immediately, regardless of the count.
        if (
            self.state == CircuitState.HALF_OPEN
            or self.failures >= self.failure_threshold
        ):
            self.state = CircuitState.OPEN

    def _should_try_recovery(self):
        if self._last_failure_monotonic is None:
            return True
        elapsed = time.monotonic() - self._last_failure_monotonic
        return elapsed >= self.recovery_timeout
# Usage: route calls to the external API through one shared breaker.
circuit = CircuitBreaker(recovery_timeout=30.0, failure_threshold=5)


async def call_external_service():
    """Call ``external_api.fetch_data`` through ``circuit``."""
    fetch = external_api.fetch_data
    return await circuit.call(fetch)
Rate Limiter
Control the request rate with a token bucket.
import asyncio
import time
class RateLimiter:
    """Token-bucket rate limiter for coroutines.

    Tokens accrue continuously at ``rate`` per second, up to ``capacity``.
    Each :meth:`acquire` consumes one token, sleeping until one is available.
    The bucket starts full, so up to ``capacity`` calls can proceed at once.
    """

    def __init__(self, rate: float, capacity: int = None):
        """
        Args:
            rate: tokens added per second; must be positive.
            capacity: maximum burst size. Defaults to ``int(rate)``, but at
                least 1, so that rates below one per second still work.

        Raises:
            ValueError: ``rate`` is not positive or ``capacity`` is below 1.
        """
        if rate <= 0:
            raise ValueError("rate must be positive")
        self.rate = rate  # tokens per second
        self.capacity = capacity or max(1, int(rate))
        if self.capacity < 1:
            raise ValueError("capacity must be at least 1")
        self.tokens = self.capacity
        self.last_update = time.monotonic()
        self.lock = asyncio.Lock()

    def _refill(self):
        """Credit the tokens earned since the previous refill."""
        now = time.monotonic()
        elapsed = now - self.last_update
        self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
        self.last_update = now

    async def acquire(self):
        """Wait until a token is available, then consume it.

        The lock is held while sleeping, so waiters are served one at a time.
        """
        async with self.lock:
            while True:
                self._refill()
                if self.tokens >= 1:
                    self.tokens -= 1
                    return
                # Sleep until roughly one whole token has accrued, then loop.
                # The refill advances ``last_update``, so time spent sleeping
                # is credited exactly once.
                await asyncio.sleep((1 - self.tokens) / self.rate)
# Usage: allow at most 10 requests per second.
limiter = RateLimiter(rate=10)


async def rate_limited_fetch(url):
    """Fetch *url* once ``limiter`` hands out a token."""
    await limiter.acquire()
    response = await fetch(url)
    return response
When to Use Which Pattern
| Situation | Pattern |
|---|---|
| Decouple fast producer from slow consumer | Producer-Consumer |
| Process many items in parallel | Worker Pool |
| Multiple data sources to aggregate | Scatter-Gather |
| Chain of transformations | Pipeline |
| Protect from cascade failures | Bulkhead |
| Handle flaky services | Circuit Breaker |
| Control API rate | Rate Limiter |
Combining Patterns
# Real-world example: API client with resilience
class ResilientApiClient:
    """API client that layers a rate limiter, a bulkhead and a circuit breaker.

    Every request first takes a rate-limiter token, then a bulkhead slot
    (giving up after ``bulkhead_timeout`` seconds), and finally runs through
    the circuit breaker.

    Defaults: the circuit opens after 5 failures, the limiter allows 100
    requests/second, the bulkhead allows 20 concurrent requests, and a caller
    waits up to 5 seconds for a bulkhead slot.
    """

    def __init__(
        self,
        *,
        failure_threshold=5,
        rate=100,
        max_concurrent=20,
        bulkhead_timeout=5.0,
    ):
        self.circuit = CircuitBreaker(failure_threshold=failure_threshold)
        self.rate_limiter = RateLimiter(rate=rate)
        self.bulkhead = Bulkhead(max_concurrent=max_concurrent)
        self.bulkhead_timeout = bulkhead_timeout

    async def call(self, endpoint, data):
        # Rate limit first: every attempt is paced, even one the bulkhead
        # later turns away.
        await self.rate_limiter.acquire()
        # Bulkhead next: bounds how many requests are in flight at once.
        async with self.bulkhead.acquire(timeout=self.bulkhead_timeout):
            # Circuit breaker innermost: only the request itself is recorded
            # by the breaker; a bulkhead timeout happens outside circuit.call
            # and so never counts toward opening the circuit.
            return await self.circuit.call(
                self._make_request,
                endpoint,
                data,
            )

    async def _make_request(self, endpoint, data):
        """Perform the HTTP request (placeholder: currently returns None)."""
        # Actual HTTP request
        pass