1use crate::models::MemoryId;
51use crate::storage::traits::{VectorBackend, VectorFilter};
52use crate::{Error, Result};
53
54#[cfg(feature = "redis")]
55use crate::storage::resilience::{StorageResilienceConfig, retry_connection};
56#[cfg(feature = "redis")]
57use redis::{Client, Commands, Connection, RedisResult};
58#[cfg(feature = "redis")]
59use std::sync::Mutex;
60#[cfg(feature = "redis")]
61use std::time::Duration;
62
/// Resolves the timeout applied to Redis connections from the
/// environment-driven operation timeout configuration.
#[cfg(feature = "redis")]
fn default_redis_timeout() -> Duration {
    use crate::config::{OperationTimeoutConfig, OperationType};

    let timeouts = OperationTimeoutConfig::from_env();
    timeouts.get(OperationType::Redis)
}
70
/// Vector storage backend that keeps embeddings in Redis hashes and queries
/// them through a RediSearch vector index.
///
/// When the `redis` feature is disabled only the configuration fields exist,
/// and every storage operation returns an error.
pub struct RedisVectorBackend {
    /// URL the backend was configured with (e.g. `redis://localhost:6379`).
    connection_url: String,
    /// Name of the search index; also the prefix of every stored hash key.
    index_name: String,
    /// Required length of every embedding handled by this backend.
    dimensions: usize,
    /// Client used to open new connections on demand.
    #[cfg(feature = "redis")]
    client: Client,
    /// At most one idle connection kept for reuse between operations.
    #[cfg(feature = "redis")]
    connection: Mutex<Option<Connection>>,
    /// Set once the search index is known to exist, so the check can be skipped.
    #[cfg(feature = "redis")]
    index_created: Mutex<bool>,
    /// Read/write timeout applied to every newly opened connection.
    #[cfg(feature = "redis")]
    timeout: Duration,
}
105
106impl RedisVectorBackend {
107 pub const DEFAULT_DIMENSIONS: usize = crate::embedding::DEFAULT_DIMENSIONS;
111
112 #[cfg(feature = "redis")]
118 pub fn new(
119 connection_url: impl Into<String>,
120 index_name: impl Into<String>,
121 dimensions: usize,
122 ) -> Result<Self> {
123 let connection_url = connection_url.into();
124 let client = Client::open(connection_url.as_str()).map_err(|e| Error::OperationFailed {
125 operation: "redis_connect".to_string(),
126 cause: e.to_string(),
127 })?;
128
129 Ok(Self {
130 connection_url,
131 index_name: index_name.into(),
132 dimensions,
133 client,
134 connection: Mutex::new(None),
135 index_created: Mutex::new(false),
136 timeout: default_redis_timeout(),
137 })
138 }
139
140 #[cfg(not(feature = "redis"))]
142 #[must_use]
143 pub fn new(
144 connection_url: impl Into<String>,
145 index_name: impl Into<String>,
146 dimensions: usize,
147 ) -> Self {
148 Self {
149 connection_url: connection_url.into(),
150 index_name: index_name.into(),
151 dimensions,
152 }
153 }
154
155 #[cfg(feature = "redis")]
161 pub fn with_defaults() -> Result<Self> {
162 Self::new(
163 "redis://localhost:6379",
164 "subcog_vectors",
165 Self::DEFAULT_DIMENSIONS,
166 )
167 }
168
169 #[cfg(not(feature = "redis"))]
171 #[must_use]
172 pub fn with_defaults() -> Self {
173 Self::new(
174 "redis://localhost:6379",
175 "subcog_vectors",
176 Self::DEFAULT_DIMENSIONS,
177 )
178 }
179
180 #[must_use]
182 pub fn connection_url(&self) -> &str {
183 &self.connection_url
184 }
185
186 #[must_use]
188 pub fn index_name(&self) -> &str {
189 &self.index_name
190 }
191
192 #[cfg(feature = "redis")]
216 pub fn health_check(&self) -> Result<bool> {
217 let mut conn = match self.get_connection() {
218 Ok(c) => c,
219 Err(_) => return Ok(false),
220 };
221
222 let result: redis::RedisResult<String> = redis::cmd("PING").query(&mut conn);
223
224 let healthy = result.is_ok_and(|response| response == "PONG");
225
226 self.return_connection(conn);
227 Ok(healthy)
228 }
229
230 #[cfg(not(feature = "redis"))]
236 pub fn health_check(&self) -> Result<bool> {
237 Err(Error::FeatureNotEnabled("redis".to_string()))
238 }
239}
240
#[cfg(feature = "redis")]
impl RedisVectorBackend {
    /// Prefix (`"{index_name}:"`) shared by every hash key this backend
    /// writes; the search index is declared over this prefix in `ensure_index`.
    fn key_prefix(&self) -> String {
        format!("{}:", self.index_name)
    }

