Skip to main content

subcog/services/
graph_rag.rs

1//! Graph RAG (Retrieval-Augmented Generation) Service.
2//!
3//! Provides hybrid search that combines traditional semantic/text search
4//! with knowledge graph expansion for enhanced memory recall.
5//!
6//! # Architecture
7//!
8//! ```text
9//! User Query: "How do we handle auth?"
10//!     │
11//!     ▼
12//! GraphRAGService.search_with_expansion()
13//!     │
14//!     ├──▶ RecallService.search() → 10 memories (semantic)
15//!     │
//!     └──▶ EntityExtractorService.extract(query)  (e.g. "auth")
17//!              │
18//!              ▼
19//!          ["AuthService", "JWT", "OAuth"]
20//!              │
21//!              ▼
//!          GraphService.get_neighbors() hop by hop (up to max_depth, e.g. 2)
23//!              │
24//!              ▼
25//!          Related entities + their source_memory_ids
26//!              │
27//!              ▼
28//!          5 additional memories via graph
29//!     │
30//!     ▼
31//! Merge + Re-rank (boost graph-based by config.expansion_boost)
32//!     │
33//!     ▼
34//! Return 15 memories with provenance
35//! ```
36//!
37//! # Example
38//!
39//! ```rust,ignore
//! use subcog::models::SearchFilter;
//! use subcog::services::{ExpansionConfig, GraphRAGConfig, GraphRAGService};
//!
//! // `recall`, `graph`, and `extractor` are previously constructed services.
//! let service = GraphRAGService::new(recall, graph, extractor, GraphRAGConfig::default());
//!
//! let results = service.search_with_expansion(
//!     "authentication patterns",
//!     &SearchFilter::new(),
//!     15,
//!     Some(ExpansionConfig::default()),
//! )?;
//!
//! for hit in &results.hits {
//!     println!("{}: {} (provenance: {:?})", hit.memory.id, hit.score, hit.provenance);
//! }
53//! ```
54
55use crate::Result;
56use crate::models::graph::{Entity, EntityId, EntityMention, EntityType};
57use crate::models::{Memory, MemoryId, SearchFilter, SearchMode, SearchResult};
58use crate::services::{EntityExtractorService, GraphService, RecallService};
59use crate::storage::traits::GraphBackend;
60use std::collections::{HashMap, HashSet};
61use std::sync::Arc;
62
63// ============================================================================
64// Configuration
65// ============================================================================
66
/// Configuration for the Graph RAG service.
#[derive(Debug, Clone)]
pub struct GraphRAGConfig {
    /// Maximum number of hops followed from a query entity during expansion.
    pub max_depth: usize,
    /// Boost factor for graph-sourced memories (1.0 = no boost).
    pub expansion_boost: f32,
    /// Maximum number of entities taken from the query extraction.
    pub max_query_entities: usize,
    /// Maximum number of memories retrieved via graph expansion.
    pub max_expansion_results: usize,
    /// Minimum extraction confidence for a query entity to be used.
    pub min_entity_confidence: f32,
    /// Whether to include relationship context in results.
    ///
    /// NOTE(review): not read by the service code in this module — confirm
    /// whether a consumer elsewhere honours it.
    pub include_relationship_context: bool,
}

impl Default for GraphRAGConfig {
    fn default() -> Self {
        Self {
            max_depth: 2,
            expansion_boost: 1.2,
            max_query_entities: 5,
            max_expansion_results: 10,
            min_entity_confidence: 0.5,
            include_relationship_context: true,
        }
    }
}

