1use crate::context::GitContext;
42use crate::current_timestamp;
43use crate::embedding::Embedder;
44use crate::gc::branch_exists;
45use crate::models::{
46 EventMeta, Memory, MemoryEvent, MemoryId, MemoryStatus, SearchFilter, SearchHit, SearchMode,
47 SearchResult,
48};
49use crate::observability::current_request_id;
50use crate::security::record_event;
51use crate::storage::index::SqliteBackend;
52use crate::storage::traits::{GraphBackend, IndexBackend, VectorBackend};
53use crate::{Error, Result};
54use chrono::{TimeZone, Utc};
55use git2::{BranchType, Repository};
56use std::borrow::Cow;
57use std::collections::{HashMap, HashSet};
58use std::sync::Arc;
59use std::time::Instant;
60use tracing::{info_span, instrument, warn};
61
62type RrfEntry = (f32, Option<usize>, Option<usize>, Option<f32>);
64
65pub const DEFAULT_SEARCH_TIMEOUT_MS: u64 = 5_000;
67
/// Read-side service for recalling memories via text, vector, or hybrid
/// search, with optional entity filtering and scope defaults.
pub struct RecallService {
    /// SQLite index used for text search and for hydrating full memories.
    index: Option<SqliteBackend>,
    /// Query embedder; together with `vector`, enables vector search.
    embedder: Option<Arc<dyn Embedder>>,
    /// Vector store queried for similarity search.
    vector: Option<Arc<dyn VectorBackend + Send + Sync>>,
    /// Graph backend used to filter results by mentioned entity names.
    graph: Option<Arc<dyn GraphBackend>>,
    /// Scope defaults merged into every incoming filter; caller-set fields
    /// always win (see `effective_filter`).
    scope_filter: Option<SearchFilter>,
    /// Search timeout budget in milliseconds; 0 disables timeout checks.
    timeout_ms: u64,
}
100
101impl RecallService {
102 #[must_use]
107 pub const fn new() -> Self {
108 Self {
109 index: None,
110 embedder: None,
111 vector: None,
112 graph: None,
113 scope_filter: None,
114 timeout_ms: DEFAULT_SEARCH_TIMEOUT_MS,
115 }
116 }
117
118 #[must_use]
122 pub const fn with_index(index: SqliteBackend) -> Self {
123 Self {
124 index: Some(index),
125 embedder: None,
126 vector: None,
127 graph: None,
128 scope_filter: None,
129 timeout_ms: DEFAULT_SEARCH_TIMEOUT_MS,
130 }
131 }
132
133 #[must_use]
141 pub fn with_backends(
142 index: SqliteBackend,
143 embedder: Arc<dyn Embedder>,
144 vector: Arc<dyn VectorBackend + Send + Sync>,
145 ) -> Self {
146 Self {
147 index: Some(index),
148 embedder: Some(embedder),
149 vector: Some(vector),
150 graph: None,
151 scope_filter: None,
152 timeout_ms: DEFAULT_SEARCH_TIMEOUT_MS,
153 }
154 }
155
156 #[must_use]
158 pub fn with_embedder(mut self, embedder: Arc<dyn Embedder>) -> Self {
159 self.embedder = Some(embedder);
160 self
161 }
162
163 #[must_use]
165 pub fn with_vector(mut self, vector: Arc<dyn VectorBackend + Send + Sync>) -> Self {
166 self.vector = Some(vector);
167 self
168 }
169
170 #[must_use]
175 pub fn with_graph(mut self, graph: Arc<dyn GraphBackend>) -> Self {
176 self.graph = Some(graph);
177 self
178 }
179
180 #[must_use]
185 pub fn with_scope_filter(mut self, filter: SearchFilter) -> Self {
186 self.scope_filter = Some(filter);
187 self
188 }
189
190 #[must_use]
198 pub const fn with_timeout_ms(mut self, timeout_ms: u64) -> Self {
199 self.timeout_ms = timeout_ms;
200 self
201 }
202
203 #[must_use]
205 pub const fn timeout_ms(&self) -> u64 {
206 self.timeout_ms
207 }
208
209 #[must_use]
211 pub fn has_vector_search(&self) -> bool {
212 self.embedder.is_some() && self.vector.is_some()
213 }
214
215 fn effective_filter<'a>(&'a self, filter: &'a SearchFilter) -> Cow<'a, SearchFilter> {
216 let Some(scope_filter) = &self.scope_filter else {
217 return Cow::Borrowed(filter);
218 };
219
220 let mut merged = filter.clone();
221
222 if merged.project_id.is_none() {
223 merged.project_id.clone_from(&scope_filter.project_id);
224 }
225 if merged.branch.is_none() {
226 merged.branch.clone_from(&scope_filter.branch);
227 }
228 if merged.file_path.is_none() {
229 merged.file_path.clone_from(&scope_filter.file_path);
230 }
231 if merged.source_pattern.is_none() {
232 merged
233 .source_pattern
234 .clone_from(&scope_filter.source_pattern);
235 }
236
237 Cow::Owned(merged)
238 }
239
    /// Searches memories in the given mode, applying (in order): scope-filter
    /// overlay, query validation, timeout checks, the mode-specific search,
    /// entity filtering, score normalization, and lazy branch tombstoning.
    /// Success/error counters and duration histograms are recorded on every
    /// exit path.
    ///
    /// # Errors
    /// Returns `Error::InvalidInput` for an empty or oversized query, and
    /// `Error::OperationFailed` when the timeout budget is exhausted before
    /// the search runs or when an underlying backend fails.
    #[allow(clippy::cast_possible_truncation)]
    #[instrument(
        name = "subcog.memory.recall",
        skip(self, query, filter),
        fields(
            request_id = tracing::field::Empty,
            component = "memory",
            operation = "recall",
            mode = %mode,
            query_length = query.len(),
            limit = limit,
            timeout_ms = self.timeout_ms
        )
    )]
    pub fn search(
        &self,
        query: &str,
        mode: SearchMode,
        filter: &SearchFilter,
        limit: usize,
    ) -> Result<SearchResult> {
        let start = Instant::now();
        // Overlay service-level scope defaults onto the caller's filter.
        let effective_filter = self.effective_filter(filter);
        let filter = effective_filter.as_ref();
        let domain_label = domain_label(filter);
        let mode_label = mode.as_str();
        if let Some(request_id) = current_request_id() {
            tracing::Span::current().record("request_id", request_id.as_str());
        }
        tracing::info!(mode = %mode_label, query_length = query.len(), limit = limit, timeout_ms = self.timeout_ms, "Searching memories");
        const MAX_QUERY_SIZE: usize = 10_000;
        let deadline_ms = self.timeout_ms;
        // Run the search inside a closure so the metrics below observe every
        // early return uniformly.
        let result = (|| {
            if query.trim().is_empty() {
                return Err(Error::InvalidInput("Query cannot be empty".to_string()));
            }
            if query.len() > MAX_QUERY_SIZE {
                return Err(Error::InvalidInput(format!(
                    "Query exceeds maximum size of {} bytes (got {} bytes)",
                    MAX_QUERY_SIZE,
                    query.len()
                )));
            }
            // Pre-search deadline check; a deadline of 0 disables timeouts.
            if deadline_ms > 0 && start.elapsed().as_millis() as u64 >= deadline_ms {
                tracing::warn!(
                    elapsed_ms = start.elapsed().as_millis(),
                    timeout_ms = deadline_ms,
                    "Search timeout before execution"
                );
                metrics::counter!("memory_search_timeouts_total", "mode" => mode_label, "phase" => "pre_search").increment(1);
                return Err(Error::OperationFailed {
                    operation: "search".to_string(),
                    cause: format!("Search timeout exceeded ({deadline_ms}ms)"),
                });
            }
            // Dispatch to the mode-specific search, each under its own span.
            let mut memories = match mode {
                SearchMode::Text => {
                    let _span = info_span!("subcog.memory.recall.text_search").entered();
                    self.text_search(query, filter, limit)?
                },
                SearchMode::Vector => {
                    let _span = info_span!("subcog.memory.recall.vector_search").entered();
                    self.vector_search(query, filter, limit)?
                },
                SearchMode::Hybrid => {
                    let _span = info_span!("subcog.memory.recall.hybrid_search").entered();
                    self.hybrid_search(query, filter, limit)?
                },
            };

            if !filter.entity_names.is_empty() {
                self.apply_entity_filter(&mut memories, &filter.entity_names);
            }

            // Post-search deadline: log and count, but still return whatever
            // results were produced (partial results beat an error here).
            if deadline_ms > 0 && start.elapsed().as_millis() as u64 >= deadline_ms {
                tracing::warn!(
                    elapsed_ms = start.elapsed().as_millis(),
                    timeout_ms = deadline_ms,
                    results_found = memories.len(),
                    "Search timeout after execution, returning partial results"
                );
                metrics::counter!("memory_search_timeouts_total", "mode" => mode_label, "phase" => "post_search").increment(1);
            }

            // Hybrid results are already normalized inside hybrid_search.
            if mode != SearchMode::Hybrid {
                normalize_scores(&mut memories);
            }

            self.lazy_tombstone_stale_branches(&mut memories, filter);

            let execution_time_ms = start.elapsed().as_millis() as u64;
            let total_count = memories.len();
            record_recall_events(&memories, query);

            Ok(SearchResult {
                memories,
                total_count,
                mode,
                execution_time_ms,
            })
        })();

        let status = if result.is_ok() { "success" } else { "error" };
        metrics::counter!(
            "memory_search_total",
            "mode" => mode_label,
            "domain" => domain_label,
            "status" => status
        )
        .increment(1);
        metrics::histogram!(
            "memory_search_duration_ms",
            "mode" => mode_label,
            "backend" => "sqlite"
        )
        .record(start.elapsed().as_secs_f64() * 1000.0);
        metrics::histogram!(
            "memory_lifecycle_duration_ms",
            "component" => "memory",
            "operation" => "recall",
            "mode" => mode_label
        )
        .record(start.elapsed().as_secs_f64() * 1000.0);

        result
    }
