1use crate::Result;
56use crate::models::graph::{Entity, EntityId, EntityMention, EntityType};
57use crate::models::{Memory, MemoryId, SearchFilter, SearchMode, SearchResult};
58use crate::services::{EntityExtractorService, GraphService, RecallService};
59use crate::storage::traits::GraphBackend;
60use std::collections::{HashMap, HashSet};
61use std::sync::Arc;
62
/// Settings that control how graph expansion augments a search.
///
/// Every field is public, so values can be set directly, through the
/// `with_*` builder methods, or loaded from the environment with `from_env`.
#[derive(Debug, Clone)]
pub struct GraphRAGConfig {
    /// Traversal depth (in hops) used when an `ExpansionConfig` does not
    /// supply its own `depth`.
    pub max_depth: usize,
    /// Multiplier applied to graph-derived scores and to the combined score
    /// of hits found by both the semantic and the graph path.
    pub expansion_boost: f32,
    /// Maximum number of entities extracted from the query that are used to
    /// seed graph expansion.
    pub max_query_entities: usize,
    /// Intended upper bound on the results contributed by graph expansion.
    ///
    /// NOTE(review): no code visible in this file reads this field — confirm
    /// where (or whether) it is enforced.
    pub max_expansion_results: usize,
    /// Extracted query entities whose confidence is below this value are
    /// ignored.
    pub min_entity_confidence: f32,
    /// Whether relationship context should accompany results.
    ///
    /// NOTE(review): no code visible in this file reads this field — confirm
    /// where it is consumed.
    pub include_relationship_context: bool,
}
83
84impl Default for GraphRAGConfig {
85 fn default() -> Self {
86 Self {
87 max_depth: 2,
88 expansion_boost: 1.2,
89 max_query_entities: 5,
90 max_expansion_results: 10,
91 min_entity_confidence: 0.5,
92 include_relationship_context: true,
93 }
94 }
95}
96
97impl GraphRAGConfig {
98 #[must_use]
100 pub fn new() -> Self {
101 Self::default()
102 }
103
104 #[must_use]
106 pub const fn with_max_depth(mut self, depth: usize) -> Self {
107 self.max_depth = depth;
108 self
109 }
110
111 #[must_use]
113 pub const fn with_expansion_boost(mut self, boost: f32) -> Self {
114 self.expansion_boost = boost;
115 self
116 }
117
118 #[must_use]
120 pub const fn with_max_query_entities(mut self, max: usize) -> Self {
121 self.max_query_entities = max;
122 self
123 }
124
125 #[must_use]
127 pub const fn with_max_expansion_results(mut self, max: usize) -> Self {
128 self.max_expansion_results = max;
129 self
130 }
131
132 #[must_use]
134 pub fn from_env() -> Self {
135 let mut config = Self::default();
136
137 if let Ok(val) = std::env::var("SUBCOG_GRAPH_RAG_MAX_DEPTH")
138 && let Ok(depth) = val.parse()
139 {
140 config.max_depth = depth;
141 }
142
143 if let Ok(val) = std::env::var("SUBCOG_GRAPH_RAG_EXPANSION_BOOST")
144 && let Ok(boost) = val.parse()
145 {
146 config.expansion_boost = boost;
147 }
148
149 if let Ok(val) = std::env::var("SUBCOG_GRAPH_RAG_MAX_QUERY_ENTITIES")
150 && let Ok(max) = val.parse()
151 {
152 config.max_query_entities = max;
153 }
154
155 if let Ok(val) = std::env::var("SUBCOG_GRAPH_RAG_MAX_EXPANSION_RESULTS")
156 && let Ok(max) = val.parse()
157 {
158 config.max_expansion_results = max;
159 }
160
161 config
162 }
163}
164
165#[derive(Debug, Clone)]
167pub struct ExpansionConfig {
168 pub depth: Option<usize>,
170 pub entity_type_filter: Option<Vec<EntityType>>,
172 pub use_relationship_weight: bool,
174}
175
176impl Default for ExpansionConfig {
177 fn default() -> Self {
178 Self {
179 depth: None,
180 entity_type_filter: None,
181 use_relationship_weight: true,
182 }
183 }
184}
185
186#[derive(Debug, Clone, PartialEq, Eq)]
192pub enum SearchProvenance {
193 Semantic,
195 GraphExpansion {
197 source_entity: EntityId,
199 hop_count: usize,
201 },
202 Both {
204 semantic_score: u32, source_entity: EntityId,
208 },
209}
210
211#[derive(Debug, Clone)]
213pub struct GraphSearchHit {
214 pub memory: Memory,
216 pub score: f32,
218 pub provenance: SearchProvenance,
220 pub related_entities: Vec<EntityId>,
222}
223
224#[derive(Debug)]
226pub struct GraphSearchResults {
227 pub query: String,
229 pub hits: Vec<GraphSearchHit>,
231 pub semantic_count: usize,
233 pub graph_count: usize,
235 pub query_entities: Vec<String>,
237}
238
239impl GraphSearchResults {
240 #[must_use]
242 pub const fn len(&self) -> usize {
243 self.hits.len()
244 }
245
246 #[must_use]
248 pub const fn is_empty(&self) -> bool {
249 self.hits.is_empty()
250 }
251
252 #[must_use]
254 pub fn semantic_hits(&self) -> Vec<&GraphSearchHit> {
255 self.hits
256 .iter()
257 .filter(|h| matches!(h.provenance, SearchProvenance::Semantic))
258 .collect()
259 }
260
261 #[must_use]
263 pub fn graph_hits(&self) -> Vec<&GraphSearchHit> {
264 self.hits
265 .iter()
266 .filter(|h| matches!(h.provenance, SearchProvenance::GraphExpansion { .. }))
267 .collect()
268 }
269
270 #[must_use]
272 pub fn hybrid_hits(&self) -> Vec<&GraphSearchHit> {
273 self.hits
274 .iter()
275 .filter(|h| matches!(h.provenance, SearchProvenance::Both { .. }))
276 .collect()
277 }
278}
279
280pub struct GraphRAGService<G: GraphBackend> {
295 recall: Arc<RecallService>,
297 graph: Arc<GraphService<G>>,
299 extractor: EntityExtractorService,
301 config: GraphRAGConfig,
303}
304
305impl<G: GraphBackend> GraphRAGService<G> {
306 pub const fn new(
315 recall: Arc<RecallService>,
316 graph: Arc<GraphService<G>>,
317 extractor: EntityExtractorService,
318 config: GraphRAGConfig,
319 ) -> Self {
320 Self {
321 recall,
322 graph,
323 extractor,
324 config,
325 }
326 }
327
328 pub fn search_with_expansion(
352 &self,
353 query: &str,
354 filter: &SearchFilter,
355 limit: usize,
356 expansion: Option<ExpansionConfig>,
357 ) -> Result<GraphSearchResults> {
358 let expansion = expansion.unwrap_or_default();
359
360 let semantic_results = self
362 .recall
363 .search(query, SearchMode::Hybrid, filter, limit)?;
364 let semantic_count = semantic_results.memories.len();
365
366 let extraction = self.extractor.extract(query)?;
368 let query_entities: Vec<String> = extraction
369 .entities
370 .iter()
371 .filter(|e| e.confidence >= self.config.min_entity_confidence)
372 .take(self.config.max_query_entities)
373 .map(|e| e.name.clone())
374 .collect();
375
376 let graph_memories = self.expand_from_entities(&query_entities, &expansion)?;
378 let graph_count = graph_memories.len();
379
380 let hits = self.merge_results(semantic_results, graph_memories, &expansion)?;
382
383 let hits: Vec<GraphSearchHit> = hits.into_iter().take(limit).collect();
385
386 Ok(GraphSearchResults {
387 query: query.to_string(),
388 hits,
389 semantic_count,
390 graph_count,
391 query_entities,
392 })
393 }
394
395 fn expand_from_entities(
397 &self,
398 entity_names: &[String],
399 expansion: &ExpansionConfig,
400 ) -> Result<HashMap<MemoryId, (f32, EntityId, usize)>> {
401 let mut results: HashMap<MemoryId, (f32, EntityId, usize)> = HashMap::new();
402 let depth = expansion.depth.unwrap_or(self.config.max_depth);
403
404 for name in entity_names {
405 let entities = self.find_entities_by_name(name)?;
406 for entity in entities {
407 self.expand_single_entity(&entity, depth, expansion, &mut results)?;
408 }
409 }
410
411 Ok(results)
412 }
413
414 fn expand_single_entity(
416 &self,
417 entity: &Entity,
418 depth: usize,
419 expansion: &ExpansionConfig,
420 results: &mut HashMap<MemoryId, (f32, EntityId, usize)>,
421 ) -> Result<()> {
422 let related = self.traverse_entity(&entity.id, depth)?;
423
424 for (related_entity, hop_count) in related {
425 let memory_ids = self.get_entity_memory_links(&related_entity)?;
426 self.score_and_insert_memories(
427 &memory_ids,
428 &entity.id,
429 hop_count,
430 expansion.use_relationship_weight,
431 results,
432 );
433 }
434 Ok(())
435 }
436
437 fn score_and_insert_memories(
439 &self,
440 memory_ids: &[MemoryId],
441 source_entity: &EntityId,
442 hop_count: usize,
443 use_weight: bool,
444 results: &mut HashMap<MemoryId, (f32, EntityId, usize)>,
445 ) {
446 #[allow(clippy::cast_precision_loss)]
447 let base_score = 1.0 / (1.0 + hop_count as f32);
448 let score = if use_weight {
449 base_score * self.config.expansion_boost
450 } else {
451 base_score
452 };
453
454 for memory_id in memory_ids {
455 Self::update_or_insert_memory(results, memory_id, score, source_entity, hop_count);
456 }
457 }
458
459 fn update_or_insert_memory(
461 results: &mut HashMap<MemoryId, (f32, EntityId, usize)>,
462 memory_id: &MemoryId,
463 score: f32,
464 source_entity: &EntityId,
465 hop_count: usize,
466 ) {
467 results
468 .entry(memory_id.clone())
469 .and_modify(|(existing_score, _, existing_hops)| {
470 if hop_count < *existing_hops {
471 *existing_score = score;
472 *existing_hops = hop_count;
473 }
474 })
475 .or_insert((score, source_entity.clone(), hop_count));
476 }
477
478 fn find_entities_by_name(&self, name: &str) -> Result<Vec<Entity>> {
480 use crate::models::graph::EntityQuery;
481
482 let query = EntityQuery::new().with_name(name).with_limit(10);
483
484 self.graph.query_entities(&query)
485 }
486
487 fn traverse_entity(
489 &self,
490 entity_id: &EntityId,
491 max_depth: usize,
492 ) -> Result<Vec<(EntityId, usize)>> {
493 let mut visited: HashSet<String> = HashSet::new();
494 let mut result: Vec<(EntityId, usize)> = Vec::new();
495 let mut frontier: Vec<(EntityId, usize)> = vec![(entity_id.clone(), 0)];
496
497 while let Some((current_id, depth)) = frontier.pop() {
498 if depth > max_depth {
499 continue;
500 }
501
502 let id_str = current_id.as_ref().to_string();
503 if visited.contains(&id_str) {
504 continue;
505 }
506 visited.insert(id_str);
507
508 if depth > 0 {
509 result.push((current_id.clone(), depth));
510 }
511
512 if depth < max_depth {
514 self.add_neighbors_to_frontier(¤t_id, depth, &mut frontier)?;
515 }
516 }
517
518 Ok(result)
519 }
520
521 fn add_neighbors_to_frontier(
523 &self,
524 entity_id: &EntityId,
525 current_depth: usize,
526 frontier: &mut Vec<(EntityId, usize)>,
527 ) -> Result<()> {
528 let neighbors = self.graph.get_neighbors(entity_id, 1)?;
529 for neighbor in neighbors {
530 frontier.push((neighbor.id.clone(), current_depth + 1));
531 }
532 Ok(())
533 }
534
535 fn get_entity_memory_links(&self, entity_id: &EntityId) -> Result<Vec<MemoryId>> {
537 let mentions: Vec<EntityMention> = self.graph.get_mentions(entity_id)?;
539
540 Ok(mentions.into_iter().map(|m| m.memory_id).collect())
541 }
542
543 fn merge_results(
545 &self,
546 semantic: SearchResult,
547 graph: HashMap<MemoryId, (f32, EntityId, usize)>,
548 _expansion: &ExpansionConfig,
549 ) -> Result<Vec<GraphSearchHit>> {
550 let mut hits: HashMap<String, GraphSearchHit> = HashMap::new();
551
552 for memory_hit in semantic.memories {
554 let id = memory_hit.memory.id.as_str().to_string();
555 hits.insert(
556 id,
557 GraphSearchHit {
558 memory: memory_hit.memory,
559 score: memory_hit.score,
560 provenance: SearchProvenance::Semantic,
561 related_entities: Vec::new(),
562 },
563 );
564 }
565
566 for (memory_id, (graph_score, source_entity, hop_count)) in graph {
568 let id = memory_id.as_str().to_string();
569 self.merge_single_graph_result(
570 &mut hits,
571 id,
572 &memory_id,
573 graph_score,
574 source_entity,
575 hop_count,
576 );
577 }
578
579 let mut result: Vec<GraphSearchHit> = hits.into_values().collect();
581 result.sort_by(|a, b| {
582 b.score
583 .partial_cmp(&a.score)
584 .unwrap_or(std::cmp::Ordering::Equal)
585 });
586
587 Ok(result)
588 }
589
590 fn merge_single_graph_result(
592 &self,
593 hits: &mut HashMap<String, GraphSearchHit>,
594 id: String,
595 memory_id: &MemoryId,
596 graph_score: f32,
597 source_entity: EntityId,
598 hop_count: usize,
599 ) {
600 if let Some(existing) = hits.get_mut(&id) {
601 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
603 let semantic_score_int = (existing.score.abs() * 1000.0) as u32;
604 existing.provenance = SearchProvenance::Both {
605 semantic_score: semantic_score_int,
606 source_entity: source_entity.clone(),
607 };
608 existing.score =
610 f32::midpoint(existing.score, graph_score) * self.config.expansion_boost;
611 existing.related_entities.push(source_entity);
612 return;
613 }
614
615 let Ok(Some(memory)) = self.recall.get_by_id(memory_id) else {
617 return;
618 };
619
620 hits.insert(
621 id,
622 GraphSearchHit {
623 memory,
624 score: graph_score,
625 provenance: SearchProvenance::GraphExpansion {
626 source_entity: source_entity.clone(),
627 hop_count,
628 },
629 related_entities: vec![source_entity],
630 },
631 );
632 }
633
634 pub fn search_semantic_only(
642 &self,
643 query: &str,
644 filter: &SearchFilter,
645 limit: usize,
646 ) -> Result<GraphSearchResults> {
647 let semantic_results = self
648 .recall
649 .search(query, SearchMode::Hybrid, filter, limit)?;
650
651 let hits: Vec<GraphSearchHit> = semantic_results
652 .memories
653 .into_iter()
654 .map(|h| GraphSearchHit {
655 memory: h.memory,
656 score: h.score,
657 provenance: SearchProvenance::Semantic,
658 related_entities: Vec::new(),
659 })
660 .collect();
661
662 let count = hits.len();
663
664 Ok(GraphSearchResults {
665 query: query.to_string(),
666 hits,
667 semantic_count: count,
668 graph_count: 0,
669 query_entities: Vec::new(),
670 })
671 }
672
673 pub fn search_graph_only(
681 &self,
682 query: &str,
683 limit: usize,
684 expansion: Option<ExpansionConfig>,
685 ) -> Result<GraphSearchResults> {
686 let expansion = expansion.unwrap_or_default();
687
688 let extraction = self.extractor.extract(query)?;
690 let query_entities: Vec<String> = extraction
691 .entities
692 .iter()
693 .filter(|e| e.confidence >= self.config.min_entity_confidence)
694 .take(self.config.max_query_entities)
695 .map(|e| e.name.clone())
696 .collect();
697
698 let graph_memories = self.expand_from_entities(&query_entities, &expansion)?;
700 let graph_count = graph_memories.len();
701
702 let mut hits: Vec<GraphSearchHit> = Vec::new();
704 for (memory_id, (score, source_entity, hop_count)) in graph_memories {
705 if let Ok(Some(memory)) = self.recall.get_by_id(&memory_id) {
706 hits.push(GraphSearchHit {
707 memory,
708 score,
709 provenance: SearchProvenance::GraphExpansion {
710 source_entity: source_entity.clone(),
711 hop_count,
712 },
713 related_entities: vec![source_entity],
714 });
715 }
716 }
717
718 hits.sort_by(|a, b| {
720 b.score
721 .partial_cmp(&a.score)
722 .unwrap_or(std::cmp::Ordering::Equal)
723 });
724 hits.truncate(limit);
725
726 Ok(GraphSearchResults {
727 query: query.to_string(),
728 hits,
729 semantic_count: 0,
730 graph_count,
731 query_entities,
732 })
733 }
734}
735
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two `f32` values agree to within machine epsilon.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < f32::EPSILON);
    }

    /// Builds a result set with no hits and the given bookkeeping values.
    fn results_without_hits(
        semantic_count: usize,
        graph_count: usize,
        query_entities: Vec<String>,
    ) -> GraphSearchResults {
        GraphSearchResults {
            query: "test".to_string(),
            hits: Vec::new(),
            semantic_count,
            graph_count,
            query_entities,
        }
    }

    // --- configuration ---

    #[test]
    fn test_config_defaults() {
        let config = GraphRAGConfig::default();
        assert_eq!(config.max_depth, 2);
        assert_close(config.expansion_boost, 1.2);
        assert_eq!(config.max_query_entities, 5);
        assert_eq!(config.max_expansion_results, 10);
    }

    #[test]
    fn test_config_builder() {
        let config = GraphRAGConfig::new()
            .with_max_depth(3)
            .with_expansion_boost(1.5)
            .with_max_query_entities(10)
            .with_max_expansion_results(20);

        assert_eq!(config.max_depth, 3);
        assert_close(config.expansion_boost, 1.5);
        assert_eq!(config.max_query_entities, 10);
        assert_eq!(config.max_expansion_results, 20);
    }

    #[test]
    fn test_expansion_config_defaults() {
        let config = ExpansionConfig::default();
        assert!(config.depth.is_none());
        assert!(config.entity_type_filter.is_none());
        assert!(config.use_relationship_weight);
    }

    // --- provenance ---

    #[test]
    fn test_provenance_semantic() {
        let provenance = SearchProvenance::Semantic;
        assert!(matches!(provenance, SearchProvenance::Semantic));
    }

    #[test]
    fn test_provenance_graph_expansion() {
        let provenance = SearchProvenance::GraphExpansion {
            source_entity: EntityId::new("e123"),
            hop_count: 2,
        };
        let SearchProvenance::GraphExpansion {
            source_entity,
            hop_count,
        } = provenance
        else {
            unreachable!("Expected GraphExpansion variant");
        };
        assert_eq!(source_entity.as_ref(), "e123");
        assert_eq!(hop_count, 2);
    }

    #[test]
    fn test_provenance_both() {
        let provenance = SearchProvenance::Both {
            semantic_score: 850,
            source_entity: EntityId::new("e456"),
        };
        let SearchProvenance::Both {
            semantic_score,
            source_entity,
        } = provenance
        else {
            unreachable!("Expected Both variant");
        };
        assert_eq!(semantic_score, 850);
        assert_eq!(source_entity.as_ref(), "e456");
    }

    // --- result set ---

    #[test]
    fn test_results_empty() {
        let results = results_without_hits(0, 0, Vec::new());
        assert!(results.is_empty());
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_results_counts() {
        let results =
            results_without_hits(10, 5, vec!["Rust".to_string(), "PostgreSQL".to_string()]);
        assert_eq!(results.semantic_count, 10);
        assert_eq!(results.graph_count, 5);
        assert_eq!(results.query_entities.len(), 2);
    }
}