subcog/embedding/
bulkhead.rs1use super::Embedder;
29use crate::{Error, Result};
30use std::sync::Arc;
31use std::time::Duration;
32use tokio::sync::Semaphore;
33
/// Configuration for the embedding bulkhead.
///
/// Describes how many embedding operations may run at the same time and what
/// happens to a caller that arrives while every slot is taken.
#[derive(Debug, Clone)]
pub struct EmbeddingBulkheadConfig {
    /// Maximum number of embedding operations allowed to run concurrently.
    ///
    /// The builder and the environment overrides never store a value below 1.
    pub max_concurrent: usize,

    /// How long a caller waits for a free slot, in milliseconds, when
    /// `fail_fast` is `false`.
    pub acquire_timeout_ms: u64,

    /// When `true`, a caller is rejected immediately if no slot is free
    /// instead of waiting for one.
    pub fail_fast: bool,
}

impl Default for EmbeddingBulkheadConfig {
    /// Returns the same values as [`EmbeddingBulkheadConfig::new`].
    fn default() -> Self {
        // Single source of truth: the defaults live in `new()` only.
        Self::new()
    }
}

impl EmbeddingBulkheadConfig {
    /// Creates the default configuration: 2 concurrent operations, a
    /// 30 000 ms acquire timeout, and waiting (not fail-fast) behaviour.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            max_concurrent: 2,
            acquire_timeout_ms: 30_000,
            fail_fast: false,
        }
    }

    /// Builds the default configuration and applies environment overrides.
    ///
    /// See [`Self::with_env_overrides`] for the variables that are read.
    #[must_use]
    pub fn from_env() -> Self {
        Self::default().with_env_overrides()
    }

    /// Applies overrides from the process environment on top of `self`.
    ///
    /// Variables read (surrounding whitespace is ignored):
    ///
    /// - `SUBCOG_EMBEDDING_BULKHEAD_MAX_CONCURRENT`: a `usize`; values below 1
    ///   are raised to 1. Unparsable values leave the field unchanged.
    /// - `SUBCOG_EMBEDDING_BULKHEAD_ACQUIRE_TIMEOUT_MS`: a `u64`. Unparsable
    ///   values leave the field unchanged.
    /// - `SUBCOG_EMBEDDING_BULKHEAD_FAIL_FAST`: `true` (any letter case) or
    ///   `1` enables fail-fast; any other value that is set disables it.
    #[must_use]
    pub fn with_env_overrides(mut self) -> Self {
        let max_concurrent = Self::env_value("SUBCOG_EMBEDDING_BULKHEAD_MAX_CONCURRENT")
            .and_then(|raw| raw.parse::<usize>().ok());
        if let Some(limit) = max_concurrent {
            self.max_concurrent = limit.max(1);
        }

        let acquire_timeout_ms = Self::env_value("SUBCOG_EMBEDDING_BULKHEAD_ACQUIRE_TIMEOUT_MS")
            .and_then(|raw| raw.parse::<u64>().ok());
        if let Some(timeout_ms) = acquire_timeout_ms {
            self.acquire_timeout_ms = timeout_ms;
        }

        if let Some(flag) = Self::env_value("SUBCOG_EMBEDDING_BULKHEAD_FAIL_FAST") {
            self.fail_fast = flag.eq_ignore_ascii_case("true") || flag == "1";
        }

        self
    }

    /// Sets the maximum number of concurrent operations.
    ///
    /// A limit of 0 would admit nothing, so it is stored as 1, matching the
    /// clamping applied to the environment override.
    #[must_use]
    pub const fn with_max_concurrent(mut self, max: usize) -> Self {
        // `Ord::max` is not usable in a `const fn`, hence the explicit branch.
        self.max_concurrent = if max == 0 { 1 } else { max };
        self
    }

    /// Sets how long to wait for a free slot, in milliseconds.
    #[must_use]
    pub const fn with_acquire_timeout_ms(mut self, timeout_ms: u64) -> Self {
        self.acquire_timeout_ms = timeout_ms;
        self
    }

    /// Chooses between rejecting immediately (`true`) and waiting (`false`)
    /// when no slot is free.
    #[must_use]
    pub const fn with_fail_fast(mut self, fail_fast: bool) -> Self {
        self.fail_fast = fail_fast;
        self
    }

    /// Reads an environment variable and strips surrounding whitespace.
    ///
    /// Returns `None` when the variable is unset or is not valid Unicode.
    fn env_value(key: &str) -> Option<String> {
        std::env::var(key).ok().map(|raw| raw.trim().to_owned())
    }
}
126
127pub struct BulkheadEmbedder<E: Embedder> {
131 inner: E,
132 config: EmbeddingBulkheadConfig,
133 semaphore: Arc<Semaphore>,
134}
135
136impl<E: Embedder> BulkheadEmbedder<E> {
137 #[must_use]
139 pub fn new(inner: E, config: EmbeddingBulkheadConfig) -> Self {
140 let semaphore = Arc::new(Semaphore::new(config.max_concurrent.max(1)));
141 Self {
142 inner,
143 config,
144 semaphore,
145 }
146 }
147
148 #[must_use]
150 pub fn available_permits(&self) -> usize {
151 self.semaphore.available_permits()
152 }
153
154 fn acquire_permit(&self) -> Result<tokio::sync::OwnedSemaphorePermit> {
156 let semaphore = &self.semaphore;
157 let available = semaphore.available_permits();
158
159 metrics::gauge!("embedding_bulkhead_available_permits").set(available as f64);
160
161 if self.config.fail_fast {
162 return self.acquire_permit_fail_fast(semaphore, available);
163 }
164
165 self.acquire_permit_with_timeout(semaphore)
166 }
167
168 fn acquire_permit_fail_fast(
170 &self,
171 semaphore: &Arc<Semaphore>,
172 available: usize,
173 ) -> Result<tokio::sync::OwnedSemaphorePermit> {
174 Arc::clone(semaphore).try_acquire_owned().map_or_else(
175 |_| {
176 metrics::counter!("embedding_bulkhead_rejections_total", "reason" => "full")
177 .increment(1);
178 Err(Error::OperationFailed {
179 operation: "embedding_bulkhead_acquire".to_string(),
180 cause: format!(
181 "Embedding bulkhead full: {} concurrent operations (max: {})",
182 self.config.max_concurrent - available,
183 self.config.max_concurrent
184 ),
185 })
186 },
187 |permit| {
188 metrics::counter!("embedding_bulkhead_permits_acquired_total").increment(1);
189 Ok(permit)
190 },
191 )
192 }
193
194 fn acquire_permit_with_timeout(
196 &self,
197 semaphore: &Arc<Semaphore>,
198 ) -> Result<tokio::sync::OwnedSemaphorePermit> {
199 let timeout_ms = if self.config.acquire_timeout_ms == 0 {
200 120_000 } else {
202 self.config.acquire_timeout_ms
203 };
204 let timeout = Duration::from_millis(timeout_ms);
205 let start = std::time::Instant::now();
206
207 loop {
208 if let Ok(permit) = Arc::clone(semaphore).try_acquire_owned() {
209 metrics::counter!("embedding_bulkhead_permits_acquired_total").increment(1);
210 return Ok(permit);
211 }
212
213 if start.elapsed() >= timeout {
214 metrics::counter!("embedding_bulkhead_rejections_total", "reason" => "timeout")
215 .increment(1);
216 return Err(Error::OperationFailed {
217 operation: "embedding_bulkhead_acquire".to_string(),
218 cause: format!(
219 "Embedding bulkhead acquire timed out after {}ms",
220 timeout.as_millis()
221 ),
222 });
223 }
224
225 std::thread::sleep(Duration::from_millis(5));
226 }
227 }
228
229 fn execute<T, F>(&self, operation: &'static str, call: F) -> Result<T>
231 where
232 F: FnOnce() -> Result<T>,
233 {
234 let _permit = self.acquire_permit()?;
235
236 tracing::trace!(operation = operation, "Acquired embedding bulkhead permit");
237
238 let result = call();
239
240 tracing::trace!(
241 operation = operation,
242 success = result.is_ok(),
243 "Released embedding bulkhead permit"
244 );
245
246 result
247 }
248}
249
250impl<E: Embedder> Embedder for BulkheadEmbedder<E> {
251 fn dimensions(&self) -> usize {
252 self.inner.dimensions()
253 }
254
255 fn embed(&self, text: &str) -> Result<Vec<f32>> {
256 self.execute("embed", || self.inner.embed(text))
257 }
258
259 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
260 self.execute("embed_batch", || self.inner.embed_batch(texts))
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that records how often `embed` was called and can
    /// optionally sleep to simulate slow work.
    struct MockEmbedder {
        delay_ms: u64,
        call_count: AtomicUsize,
    }

    impl MockEmbedder {
        fn new(delay_ms: u64) -> Self {
            Self {
                delay_ms,
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of `embed` calls that reached the mock.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    impl Embedder for MockEmbedder {
        fn dimensions(&self) -> usize {
            384
        }

        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            if self.delay_ms > 0 {
                std::thread::sleep(Duration::from_millis(self.delay_ms));
            }
            Ok(vec![0.0; 384])
        }
    }

    #[test]
    fn test_embedding_bulkhead_config_default() {
        let config = EmbeddingBulkheadConfig::default();
        assert_eq!(config.max_concurrent, 2);
        assert_eq!(config.acquire_timeout_ms, 30_000);
        assert!(!config.fail_fast);
    }

    #[test]
    fn test_embedding_bulkhead_config_builder() {
        let config = EmbeddingBulkheadConfig::new()
            .with_max_concurrent(4)
            .with_acquire_timeout_ms(10_000)
            .with_fail_fast(true);

        assert_eq!(config.max_concurrent, 4);
        assert_eq!(config.acquire_timeout_ms, 10_000);
        assert!(config.fail_fast);
    }

    #[test]
    fn test_bulkhead_allows_operations_within_limit() {
        let embedder = MockEmbedder::new(0);
        let bulkhead = BulkheadEmbedder::new(embedder, EmbeddingBulkheadConfig::default());

        let result = bulkhead.embed("test");
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 384);
        assert_eq!(bulkhead.inner.calls(), 1);
    }

    #[test]
    fn test_bulkhead_available_permits() {
        let embedder = MockEmbedder::new(0);
        let config = EmbeddingBulkheadConfig::new().with_max_concurrent(3);
        let bulkhead = BulkheadEmbedder::new(embedder, config);

        assert_eq!(bulkhead.available_permits(), 3);
    }

    #[test]
    fn test_bulkhead_zero_limit_still_admits_one() {
        let embedder = MockEmbedder::new(0);
        let config = EmbeddingBulkheadConfig::new().with_max_concurrent(0);
        let bulkhead = BulkheadEmbedder::new(embedder, config);

        assert_eq!(bulkhead.available_permits(), 1);
        assert!(bulkhead.embed("test").is_ok());
    }

    #[test]
    fn test_bulkhead_releases_permit_after_call() {
        let embedder = MockEmbedder::new(0);
        let config = EmbeddingBulkheadConfig::new()
            .with_max_concurrent(1)
            .with_fail_fast(true);
        let bulkhead = BulkheadEmbedder::new(embedder, config);

        // With a single permit, the second call only succeeds if the first
        // one gave its permit back.
        assert!(bulkhead.embed("first").is_ok());
        assert!(bulkhead.embed("second").is_ok());
        assert_eq!(bulkhead.available_permits(), 1);
        assert_eq!(bulkhead.inner.calls(), 2);
    }

    #[test]
    fn test_bulkhead_fail_fast_when_full() {
        let embedder = MockEmbedder::new(0);
        let config = EmbeddingBulkheadConfig::new()
            .with_max_concurrent(1)
            .with_fail_fast(true);
        let bulkhead = BulkheadEmbedder::new(embedder, config);

        // Occupy the only permit directly instead of racing a slow thread, so
        // the rejection is deterministic.
        let held = Arc::clone(&bulkhead.semaphore).try_acquire_owned();
        assert!(held.is_ok());

        let result = bulkhead.embed("rejected");
        assert!(result.is_err(), "call must be rejected while the bulkhead is full");
        if let Err(err) = result {
            assert!(err.to_string().contains("bulkhead full"));
        }
        assert_eq!(bulkhead.inner.calls(), 0);

        drop(held);
        assert!(bulkhead.embed("admitted").is_ok());
    }

    #[test]
    fn test_bulkhead_times_out_when_full() {
        let embedder = MockEmbedder::new(0);
        let config = EmbeddingBulkheadConfig::new()
            .with_max_concurrent(1)
            .with_acquire_timeout_ms(20);
        let bulkhead = BulkheadEmbedder::new(embedder, config);

        let held = Arc::clone(&bulkhead.semaphore).try_acquire_owned();
        assert!(held.is_ok());

        let result = bulkhead.embed("waits");
        assert!(result.is_err(), "call must time out while the bulkhead is full");
        if let Err(err) = result {
            assert!(err.to_string().contains("timed out"));
        }
        assert_eq!(bulkhead.inner.calls(), 0);

        drop(held);
    }

    #[test]
    fn test_bulkhead_dimensions_passthrough() {
        let embedder = MockEmbedder::new(0);
        let bulkhead = BulkheadEmbedder::new(embedder, EmbeddingBulkheadConfig::default());

        assert_eq!(bulkhead.dimensions(), 384);
    }
}