387
388 fn lazy_tombstone_stale_branches(&self, hits: &mut Vec<SearchHit>, filter: &SearchFilter) {
397 let ctx = GitContext::from_cwd();
398 let Some(project_id) = ctx.project_id else {
399 return;
400 };
401
402 let tombstoned_ids = Self::mark_stale_branch_hits(hits, &project_id);
404
405 if !tombstoned_ids.is_empty() {
407 self.persist_tombstones_to_index(hits, &tombstoned_ids);
408 }
409
410 if !filter.include_tombstoned {
412 hits.retain(|hit| hit.memory.status != MemoryStatus::Tombstoned);
413 }
414 }
415
    /// Marks hits as tombstoned when their recorded git branch no longer
    /// exists (locally or on a remote), returning the indices of the hits it
    /// changed. Only non-tombstoned hits belonging to `project_id` are
    /// considered.
    fn mark_stale_branch_hits(hits: &mut [SearchHit], project_id: &str) -> Vec<usize> {
        let now = current_timestamp();
        // Clamp to i64::MAX so an out-of-range timestamp cannot panic.
        let now_i64 = i64::try_from(now).unwrap_or(i64::MAX);
        let now_dt = Utc
            .timestamp_opt(now_i64, 0)
            .single()
            .unwrap_or_else(Utc::now);

        // Enumerate all branch names once up front; fall back to per-branch
        // lookups if the repository cannot be enumerated.
        let branch_names = Self::load_branch_names();
        let mut tombstoned_indices = Vec::new();

        for (idx, hit) in hits.iter_mut().enumerate() {
            // Memories with no branch can never be branch-stale.
            let Some(branch) = hit.memory.branch.as_deref() else {
                continue;
            };

            // Already tombstoned: nothing to do.
            if hit.memory.status == MemoryStatus::Tombstoned {
                continue;
            }

            // Never touch memories belonging to a different project.
            if hit.memory.project_id.as_deref() != Some(project_id) {
                continue;
            }

            let exists = match &branch_names {
                Some(names) => names.contains(branch),
                None => branch_exists(branch),
            };

            if exists {
                continue;
            }

            hit.memory.status = MemoryStatus::Tombstoned;
            hit.memory.tombstoned_at = Some(now_dt);
            tombstoned_indices.push(idx);
        }

        tombstoned_indices
    }