impl GraphRAGConfig {
    /// Creates a new configuration with default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum traversal depth.
    #[must_use]
    pub const fn with_max_depth(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

    /// Sets the expansion boost factor.
    #[must_use]
    pub const fn with_expansion_boost(mut self, boost: f32) -> Self {
        self.expansion_boost = boost;
        self
    }

    /// Sets the maximum query entities.
    #[must_use]
    pub const fn with_max_query_entities(mut self, max: usize) -> Self {
        self.max_query_entities = max;
        self
    }

    /// Sets the maximum expansion results.
    #[must_use]
    pub const fn with_max_expansion_results(mut self, max: usize) -> Self {
        self.max_expansion_results = max;
        self
    }

    /// Sets the minimum confidence a query entity needs to be used.
    #[must_use]
    pub const fn with_min_entity_confidence(mut self, confidence: f32) -> Self {
        self.min_entity_confidence = confidence;
        self
    }

    /// Sets whether relationship context should be included in results.
    #[must_use]
    pub const fn with_include_relationship_context(mut self, include: bool) -> Self {
        self.include_relationship_context = include;
        self
    }

    /// Loads configuration from environment variables.
    ///
    /// Recognised variables (unset or unparsable values keep the default):
    /// - `SUBCOG_GRAPH_RAG_MAX_DEPTH`
    /// - `SUBCOG_GRAPH_RAG_EXPANSION_BOOST` (must be finite and non-negative)
    /// - `SUBCOG_GRAPH_RAG_MAX_QUERY_ENTITIES`
    /// - `SUBCOG_GRAPH_RAG_MAX_EXPANSION_RESULTS`
    #[must_use]
    pub fn from_env() -> Self {
        let mut config = Self::default();

        if let Some(depth) = Self::env_parse("SUBCOG_GRAPH_RAG_MAX_DEPTH") {
            config.max_depth = depth;
        }

        // `f32::from_str` happily accepts "NaN" and "inf". A NaN boost makes
        // every score comparison fall back to `Equal`, silently destroying
        // ranking, so only finite, non-negative boosts are honoured.
        if let Some(boost) = Self::env_parse::<f32>("SUBCOG_GRAPH_RAG_EXPANSION_BOOST")
            .filter(|b| b.is_finite() && *b >= 0.0)
        {
            config.expansion_boost = boost;
        }

        if let Some(max) = Self::env_parse("SUBCOG_GRAPH_RAG_MAX_QUERY_ENTITIES") {
            config.max_query_entities = max;
        }

        if let Some(max) = Self::env_parse("SUBCOG_GRAPH_RAG_MAX_EXPANSION_RESULTS") {
            config.max_expansion_results = max;
        }

        config
    }

    /// Reads `key` from the environment and parses it.
    ///
    /// Returns `None` when the variable is unset, not valid Unicode, or fails
    /// to parse as `T`.
    fn env_parse<T: std::str::FromStr>(key: &str) -> Option<T> {
        std::env::var(key).ok()?.parse().ok()
    }
}
164
165/// Configuration for a specific expansion operation.
166#[derive(Debug, Clone)]
167pub struct ExpansionConfig {
168    /// Traversal depth for this expansion (overrides default).
169    pub depth: Option<usize>,
170    /// Entity types to prioritize during expansion.
171    pub entity_type_filter: Option<Vec<EntityType>>,
172    /// Whether to boost results based on relationship strength.
173    pub use_relationship_weight: bool,
174}
175
176impl Default for ExpansionConfig {
177    fn default() -> Self {
178        Self {
179            depth: None,
180            entity_type_filter: None,
181            use_relationship_weight: true,
182        }
183    }
184}
185
186// ============================================================================
187// Provenance Tracking
188// ============================================================================
189
190/// Indicates how a memory was discovered.
191#[derive(Debug, Clone, PartialEq, Eq)]
192pub enum SearchProvenance {
193    /// Found via traditional semantic/text search.
194    Semantic,
195    /// Found via graph expansion.
196    GraphExpansion {
197        /// The entity that linked to this memory.
198        source_entity: EntityId,
199        /// The relationship path length.
200        hop_count: usize,
201    },
202    /// Found via both semantic search and graph expansion.
203    Both {
204        /// Semantic search score.
205        semantic_score: u32, // Using u32 to represent f32 * 1000 for Eq
206        /// Graph expansion details.
207        source_entity: EntityId,
208    },
209}
210
211/// A search result with provenance information.
212#[derive(Debug, Clone)]
213pub struct GraphSearchHit {
214    /// The memory that was found.
215    pub memory: Memory,
216    /// The relevance score (0.0 to 1.0).
217    pub score: f32,
218    /// How this memory was discovered.
219    pub provenance: SearchProvenance,
220    /// Related entities found via graph (if any).
221    pub related_entities: Vec<EntityId>,
222}
223
224/// Results from a Graph RAG search.
225#[derive(Debug)]
226pub struct GraphSearchResults {
227    /// The search query.
228    pub query: String,
229    /// All matched memories with provenance.
230    pub hits: Vec<GraphSearchHit>,
231    /// Total semantic results before merging.
232    pub semantic_count: usize,
233    /// Total graph expansion results before merging.
234    pub graph_count: usize,
235    /// Entities extracted from the query.
236    pub query_entities: Vec<String>,
237}
238
239impl GraphSearchResults {
240    /// Returns the total number of hits.
241    #[must_use]
242    pub const fn len(&self) -> usize {
243        self.hits.len()
244    }
245
246    /// Returns whether there are no hits.
247    #[must_use]
248    pub const fn is_empty(&self) -> bool {
249        self.hits.is_empty()
250    }
251
252    /// Returns hits found via semantic search only.
253    #[must_use]
254    pub fn semantic_hits(&self) -> Vec<&GraphSearchHit> {
255        self.hits
256            .iter()
257            .filter(|h| matches!(h.provenance, SearchProvenance::Semantic))
258            .collect()
259    }
260
261    /// Returns hits found via graph expansion only.
262    #[must_use]
263    pub fn graph_hits(&self) -> Vec<&GraphSearchHit> {
264        self.hits
265            .iter()
266            .filter(|h| matches!(h.provenance, SearchProvenance::GraphExpansion { .. }))
267            .collect()
268    }
269
270    /// Returns hits found via both methods.
271    #[must_use]
272    pub fn hybrid_hits(&self) -> Vec<&GraphSearchHit> {
273        self.hits
274            .iter()
275            .filter(|h| matches!(h.provenance, SearchProvenance::Both { .. }))
276            .collect()
277    }
278}
279
280// ============================================================================
281// Graph RAG Service
282// ============================================================================
283
284/// Service for hybrid search combining semantic search with graph expansion.
285///
286/// The Graph RAG service enhances traditional memory recall by:
287/// 1. Extracting entities from the user's query
288/// 2. Expanding the knowledge graph to find related entities
289/// 3. Retrieving memories linked to those entities
290/// 4. Merging and re-ranking results from both sources
291///
292/// This provides contextually richer results that leverage the connections
293/// between concepts, people, and technologies in the knowledge graph.
294pub struct GraphRAGService<G: GraphBackend> {
295    /// The recall service for semantic/text search.
296    recall: Arc<RecallService>,
297    /// The graph service for knowledge graph operations.
298    graph: Arc<GraphService<G>>,
299    /// The entity extractor for query analysis.
300    extractor: EntityExtractorService,
301    /// Configuration for the service.
302    config: GraphRAGConfig,
303}
304
305impl<G: GraphBackend> GraphRAGService<G> {
306    /// Creates a new Graph RAG service.
307    ///
308    /// # Arguments
309    ///
310    /// * `recall` - The recall service for semantic search.
311    /// * `graph` - The graph service for knowledge graph operations.
312    /// * `extractor` - The entity extractor for query analysis.
313    /// * `config` - Configuration for the service.
314    pub const fn new(
315        recall: Arc<RecallService>,
316        graph: Arc<GraphService<G>>,
317        extractor: EntityExtractorService,
318        config: GraphRAGConfig,
319    ) -> Self {
320        Self {
321            recall,
322            graph,
323            extractor,
324            config,
325        }
326    }
327
328    /// Performs a hybrid search with graph expansion.
329    ///
330    /// This method:
331    /// 1. Runs traditional semantic/text search via `RecallService`
332    /// 2. Extracts entities from the query
333    /// 3. Traverses the knowledge graph to find related entities
334    /// 4. Retrieves memories linked to those entities
335    /// 5. Merges and re-ranks all results
336    ///
337    /// # Arguments
338    ///
339    /// * `query` - The search query.
340    /// * `filter` - Search filter to apply.
341    /// * `limit` - Maximum number of results to return.
342    /// * `expansion` - Optional expansion configuration.
343    ///
344    /// # Returns
345    ///
346    /// A [`GraphSearchResults`] containing all matched memories with provenance.
347    ///
348    /// # Errors
349    ///
350    /// Returns an error if search or graph operations fail.
351    pub fn search_with_expansion(
352        &self,
353        query: &str,
354        filter: &SearchFilter,
355        limit: usize,
356        expansion: Option<ExpansionConfig>,
357    ) -> Result<GraphSearchResults> {
358        let expansion = expansion.unwrap_or_default();
359
360        // Step 1: Run semantic search
361        let semantic_results = self
362            .recall
363            .search(query, SearchMode::Hybrid, filter, limit)?;
364        let semantic_count = semantic_results.memories.len();
365
366        // Step 2: Extract entities from query
367        let extraction = self.extractor.extract(query)?;
368        let query_entities: Vec<String> = extraction
369            .entities
370            .iter()
371            .filter(|e| e.confidence >= self.config.min_entity_confidence)
372            .take(self.config.max_query_entities)
373            .map(|e| e.name.clone())
374            .collect();
375
376        // Step 3: Find matching entities in graph and expand
377        let graph_memories = self.expand_from_entities(&query_entities, &expansion)?;
378        let graph_count = graph_memories.len();
379
380        // Step 4: Merge and re-rank results
381        let hits = self.merge_results(semantic_results, graph_memories, &expansion)?;
382
383        // Step 5: Apply final limit
384        let hits: Vec<GraphSearchHit> = hits.into_iter().take(limit).collect();
385
386        Ok(GraphSearchResults {
387            query: query.to_string(),
388            hits,
389            semantic_count,
390            graph_count,
391            query_entities,
392        })
393    }
394
395    /// Expands the graph from extracted entity names.
396    fn expand_from_entities(
397        &self,
398        entity_names: &[String],
399        expansion: &ExpansionConfig,
400    ) -> Result<HashMap<MemoryId, (f32, EntityId, usize)>> {
401        let mut results: HashMap<MemoryId, (f32, EntityId, usize)> = HashMap::new();
402        let depth = expansion.depth.unwrap_or(self.config.max_depth);
403
404        for name in entity_names {
405            let entities = self.find_entities_by_name(name)?;
406            for entity in entities {
407                self.expand_single_entity(&entity, depth, expansion, &mut results)?;
408            }
409        }
410
411        Ok(results)
412    }
413
414    /// Expands a single entity and collects memory links.
415    fn expand_single_entity(
416        &self,
417        entity: &Entity,
418        depth: usize,
419        expansion: &ExpansionConfig,
420        results: &mut HashMap<MemoryId, (f32, EntityId, usize)>,
421    ) -> Result<()> {
422        let related = self.traverse_entity(&entity.id, depth)?;
423
424        for (related_entity, hop_count) in related {
425            let memory_ids = self.get_entity_memory_links(&related_entity)?;
426            self.score_and_insert_memories(
427                &memory_ids,
428                &entity.id,
429                hop_count,
430                expansion.use_relationship_weight,
431                results,
432            );
433        }
434        Ok(())
435    }
436
437    /// Calculates scores and inserts memory links into results.
438    fn score_and_insert_memories(
439        &self,
440        memory_ids: &[MemoryId],
441        source_entity: &EntityId,
442        hop_count: usize,
443        use_weight: bool,
444        results: &mut HashMap<MemoryId, (f32, EntityId, usize)>,
445    ) {
446        #[allow(clippy::cast_precision_loss)]
447        let base_score = 1.0 / (1.0 + hop_count as f32);
448        let score = if use_weight {
449            base_score * self.config.expansion_boost
450        } else {
451            base_score
452        };
453
454        for memory_id in memory_ids {
455            Self::update_or_insert_memory(results, memory_id, score, source_entity, hop_count);
456        }
457    }
458
459    /// Updates an existing memory entry or inserts a new one.
460    fn update_or_insert_memory(
461        results: &mut HashMap<MemoryId, (f32, EntityId, usize)>,
462        memory_id: &MemoryId,
463        score: f32,
464        source_entity: &EntityId,
465        hop_count: usize,
466    ) {
467        results
468            .entry(memory_id.clone())
469            .and_modify(|(existing_score, _, existing_hops)| {
470                if hop_count < *existing_hops {
471                    *existing_score = score;
472                    *existing_hops = hop_count;
473                }
474            })
475            .or_insert((score, source_entity.clone(), hop_count));
476    }
477
478    /// Finds entities by name (case-insensitive search).
479    fn find_entities_by_name(&self, name: &str) -> Result<Vec<Entity>> {
480        use crate::models::graph::EntityQuery;
481
482        let query = EntityQuery::new().with_name(name).with_limit(10);
483
484        self.graph.query_entities(&query)
485    }
486
487    /// Traverses the graph from an entity up to the specified depth.
488    fn traverse_entity(
489        &self,
490        entity_id: &EntityId,
491        max_depth: usize,
492    ) -> Result<Vec<(EntityId, usize)>> {
493        let mut visited: HashSet<String> = HashSet::new();
494        let mut result: Vec<(EntityId, usize)> = Vec::new();
495        let mut frontier: Vec<(EntityId, usize)> = vec![(entity_id.clone(), 0)];
496
497        while let Some((current_id, depth)) = frontier.pop() {
498            if depth > max_depth {
499                continue;
500            }
501
502            let id_str = current_id.as_ref().to_string();
503            if visited.contains(&id_str) {
504                continue;
505            }
506            visited.insert(id_str);
507
508            if depth > 0 {
509                result.push((current_id.clone(), depth));
510            }
511
512            // Get relationships from this entity
513            if depth < max_depth {
514                self.add_neighbors_to_frontier(&current_id, depth, &mut frontier)?;
515            }
516        }
517
518        Ok(result)
519    }
520
521    /// Adds neighbors of an entity to the traversal frontier.
522    fn add_neighbors_to_frontier(
523        &self,
524        entity_id: &EntityId,
525        current_depth: usize,
526        frontier: &mut Vec<(EntityId, usize)>,
527    ) -> Result<()> {
528        let neighbors = self.graph.get_neighbors(entity_id, 1)?;
529        for neighbor in neighbors {
530            frontier.push((neighbor.id.clone(), current_depth + 1));
531        }
532        Ok(())
533    }
534
535    /// Gets memory IDs linked to an entity via mentions.
536    fn get_entity_memory_links(&self, entity_id: &EntityId) -> Result<Vec<MemoryId>> {
537        // Get mentions to find memories that reference this entity
538        let mentions: Vec<EntityMention> = self.graph.get_mentions(entity_id)?;
539
540        Ok(mentions.into_iter().map(|m| m.memory_id).collect())
541    }
542
543    /// Merges semantic results with graph expansion results.
544    fn merge_results(
545        &self,
546        semantic: SearchResult,
547        graph: HashMap<MemoryId, (f32, EntityId, usize)>,
548        _expansion: &ExpansionConfig,
549    ) -> Result<Vec<GraphSearchHit>> {
550        let mut hits: HashMap<String, GraphSearchHit> = HashMap::new();
551
552        // Add semantic results
553        for memory_hit in semantic.memories {
554            let id = memory_hit.memory.id.as_str().to_string();
555            hits.insert(
556                id,
557                GraphSearchHit {
558                    memory: memory_hit.memory,
559                    score: memory_hit.score,
560                    provenance: SearchProvenance::Semantic,
561                    related_entities: Vec::new(),
562                },
563            );
564        }
565
566        // Merge graph results
567        for (memory_id, (graph_score, source_entity, hop_count)) in graph {
568            let id = memory_id.as_str().to_string();
569            self.merge_single_graph_result(
570                &mut hits,
571                id,
572                &memory_id,
573                graph_score,
574                source_entity,
575                hop_count,
576            );
577        }
578
579        // Sort by score descending
580        let mut result: Vec<GraphSearchHit> = hits.into_values().collect();
581        result.sort_by(|a, b| {
582            b.score
583                .partial_cmp(&a.score)
584                .unwrap_or(std::cmp::Ordering::Equal)
585        });
586
587        Ok(result)
588    }
589
590    /// Merges a single graph result into the hits map.
591    fn merge_single_graph_result(
592        &self,
593        hits: &mut HashMap<String, GraphSearchHit>,
594        id: String,
595        memory_id: &MemoryId,
596        graph_score: f32,
597        source_entity: EntityId,
598        hop_count: usize,
599    ) {
600        if let Some(existing) = hits.get_mut(&id) {
601            // Found in both - upgrade provenance
602            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
603            let semantic_score_int = (existing.score.abs() * 1000.0) as u32;
604            existing.provenance = SearchProvenance::Both {
605                semantic_score: semantic_score_int,
606                source_entity: source_entity.clone(),
607            };
608            // Boost score for appearing in both using midpoint
609            existing.score =
610                f32::midpoint(existing.score, graph_score) * self.config.expansion_boost;
611            existing.related_entities.push(source_entity);
612            return;
613        }
614
615        // Only found via graph - need to fetch the memory
616        let Ok(Some(memory)) = self.recall.get_by_id(memory_id) else {
617            return;
618        };
619
620        hits.insert(
621            id,
622            GraphSearchHit {
623                memory,
624                score: graph_score,
625                provenance: SearchProvenance::GraphExpansion {
626                    source_entity: source_entity.clone(),
627                    hop_count,
628                },
629                related_entities: vec![source_entity],
630            },
631        );
632    }
633
634    /// Performs semantic-only search (no graph expansion).
635    ///
636    /// This is useful for comparison or when graph expansion is not desired.
637    ///
638    /// # Errors
639    ///
640    /// Returns an error if the semantic search operation fails.
641    pub fn search_semantic_only(
642        &self,
643        query: &str,
644        filter: &SearchFilter,
645        limit: usize,
646    ) -> Result<GraphSearchResults> {
647        let semantic_results = self
648            .recall
649            .search(query, SearchMode::Hybrid, filter, limit)?;
650
651        let hits: Vec<GraphSearchHit> = semantic_results
652            .memories
653            .into_iter()
654            .map(|h| GraphSearchHit {
655                memory: h.memory,
656                score: h.score,
657                provenance: SearchProvenance::Semantic,
658                related_entities: Vec::new(),
659            })
660            .collect();
661
662        let count = hits.len();
663
664        Ok(GraphSearchResults {
665            query: query.to_string(),
666            hits,
667            semantic_count: count,
668            graph_count: 0,
669            query_entities: Vec::new(),
670        })
671    }
672
673    /// Performs graph-only search (no semantic search).
674    ///
675    /// Searches by extracting entities from the query and expanding the graph.
676    ///
677    /// # Errors
678    ///
679    /// Returns an error if entity extraction or graph expansion fails.
680    pub fn search_graph_only(
681        &self,
682        query: &str,
683        limit: usize,
684        expansion: Option<ExpansionConfig>,
685    ) -> Result<GraphSearchResults> {
686        let expansion = expansion.unwrap_or_default();
687
688        // Extract entities from query
689        let extraction = self.extractor.extract(query)?;
690        let query_entities: Vec<String> = extraction
691            .entities
692            .iter()
693            .filter(|e| e.confidence >= self.config.min_entity_confidence)
694            .take(self.config.max_query_entities)
695            .map(|e| e.name.clone())
696            .collect();
697
698        // Expand from entities
699        let graph_memories = self.expand_from_entities(&query_entities, &expansion)?;
700        let graph_count = graph_memories.len();
701
702        // Convert to hits
703        let mut hits: Vec<GraphSearchHit> = Vec::new();
704        for (memory_id, (score, source_entity, hop_count)) in graph_memories {
705            if let Ok(Some(memory)) = self.recall.get_by_id(&memory_id) {
706                hits.push(GraphSearchHit {
707                    memory,
708                    score,
709                    provenance: SearchProvenance::GraphExpansion {
710                        source_entity: source_entity.clone(),
711                        hop_count,
712                    },
713                    related_entities: vec![source_entity],
714                });
715            }
716        }
717
718        // Sort by score and limit
719        hits.sort_by(|a, b| {
720            b.score
721                .partial_cmp(&a.score)
722                .unwrap_or(std::cmp::Ordering::Equal)
723        });
724        hits.truncate(limit);
725
726        Ok(GraphSearchResults {
727            query: query.to_string(),
728            hits,
729            semantic_count: 0,
730            graph_count,
731            query_entities,
732        })
733    }
734}
735
736// ============================================================================
737// Tests
738// ============================================================================
739
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a hit-less result set with the given counters.
    fn results_without_hits(
        semantic_count: usize,
        graph_count: usize,
        query_entities: Vec<String>,
    ) -> GraphSearchResults {
        GraphSearchResults {
            query: String::from("test"),
            hits: Vec::new(),
            semantic_count,
            graph_count,
            query_entities,
        }
    }