    /// Redis key of the hash that stores the embedding for `id`.
    fn memory_key(&self, id: &MemoryId) -> String {
        format!("{}:{}", self.index_name, id.as_str())
    }

    /// Rejects embeddings whose length differs from the configured dimensions.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidInput`] on a length mismatch.
    fn validate_embedding(&self, embedding: &[f32]) -> Result<()> {
        if embedding.len() != self.dimensions {
            return Err(Error::InvalidInput(format!(
                "Embedding dimension mismatch: expected {}, got {}",
                self.dimensions,
                embedding.len()
            )));
        }
        Ok(())
    }

    /// Serializes an embedding as packed little-endian `f32` values, matching
    /// the `FLOAT32` vector type declared in `ensure_index`.
    fn vector_to_bytes(embedding: &[f32]) -> Vec<u8> {
        embedding.iter().flat_map(|f| f.to_le_bytes()).collect()
    }

    /// Checks out a connection: the cached one when it is still open, otherwise
    /// a freshly opened connection with read/write timeouts set to
    /// `self.timeout`. Opening is retried per `StorageResilienceConfig::from_env()`.
    ///
    /// Callers hand the connection back with [`Self::return_connection`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::OperationFailed`] if the cache lock is poisoned, or if
    /// opening or configuring a new connection fails after retries.
    fn get_connection(&self) -> Result<Connection> {
        {
            let mut cached = self.connection.lock().map_err(|e| Error::OperationFailed {
                operation: "redis_lock_connection".to_string(),
                cause: e.to_string(),
            })?;

            if let Some(conn) = cached.take() {
                if redis::ConnectionLike::is_open(&conn) {
                    return Ok(conn);
                }
                // The cached connection is dead; drop it and reconnect below.
            }
        }

        let resilience_config = StorageResilienceConfig::from_env();
        let timeout = self.timeout;

        retry_connection(&resilience_config, "redis_vector", "get_connection", || {
            let conn = self
                .client
                .get_connection()
                .map_err(|e| Error::OperationFailed {
                    operation: "redis_get_connection".to_string(),
                    cause: e.to_string(),
                })?;

            conn.set_read_timeout(Some(timeout))
                .map_err(|e| Error::OperationFailed {
                    operation: "redis_set_read_timeout".to_string(),
                    cause: e.to_string(),
                })?;
            conn.set_write_timeout(Some(timeout))
                .map_err(|e| Error::OperationFailed {
                    operation: "redis_set_write_timeout".to_string(),
                    cause: e.to_string(),
                })?;

            Ok(conn)
        })
    }

    /// Puts a connection back into the single-slot cache.
    ///
    /// Connections that are no longer open (e.g. after an unrecoverable I/O
    /// error) are dropped instead, so the next caller reconnects rather than
    /// reusing a dead socket. A poisoned lock also just drops the connection.
    fn return_connection(&self, conn: Connection) {
        if !redis::ConnectionLike::is_open(&conn) {
            return;
        }
        if let Ok(mut guard) = self.connection.lock() {
            *guard = Some(conn);
        }
    }