463
464 fn persist_tombstones_to_index(&self, hits: &[SearchHit], indices: &[usize]) {
468 let Some(index) = self.index.as_ref() else {
469 return;
470 };
471
472 for &idx in indices {
473 let Some(hit) = hits.get(idx) else {
475 continue;
476 };
477 if let Err(err) = index.index(&hit.memory) {
478 warn!(
479 error = %err,
480 memory_id = %hit.memory.id.as_str(),
481 "Failed to persist tombstone to index during recall"
482 );
483 }
484 }
485 }
486
487 fn apply_entity_filter(&self, hits: &mut Vec<SearchHit>, entity_names: &[String]) {
492 let Some(graph) = &self.graph else {
493 tracing::debug!("Entity filter requested but no graph backend configured");
494 return;
495 };
496
497 let _span = info_span!(
498 "subcog.memory.recall.entity_filter",
499 entity_count = entity_names.len()
500 )
501 .entered();
502
503 let allowed_memory_ids: HashSet<MemoryId> = entity_names
505 .iter()
506 .flat_map(|name| self.collect_mentions_for_entity_name(graph.as_ref(), name))
507 .collect();
508
509 if allowed_memory_ids.is_empty() {
511 tracing::debug!(
512 entity_names = ?entity_names,
513 "No entities found for filter, returning empty results"
514 );
515 hits.clear();
516 return;
517 }
518
519 let before_count = hits.len();
521 hits.retain(|hit| allowed_memory_ids.contains(&hit.memory.id));
522
523 tracing::debug!(
524 before = before_count,
525 after = hits.len(),
526 allowed_memories = allowed_memory_ids.len(),
527 "Applied entity filter"
528 );
529 }
530
531 fn collect_mentions_for_entity_name(
533 &self,
534 graph: &dyn GraphBackend,
535 entity_name: &str,
536 ) -> Vec<MemoryId> {
537 let entities = match graph.find_entities_by_name(entity_name, None, None, 10) {
538 Ok(e) => e,
539 Err(e) => {
540 tracing::warn!(error = %e, entity_name = %entity_name, "Failed to find entities");
541 return Vec::new();
542 },
543 };
544
545 entities
546 .into_iter()
547 .flat_map(|entity| {
548 graph
549 .get_mentions_for_entity(&entity.id)
550 .inspect_err(|e| {
551 tracing::warn!(error = %e, entity_id = %entity.id, "Failed to get mentions");
552 })
553 .unwrap_or_default()
554 .into_iter()
555 .map(|m| m.memory_id)
556 })
557 .collect()
558 }
559
560 fn load_branch_names() -> Option<HashSet<String>> {
561 let cwd = std::env::current_dir().ok()?;
562 let repo = Repository::discover(&cwd).ok()?;
563 let mut names = HashSet::new();
564
565 if let Ok(branches) = repo.branches(Some(BranchType::Local)) {
566 for name in branches
567 .flatten()
568 .filter_map(|(branch, _)| branch.name().ok().flatten().map(str::to_string))
569 {
570 names.insert(name);
571 }
572 }
573
574 if let Ok(branches) = repo.branches(Some(BranchType::Remote)) {
575 for name in branches
576 .flatten()
577 .filter_map(|(branch, _)| branch.name().ok().flatten().map(str::to_string))
578 .filter_map(|name| {
579 name.split_once('/')
580 .map(|(_, branch_name)| branch_name.to_string())
581 })
582 {
583 names.insert(name);
584 }
585 }
586
587 Some(names)
588 }
589
    /// Lists memories matching `filter` without a text query, with memory
    /// content stripped to keep payloads small — use `list_all_with_content`
    /// when bodies are needed. Records the same search metrics family as
    /// `search`, under mode "list_all".
    ///
    /// # Errors
    /// Fails when no index backend is configured or the index query fails.
    #[allow(clippy::cast_possible_truncation)]
    #[instrument(
        name = "subcog.memory.recall.list_all",
        skip(self, filter),
        fields(
            request_id = tracing::field::Empty,
            component = "memory",
            operation = "list_all",
            limit = limit
        )
    )]
    pub fn list_all(&self, filter: &SearchFilter, limit: usize) -> Result<SearchResult> {
        let start = Instant::now();
        let effective_filter = self.effective_filter(filter);
        let filter = effective_filter.as_ref();
        let domain_label = domain_label(filter);
        if let Some(request_id) = current_request_id() {
            tracing::Span::current().record("request_id", request_id.as_str());
        }

        // Closure so the metrics below observe every exit path.
        let result = (|| {
            let index = self.index.as_ref().ok_or_else(|| Error::OperationFailed {
                operation: "list_all".to_string(),
                cause: "No index backend configured".to_string(),
            })?;

            let results = index.list_all(filter, limit)?;

            // Hydrate all memories in a single batched lookup.
            let ids: Vec<_> = results.iter().map(|(id, _)| id.clone()).collect();
            let batch_memories = index.get_memories_batch(&ids)?;

            let memories: Vec<SearchHit> = results
                .into_iter()
                .zip(batch_memories)
                .filter_map(|((_, score), memory_opt)| {
                    // Ids whose memory vanished from storage are dropped.
                    memory_opt.map(|mut memory| {
                        // Strip bodies: this listing is metadata-only.
                        memory.content = String::new();
                        SearchHit {
                            memory,
                            score,
                            raw_score: score,
                            vector_score: None,
                            bm25_score: None,
                        }
                    })
                })
                .collect();

            let execution_time_ms = start.elapsed().as_millis() as u64;
            let total_count = memories.len();
            record_recall_events(&memories, "*");

            Ok(SearchResult {
                memories,
                total_count,
                mode: SearchMode::Text,
                execution_time_ms,
            })
        })();

        let status = if result.is_ok() { "success" } else { "error" };
        metrics::counter!(
            "memory_search_total",
            "mode" => "list_all",
            "domain" => domain_label,
            "status" => status
        )
        .increment(1);
        metrics::histogram!(
            "memory_search_duration_ms",
            "mode" => "list_all",
            "backend" => "sqlite"
        )
        .record(start.elapsed().as_secs_f64() * 1000.0);
        metrics::histogram!(
            "memory_lifecycle_duration_ms",
            "component" => "memory",
            "operation" => "recall",
            "mode" => "list_all"
        )
        .record(start.elapsed().as_secs_f64() * 1000.0);

        result
    }
688
    /// Lists memories matching `filter` without a text query, keeping full
    /// memory content (unlike `list_all`). Records the same search metrics
    /// family as `search`, under mode "list_all_with_content".
    ///
    /// # Errors
    /// Fails when no index backend is configured or the index query fails.
    #[allow(clippy::cast_possible_truncation)]
    #[instrument(
        name = "subcog.memory.recall.list_all_with_content",
        skip(self, filter),
        fields(
            request_id = tracing::field::Empty,
            component = "memory",
            operation = "list_all_with_content",
            limit = limit
        )
    )]
    pub fn list_all_with_content(
        &self,
        filter: &SearchFilter,
        limit: usize,
    ) -> Result<SearchResult> {
        let start = Instant::now();
        let effective_filter = self.effective_filter(filter);
        let filter = effective_filter.as_ref();
        let domain_label = domain_label(filter);
        if let Some(request_id) = current_request_id() {
            tracing::Span::current().record("request_id", request_id.as_str());
        }

        // Closure so the metrics below observe every exit path.
        let result = (|| {
            let index = self.index.as_ref().ok_or_else(|| Error::OperationFailed {
                operation: "list_all_with_content".to_string(),
                cause: "No index backend configured".to_string(),
            })?;

            let results = index.list_all(filter, limit)?;

            // Hydrate all memories in a single batched lookup.
            let ids: Vec<_> = results.iter().map(|(id, _)| id.clone()).collect();
            let batch_memories = index.get_memories_batch(&ids)?;

            let memories: Vec<SearchHit> = results
                .into_iter()
                .zip(batch_memories)
                .filter_map(|((_, score), memory_opt)| {
                    // Ids whose memory vanished from storage are dropped.
                    memory_opt.map(|memory| SearchHit {
                        memory,
                        score,
                        raw_score: score,
                        vector_score: None,
                        bm25_score: None,
                    })
                })
                .collect();

            let execution_time_ms = start.elapsed().as_millis() as u64;
            let total_count = memories.len();
            record_recall_events(&memories, "*");

            Ok(SearchResult {
                memories,
                total_count,
                mode: SearchMode::Text,
                execution_time_ms,
            })
        })();

        let status = if result.is_ok() { "success" } else { "error" };
        metrics::counter!(
            "memory_search_total",
            "mode" => "list_all_with_content",
            "domain" => domain_label,
            "status" => status
        )
        .increment(1);
        metrics::histogram!(
            "memory_search_duration_ms",
            "mode" => "list_all_with_content",
            "backend" => "sqlite"
        )
        .record(start.elapsed().as_secs_f64() * 1000.0);
        metrics::histogram!(
            "memory_lifecycle_duration_ms",
            "component" => "memory",
            "operation" => "recall",
            "mode" => "list_all_with_content"
        )
        .record(start.elapsed().as_secs_f64() * 1000.0);

        result
    }