    // ---- configuration ----

    #[test]
    fn test_config_defaults() {
        let GraphRAGConfig {
            max_depth,
            expansion_boost,
            max_query_entities,
            max_expansion_results,
            ..
        } = GraphRAGConfig::default();

        assert_eq!(max_depth, 2);
        assert!((expansion_boost - 1.2).abs() < f32::EPSILON);
        assert_eq!(max_query_entities, 5);
        assert_eq!(max_expansion_results, 10);
    }

    #[test]
    fn test_config_builder() {
        let built = GraphRAGConfig::new()
            .with_max_depth(3)
            .with_expansion_boost(1.5)
            .with_max_query_entities(10)
            .with_max_expansion_results(20);

        assert_eq!(built.max_depth, 3);
        assert!((built.expansion_boost - 1.5).abs() < f32::EPSILON);
        assert_eq!(built.max_query_entities, 10);
        assert_eq!(built.max_expansion_results, 20);
    }

    #[test]
    fn test_expansion_config_defaults() {
        let ExpansionConfig {
            depth,
            entity_type_filter,
            use_relationship_weight,
        } = ExpansionConfig::default();

        assert!(depth.is_none());
        assert!(entity_type_filter.is_none());
        assert!(use_relationship_weight);
    }

    // ---- provenance ----

