Skip to main content

subcog/embedding/
bulkhead.rs

1//! Bulkhead pattern implementation for embedding operations.
2//!
3//! Provides concurrency limiting to prevent resource exhaustion when making
4//! parallel embedding calls. Uses a semaphore-based approach to limit the number
5//! of concurrent operations.
6//!
7//! # Why Bulkhead for Embeddings?
8//!
9//! Embedding generation is CPU and memory intensive:
10//!
11//! - **CPU**: ONNX runtime uses significant CPU per embedding
12//! - **Memory**: Model weights and intermediate tensors
13//! - **Batching**: Large batches can exhaust memory
14//! - **Latency**: Too many concurrent operations increase latency for all
15//!
16//! # Usage
17//!
18//! ```rust,ignore
19//! use subcog::embedding::{BulkheadEmbedder, EmbeddingBulkheadConfig, FastEmbedEmbedder};
20//!
21//! let embedder = FastEmbedEmbedder::new()?;
22//! let bulkhead = BulkheadEmbedder::new(embedder, EmbeddingBulkheadConfig::default());
23//!
24//! // Only 2 concurrent embedding operations allowed (default)
25//! let embedding = bulkhead.embed("Hello world")?;
26//! ```
27
28use super::Embedder;
29use crate::{Error, Result};
30use std::sync::Arc;
31use std::time::Duration;
32use tokio::sync::Semaphore;
33
/// Configuration for the embedding bulkhead pattern.
///
/// Controls how many embedding calls may run at once and what a caller does
/// when every slot is taken: wait (bounded by a timeout) or fail immediately.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EmbeddingBulkheadConfig {
    /// Maximum concurrent embedding operations allowed.
    ///
    /// Default: 2 (conservative due to CPU/memory intensity). A value of 0 is
    /// treated as 1: the builder and the environment overrides clamp it, and
    /// `BulkheadEmbedder::new` never creates a semaphore with fewer than one permit.
    pub max_concurrent: usize,

    /// Timeout for acquiring a permit, in milliseconds.
    ///
    /// Default: 30000ms (30 seconds - embeddings can be slow). `0` does NOT mean
    /// "wait forever": `BulkheadEmbedder` substitutes a 2-minute safety cap.
    /// Not consulted when `fail_fast` is set.
    pub acquire_timeout_ms: u64,

    /// Whether to fail fast when the bulkhead is full (vs. waiting).
    ///
    /// Default: false (wait for a permit, bounded by `acquire_timeout_ms`).
    pub fail_fast: bool,
}

impl Default for EmbeddingBulkheadConfig {
    /// Same values as [`EmbeddingBulkheadConfig::new`]; delegating keeps the
    /// two constructors from drifting apart.
    fn default() -> Self {
        Self::new()
    }
}

