1use super::traits::{IndexBackend, PersistenceBackend, VectorBackend, VectorFilter};
42use crate::models::{Memory, MemoryId, SearchFilter};
43use crate::{Error, Result};
44use std::sync::Arc;
45use std::time::Duration;
46use tokio::sync::{OwnedSemaphorePermit, Semaphore};
47
/// Settings for a storage [`Bulkhead`]: how many operations may run at once,
/// and what happens when a caller arrives while every slot is taken.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StorageBulkheadConfig {
    /// Maximum number of operations allowed to run concurrently.
    ///
    /// The bulkhead always admits at least one operation, so `0` behaves like `1`.
    pub max_concurrent: usize,

    /// How long (in milliseconds) a caller waits for a free slot when
    /// `fail_fast` is `false`. `0` selects the bulkhead's 60 000 ms fallback.
    pub acquire_timeout_ms: u64,

    /// When `true`, a caller that finds no free slot is rejected immediately
    /// instead of waiting up to `acquire_timeout_ms`.
    pub fail_fast: bool,
}
66
67impl Default for StorageBulkheadConfig {
68 fn default() -> Self {
69 Self {
70 max_concurrent: 10,
71 acquire_timeout_ms: 5000,
72 fail_fast: false,
73 }
74 }
75}
76
77impl StorageBulkheadConfig {
78 #[must_use]
80 pub const fn new() -> Self {
81 Self {
82 max_concurrent: 10,
83 acquire_timeout_ms: 5000,
84 fail_fast: false,
85 }
86 }
87
88 #[must_use]
96 pub fn from_env() -> Self {
97 Self::default().with_env_overrides()
98 }
99
100 #[must_use]
102 pub fn with_env_overrides(mut self) -> Self {
103 if let Ok(v) = std::env::var("SUBCOG_STORAGE_BULKHEAD_MAX_CONCURRENT")
104 && let Ok(parsed) = v.parse::<usize>()
105 {
106 self.max_concurrent = parsed.max(1);
107 }
108 if let Ok(v) = std::env::var("SUBCOG_STORAGE_BULKHEAD_ACQUIRE_TIMEOUT_MS")
109 && let Ok(parsed) = v.parse::<u64>()
110 {
111 self.acquire_timeout_ms = parsed;
112 }
113 if let Ok(v) = std::env::var("SUBCOG_STORAGE_BULKHEAD_FAIL_FAST") {
114 self.fail_fast = v.to_lowercase() == "true" || v == "1";
115 }
116 self
117 }
118
119 #[must_use]
121 pub const fn with_max_concurrent(mut self, max: usize) -> Self {
122 self.max_concurrent = max;
123 self
124 }
125
126 #[must_use]
128 pub const fn with_acquire_timeout_ms(mut self, timeout_ms: u64) -> Self {
129 self.acquire_timeout_ms = timeout_ms;
130 self
131 }
132
133 #[must_use]
135 pub const fn with_fail_fast(mut self, fail_fast: bool) -> Self {
136 self.fail_fast = fail_fast;
137 self
138 }
139}
140
141pub struct Bulkhead<T> {
155 inner: T,
156 config: StorageBulkheadConfig,
157 semaphore: Arc<Semaphore>,
158 backend_name: &'static str,
159}
160
161impl<T> Bulkhead<T> {
162 #[must_use]
164 pub fn new(inner: T, config: StorageBulkheadConfig, backend_name: &'static str) -> Self {
165 let semaphore = Arc::new(Semaphore::new(config.max_concurrent.max(1)));
166 Self {
167 inner,
168 config,
169 semaphore,
170 backend_name,
171 }
172 }
173
174 #[must_use]
176 pub const fn inner(&self) -> &T {
177 &self.inner
178 }
179
180 #[must_use]
182 pub const fn backend_name(&self) -> &'static str {
183 self.backend_name
184 }
185
186 #[must_use]
188 pub fn available_permits(&self) -> usize {
189 self.semaphore.available_permits()
190 }
191
192 fn acquire_permit(&self, operation_prefix: &str) -> Result<OwnedSemaphorePermit> {
194 let semaphore = &self.semaphore;
195 let available = semaphore.available_permits();
196
197 metrics::gauge!(
198 "storage_bulkhead_available_permits",
199 "backend" => self.backend_name
200 )
201 .set(available as f64);
202
203 if self.config.fail_fast {
204 return self.acquire_permit_fail_fast(semaphore, available, operation_prefix);
205 }
206
207 let timeout_ms = if self.config.acquire_timeout_ms == 0 {
208 60_000 } else {
210 self.config.acquire_timeout_ms
211 };
212
213 self.acquire_permit_with_timeout(timeout_ms, operation_prefix)
214 }
215
216 fn acquire_permit_fail_fast(
218 &self,
219 semaphore: &Arc<Semaphore>,
220 available: usize,
221 operation_prefix: &str,
222 ) -> Result<OwnedSemaphorePermit> {
223 Arc::clone(semaphore).try_acquire_owned().map_or_else(
224 |_| {
225 metrics::counter!(
226 "storage_bulkhead_rejections_total",
227 "backend" => self.backend_name,
228 "reason" => "full"
229 )
230 .increment(1);
231 Err(Error::OperationFailed {
232 operation: format!("{operation_prefix}_bulkhead_acquire"),
233 cause: format!(
234 "{} bulkhead full: {} concurrent operations (max: {})",
235 capitalize_first(operation_prefix),
236 self.config.max_concurrent - available,
237 self.config.max_concurrent
238 ),
239 })
240 },
241 |permit| {
242 metrics::counter!(
243 "storage_bulkhead_permits_acquired_total",
244 "backend" => self.backend_name
245 )
246 .increment(1);
247 Ok(permit)
248 },
249 )
250 }
251
252 fn acquire_permit_with_timeout(
254 &self,
255 timeout_ms: u64,
256 operation_prefix: &str,
257 ) -> Result<OwnedSemaphorePermit> {
258 let timeout = Duration::from_millis(timeout_ms);
259 let start = std::time::Instant::now();
260
261 loop {
262 if let Ok(permit) = Arc::clone(&self.semaphore).try_acquire_owned() {
263 metrics::counter!(
264 "storage_bulkhead_permits_acquired_total",
265 "backend" => self.backend_name
266 )
267 .increment(1);
268 return Ok(permit);
269 }
270
271 if start.elapsed() >= timeout {
272 metrics::counter!(
273 "storage_bulkhead_rejections_total",
274 "backend" => self.backend_name,
275 "reason" => "timeout"
276 )
277 .increment(1);
278 return Err(Error::OperationFailed {
279 operation: format!("{operation_prefix}_bulkhead_acquire"),
280 cause: format!(
281 "{} bulkhead acquire timed out after {timeout_ms}ms",
282 capitalize_first(operation_prefix)
283 ),
284 });
285 }
286
287 std::thread::sleep(Duration::from_millis(1));
288 }
289 }
290
291 pub fn execute<R, F>(
300 &self,
301 operation: &'static str,
302 operation_prefix: &str,
303 call: F,
304 ) -> Result<R>
305 where
306 F: FnOnce(&T) -> Result<R>,
307 {
308 let _permit = self.acquire_permit(operation_prefix)?;
309
310 tracing::trace!(
311 backend = self.backend_name,
312 operation = operation,
313 "Acquired bulkhead permit"
314 );
315
316 let result = call(&self.inner);
317
318 tracing::trace!(
319 backend = self.backend_name,
320 operation = operation,
321 success = result.is_ok(),
322 "Released bulkhead permit"
323 );
324
325 result
326 }
327
328 pub fn execute_quiet<R, F>(&self, operation_prefix: &str, call: F) -> Result<R>
336 where
337 F: FnOnce(&T) -> Result<R>,
338 {
339 let _permit = self.acquire_permit(operation_prefix)?;
340 call(&self.inner)
341 }
342}
343
/// Returns `s` with its first character converted to uppercase.
///
/// Uses Unicode case mapping, so one character may expand (e.g. `ß` → `SS`).
/// An empty input yields an empty string.
fn capitalize_first(s: &str) -> String {
    let mut rest = s.chars();
    let Some(head) = rest.next() else {
        return String::new();
    };
    let mut capitalized: String = head.to_uppercase().collect();
    capitalized.push_str(rest.as_str());
    capitalized
}
351
352pub struct BulkheadPersistenceBackend<P: PersistenceBackend> {
358 bulkhead: Bulkhead<P>,
359}
360
361impl<P: PersistenceBackend> BulkheadPersistenceBackend<P> {
362 #[must_use]
364 pub fn new(inner: P, config: StorageBulkheadConfig, backend_name: &'static str) -> Self {
365 Self {
366 bulkhead: Bulkhead::new(inner, config, backend_name),
367 }
368 }
369
370 #[must_use]
372 pub fn available_permits(&self) -> usize {
373 self.bulkhead.available_permits()
374 }
375}
376
377#[allow(clippy::redundant_closure_for_method_calls)]
378impl<P: PersistenceBackend> PersistenceBackend for BulkheadPersistenceBackend<P> {
379 fn store(&self, memory: &Memory) -> Result<()> {
380 self.bulkhead
381 .execute("store", "storage", |inner| inner.store(memory))
382 }
383
384 fn get(&self, id: &MemoryId) -> Result<Option<Memory>> {
385 self.bulkhead
386 .execute("get", "storage", |inner| inner.get(id))
387 }
388
389 fn get_batch(&self, ids: &[MemoryId]) -> Result<Vec<Memory>> {
390 self.bulkhead
391 .execute("get_batch", "storage", |inner| inner.get_batch(ids))
392 }
393
394 fn delete(&self, id: &MemoryId) -> Result<bool> {
395 self.bulkhead
396 .execute("delete", "storage", |inner| inner.delete(id))
397 }
398
399 fn exists(&self, id: &MemoryId) -> Result<bool> {
400 self.bulkhead
401 .execute("exists", "storage", |inner| inner.exists(id))
402 }
403
404 fn list_ids(&self) -> Result<Vec<MemoryId>> {
405 self.bulkhead
406 .execute("list_ids", "storage", |inner| inner.list_ids())
407 }
408
409 fn count(&self) -> Result<usize> {
410 self.bulkhead
411 .execute("count", "storage", |inner| inner.count())
412 }
413}
414
415pub struct BulkheadIndexBackend<I: IndexBackend> {
421 bulkhead: Bulkhead<I>,
422}
423
424impl<I: IndexBackend> BulkheadIndexBackend<I> {
425 #[must_use]
427 pub fn new(inner: I, config: StorageBulkheadConfig, backend_name: &'static str) -> Self {
428 Self {
429 bulkhead: Bulkhead::new(inner, config, backend_name),
430 }
431 }
432
433 #[must_use]
435 pub fn available_permits(&self) -> usize {
436 self.bulkhead.available_permits()
437 }
438}
439
440#[allow(clippy::redundant_closure_for_method_calls)]
441impl<I: IndexBackend> IndexBackend for BulkheadIndexBackend<I> {
442 fn index(&self, memory: &Memory) -> Result<()> {
443 self.bulkhead
444 .execute_quiet("index", |inner| inner.index(memory))
445 }
446
447 fn remove(&self, id: &MemoryId) -> Result<bool> {
448 self.bulkhead
449 .execute_quiet("index", |inner| inner.remove(id))
450 }
451
452 fn search(
453 &self,
454 query: &str,
455 filter: &SearchFilter,
456 limit: usize,
457 ) -> Result<Vec<(MemoryId, f32)>> {
458 self.bulkhead
459 .execute_quiet("index", |inner| inner.search(query, filter, limit))
460 }
461
462 fn reindex(&self, memories: &[Memory]) -> Result<()> {
463 self.bulkhead
464 .execute_quiet("index", |inner| inner.reindex(memories))
465 }
466
467 fn clear(&self) -> Result<()> {
468 self.bulkhead.execute_quiet("index", |inner| inner.clear())
469 }
470
471 fn list_all(&self, filter: &SearchFilter, limit: usize) -> Result<Vec<(MemoryId, f32)>> {
472 self.bulkhead
473 .execute_quiet("index", |inner| inner.list_all(filter, limit))
474 }
475
476 fn get_memory(&self, id: &MemoryId) -> Result<Option<Memory>> {
477 self.bulkhead
478 .execute_quiet("index", |inner| inner.get_memory(id))
479 }
480
481 fn get_memories_batch(&self, ids: &[MemoryId]) -> Result<Vec<Option<Memory>>> {
482 self.bulkhead
483 .execute_quiet("index", |inner| inner.get_memories_batch(ids))
484 }
485}
486
487pub struct BulkheadVectorBackend<V: VectorBackend> {
493 bulkhead: Bulkhead<V>,
494}
495
496impl<V: VectorBackend> BulkheadVectorBackend<V> {
497 #[must_use]
499 pub fn new(inner: V, config: StorageBulkheadConfig, backend_name: &'static str) -> Self {
500 Self {
501 bulkhead: Bulkhead::new(inner, config, backend_name),
502 }
503 }
504
505 #[must_use]
507 pub fn available_permits(&self) -> usize {
508 self.bulkhead.available_permits()
509 }
510}
511
512#[allow(clippy::redundant_closure_for_method_calls)]
513impl<V: VectorBackend> VectorBackend for BulkheadVectorBackend<V> {
514 fn dimensions(&self) -> usize {
515 self.bulkhead.inner().dimensions()
516 }
517
518 fn upsert(&self, id: &MemoryId, embedding: &[f32]) -> Result<()> {
519 self.bulkhead
520 .execute("upsert", "vector", |inner| inner.upsert(id, embedding))
521 }
522
523 fn remove(&self, id: &MemoryId) -> Result<bool> {
524 self.bulkhead
525 .execute("remove", "vector", |inner| inner.remove(id))
526 }
527
528 fn search(
529 &self,
530 query_embedding: &[f32],
531 filter: &VectorFilter,
532 limit: usize,
533 ) -> Result<Vec<(MemoryId, f32)>> {
534 self.bulkhead.execute("search", "vector", |inner| {
535 inner.search(query_embedding, filter, limit)
536 })
537 }
538
539 fn count(&self) -> Result<usize> {
540 self.bulkhead
541 .execute("count", "vector", |inner| inner.count())
542 }
543
544 fn clear(&self) -> Result<()> {
545 self.bulkhead
546 .execute("clear", "vector", |inner| inner.clear())
547 }
548}
549
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{Domain, Memory, MemoryId, MemoryStatus, Namespace};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds an active `Decisions` memory with a fixed id and the given content.
    fn create_test_memory(content: &str) -> Memory {
        Memory {
            id: MemoryId::new("test-memory"),
            content: content.to_string(),
            namespace: Namespace::Decisions,
            domain: Domain::default(),
            project_id: None,
            branch: None,
            file_path: None,
            status: MemoryStatus::Active,
            created_at: 0,
            updated_at: 0,
            tombstoned_at: None,
            expires_at: None,
            embedding: None,
            tags: vec![],
            #[cfg(feature = "group-scope")]
            group_id: None,
            source: None,
            is_summary: false,
            source_memory_ids: None,
            consolidation_timestamp: None,
        }
    }

    /// Persistence double that can sleep in `store`/`get` and counts how many
    /// `store` calls actually reached it.
    struct MockPersistence {
        delay_ms: u64,
        call_count: AtomicUsize,
    }

    impl MockPersistence {
        fn new(delay_ms: u64) -> Self {
            Self {
                delay_ms,
                call_count: AtomicUsize::new(0),
            }
        }
    }

    impl PersistenceBackend for MockPersistence {
        fn store(&self, _memory: &Memory) -> Result<()> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            if self.delay_ms > 0 {
                std::thread::sleep(Duration::from_millis(self.delay_ms));
            }
            Ok(())
        }

        fn get(&self, _id: &MemoryId) -> Result<Option<Memory>> {
            if self.delay_ms > 0 {
                std::thread::sleep(Duration::from_millis(self.delay_ms));
            }
            Ok(None)
        }

        fn delete(&self, _id: &MemoryId) -> Result<bool> {
            Ok(true)
        }

        fn list_ids(&self) -> Result<Vec<MemoryId>> {
            Ok(vec![])
        }
    }

    /// Number of `store` calls that got past the bulkhead into the mock.
    fn store_calls(bulkhead: &BulkheadPersistenceBackend<MockPersistence>) -> usize {
        bulkhead.bulkhead.inner().call_count.load(Ordering::SeqCst)
    }

    #[test]
    fn test_storage_bulkhead_config_default() {
        let config = StorageBulkheadConfig::default();
        assert_eq!(config.max_concurrent, 10);
        assert_eq!(config.acquire_timeout_ms, 5000);
        assert!(!config.fail_fast);
    }

    #[test]
    fn test_storage_bulkhead_config_builder() {
        let config = StorageBulkheadConfig::new()
            .with_max_concurrent(20)
            .with_acquire_timeout_ms(10_000)
            .with_fail_fast(true);

        assert_eq!(config.max_concurrent, 20);
        assert_eq!(config.acquire_timeout_ms, 10_000);
        assert!(config.fail_fast);
    }

    #[test]
    fn test_bulkhead_allows_operations_within_limit() {
        let bulkhead = BulkheadPersistenceBackend::new(
            MockPersistence::new(0),
            StorageBulkheadConfig::default(),
            "mock",
        );
        let memory = create_test_memory("test content");

        assert!(bulkhead.store(&memory).is_ok());
        assert_eq!(store_calls(&bulkhead), 1);
        // The permit must be returned once the call completes.
        assert_eq!(bulkhead.available_permits(), 10);
    }

    #[test]
    fn test_bulkhead_available_permits() {
        let backend = MockPersistence::new(0);
        let config = StorageBulkheadConfig::new().with_max_concurrent(5);
        let bulkhead = BulkheadPersistenceBackend::new(backend, config, "mock");

        assert_eq!(bulkhead.available_permits(), 5);
    }

    #[test]
    fn test_bulkhead_fail_fast_when_full() {
        let config = StorageBulkheadConfig::new()
            .with_max_concurrent(1)
            .with_fail_fast(true);
        let bulkhead = BulkheadPersistenceBackend::new(MockPersistence::new(0), config, "mock");
        let memory = create_test_memory("test content");

        // Occupy the only slot directly so the rejection is deterministic
        // instead of depending on a background thread winning a race.
        let held = bulkhead
            .bulkhead
            .acquire_permit("storage")
            .expect("the first permit should be free");
        assert_eq!(bulkhead.available_permits(), 0);

        let err = bulkhead
            .store(&memory)
            .expect_err("store must be rejected while the bulkhead is full");
        assert!(err.to_string().contains("bulkhead full"));
        assert_eq!(store_calls(&bulkhead), 0, "a rejected call must not reach the backend");

        drop(held);
        assert!(bulkhead.store(&memory).is_ok());
        assert_eq!(store_calls(&bulkhead), 1);
    }

    #[test]
    fn test_bulkhead_timeout() {
        let config = StorageBulkheadConfig::new()
            .with_max_concurrent(1)
            .with_acquire_timeout_ms(50);
        let bulkhead = BulkheadPersistenceBackend::new(MockPersistence::new(0), config, "mock");
        let memory = create_test_memory("test content");

        let _held = bulkhead
            .bulkhead
            .acquire_permit("storage")
            .expect("the first permit should be free");

        let started = std::time::Instant::now();
        let err = bulkhead
            .store(&memory)
            .expect_err("store must time out while the only permit is held");
        assert!(err.to_string().contains("timed out"));
        assert!(started.elapsed() >= Duration::from_millis(50));
        assert_eq!(store_calls(&bulkhead), 0, "a timed-out call must not reach the backend");
    }

    #[test]
    fn test_bulkhead_waiter_proceeds_when_permit_released() {
        let config = StorageBulkheadConfig::new()
            .with_max_concurrent(1)
            .with_acquire_timeout_ms(5_000);
        let bulkhead = BulkheadPersistenceBackend::new(MockPersistence::new(0), config, "mock");
        let memory = create_test_memory("test content");

        let held = bulkhead
            .bulkhead
            .acquire_permit("storage")
            .expect("the first permit should be free");

        std::thread::scope(|scope| {
            scope.spawn(move || {
                std::thread::sleep(Duration::from_millis(20));
                drop(held);
            });
            assert!(bulkhead.store(&memory).is_ok());
        });

        assert_eq!(store_calls(&bulkhead), 1);
        assert_eq!(bulkhead.available_permits(), 1);
    }
}