    /// Reads the cached "index exists" flag.
    fn index_created_flag(&self) -> Result<bool> {
        self.index_created
            .lock()
            .map(|guard| *guard)
            .map_err(|e| Error::OperationFailed {
                operation: "redis_lock_index_created".to_string(),
                cause: e.to_string(),
            })
    }

    /// Records that the search index is known to exist.
    fn mark_index_created(&self) -> Result<()> {
        let mut guard = self
            .index_created
            .lock()
            .map_err(|e| Error::OperationFailed {
                operation: "redis_lock_index_created".to_string(),
                cause: e.to_string(),
            })?;
        *guard = true;
        Ok(())
    }

    /// Makes sure the HNSW/cosine vector index over `key_prefix()` exists,
    /// creating it if `FT.INFO` cannot find it.
    ///
    /// # Errors
    ///
    /// Returns [`Error::OperationFailed`] if `FT.CREATE` fails for any reason
    /// other than the index already existing, or if the flag lock is poisoned.
    fn ensure_index(&self, conn: &mut Connection) -> Result<()> {
        if self.index_created_flag()? {
            return Ok(());
        }

        // FT.INFO only succeeds for an existing index.
        let info_result: RedisResult<redis::Value> =
            redis::cmd("FT.INFO").arg(&self.index_name).query(conn);
        if info_result.is_ok() {
            return self.mark_index_created();
        }

        let create_result: RedisResult<()> = redis::cmd("FT.CREATE")
            .arg(&self.index_name)
            .arg("ON")
            .arg("HASH")
            .arg("PREFIX")
            .arg("1")
            .arg(self.key_prefix())
            .arg("SCHEMA")
            .arg("embedding")
            .arg("VECTOR")
            .arg("HNSW")
            // Number of attribute tokens that follow (TYPE, DIM, DISTANCE_METRIC pairs).
            .arg("6")
            .arg("TYPE")
            .arg("FLOAT32")
            .arg("DIM")
            .arg(self.dimensions)
            .arg("DISTANCE_METRIC")
            .arg("COSINE")
            .arg("memory_id")
            .arg("TAG")
            .query(conn);

        match create_result {
            Ok(()) => self.mark_index_created(),
            // Someone else may have created the index between FT.INFO and FT.CREATE.
            Err(e) if e.to_string().contains("Index already exists") => self.mark_index_created(),
            Err(e) => Err(Error::OperationFailed {
                operation: "create_index".to_string(),
                cause: e.to_string(),
            }),
        }
    }

    /// Parses an `FT.SEARCH` reply of the shape
    /// `[total, key1, fields1, key2, fields2, ...]` into `(id, similarity)`
    /// pairs, ordered from most to least similar.
    ///
    /// The id comes from the returned `memory_id` field, so ids containing `:`
    /// survive intact. The last `:`-segment of the key is used only when that
    /// field is absent. Unexpected reply shapes yield an empty list.
    fn parse_search_results(value: &redis::Value) -> Vec<(MemoryId, f32)> {
        use redis::Value;

        let Value::Array(arr) = value else {
            return Vec::new();
        };
        // Skip the leading total-count element; an empty reply has nothing to skip.
        let Some(entries) = arr.get(1..) else {
            return Vec::new();
        };

        let mut results: Vec<(MemoryId, f32)> = entries
            .chunks_exact(2)
            .filter_map(|entry| {
                let key = Self::value_to_string(&entry[0])?;
                let fields = &entry[1];
                let memory_id = Self::field_as_string(fields, "memory_id").unwrap_or_else(|| {
                    key.rsplit(':').next().unwrap_or(&key).to_string()
                });
                let score = Self::extract_score_from_fields(fields);
                Some((MemoryId::new(memory_id.as_str()), score))
            })
            .collect();

        // Stable sort: equal scores keep the server's order.
        results.sort_by(|a, b| b.1.total_cmp(&a.1));
        results
    }