787
788 fn text_search(
794 &self,
795 query: &str,
796 filter: &SearchFilter,
797 limit: usize,
798 ) -> Result<Vec<SearchHit>> {
799 let index = self.index.as_ref().ok_or_else(|| Error::OperationFailed {
800 operation: "text_search".to_string(),
801 cause: "No index backend configured".to_string(),
802 })?;
803
804 let results = index.search(query, filter, limit)?;
805
806 let ids: Vec<_> = results.iter().map(|(id, _)| id.clone()).collect();
808 let batch_memories = index.get_memories_batch(&ids)?;
809
810 let hits: Vec<SearchHit> = results
812 .into_iter()
813 .zip(batch_memories)
814 .map(|((id, score), memory_opt)| {
815 let memory = memory_opt.unwrap_or_else(|| create_placeholder_memory(id));
816 SearchHit {
817 memory,
818 score,
819 raw_score: score,
820 vector_score: None,
821 bm25_score: Some(score),
822 }
823 })
824 .collect();
825
826 Ok(hits)
827 }
828
829 fn vector_search(
841 &self,
842 query: &str,
843 filter: &SearchFilter,
844 limit: usize,
845 ) -> Result<Vec<SearchHit>> {
846 let (embedder, vector) = match (&self.embedder, &self.vector) {
848 (Some(e), Some(v)) => (e, v),
849 (None, _) => {
850 tracing::debug!("Vector search unavailable: no embedder configured");
851 return Ok(Vec::new());
852 },
853 (_, None) => {
854 tracing::debug!("Vector search unavailable: no vector backend configured");
855 return Ok(Vec::new());
856 },
857 };
858
859 let query_embedding = match embedder.embed(query) {
861 Ok(emb) => emb,
862 Err(e) => {
863 tracing::warn!("Failed to embed query for vector search: {e}");
864 return Ok(Vec::new());
865 },
866 };
867
868 let vector_filter = crate::storage::traits::VectorFilter::from(filter);
870 let results = match vector.search(&query_embedding, &vector_filter, limit) {
871 Ok(r) => r,
872 Err(e) => {
873 tracing::warn!("Vector search failed: {e}");
874 return Ok(Vec::new());
875 },
876 };
877
878 let index = match &self.index {
880 Some(idx) => idx,
881 None => {
882 return Ok(results
884 .into_iter()
885 .map(|(id, score)| SearchHit {
886 memory: create_placeholder_memory(id),
887 score,
888 raw_score: score,
889 vector_score: Some(score),
890 bm25_score: None,
891 })
892 .collect());
893 },
894 };
895
896 let ids: Vec<_> = results.iter().map(|(id, _)| id.clone()).collect();
898 let batch_memories = match index.get_memories_batch(&ids) {
899 Ok(m) => m,
900 Err(e) => {
901 tracing::warn!("Failed to fetch memories for vector results: {e}");
902 return Ok(results
904 .into_iter()
905 .map(|(id, score)| SearchHit {
906 memory: create_placeholder_memory(id),
907 score,
908 raw_score: score,
909 vector_score: Some(score),
910 bm25_score: None,
911 })
912 .collect());
913 },
914 };
915
916 let hits: Vec<SearchHit> = results
918 .into_iter()
919 .zip(batch_memories)
920 .map(|((id, score), memory_opt)| {
921 let memory = memory_opt.unwrap_or_else(|| create_placeholder_memory(id));
922 SearchHit {
923 memory,
924 score,
925 raw_score: score,
926 vector_score: Some(score),
927 bm25_score: None,
928 }
929 })
930 .collect();
931
932 Ok(hits)
933 }
934
935 fn hybrid_search(
937 &self,
938 query: &str,
939 filter: &SearchFilter,
940 limit: usize,
941 ) -> Result<Vec<SearchHit>> {
942 let text_results = self.text_search(query, filter, limit * 2)?;
944 let vector_results = self.vector_search(query, filter, limit * 2)?;
945
946 let mut fused = self.rrf_fusion(&text_results, &vector_results, limit);
948
949 normalize_scores(&mut fused);
951
952 Ok(fused)
953 }
954
    /// Fuses text and vector rankings with Reciprocal Rank Fusion (RRF):
    /// each list a document appears in contributes `1 / (K + rank + 1)` to
    /// its fused score. Results are sorted by fused score descending and
    /// truncated to `limit`.
    fn rrf_fusion(
        &self,
        text_results: &[SearchHit],
        vector_results: &[SearchHit],
        limit: usize,
    ) -> Vec<SearchHit> {
        // K = 60 is the customary RRF smoothing constant.
        const K: f32 = 60.0; let capacity = text_results.len() + vector_results.len();
        // id -> (fused score, text rank, vector rank, vector similarity).
        let mut scores: HashMap<String, RrfEntry> = HashMap::with_capacity(capacity);

        for (rank, hit) in text_results.iter().enumerate() {
            let id = hit.memory.id.to_string();
            let rrf_score = 1.0 / (K + rank as f32 + 1.0);

            scores
                .entry(id)
                .and_modify(|(s, _, _, _)| *s += rrf_score)
                .or_insert((rrf_score, Some(rank), None, None));
        }

        for (rank, hit) in vector_results.iter().enumerate() {
            let id = hit.memory.id.to_string();
            let rrf_score = 1.0 / (K + rank as f32 + 1.0);

            scores
                .entry(id)
                .and_modify(|(s, _, vec_idx, vec_score)| {
                    *s += rrf_score;
                    *vec_idx = Some(rank);
                    *vec_score = hit.vector_score;
                })
                .or_insert((rrf_score, None, Some(rank), hit.vector_score));
        }

        let mut results: Vec<_> = scores
            .into_iter()
            .filter_map(|(_, (score, text_idx, vec_idx, vec_score))| {
                // Prefer the text hit as the base (it carries bm25_score);
                // fall back to the vector hit for vector-only documents.
                let mut hit = if let Some(idx) = text_idx {
                    text_results.get(idx).cloned()
                } else {
                    vec_idx.and_then(|idx| vector_results.get(idx).cloned())
                }?;

                // Attach vector similarity even when the base hit came from
                // the text list.
                if vec_score.is_some() {
                    hit.vector_score = vec_score;
                }

                hit.score = score;
                Some(hit)
            })
            .collect();

        // Descending by fused score; NaN-safe via the Equal fallback.
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(limit);

        results
    }
