Skip to main content

subcog/services/
enrichment.rs

1//! Memory enrichment service.
2//!
3//! Enriches memories with tags, structure, and context using LLM.
4
5use crate::llm::{
6    LlmProvider, OperationMode, build_system_prompt, sanitize_llm_response_for_error,
7};
8use crate::models::{Memory, MemoryId, SearchFilter};
9use crate::storage::traits::IndexBackend;
10use crate::{Error, Result};
11use std::sync::Arc;
12use std::time::Instant;
13use tracing::instrument;
14
15/// Service for enriching memories with LLM-generated tags and metadata.
16pub struct EnrichmentService<P: LlmProvider> {
17    /// LLM provider for generating enrichments.
18    llm: P,
19    /// Index backend for memory access.
20    index: Arc<dyn IndexBackend>,
21}
22
23impl<P: LlmProvider> EnrichmentService<P> {
24    /// Creates a new enrichment service.
25    #[must_use]
26    pub fn new(llm: P, index: Arc<dyn IndexBackend>) -> Self {
27        Self { llm, index }
28    }
29
30    /// Enriches all memories that have empty tags.
31    ///
32    /// # Arguments
33    ///
34    /// * `dry_run` - If true, shows what would be changed without applying
35    /// * `update_all` - If true, updates all memories even if they have tags
36    ///
37    /// # Errors
38    ///
39    /// Returns [`Error::OperationFailed`] if:
40    /// - Memory listing fails (database access error)
41    /// - LLM enrichment fails for all memories
42    #[instrument(skip(self), fields(operation = "enrich_all", dry_run = dry_run, update_all = update_all))]
43    pub fn enrich_all(&self, dry_run: bool, update_all: bool) -> Result<EnrichmentStats> {
44        let start = Instant::now();
45        let result = (|| {
46            // Get all memory IDs from SQLite
47            let filter = SearchFilter::default();
48            let all_ids = self.index.list_all(&filter, usize::MAX)?;
49
50            let mut stats = EnrichmentStats {
51                total: all_ids.len(),
52                ..Default::default()
53            };
54
55            for (memory_id, _score) in &all_ids {
56                if let Some(memory) = self.index.get_memory(memory_id)? {
57                    self.process_memory(&memory, dry_run, update_all, &mut stats);
58                }
59            }
60
61            Ok(stats)
62        })();
63
64        let status = if result.is_ok() { "success" } else { "error" };
65        metrics::counter!(
66            "memory_operations_total",
67            "operation" => "enrich",
68            "namespace" => "mixed",
69            "domain" => "project",
70            "status" => status
71        )
72        .increment(1);
73        metrics::histogram!(
74            "memory_operation_duration_ms",
75            "operation" => "enrich",
76            "namespace" => "mixed"
77        )
78        .record(start.elapsed().as_secs_f64() * 1000.0);
79
80        result
81    }
82
83    /// Processes a single memory for enrichment.
84    fn process_memory(
85        &self,
86        memory: &Memory,
87        dry_run: bool,
88        update_all: bool,
89        stats: &mut EnrichmentStats,
90    ) {
91        // Check if tags exist
92        let has_tags = !memory.tags.is_empty();
93
94        // Skip if has tags and not updating all
95        if has_tags && !update_all {
96            stats.skipped += 1;
97            return;
98        }
99
100        let namespace = memory.namespace.as_str();
101
102        let new_tags = match self.generate_tags(&memory.content, namespace) {
103            Ok(tags) => tags,
104            Err(e) => {
105                tracing::warn!("Failed to generate tags for {}: {e}", memory.id.as_str());
106                stats.failed += 1;
107                return;
108            },
109        };
110
111        let action = if has_tags { "update" } else { "enrich" };
112
113        if dry_run {
114            tracing::info!(
115                "Would {action} {} with tags: {new_tags:?}",
116                memory.id.as_str()
117            );
118            if has_tags {
119                stats.would_update += 1;
120            } else {
121                stats.would_enrich += 1;
122            }
123            return;
124        }
125
126        match self.update_memory_tags(memory, &new_tags) {
127            Ok(()) => {
128                tracing::info!("{action}ed {} with tags: {new_tags:?}", memory.id.as_str());
129                if has_tags {
130                    stats.updated += 1;
131                } else {
132                    stats.enriched += 1;
133                }
134            },
135            Err(e) => {
136                tracing::warn!("Failed to update memory {}: {e}", memory.id.as_str());
137                stats.failed += 1;
138            },
139        }
140    }
141
142    /// Enriches a specific memory by ID.
143    ///
144    /// # Errors
145    ///
146    /// Returns [`Error::OperationFailed`] if:
147    /// - The memory with the given ID is not found
148    /// - LLM tag generation fails (provider error or invalid response)
149    /// - Memory update fails (database error)
150    #[instrument(skip(self), fields(operation = "enrich_one", dry_run = dry_run, memory_id = memory_id))]
151    pub fn enrich_one(&self, memory_id: &str, dry_run: bool) -> Result<EnrichmentResult> {
152        let start = Instant::now();
153        let result = (|| {
154            let id = MemoryId::new(memory_id);
155            let memory = self
156                .index
157                .get_memory(&id)?
158                .ok_or_else(|| Error::OperationFailed {
159                    operation: "enrich_one".to_string(),
160                    cause: format!("Memory not found: {memory_id}"),
161                })?;
162
163            let namespace = memory.namespace.as_str();
164
165            // Generate tags
166            let new_tags = self.generate_tags(&memory.content, namespace)?;
167
168            if dry_run {
169                return Ok(EnrichmentResult {
170                    memory_id: memory_id.to_string(),
171                    new_tags,
172                    applied: false,
173                });
174            }
175
176            // Update the memory
177            self.update_memory_tags(&memory, &new_tags)?;
178
179            Ok(EnrichmentResult {
180                memory_id: memory_id.to_string(),
181                new_tags,
182                applied: true,
183            })
184        })();
185
186        let status = if result.is_ok() { "success" } else { "error" };
187        metrics::counter!(
188            "memory_operations_total",
189            "operation" => "enrich",
190            "namespace" => "mixed",
191            "domain" => "project",
192            "status" => status
193        )
194        .increment(1);
195        metrics::histogram!(
196            "memory_operation_duration_ms",
197            "operation" => "enrich",
198            "namespace" => "mixed"
199        )
200        .record(start.elapsed().as_secs_f64() * 1000.0);
201
202        result
203    }
204
205    /// Generates tags for content using LLM.
206    fn generate_tags(&self, content: &str, namespace: &str) -> Result<Vec<String>> {
207        let system = build_system_prompt(OperationMode::Enrichment, None);
208        let user_prompt = format!(
209            "Generate tags for this memory.\n\nNamespace: {namespace}\nContent: {content}\n\nReturn ONLY a JSON array of strings."
210        );
211        let response = self.llm.complete_with_system(&system, &user_prompt)?;
212
213        // Parse the JSON response
214        let sanitized = sanitize_llm_response_for_error(&response);
215        let tags: Vec<String> =
216            serde_json::from_str(&response).map_err(|e| Error::OperationFailed {
217                operation: "parse_tags".to_string(),
218                cause: format!("Failed to parse LLM response: {e}. Response was: {sanitized}"),
219            })?;
220
221        Ok(tags)
222    }
223
224    /// Updates a memory with new tags.
225    fn update_memory_tags(&self, memory: &Memory, new_tags: &[String]) -> Result<()> {
226        // Create updated memory with new tags
227        let updated_memory = Memory {
228            id: memory.id.clone(),
229            content: memory.content.clone(),
230            namespace: memory.namespace,
231            domain: memory.domain.clone(),
232            project_id: memory.project_id.clone(),
233            branch: memory.branch.clone(),
234            file_path: memory.file_path.clone(),
235            status: memory.status,
236            created_at: memory.created_at,
237            updated_at: std::time::SystemTime::now()
238                .duration_since(std::time::UNIX_EPOCH)
239                .map(|d| d.as_secs())
240                .unwrap_or(memory.updated_at),
241            tombstoned_at: memory.tombstoned_at,
242            expires_at: memory.expires_at,
243            embedding: memory.embedding.clone(),
244            tags: new_tags.to_vec(),
245            #[cfg(feature = "group-scope")]
246            group_id: memory.group_id.clone(),
247            source: memory.source.clone(),
248            is_summary: memory.is_summary,
249            source_memory_ids: memory.source_memory_ids.clone(),
250            consolidation_timestamp: memory.consolidation_timestamp,
251        };
252
253        // Re-index the updated memory
254        self.index.index(&updated_memory)?;
255
256        Ok(())
257    }
258}
259
260/// Statistics from a batch enrichment operation.
261#[derive(Debug, Clone, Default)]
262pub struct EnrichmentStats {
263    /// Total memories scanned.
264    pub total: usize,
265    /// Memories newly enriched (had no tags).
266    pub enriched: usize,
267    /// Memories updated (had existing tags).
268    pub updated: usize,
269    /// Memories skipped (already have tags, not in update mode).
270    pub skipped: usize,
271    /// Memories that would be enriched (dry run).
272    pub would_enrich: usize,
273    /// Memories that would be updated (dry run).
274    pub would_update: usize,
275    /// Memories that failed to enrich.
276    pub failed: usize,
277}
278
279impl EnrichmentStats {
280    /// Returns a human-readable summary.
281    #[must_use]
282    pub fn summary(&self) -> String {
283        if self.would_enrich > 0 || self.would_update > 0 {
284            format!(
285                "Dry run: {} would be enriched, {} would be updated, {} skipped, {} failed (of {} total)",
286                self.would_enrich, self.would_update, self.skipped, self.failed, self.total
287            )
288        } else {
289            format!(
290                "Enriched: {}, Updated: {}, Skipped: {}, Failed: {} (of {} total)",
291                self.enriched, self.updated, self.skipped, self.failed, self.total
292            )
293        }
294    }
295}
296
297/// Result of enriching a single memory.
298#[derive(Debug, Clone)]
299pub struct EnrichmentResult {
300    /// Memory ID that was enriched.
301    pub memory_id: String,
302    /// New tags that were generated.
303    pub new_tags: Vec<String>,
304    /// Whether the changes were applied.
305    pub applied: bool,
306}
307
308#[cfg(test)]
309mod tests {
310    use super::*;
311
312    #[test]
313    fn test_enrichment_stats_summary() {
314        let stats = EnrichmentStats {
315            total: 10,
316            enriched: 5,
317            updated: 2,
318            skipped: 1,
319            would_enrich: 0,
320            would_update: 0,
321            failed: 2,
322        };
323        let summary = stats.summary();
324        assert!(summary.contains("Enriched: 5"));
325        assert!(summary.contains("Updated: 2"));
326        assert!(summary.contains("Skipped: 1"));
327        assert!(summary.contains("Failed: 2"));
328    }
329
330    #[test]
331    fn test_enrichment_stats_dry_run_summary() {
332        let stats = EnrichmentStats {
333            total: 10,
334            enriched: 0,
335            updated: 0,
336            skipped: 1,
337            would_enrich: 5,
338            would_update: 2,
339            failed: 2,
340        };
341        let summary = stats.summary();
342        assert!(summary.contains("Dry run"));
343        assert!(summary.contains("5 would be enriched"));
344        assert!(summary.contains("2 would be updated"));
345    }
346
347    #[test]
348    fn test_enrichment_stats_default() {
349        let stats = EnrichmentStats::default();
350        assert_eq!(stats.total, 0);
351        assert_eq!(stats.enriched, 0);
352        assert_eq!(stats.updated, 0);
353        assert_eq!(stats.skipped, 0);
354        assert_eq!(stats.failed, 0);
355    }
356
357    #[test]
358    fn test_enrichment_result() {
359        let result = EnrichmentResult {
360            memory_id: "test-id".to_string(),
361            new_tags: vec!["rust".to_string(), "memory".to_string()],
362            applied: true,
363        };
364        assert_eq!(result.memory_id, "test-id");
365        assert_eq!(result.new_tags.len(), 2);
366        assert!(result.applied);
367    }
368}