    /// Looks up `name` in a flat `[name1, value1, name2, value2, ...]` array
    /// reply and returns the first string-convertible value for it.
    fn field_as_string(fields: &redis::Value, name: &str) -> Option<String> {
        let redis::Value::Array(items) = fields else {
            return None;
        };
        items.chunks_exact(2).find_map(|pair| {
            let is_match = Self::value_to_string(&pair[0]).is_some_and(|field| field == name);
            if is_match {
                Self::value_to_string(&pair[1])
            } else {
                None
            }
        })
    }

    /// Converts the `__embedding_score` cosine distance (clamped to `[0, 2]`)
    /// into a similarity in `[0, 1]`. Returns `0.0` when the score is missing
    /// or unparsable.
    fn extract_score_from_fields(value: &redis::Value) -> f32 {
        Self::field_as_string(value, "__embedding_score")
            .and_then(|raw| raw.parse::<f32>().ok())
            .map_or(0.0, |distance| 1.0 - distance.clamp(0.0, 2.0) / 2.0)
    }

    /// Extracts `num_docs` from an `FT.INFO` reply, defaulting to `0` when it
    /// is missing or not an unsigned integer.
    fn parse_info_num_docs(value: &redis::Value) -> usize {
        Self::field_as_string(value, "num_docs")
            .and_then(|raw| raw.parse::<usize>().ok())
            .unwrap_or(0)
    }

    /// Renders scalar reply values as text; aggregates and nil yield `None`.
    fn value_to_string(value: &redis::Value) -> Option<String> {
        use redis::Value;

        match value {
            Value::BulkString(s) => Some(String::from_utf8_lossy(s).to_string()),
            Value::SimpleString(s) => Some(s.clone()),
            Value::Int(i) => Some(i.to_string()),
            // RESP3 servers may report numeric fields as doubles.
            Value::Double(d) => Some(d.to_string()),
            _ => None,
        }
    }
}
516
/// Returns `true` when a RediSearch error reports that the index does not
/// exist. Message casing differs between RediSearch versions, so the match is
/// case-insensitive.
#[cfg(feature = "redis")]
fn is_missing_index_error(err: &redis::RedisError) -> bool {
    let message = err.to_string().to_ascii_lowercase();
    message.contains("unknown index name") || message.contains("no such index")
}

/// Redis/RediSearch-backed implementation: embeddings are stored as hashes
/// under `"{index_name}:{id}"` and searched with a KNN vector query.
#[cfg(feature = "redis")]
impl VectorBackend for RedisVectorBackend {
    fn dimensions(&self) -> usize {
        self.dimensions
    }

    /// Stores (or overwrites) the embedding for `id`, creating the index first
    /// if needed.
    fn upsert(&self, id: &MemoryId, embedding: &[f32]) -> Result<()> {
        self.validate_embedding(embedding)?;

        let mut conn = self.get_connection()?;
        if let Err(e) = self.ensure_index(&mut conn) {
            self.return_connection(conn);
            return Err(e);
        }

        let key = self.memory_key(id);
        let vector_bytes = Self::vector_to_bytes(embedding);
        let result: RedisResult<()> = conn.hset_multiple(
            &key,
            &[
                ("embedding", vector_bytes.as_slice()),
                ("memory_id", id.as_str().as_bytes()),
            ],
        );
        self.return_connection(conn);

        result.map_err(|e| Error::OperationFailed {
            operation: "upsert".to_string(),
            cause: e.to_string(),
        })
    }

    /// Deletes the hash for `id`; returns whether anything was deleted.
    fn remove(&self, id: &MemoryId) -> Result<bool> {
        let mut conn = self.get_connection()?;
        let key = self.memory_key(id);

        let result: RedisResult<i32> = conn.del(&key);
        self.return_connection(conn);

        result
            .map(|deleted| deleted > 0)
            .map_err(|e| Error::OperationFailed {
                operation: "remove".to_string(),
                cause: e.to_string(),
            })
    }