1069
1070 pub fn get_by_id(&self, id: &MemoryId) -> Result<Option<Memory>> {
1080 let index = self.index.as_ref().ok_or_else(|| Error::OperationFailed {
1081 operation: "get_by_id".to_string(),
1082 cause: "No index backend configured".to_string(),
1083 })?;
1084
1085 index.get_memory(id)
1086 }
1087
    /// Returns recent memories.
    ///
    /// NOTE(review): currently a stub — both arguments are ignored and an
    /// empty list is always returned; confirm before relying on recency
    /// queries.
    pub const fn recent(&self, _limit: usize, _filter: &SearchFilter) -> Result<Vec<Memory>> {
        Ok(Vec::new())
    }
1103
    /// Permission-checked variant of [`Self::search`].
    ///
    /// # Errors
    /// Fails when `auth` lacks the `Read` permission; otherwise propagates
    /// any error from `search`.
    pub fn search_authorized(
        &self,
        query: &str,
        mode: SearchMode,
        filter: &SearchFilter,
        limit: usize,
        auth: &super::auth::AuthContext,
    ) -> Result<SearchResult> {
        auth.require(super::auth::Permission::Read)?;
        self.search(query, mode, filter, limit)
    }

    /// Permission-checked variant of [`Self::get_by_id`].
    ///
    /// # Errors
    /// Fails when `auth` lacks the `Read` permission; otherwise propagates
    /// any error from `get_by_id`.
    pub fn get_by_id_authorized(
        &self,
        id: &MemoryId,
        auth: &super::auth::AuthContext,
    ) -> Result<Option<Memory>> {
        auth.require(super::auth::Permission::Read)?;
        self.get_by_id(id)
    }