impl EmbeddingBulkheadConfig {
    /// Creates a configuration with the default limits.
    ///
    /// 2 concurrent operations, a 30000ms acquire timeout, and waiting (not
    /// failing fast) when full. Unlike `default()`, this is usable in `const`
    /// contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            max_concurrent: 2,
            acquire_timeout_ms: 30_000,
            fail_fast: false,
        }
    }

    /// Loads configuration from environment variables.
    ///
    /// | Variable | Description | Default |
    /// |----------|-------------|---------|
    /// | `SUBCOG_EMBEDDING_BULKHEAD_MAX_CONCURRENT` | Max concurrent ops | 2 |
    /// | `SUBCOG_EMBEDDING_BULKHEAD_ACQUIRE_TIMEOUT_MS` | Permit timeout | 30000 |
    /// | `SUBCOG_EMBEDDING_BULKHEAD_FAIL_FAST` | Fail when full | false |
    ///
    /// See [`EmbeddingBulkheadConfig::with_env_overrides`] for parsing rules.
    #[must_use]
    pub fn from_env() -> Self {
        Self::default().with_env_overrides()
    }

    /// Applies environment variable overrides on top of `self`.
    ///
    /// - Values are trimmed before parsing.
    /// - Unset or unparsable variables leave the corresponding field unchanged.
    /// - `..._MAX_CONCURRENT` is clamped to at least 1.
    /// - `..._FAIL_FAST` accepts `1`/`true`/`yes`/`on` and
    ///   `0`/`false`/`no`/`off`, case-insensitively. Any other value is ignored
    ///   rather than silently forcing `false`.
    #[must_use]
    pub fn with_env_overrides(mut self) -> Self {
        if let Some(max) = Self::env_parsed::<usize>("SUBCOG_EMBEDDING_BULKHEAD_MAX_CONCURRENT") {
            // Keep the config consistent with the limit actually enforced
            // (the semaphore never has fewer than one permit).
            self.max_concurrent = max.max(1);
        }
        if let Some(timeout_ms) =
            Self::env_parsed::<u64>("SUBCOG_EMBEDDING_BULKHEAD_ACQUIRE_TIMEOUT_MS")
        {
            self.acquire_timeout_ms = timeout_ms;
        }
        if let Some(fail_fast) = std::env::var("SUBCOG_EMBEDDING_BULKHEAD_FAIL_FAST")
            .ok()
            .and_then(|raw| Self::parse_flag(&raw))
        {
            self.fail_fast = fail_fast;
        }
        self
    }

    /// Sets the maximum concurrent operations (0 is clamped to 1).
    #[must_use]
    pub const fn with_max_concurrent(mut self, max: usize) -> Self {
        // `Ord::max` is not callable in a `const fn`, so clamp by hand. This
        // matches the env-override path and the limit the embedder enforces.
        self.max_concurrent = if max == 0 { 1 } else { max };
        self
    }

    /// Sets the acquire timeout in milliseconds (0 selects the 2-minute safety cap).
    #[must_use]
    pub const fn with_acquire_timeout_ms(mut self, timeout_ms: u64) -> Self {
        self.acquire_timeout_ms = timeout_ms;
        self
    }

    /// Sets whether to fail fast when the bulkhead is full.
    #[must_use]
    pub const fn with_fail_fast(mut self, fail_fast: bool) -> Self {
        self.fail_fast = fail_fast;
        self
    }

    /// Reads `name` from the environment and parses its trimmed value.
    ///
    /// Returns `None` when the variable is unset, not valid Unicode, or does
    /// not parse as `T`.
    fn env_parsed<T: std::str::FromStr>(name: &str) -> Option<T> {
        std::env::var(name).ok()?.trim().parse().ok()
    }

    /// Interprets a boolean-ish flag string, returning `None` if unrecognized.
    fn parse_flag(raw: &str) -> Option<bool> {
        match raw.trim().to_ascii_lowercase().as_str() {
            "1" | "true" | "yes" | "on" => Some(true),
            "0" | "false" | "no" | "off" => Some(false),
            _ => None,
        }
    }
}
126
127/// Embedder wrapper with bulkhead (concurrency limiting) pattern.
128///
129/// Limits the number of concurrent embedding operations to prevent resource exhaustion.
130pub struct BulkheadEmbedder<E: Embedder> {
131    inner: E,
132    config: EmbeddingBulkheadConfig,
133    semaphore: Arc<Semaphore>,
134}
135
136impl<E: Embedder> BulkheadEmbedder<E> {
137    /// Creates a new bulkhead-wrapped embedder.
138    #[must_use]
139    pub fn new(inner: E, config: EmbeddingBulkheadConfig) -> Self {
140        let semaphore = Arc::new(Semaphore::new(config.max_concurrent.max(1)));
141        Self {
142            inner,
143            config,
144            semaphore,
145        }
146    }
147
148    /// Returns the current number of available permits.
149    #[must_use]
150    pub fn available_permits(&self) -> usize {
151        self.semaphore.available_permits()
152    }
153
154    /// Acquires a permit, respecting the configured timeout and fail-fast settings.
155    fn acquire_permit(&self) -> Result<tokio::sync::OwnedSemaphorePermit> {
156        let semaphore = &self.semaphore;
157        let available = semaphore.available_permits();
158
159        metrics::gauge!("embedding_bulkhead_available_permits").set(available as f64);
160
161        if self.config.fail_fast {
162            return self.acquire_permit_fail_fast(semaphore, available);
163        }
164
165        self.acquire_permit_with_timeout(semaphore)
166    }
167
168    /// Fast-fail acquisition that returns error immediately if bulkhead is full.
169    fn acquire_permit_fail_fast(
170        &self,
171        semaphore: &Arc<Semaphore>,
172        available: usize,
173    ) -> Result<tokio::sync::OwnedSemaphorePermit> {
174        Arc::clone(semaphore).try_acquire_owned().map_or_else(
175            |_| {
176                metrics::counter!("embedding_bulkhead_rejections_total", "reason" => "full")
177                    .increment(1);
178                Err(Error::OperationFailed {
179                    operation: "embedding_bulkhead_acquire".to_string(),
180                    cause: format!(
181                        "Embedding bulkhead full: {} concurrent operations (max: {})",
182                        self.config.max_concurrent - available,
183                        self.config.max_concurrent
184                    ),
185                })
186            },
187            |permit| {
188                metrics::counter!("embedding_bulkhead_permits_acquired_total").increment(1);
189                Ok(permit)
190            },
191        )
192    }
193
194    /// Acquisition with timeout that waits for a permit.
195    fn acquire_permit_with_timeout(
196        &self,
197        semaphore: &Arc<Semaphore>,
198    ) -> Result<tokio::sync::OwnedSemaphorePermit> {
199        let timeout_ms = if self.config.acquire_timeout_ms == 0 {
200            120_000 // 2 minute safety cap
201        } else {
202            self.config.acquire_timeout_ms
203        };
204        let timeout = Duration::from_millis(timeout_ms);
205        let start = std::time::Instant::now();
206
207        loop {
208            if let Ok(permit) = Arc::clone(semaphore).try_acquire_owned() {
209                metrics::counter!("embedding_bulkhead_permits_acquired_total").increment(1);
210                return Ok(permit);
211            }
212
213            if start.elapsed() >= timeout {
214                metrics::counter!("embedding_bulkhead_rejections_total", "reason" => "timeout")
215                    .increment(1);
216                return Err(Error::OperationFailed {
217                    operation: "embedding_bulkhead_acquire".to_string(),
218                    cause: format!(
219                        "Embedding bulkhead acquire timed out after {}ms",
220                        timeout.as_millis()
221                    ),
222                });
223            }
224
225            std::thread::sleep(Duration::from_millis(5));
226        }
227    }
228
229    /// Executes an operation with bulkhead protection.
230    fn execute<T, F>(&self, operation: &'static str, call: F) -> Result<T>
231    where
232        F: FnOnce() -> Result<T>,
233    {
234        let _permit = self.acquire_permit()?;
235
236        tracing::trace!(operation = operation, "Acquired embedding bulkhead permit");
237
238        let result = call();
239
240        tracing::trace!(
241            operation = operation,
242            success = result.is_ok(),
243            "Released embedding bulkhead permit"
244        );
245
246        result
247    }
248}
249
250impl<E: Embedder> Embedder for BulkheadEmbedder<E> {
251    fn dimensions(&self) -> usize {
252        self.inner.dimensions()
253    }
254
255    fn embed(&self, text: &str) -> Result<Vec<f32>> {
256        self.execute("embed", || self.inner.embed(text))
257    }
258
259    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
260        self.execute("embed_batch", || self.inner.embed_batch(texts))
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that counts calls and can simulate slow embeddings.
    struct MockEmbedder {
        delay_ms: u64,
        call_count: AtomicUsize,
    }

    impl MockEmbedder {
        fn new(delay_ms: u64) -> Self {
            Self {
                delay_ms,
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of times `embed` has reached the mock.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    impl Embedder for MockEmbedder {
        fn dimensions(&self) -> usize {
            384
        }

        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            if self.delay_ms > 0 {
                std::thread::sleep(Duration::from_millis(self.delay_ms));
            }
            Ok(vec![0.0; 384])
        }
    }

    /// Takes one permit straight from the bulkhead's semaphore so a test can
    /// make the bulkhead deterministically full without racing a thread.
    fn occupy_permit<E: Embedder>(
        bulkhead: &BulkheadEmbedder<E>,
    ) -> tokio::sync::OwnedSemaphorePermit {
        Arc::clone(&bulkhead.semaphore)
            .try_acquire_owned()
            .expect("bulkhead should have a free permit")
    }

    /// Extracts the `cause` of an `OperationFailed` error, panicking otherwise.
    fn failure_cause(err: Error) -> String {
        match err {
            Error::OperationFailed { cause, .. } => cause,
            other => panic!("expected OperationFailed, got: {other}"),
        }
    }

    #[test]
    fn test_embedding_bulkhead_config_default() {
        let config = EmbeddingBulkheadConfig::default();
        assert_eq!(config.max_concurrent, 2);
        assert_eq!(config.acquire_timeout_ms, 30_000);
        assert!(!config.fail_fast);
    }

    #[test]
    fn test_embedding_bulkhead_config_builder() {
        let config = EmbeddingBulkheadConfig::new()
            .with_max_concurrent(4)
            .with_acquire_timeout_ms(10_000)
            .with_fail_fast(true);

        assert_eq!(config.max_concurrent, 4);
        assert_eq!(config.acquire_timeout_ms, 10_000);
        assert!(config.fail_fast);
    }

    #[test]
    fn test_bulkhead_allows_operations_within_limit() {
        let bulkhead = BulkheadEmbedder::new(MockEmbedder::new(0), EmbeddingBulkheadConfig::default());

        let embedding = bulkhead.embed("test").expect("embed within limit should succeed");
        assert_eq!(embedding.len(), 384);
        // The permit must be returned once the call completes.
        assert_eq!(bulkhead.available_permits(), 2);
    }

    #[test]
    fn test_bulkhead_available_permits() {
        let config = EmbeddingBulkheadConfig::new().with_max_concurrent(3);
        let bulkhead = BulkheadEmbedder::new(MockEmbedder::new(0), config);

        assert_eq!(bulkhead.available_permits(), 3);
    }

    #[test]
    fn test_bulkhead_zero_max_concurrent_still_allows_one() {
        let config = EmbeddingBulkheadConfig {
            max_concurrent: 0,
            ..EmbeddingBulkheadConfig::default()
        };
        let bulkhead = BulkheadEmbedder::new(MockEmbedder::new(0), config);

        assert_eq!(bulkhead.available_permits(), 1);
        assert!(bulkhead.embed("x").is_ok());
    }

    #[test]
    fn test_bulkhead_fail_fast_when_full() {
        let config = EmbeddingBulkheadConfig::new()
            .with_max_concurrent(1)
            .with_fail_fast(true);
        let bulkhead = BulkheadEmbedder::new(MockEmbedder::new(0), config);

        let held = occupy_permit(&bulkhead);
        let err = bulkhead
            .embed("fast")
            .expect_err("a full fail-fast bulkhead must reject");
        let cause = failure_cause(err);
        assert!(cause.contains("bulkhead full"), "unexpected cause: {cause}");
        // The rejected call must never reach the wrapped embedder.
        assert_eq!(bulkhead.inner.calls(), 0);

        drop(held);
        assert!(bulkhead.embed("again").is_ok());
        assert_eq!(bulkhead.inner.calls(), 1);
    }

    #[test]
    fn test_bulkhead_times_out_when_full() {
        let config = EmbeddingBulkheadConfig::new()
            .with_max_concurrent(1)
            .with_acquire_timeout_ms(20);
        let bulkhead = BulkheadEmbedder::new(MockEmbedder::new(0), config);

        let _held = occupy_permit(&bulkhead);
        let err = bulkhead
            .embed("blocked")
            .expect_err("acquisition should time out while the permit is held");
        let cause = failure_cause(err);
        assert!(cause.contains("timed out"), "unexpected cause: {cause}");
        assert_eq!(bulkhead.inner.calls(), 0);
    }

    #[test]
    fn test_bulkhead_waits_for_permit_when_not_fail_fast() {
        let config = EmbeddingBulkheadConfig::new()
            .with_max_concurrent(1)
            .with_acquire_timeout_ms(5_000);
        let bulkhead = Arc::new(BulkheadEmbedder::new(MockEmbedder::new(20), config));

        let workers: Vec<_> = (0..2)
            .map(|_| {
                let shared = Arc::clone(&bulkhead);
                std::thread::spawn(move || shared.embed("queued"))
            })
            .collect();

        for worker in workers {
            let outcome = worker.join().expect("worker thread panicked");
            assert!(outcome.is_ok(), "waiting caller should eventually get a permit");
        }
        assert_eq!(bulkhead.inner.calls(), 2);
        assert_eq!(bulkhead.available_permits(), 1);
    }

    #[test]
    fn test_bulkhead_dimensions_passthrough() {
        let bulkhead = BulkheadEmbedder::new(MockEmbedder::new(0), EmbeddingBulkheadConfig::default());

        assert_eq!(bulkhead.dimensions(), 384);
    }
}