    /// Returns up to `limit` nearest neighbours of `query_embedding`, most
    /// similar first.
    ///
    /// NOTE: `_filter` is currently not applied to the query.
    fn search(
        &self,
        query_embedding: &[f32],
        _filter: &VectorFilter,
        limit: usize,
    ) -> Result<Vec<(MemoryId, f32)>> {
        self.validate_embedding(query_embedding)?;
        // A zero limit can never yield results; skip the round trip.
        if limit == 0 {
            return Ok(Vec::new());
        }

        let mut conn = self.get_connection()?;

        let vector_bytes = Self::vector_to_bytes(query_embedding);
        let query = format!("*=>[KNN {limit} @embedding $BLOB]");

        let result: RedisResult<redis::Value> = redis::cmd("FT.SEARCH")
            .arg(&self.index_name)
            .arg(&query)
            .arg("PARAMS")
            .arg("2")
            .arg("BLOB")
            .arg(vector_bytes.as_slice())
            .arg("RETURN")
            .arg("2")
            .arg("memory_id")
            .arg("__embedding_score")
            // FT.SEARCH pages with LIMIT 0 10 by default, which would silently
            // cap KNN results at ten regardless of `limit`.
            .arg("LIMIT")
            .arg("0")
            .arg(limit)
            .arg("DIALECT")
            .arg("2")
            .query(&mut conn);
        self.return_connection(conn);

        result
            .map(|value| Self::parse_search_results(&value))
            .map_err(|e| Error::OperationFailed {
                operation: "search".to_string(),
                cause: e.to_string(),
            })
    }

    /// Number of documents in the index, as reported by `FT.INFO`.
    ///
    /// A missing index is reported as an error with cause `index_not_found`.
    fn count(&self) -> Result<usize> {
        let mut conn = self.get_connection()?;

        let result: RedisResult<redis::Value> =
            redis::cmd("FT.INFO").arg(&self.index_name).query(&mut conn);
        self.return_connection(conn);

        match result {
            Ok(info) => Ok(Self::parse_info_num_docs(&info)),
            Err(e) if is_missing_index_error(&e) => Err(Error::OperationFailed {
                operation: "count".to_string(),
                cause: "index_not_found".to_string(),
            }),
            Err(e) => Err(Error::OperationFailed {
                operation: "count".to_string(),
                cause: e.to_string(),
            }),
        }
    }