1148}
1149
1150fn domain_label(filter: &SearchFilter) -> Cow<'static, str> {
1152 match filter.domains.len() {
1153 0 => Cow::Borrowed("all"),
1154 1 => Cow::Owned(filter.domains[0].to_string()),
1155 _ => Cow::Borrowed("multi"),
1156 }
1157}
1158
1159fn record_recall_events(memories: &[SearchHit], query: &str) {
1160 let timestamp = current_timestamp();
1161 let query_arc: std::sync::Arc<str> = query.into();
1162 for hit in memories {
1163 record_event(MemoryEvent::Retrieved {
1164 meta: EventMeta::with_timestamp("recall", current_request_id(), timestamp),
1165 memory_id: hit.memory.id.clone(),
1166 query: std::sync::Arc::clone(&query_arc),
1167 score: hit.score,
1168 });
1169 }
1170}
1171
1172fn normalize_scores(results: &mut [SearchHit]) {
1200 if results.is_empty() {
1201 return;
1202 }
1203
1204 let max_score = results.iter().map(|h| h.score).fold(0.0_f32, f32::max);
1206
1207 if max_score <= f32::EPSILON {
1209 return;
1210 }
1211
1212 for hit in results {
1214 hit.raw_score = hit.score;
1216 hit.score /= max_score;
1218 }
1219}
1220
/// `Default` is a service with no backends configured, identical to
/// [`RecallService::new`].
impl Default for RecallService {
    fn default() -> Self {
        Self::new()
    }
}
1226
1227#[allow(clippy::missing_const_for_fn)] fn create_placeholder_memory(id: MemoryId) -> Memory {
1230 use crate::models::{Domain, Namespace};
1231
1232 Memory {
1233 id,
1234 content: String::new(),
1235 namespace: Namespace::Decisions,
1236 domain: Domain::new(),
1237 project_id: None,
1238 branch: None,
1239 file_path: None,
1240 status: MemoryStatus::Active,
1241 created_at: 0,
1242 updated_at: 0,
1243 tombstoned_at: None,
1244 expires_at: None,
1245 embedding: None,
1246 tags: Vec::new(),
1247 #[cfg(feature = "group-scope")]
1248 group_id: None,
1249 source: None,
1250 is_summary: false,
1251 source_memory_ids: None,
1252 consolidation_timestamp: None,
1253 }
1254}
1255
1256#[cfg(test)]
1257mod tests {
1258 use super::*;
1259 use crate::models::Namespace;
1260
1261 fn create_test_memory(id: &str, content: &str) -> Memory {
1262 use crate::models::Domain;
1263
1264 Memory {
1265 id: MemoryId::new(id),
1266 content: content.to_string(),
1267 namespace: Namespace::Decisions,
1268 domain: Domain::new(),
1269 project_id: None,
1270 branch: None,
1271 file_path: None,
1272 status: MemoryStatus::Active,
1273 created_at: 0,
1274 updated_at: 0,
1275 tombstoned_at: None,
1276 expires_at: None,
1277 embedding: None,
1278 tags: Vec::new(),
1279 #[cfg(feature = "group-scope")]
1280 group_id: None,
1281 source: None,
1282 is_summary: false,
1283 source_memory_ids: None,
1284 consolidation_timestamp: None,
1285 }
1286 }
1287
1288 #[test]
1289 fn test_search_empty_query() {
1290 let service = RecallService::default();
1291 let result = service.search("", SearchMode::Text, &SearchFilter::new(), 10);
1292 assert!(result.is_err());
1293 }
1294
1295 #[test]
1296 fn test_search_no_backend() {
1297 let service = RecallService::default();
1298 let result = service.search("test", SearchMode::Text, &SearchFilter::new(), 10);
1299 assert!(result.is_err());
1300 }
1301
1302 #[test]
1303 fn test_search_with_backend() {
1304 let index = SqliteBackend::in_memory().unwrap();
1305
1306 index
1308 .index(&create_test_memory("id1", "Rust programming language"))
1309 .unwrap();
1310 index
1311 .index(&create_test_memory("id2", "Python scripting"))
1312 .unwrap();
1313
1314 let service = RecallService::with_index(index);
1315
1316 let result = service.search("Rust", SearchMode::Text, &SearchFilter::new(), 10);
1317 assert!(result.is_ok());
1318
1319 let result = result.unwrap();
1320 assert!(!result.memories.is_empty());
1321 }
1322
1323 #[test]
1324 fn test_rrf_fusion() {
1325 let service = RecallService::default();
1326
1327 let text_hits = vec![
1328 SearchHit {
1329 memory: create_test_memory("id1", ""),
1330 score: 0.9,
1331 raw_score: 0.9,
1332 vector_score: None,
1333 bm25_score: Some(0.9),
1334 },
1335 SearchHit {
1336 memory: create_test_memory("id2", ""),
1337 score: 0.8,
1338 raw_score: 0.8,
1339 vector_score: None,
1340 bm25_score: Some(0.8),
1341 },
1342 ];
1343
1344 let vector_hits = vec![
1345 SearchHit {
1346 memory: create_test_memory("id2", ""),
1347 score: 0.95,
1348 raw_score: 0.95,
1349 vector_score: Some(0.95),
1350 bm25_score: None,
1351 },
1352 SearchHit {
1353 memory: create_test_memory("id3", ""),
1354 score: 0.85,
1355 raw_score: 0.85,
1356 vector_score: Some(0.85),
1357 bm25_score: None,
1358 },
1359 ];
1360
1361 let fused = service.rrf_fusion(&text_hits, &vector_hits, 10);
1362
1363 assert!(!fused.is_empty());
1365
1366 let id2_score = fused
1368 .iter()
1369 .find(|h| h.memory.id.as_str() == "id2")
1370 .map(|h| h.score);
1371 let id1_score = fused
1372 .iter()
1373 .find(|h| h.memory.id.as_str() == "id1")
1374 .map(|h| h.score);
1375
1376 assert!(id2_score > id1_score);
1377 }
1378
1379 #[test]
1380 fn test_hybrid_search_mode() {
1381 let result =
1382 RecallService::default().search("test", SearchMode::Hybrid, &SearchFilter::new(), 10);
1383 assert!(result.is_err());
1385 }
1386
1387 #[test]
1388 fn test_vector_search_no_embedder() {
1389 let service = RecallService::default();
1390 let result = service.vector_search("test query", &SearchFilter::new(), 10);
1391
1392 assert!(result.is_ok());
1394 assert!(result.expect("vector_search failed").is_empty());
1395 }
1396
1397 #[test]
1398 fn test_vector_search_no_vector_backend() {
1399 use crate::embedding::FastEmbedEmbedder;
1400
1401 let embedder: Arc<dyn Embedder> = Arc::new(FastEmbedEmbedder::new());
1402 let service = RecallService::new().with_embedder(embedder);
1403
1404 let result = service.vector_search("test query", &SearchFilter::new(), 10);
1405
1406 assert!(result.is_ok());
1408 assert!(result.expect("vector_search failed").is_empty());
1409 }
1410
1411 #[test]
1412 fn test_has_vector_search() {
1413 use crate::embedding::FastEmbedEmbedder;
1414
1415 let service = RecallService::default();
1416 assert!(!service.has_vector_search());
1417
1418 let embedder: Arc<dyn Embedder> = Arc::new(FastEmbedEmbedder::new());
1419 let service_with_embedder = RecallService::new().with_embedder(embedder);
1420 assert!(!service_with_embedder.has_vector_search());
1421 }
1422
1423 #[test]
1424 fn test_with_backends_builder() {
1425 let index = SqliteBackend::in_memory().expect("in_memory failed");
1426 let service = RecallService::with_index(index);
1427
1428 assert!(!service.has_vector_search());
1430 }
1431
1432 #[test]
1433 fn test_hybrid_search_fallback_text_only() {
1434 let index = SqliteBackend::in_memory().expect("in_memory failed");
1435
1436 index
1438 .index(&create_test_memory("id1", "Rust programming language"))
1439 .expect("index failed");
1440
1441 let service = RecallService::with_index(index);
1443
1444 let result = service.search("Rust", SearchMode::Hybrid, &SearchFilter::new(), 10);
1446 assert!(result.is_ok());
1447
1448 let search_result = result.expect("search failed");
1449 assert!(!search_result.memories.is_empty());
1451 }
1452
1453 #[test]
1454 fn test_vector_search_mode_graceful() {
1455 let index = SqliteBackend::in_memory().expect("in_memory failed");
1456 let service = RecallService::with_index(index);
1457
1458 let result = service.search("test", SearchMode::Vector, &SearchFilter::new(), 10);
1460 assert!(result.is_ok());
1461
1462 let search_result = result.expect("search failed");
1463 assert!(search_result.memories.is_empty());
1464 }
1465
1466 #[test]
1467 fn test_rrf_with_empty_vector_results() {
1468 let service = RecallService::default();
1469
1470 let text_hits = vec![SearchHit {
1471 memory: create_test_memory("id1", "content"),
1472 score: 0.9,
1473 raw_score: 0.9,
1474 vector_score: None,
1475 bm25_score: Some(0.9),
1476 }];
1477 let vector_hits: Vec<SearchHit> = vec![]; let fused = service.rrf_fusion(&text_hits, &vector_hits, 10);
1480
1481 assert_eq!(fused.len(), 1);
1483 assert_eq!(fused[0].memory.id.as_str(), "id1");
1484 }
1485
1486 #[test]
1487 fn test_rrf_with_empty_text_results() {
1488 let service = RecallService::default();
1489
1490 let text_hits: Vec<SearchHit> = vec![]; let vector_hits = vec![SearchHit {
1492 memory: create_test_memory("id1", "content"),
1493 score: 0.9,
1494 raw_score: 0.9,
1495 vector_score: Some(0.9),
1496 bm25_score: None,
1497 }];
1498
1499 let fused = service.rrf_fusion(&text_hits, &vector_hits, 10);
1500
1501 assert_eq!(fused.len(), 1);
1503 assert_eq!(fused[0].memory.id.as_str(), "id1");
1504 }
1505
1506 #[test]
1507 fn test_domain_label() {
1508 let filter = SearchFilter::new();
1509 assert_eq!(domain_label(&filter), "all");
1510 }
1511
1512 #[test]
1517 fn test_default_timeout() {
1518 let service = RecallService::new();
1519 assert_eq!(service.timeout_ms(), DEFAULT_SEARCH_TIMEOUT_MS);
1520 assert_eq!(service.timeout_ms(), 5_000);
1521 }
1522
1523 #[test]
1524 fn test_with_timeout_ms() {
1525 let service = RecallService::new().with_timeout_ms(1_000);
1526 assert_eq!(service.timeout_ms(), 1_000);
1527 }
1528
1529 #[test]
1530 fn test_timeout_zero_disables_check() {
1531 let index = SqliteBackend::in_memory().expect("in_memory failed");
1532 index
1533 .index(&create_test_memory("id1", "Rust programming"))
1534 .expect("index failed");
1535
1536 let service = RecallService::with_index(index).with_timeout_ms(0);
1538
1539 let result = service.search("Rust", SearchMode::Text, &SearchFilter::new(), 10);
1540 assert!(
1541 result.is_ok(),
1542 "Search should succeed with timeout disabled"
1543 );
1544 }
1545
1546 #[test]
1547 fn test_timeout_with_index_builder() {
1548 let index = SqliteBackend::in_memory().expect("in_memory failed");
1549 let service = RecallService::with_index(index);
1550
1551 assert_eq!(service.timeout_ms(), DEFAULT_SEARCH_TIMEOUT_MS);
1553 }
1554
1555 #[test]
1556 fn test_timeout_builder_chaining() {
1557 let service = RecallService::new().with_timeout_ms(2_500);
1558
1559 assert_eq!(service.timeout_ms(), 2_500);
1560 }
1561
1562 #[test]
1567 fn test_normalize_scores_max_becomes_one() {
1568 let mut hits = vec![
1569 SearchHit {
1570 memory: create_test_memory("id1", "high score"),
1571 score: 0.033,
1572 raw_score: 0.0,
1573 vector_score: None,
1574 bm25_score: None,
1575 },
1576 SearchHit {
1577 memory: create_test_memory("id2", "low score"),
1578 score: 0.020,
1579 raw_score: 0.0,
1580 vector_score: None,
1581 bm25_score: None,
1582 },
1583 ];
1584
1585 normalize_scores(&mut hits);
1586
1587 assert!(
1589 (hits[0].score - 1.0).abs() < f32::EPSILON,
1590 "Max score should be 1.0"
1591 );
1592 assert!(
1594 (hits[0].raw_score - 0.033).abs() < f32::EPSILON,
1595 "raw_score should be preserved"
1596 );
1597 }
1598
1599 #[test]
1600 fn test_normalize_scores_all_in_range() {
1601 let mut hits = vec![
1602 SearchHit {
1603 memory: create_test_memory("id1", ""),
1604 score: 0.033,
1605 raw_score: 0.0,
1606 vector_score: None,
1607 bm25_score: None,
1608 },
1609 SearchHit {
1610 memory: create_test_memory("id2", ""),
1611 score: 0.020,
1612 raw_score: 0.0,
1613 vector_score: None,
1614 bm25_score: None,
1615 },
1616 SearchHit {
1617 memory: create_test_memory("id3", ""),
1618 score: 0.016,
1619 raw_score: 0.0,
1620 vector_score: None,
1621 bm25_score: None,
1622 },
1623 ];
1624
1625 normalize_scores(&mut hits);
1626
1627 for hit in &hits {
1628 assert!(
1629 hit.score >= 0.0,
1630 "Score should be >= 0.0, got {}",
1631 hit.score
1632 );
1633 assert!(
1634 hit.score <= 1.0,
1635 "Score should be <= 1.0, got {}",
1636 hit.score
1637 );
1638 }
1639 }
1640
1641 #[test]
1642 fn test_normalize_scores_empty_results() {
1643 let mut hits: Vec<SearchHit> = vec![];
1644 normalize_scores(&mut hits);
1645 assert!(hits.is_empty());
1647 }
1648
1649 #[test]
1650 fn test_normalize_scores_single_result() {
1651 let mut hits = vec![SearchHit {
1652 memory: create_test_memory("id1", ""),
1653 score: 0.5,
1654 raw_score: 0.0,
1655 vector_score: None,
1656 bm25_score: None,
1657 }];
1658
1659 normalize_scores(&mut hits);
1660
1661 assert!(
1663 (hits[0].score - 1.0).abs() < f32::EPSILON,
1664 "Single result should have score 1.0"
1665 );
1666 }
1667
1668 #[test]
1669 fn test_normalize_scores_proportions_preserved() {
1670 let mut hits = vec![
1671 SearchHit {
1672 memory: create_test_memory("id1", ""),
1673 score: 0.040,
1674 raw_score: 0.0,
1675 vector_score: None,
1676 bm25_score: None,
1677 },
1678 SearchHit {
1679 memory: create_test_memory("id2", ""),
1680 score: 0.020,
1681 raw_score: 0.0,
1682 vector_score: None,
1683 bm25_score: None,
1684 },
1685 ];
1686
1687 let ratio_before = hits[0].score / hits[1].score;
1689
1690 normalize_scores(&mut hits);
1691
1692 let ratio_after = hits[0].score / hits[1].score;
1694 assert!(
1695 (ratio_before - ratio_after).abs() < 0.001,
1696 "Proportions should be preserved: before={ratio_before}, after={ratio_after}"
1697 );
1698 }
1699
1700 #[test]
1701 fn test_normalize_scores_ordering_preserved() {
1702 let mut hits = vec![
1703 SearchHit {
1704 memory: create_test_memory("id1", ""),
1705 score: 0.033,
1706 raw_score: 0.0,
1707 vector_score: None,
1708 bm25_score: None,
1709 },
1710 SearchHit {
1711 memory: create_test_memory("id2", ""),
1712 score: 0.020,
1713 raw_score: 0.0,
1714 vector_score: None,
1715 bm25_score: None,
1716 },
1717 SearchHit {
1718 memory: create_test_memory("id3", ""),
1719 score: 0.016,
1720 raw_score: 0.0,
1721 vector_score: None,
1722 bm25_score: None,
1723 },
1724 ];
1725
1726 normalize_scores(&mut hits);
1727
1728 assert!(hits[0].score > hits[1].score, "id1 > id2");
1730 assert!(hits[1].score > hits[2].score, "id2 > id3");
1731 }
1732
1733 #[test]
1734 fn test_normalize_scores_zero_scores() {
1735 let mut hits = vec![
1736 SearchHit {
1737 memory: create_test_memory("id1", ""),
1738 score: 0.0,
1739 raw_score: 0.0,
1740 vector_score: None,
1741 bm25_score: None,
1742 },
1743 SearchHit {
1744 memory: create_test_memory("id2", ""),
1745 score: 0.0,
1746 raw_score: 0.0,
1747 vector_score: None,
1748 bm25_score: None,
1749 },
1750 ];
1751
1752 normalize_scores(&mut hits);
1753
1754 assert!(
1756 hits[0].score.abs() < f32::EPSILON,
1757 "Zero score should remain zero"
1758 );
1759 assert!(
1760 hits[1].score.abs() < f32::EPSILON,
1761 "Zero score should remain zero"
1762 );
1763 }
1764
1765 #[test]
1766 fn test_scope_filter_applies_project_id() {
1767 let base = SearchFilter::new().with_project_id("github.com/org/repo");
1768 let service = RecallService::new().with_scope_filter(base);
1769 let filter = SearchFilter::new();
1770 let effective = service.effective_filter(&filter);
1771
1772 assert_eq!(
1773 effective.as_ref().project_id.as_deref(),
1774 Some("github.com/org/repo")
1775 );
1776 }
1777
1778 #[test]
1779 fn test_scope_filter_does_not_override_explicit_project_id() {
1780 let base = SearchFilter::new().with_project_id("github.com/org/repo");
1781 let service = RecallService::new().with_scope_filter(base);
1782 let user_filter = SearchFilter::new().with_project_id("github.com/other/repo");
1783 let effective = service.effective_filter(&user_filter);
1784
1785 assert_eq!(
1786 effective.as_ref().project_id.as_deref(),
1787 Some("github.com/other/repo")
1788 );
1789 }
1790
1791 #[test]
1792 fn test_normalize_scores_raw_score_preserved() {
1793 let mut hits = vec![
1794 SearchHit {
1795 memory: create_test_memory("id1", ""),
1796 score: 0.033,
1797 raw_score: 0.0,
1798 vector_score: None,
1799 bm25_score: None,
1800 },
1801 SearchHit {
1802 memory: create_test_memory("id2", ""),
1803 score: 0.020,
1804 raw_score: 0.0,
1805 vector_score: None,
1806 bm25_score: None,
1807 },
1808 ];
1809
1810 normalize_scores(&mut hits);
1811
1812 assert!(
1814 (hits[0].raw_score - 0.033).abs() < f32::EPSILON,
1815 "raw_score should be 0.033"
1816 );
1817 assert!(
1818 (hits[1].raw_score - 0.020).abs() < f32::EPSILON,
1819 "raw_score should be 0.020"
1820 );
1821 }
1822}
1823
#[cfg(test)]
mod proptests {
    //! Property-based tests for `normalize_scores`, driven by the
    //! `proptest` crate: scores are generated randomly and the
    //! normalization invariants are checked over many cases.