    #[test]
    fn test_provenance_semantic() {
        assert!(matches!(SearchProvenance::Semantic, SearchProvenance::Semantic));
    }

    #[test]
    fn test_provenance_graph_expansion() {
        let provenance = SearchProvenance::GraphExpansion {
            source_entity: EntityId::new("e123"),
            hop_count: 2,
        };
        match provenance {
            SearchProvenance::GraphExpansion {
                source_entity,
                hop_count,
            } => {
                assert_eq!(source_entity.as_ref(), "e123");
                assert_eq!(hop_count, 2);
            }
            _ => unreachable!("Expected GraphExpansion variant"),
        }
    }

    #[test]
    fn test_provenance_both() {
        let provenance = SearchProvenance::Both {
            semantic_score: 850,
            source_entity: EntityId::new("e456"),
        };
        match provenance {
            SearchProvenance::Both {
                semantic_score,
                source_entity,
            } => {
                assert_eq!(semantic_score, 850);
                assert_eq!(source_entity.as_ref(), "e456");
            }
            _ => unreachable!("Expected Both variant"),
        }
    }

    // ---- result sets ----

    #[test]
    fn test_results_empty() {
        let results = results_without_hits(0, 0, Vec::new());
        assert!(results.is_empty());
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_results_counts() {
        let entities = ["Rust", "PostgreSQL"].iter().map(|s| s.to_string()).collect();
        let results = results_without_hits(10, 5, entities);
        assert_eq!(results.semantic_count, 10);
        assert_eq!(results.graph_count, 5);
        assert_eq!(results.query_entities.len(), 2);
    }
}