    /// Drops the index together with its documents (`FT.DROPINDEX ... DD`).
    ///
    /// Clearing a store whose index does not exist succeeds. Any other Redis
    /// failure is reported instead of being swallowed.
    fn clear(&self) -> Result<()> {
        let mut conn = self.get_connection()?;

        let drop_result: RedisResult<()> = redis::cmd("FT.DROPINDEX")
            .arg(&self.index_name)
            .arg("DD")
            .query(&mut conn);
        self.return_connection(conn);

        // Whatever the outcome, force the next upsert to re-verify the index.
        {
            let mut guard = self
                .index_created
                .lock()
                .map_err(|e| Error::OperationFailed {
                    operation: "redis_lock_index_created".to_string(),
                    cause: e.to_string(),
                })?;
            *guard = false;
        }

        match drop_result {
            Ok(()) => Ok(()),
            Err(e) if is_missing_index_error(&e) => Ok(()),
            Err(e) => Err(Error::OperationFailed {
                operation: "clear".to_string(),
                cause: e.to_string(),
            }),
        }
    }
}
675
676#[cfg(not(feature = "redis"))]
677impl VectorBackend for RedisVectorBackend {
678 fn dimensions(&self) -> usize {
679 self.dimensions
680 }
681
682 fn upsert(&self, _id: &MemoryId, _embedding: &[f32]) -> Result<()> {
683 Err(Error::NotImplemented(
684 "Redis vector backend requires 'redis' feature".to_string(),
685 ))
686 }
687
688 fn remove(&self, _id: &MemoryId) -> Result<bool> {
689 Err(Error::NotImplemented(
690 "Redis vector backend requires 'redis' feature".to_string(),
691 ))
692 }
693
694 fn search(
695 &self,
696 _query_embedding: &[f32],
697 _filter: &VectorFilter,
698 _limit: usize,
699 ) -> Result<Vec<(MemoryId, f32)>> {
700 Err(Error::NotImplemented(
701 "Redis vector backend requires 'redis' feature".to_string(),
702 ))
703 }
704
705 fn count(&self) -> Result<usize> {
706 Err(Error::NotImplemented(
707 "Redis vector backend requires 'redis' feature".to_string(),
708 ))
709 }
710
711 fn clear(&self) -> Result<()> {
712 Err(Error::NotImplemented(
713 "Redis vector backend requires 'redis' feature".to_string(),
714 ))
715 }
716}
717
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(not(feature = "redis"))]
    #[test]
    fn test_redis_backend_creation() {
        let backend = RedisVectorBackend::new("redis://localhost:6379", "test_idx", 384);
        assert_eq!(backend.dimensions(), 384);
        assert_eq!(backend.connection_url(), "redis://localhost:6379");
        assert_eq!(backend.index_name(), "test_idx");
    }

    #[cfg(not(feature = "redis"))]
    #[test]
    fn test_redis_backend_defaults() {
        let backend = RedisVectorBackend::with_defaults();
        assert_eq!(backend.dimensions(), RedisVectorBackend::DEFAULT_DIMENSIONS);
        assert_eq!(backend.connection_url(), "redis://localhost:6379");
        assert_eq!(backend.index_name(), "subcog_vectors");
    }

    /// Builds a backend for tests. `new` only opens a `redis::Client`, which
    /// does not contact the server, so a well-formed URL must succeed. Using
    /// `expect` (rather than `if let Ok`) keeps a construction regression
    /// from turning the tests below into silent no-ops.
    #[cfg(feature = "redis")]
    fn make_backend(url: &str, index: &str, dimensions: usize) -> RedisVectorBackend {
        RedisVectorBackend::new(url, index, dimensions)
            .expect("constructing a backend from a well-formed URL must not fail")
    }

    /// Shorthand for a bulk-string reply element.
    #[cfg(feature = "redis")]
    fn bulk(text: &str) -> redis::Value {
        redis::Value::BulkString(text.as_bytes().to_vec())
    }

    /// Field array carrying only an `__embedding_score` entry.
    #[cfg(feature = "redis")]
    fn score_fields(distance: &str) -> redis::Value {
        redis::Value::Array(vec![bulk("__embedding_score"), bulk(distance)])
    }

    #[cfg(feature = "redis")]
    fn approx_eq(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-6
    }

    #[cfg(feature = "redis")]
    #[test]
    fn test_redis_backend_creation() {
        let backend = make_backend("redis://localhost:6379", "test_idx", 384);
        assert_eq!(backend.dimensions(), 384);
        assert_eq!(backend.connection_url(), "redis://localhost:6379");
        assert_eq!(backend.index_name(), "test_idx");
    }

    #[cfg(feature = "redis")]
    #[test]
    fn test_key_generation() {
        let backend = make_backend("redis://localhost", "idx", 384);
        assert_eq!(backend.key_prefix(), "idx:");
        assert_eq!(backend.memory_key(&MemoryId::new("mem-001")), "idx:mem-001");
    }

    #[cfg(feature = "redis")]
    #[test]
    fn test_validate_embedding() {
        let backend = make_backend("redis://localhost", "idx", 384);

        let valid: Vec<f32> = vec![0.0; 384];
        assert!(backend.validate_embedding(&valid).is_ok());

        let invalid: Vec<f32> = vec![0.0; 256];
        assert!(backend.validate_embedding(&invalid).is_err());
    }

    #[cfg(feature = "redis")]
    #[test]
    fn test_vector_to_bytes_is_packed_little_endian() {
        let bytes = RedisVectorBackend::vector_to_bytes(&[1.0, -2.5]);
        let mut expected = 1.0f32.to_le_bytes().to_vec();
        expected.extend_from_slice(&(-2.5f32).to_le_bytes());
        assert_eq!(bytes, expected);
        assert!(RedisVectorBackend::vector_to_bytes(&[]).is_empty());
    }

    #[cfg(feature = "redis")]
    #[test]
    fn test_extract_score_maps_cosine_distance_to_similarity() {
        assert!(approx_eq(
            RedisVectorBackend::extract_score_from_fields(&score_fields("0")),
            1.0
        ));
        assert!(approx_eq(
            RedisVectorBackend::extract_score_from_fields(&score_fields("1")),
            0.5
        ));
        assert!(approx_eq(
            RedisVectorBackend::extract_score_from_fields(&score_fields("2")),
            0.0
        ));
        // Distances beyond the cosine range are clamped.
        assert!(approx_eq(
            RedisVectorBackend::extract_score_from_fields(&score_fields("3")),
            0.0
        ));
        // Missing or malformed scores fall back to zero.
        assert!(approx_eq(
            RedisVectorBackend::extract_score_from_fields(&score_fields("abc")),
            0.0
        ));
        assert!(approx_eq(
            RedisVectorBackend::extract_score_from_fields(&redis::Value::Nil),
            0.0
        ));
    }

    #[cfg(feature = "redis")]
    #[test]
    fn test_parse_search_results() {
        let reply = redis::Value::Array(vec![
            redis::Value::Int(1),
            bulk("idx:mem-001"),
            redis::Value::Array(vec![
                bulk("memory_id"),
                bulk("mem-001"),
                bulk("__embedding_score"),
                bulk("0"),
            ]),
        ]);
        let results = RedisVectorBackend::parse_search_results(&reply);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0.as_str(), "mem-001");
        assert!(approx_eq(results[0].1, 1.0));

        assert!(RedisVectorBackend::parse_search_results(&redis::Value::Nil).is_empty());
        assert!(
            RedisVectorBackend::parse_search_results(&redis::Value::Array(vec![])).is_empty()
        );
    }

    #[cfg(feature = "redis")]
    #[test]
    fn test_parse_info_num_docs() {
        let info = redis::Value::Array(vec![
            bulk("index_name"),
            bulk("idx"),
            bulk("num_docs"),
            bulk("42"),
        ]);
        assert_eq!(RedisVectorBackend::parse_info_num_docs(&info), 42);

        let malformed = redis::Value::Array(vec![bulk("num_docs"), bulk("many")]);
        assert_eq!(RedisVectorBackend::parse_info_num_docs(&malformed), 0);
        assert_eq!(RedisVectorBackend::parse_info_num_docs(&redis::Value::Nil), 0);
    }

    #[cfg(feature = "redis")]
    #[test]
    fn test_value_to_string() {
        assert_eq!(
            RedisVectorBackend::value_to_string(&redis::Value::Int(7)).as_deref(),
            Some("7")
        );
        assert_eq!(
            RedisVectorBackend::value_to_string(&bulk("abc")).as_deref(),
            Some("abc")
        );
        assert_eq!(RedisVectorBackend::value_to_string(&redis::Value::Nil), None);
    }

    #[cfg(not(feature = "redis"))]
    #[test]
    fn test_stub_returns_not_implemented() {
        let backend = RedisVectorBackend::with_defaults();
        let embedding: Vec<f32> = vec![0.0; 384];
        let id = MemoryId::new("test");

        assert!(backend.upsert(&id, &embedding).is_err());
        assert!(backend.remove(&id).is_err());
        assert!(
            backend
                .search(&embedding, &VectorFilter::new(), 10)
                .is_err()
        );
        assert!(backend.count().is_err());
        assert!(backend.clear().is_err());
    }
}