    use super::*;
    use crate::models::{Domain, Namespace};
    use proptest::prelude::*;

    /// Builds a minimal active `Memory` with the given id; every
    /// optional field is left empty — just enough to wrap in a
    /// `SearchHit` for score testing.
    fn create_test_memory_prop(id: &str) -> Memory {
        Memory {
            id: MemoryId::new(id),
            content: String::new(),
            namespace: Namespace::Decisions,
            domain: Domain::new(),
            project_id: None,
            branch: None,
            file_path: None,
            status: MemoryStatus::Active,
            created_at: 0,
            updated_at: 0,
            tombstoned_at: None,
            expires_at: None,
            embedding: None,
            tags: Vec::new(),
            #[cfg(feature = "group-scope")]
            group_id: None,
            source: None,
            is_summary: false,
            source_memory_ids: None,
            consolidation_timestamp: None,
        }
    }

    /// Strategy yielding scores in (0, 1]: integers 1..=1_000_000
    /// scaled down by 1e6. Starting at 1 guarantees a strictly
    /// positive score, so the all-zero edge case is never generated.
    fn score_strategy() -> impl Strategy<Value = f32> {
        (1u32..=1_000_000u32).prop_map(|n| n as f32 / 1_000_000.0)
    }

    proptest! {
        // Cap each property at 100 generated cases to keep the suite fast.
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Invariant: every normalized score lies in [0.0, 1.0].
        #[test]
        fn prop_normalized_scores_in_range(
            scores in prop::collection::vec(score_strategy(), 1..20)
        ) {
            // Wrap each generated score in a SearchHit with a unique id.
            let mut hits: Vec<SearchHit> = scores
                .into_iter()
                .enumerate()
                .map(|(i, score)| SearchHit {
                    memory: create_test_memory_prop(&format!("id{i}")),
                    score,
                    raw_score: 0.0,
                    vector_score: None,
                    bm25_score: None,
                })
                .collect();

            normalize_scores(&mut hits);

            for hit in &hits {
                prop_assert!(
                    hit.score >= 0.0,
                    "Score {} should be >= 0.0",
                    hit.score
                );
                prop_assert!(
                    hit.score <= 1.0,
                    "Score {} should be <= 1.0",
                    hit.score
                );
            }
        }

        // Invariant: normalization never changes the relative ordering
        // of hits. Compared via index lists sorted by score descending
        // before and after (needs >= 2 elements to be meaningful).
        #[test]
        fn prop_ordering_preserved(
            scores in prop::collection::vec(score_strategy(), 2..20)
        ) {
            // Ids encode the original position ("id0", "id1", ...) so the
            // post-normalization order can be mapped back to input indices.
            let mut hits: Vec<SearchHit> = scores
                .iter()
                .enumerate()
                .map(|(i, &score)| SearchHit {
                    memory: create_test_memory_prop(&format!("id{i}")),
                    score,
                    raw_score: 0.0,
                    vector_score: None,
                    bm25_score: None,
                })
                .collect();

            // Expected order: input indices sorted by raw score, descending.
            let mut original_order: Vec<_> = scores.iter().enumerate().collect();
            original_order.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
            let original_ids: Vec<_> = original_order.iter().map(|(i, _)| *i).collect();

            normalize_scores(&mut hits);

            // Actual order: indices recovered from the "idN" memory ids
            // after sorting by normalized score, descending.
            hits.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
            let normalized_ids: Vec<_> = hits
                .iter()
                .map(|h| h.memory.id.as_str().strip_prefix("id").unwrap().parse::<usize>().unwrap())
                .collect();

            prop_assert_eq!(
                original_ids,
                normalized_ids,
                "Score ordering should be preserved"
            );
        }

        // Invariant: the maximum normalized score is exactly 1.0
        // (the strategy guarantees at least one strictly positive score).
        #[test]
        fn prop_max_score_is_one(
            scores in prop::collection::vec(score_strategy(), 1..20)
        ) {
            let mut hits: Vec<SearchHit> = scores
                .into_iter()
                .enumerate()
                .map(|(i, score)| SearchHit {
                    memory: create_test_memory_prop(&format!("id{i}")),
                    score,
                    raw_score: 0.0,
                    vector_score: None,
                    bm25_score: None,
                })
                .collect();

            normalize_scores(&mut hits);

            let max_score = hits.iter().map(|h| h.score).fold(0.0_f32, f32::max);
            prop_assert!(
                (max_score - 1.0).abs() < f32::EPSILON,
                "Max score should be 1.0, got {}",
                max_score
            );
        }

        // Invariant: after normalization each hit's raw_score equals the
        // original input score (normalize_scores stashes it before rescaling).
        #[test]
        fn prop_raw_score_preserved(
            scores in prop::collection::vec(score_strategy(), 1..20)
        ) {
            // Keep a copy of the inputs to compare against raw_score later.
            let original_scores = scores.clone();

            let mut hits: Vec<SearchHit> = scores
                .into_iter()
                .enumerate()
                .map(|(i, score)| SearchHit {
                    memory: create_test_memory_prop(&format!("id{i}")),
                    score,
                    raw_score: 0.0,
                    vector_score: None,
                    bm25_score: None,
                })
                .collect();

            normalize_scores(&mut hits);

            for (hit, original) in hits.iter().zip(original_scores.iter()) {
                prop_assert!(
                    (hit.raw_score - original).abs() < f32::EPSILON,
                    "raw_score {} should equal original {}",
                    hit.raw_score,
                    original
                );
            }
        }
    }
}