Skip to main content

subcog/hooks/
post_tool_use.rs

1//! Post tool use hook handler.
2
3use super::HookHandler;
4use crate::Result;
5use crate::models::{IssueSeverity, SearchFilter, SearchMode, validate_prompt_content};
6use crate::observability::current_request_id;
7use crate::services::RecallService;
8use std::fmt::Write;
9use std::time::Instant;
10use tracing::instrument;
11
12/// Handles `PostToolUse` hook events.
13///
14/// Surfaces related memories after tool usage.
15pub struct PostToolUseHandler {
16    /// Recall service for searching memories.
17    recall: Option<RecallService>,
18    /// Maximum number of memories to surface.
19    max_memories: usize,
20    /// Minimum relevance score to surface.
21    min_relevance: f32,
22}
23
/// Names of tools whose use may benefit from surfacing related memories.
///
/// Compared ASCII case-insensitively against the incoming `tool_name`.
const CONTEXTUAL_TOOLS: &[&str] = &[
    // File access and editing.
    "Read",
    "Write",
    "Edit",
    // Shell execution.
    "Bash",
    // Code search and navigation.
    "Search",
    "Grep",
    "Glob",
    "LSP",
];
28
29impl PostToolUseHandler {
30    /// Creates a new handler.
31    #[must_use]
32    pub const fn new() -> Self {
33        Self {
34            recall: None,
35            max_memories: 3,
36            min_relevance: 0.5,
37        }
38    }
39
40    /// Sets the recall service.
41    #[must_use]
42    pub fn with_recall(mut self, recall: RecallService) -> Self {
43        self.recall = Some(recall);
44        self
45    }
46
47    /// Sets the maximum number of memories to surface.
48    #[must_use]
49    pub const fn with_max_memories(mut self, max: usize) -> Self {
50        self.max_memories = max;
51        self
52    }
53
54    /// Sets the minimum relevance score.
55    #[must_use]
56    pub const fn with_min_relevance(mut self, min: f32) -> Self {
57        self.min_relevance = min;
58        self
59    }
60
61    /// Determines if a tool use warrants memory lookup.
62    /// Kept as method for API consistency.
63    #[allow(clippy::unused_self)]
64    fn should_lookup(&self, tool_name: &str) -> bool {
65        CONTEXTUAL_TOOLS
66            .iter()
67            .any(|t| t.eq_ignore_ascii_case(tool_name))
68    }
69
70    /// Checks if a tool is a prompt save tool.
71    fn is_prompt_save_tool(tool_name: &str) -> bool {
72        let lower = tool_name.to_lowercase();
73        lower == "prompt_save" || lower == "prompt.save" || lower == "subcog_prompt_save"
74    }
75
76    /// Validates prompt content and returns any issues.
77    ///
78    /// Returns a guidance message if validation issues are found.
79    fn validate_prompt(&self, tool_input: &serde_json::Value) -> Option<String> {
80        // Extract content from tool input
81        let content = tool_input.get("content").and_then(|v| v.as_str())?;
82
83        // Skip validation for empty content
84        if content.is_empty() {
85            return None;
86        }
87
88        // Validate the prompt content
89        let validation = validate_prompt_content(content);
90
91        if validation.is_valid {
92            return None;
93        }
94
95        // Build guidance message for issues
96        let mut guidance = vec!["**Prompt Validation Issues**\n".to_string()];
97
98        for issue in &validation.issues {
99            let severity_icon = match issue.severity {
100                IssueSeverity::Error => "\u{274c}",   // X
101                IssueSeverity::Warning => "\u{26a0}", // Warning sign
102            };
103
104            let position_info = issue
105                .position
106                .map_or(String::new(), |pos| format!(" at position {pos}"));
107
108            guidance.push(format!(
109                "- {severity_icon} {}{position_info}",
110                issue.message
111            ));
112        }
113
114        guidance.push("\n**Tips:**".to_string());
115        guidance.push("- Variables use `{{variable_name}}` syntax".to_string());
116        guidance.push("- Ensure all `{{` have matching `}}`".to_string());
117        guidance.push("- Variable names should be alphanumeric with underscores".to_string());
118        guidance.push("- See `subcog://help/prompts` for format documentation".to_string());
119
120        Some(guidance.join("\n"))
121    }
122
123    /// Extracts a search query from tool input.
124    /// Kept as method for API consistency.
125    #[allow(clippy::unused_self)]
126    fn extract_query(&self, tool_name: &str, tool_input: &serde_json::Value) -> Option<String> {
127        match tool_name.to_lowercase().as_str() {
128            "read" | "write" | "edit" => {
129                // Use file path as query
130                tool_input
131                    .get("file_path")
132                    .or_else(|| tool_input.get("path"))
133                    .and_then(|v| v.as_str())
134                    .map(|p| {
135                        // Extract meaningful parts from path
136                        let parts: Vec<&str> = p.split('/').filter(|s| !s.is_empty()).collect();
137                        parts.join(" ")
138                    })
139            },
140            "bash" => {
141                // Use command as query
142                tool_input.get("command").and_then(|v| v.as_str()).map(|c| {
143                    // Extract key terms from command
144                    c.split_whitespace().take(5).collect::<Vec<_>>().join(" ")
145                })
146            },
147            "search" | "grep" => {
148                // Use pattern as query
149                tool_input
150                    .get("pattern")
151                    .or_else(|| tool_input.get("query"))
152                    .and_then(|v| v.as_str())
153                    .map(String::from)
154            },
155            "glob" => {
156                // Use pattern as query
157                tool_input
158                    .get("pattern")
159                    .and_then(|v| v.as_str())
160                    .map(|p| p.replace(['*', '.'], " "))
161            },
162            "lsp" => {
163                // Use symbol or file as query
164                tool_input
165                    .get("symbol")
166                    .or_else(|| tool_input.get("file_path"))
167                    .and_then(|v| v.as_str())
168                    .map(String::from)
169            },
170            _ => None,
171        }
172    }
173
174    /// Searches for related memories.
175    fn find_related_memories(&self, query: &str) -> Result<Vec<RelatedMemory>> {
176        let Some(recall) = &self.recall else {
177            return Ok(Vec::new());
178        };
179
180        let result = recall.search(
181            query,
182            SearchMode::Hybrid,
183            &SearchFilter::new(),
184            self.max_memories,
185        )?;
186
187        let memories: Vec<RelatedMemory> = result
188            .memories
189            .into_iter()
190            .filter(|hit| hit.score >= self.min_relevance)
191            .map(|hit| {
192                // Build full URN: subcog://{domain}/{namespace}/{id}
193                let domain_part = if hit.memory.domain.is_project_scoped() {
194                    "project".to_string()
195                } else {
196                    hit.memory.domain.to_string()
197                };
198                let urn = format!(
199                    "subcog://{}/{}/{}",
200                    domain_part,
201                    hit.memory.namespace.as_str(),
202                    hit.memory.id.as_str()
203                );
204                RelatedMemory {
205                    urn,
206                    namespace: hit.memory.namespace.as_str().to_string(),
207                    content: truncate_content(&hit.memory.content, 200),
208                    relevance: hit.score,
209                }
210            })
211            .collect();
212
213        Ok(memories)
214    }
215
216    fn empty_response() -> Result<String> {
217        Self::serialize_response(&serde_json::json!({}))
218    }
219
220    fn serialize_response(response: &serde_json::Value) -> Result<String> {
221        serde_json::to_string(response).map_err(|e| crate::Error::OperationFailed {
222            operation: "serialize_response".to_string(),
223            cause: e.to_string(),
224        })
225    }
226
227    fn build_memories_response(
228        tool_name: &str,
229        query: &str,
230        memories: &[RelatedMemory],
231    ) -> serde_json::Value {
232        if memories.is_empty() {
233            return serde_json::json!({});
234        }
235
236        let memories_json: Vec<serde_json::Value> = memories
237            .iter()
238            .map(|m| {
239                serde_json::json!({
240                    "urn": m.urn,
241                    "namespace": m.namespace,
242                    "content": m.content,
243                    "relevance": m.relevance
244                })
245            })
246            .collect();
247
248        let metadata = serde_json::json!({
249            "memories": memories_json,
250            "lookup_performed": true,
251            "query": query,
252            "tool_name": tool_name
253        });
254
255        // Build single-line XML format for token efficiency
256        let mut xml = String::from("<memories>");
257        for m in memories {
258            // Escape XML special chars in content
259            let content = m
260                .content
261                .replace('&', "&amp;")
262                .replace('<', "&lt;")
263                .replace('>', "&gt;")
264                .replace('"', "&quot;");
265            let _ = write!(
266                xml,
267                "<m urn=\"{}\" ns=\"{}\" rel=\"{:.0}\">{}</m>",
268                m.urn,
269                m.namespace,
270                m.relevance * 100.0,
271                content
272            );
273        }
274        xml.push_str("</memories>");
275        let context = xml;
276
277        let metadata_str = serde_json::to_string(&metadata).unwrap_or_default();
278        let context_with_metadata =
279            format!("{context}\n\n<!-- subcog-metadata: {metadata_str} -->");
280
281        serde_json::json!({
282            "hookSpecificOutput": {
283                "hookEventName": "PostToolUse",
284                "additionalContext": context_with_metadata
285            }
286        })
287    }
288
289    fn handle_inner(
290        &self,
291        input: &str,
292        lookup_performed: &mut bool,
293        memories_found: &mut usize,
294    ) -> Result<String> {
295        let input_json: serde_json::Value =
296            serde_json::from_str(input).unwrap_or_else(|_| serde_json::json!({}));
297
298        let tool_name = input_json
299            .get("tool_name")
300            .and_then(|v| v.as_str())
301            .unwrap_or("");
302        let span = tracing::Span::current();
303        span.record("tool_name", tool_name);
304
305        let tool_input = input_json
306            .get("tool_input")
307            .unwrap_or(&serde_json::Value::Null);
308
309        if Self::is_prompt_save_tool(tool_name) {
310            if let Some(guidance) = self.validate_prompt(tool_input) {
311                let response = serde_json::json!({
312                    "hookSpecificOutput": {
313                        "hookEventName": "PostToolUse",
314                        "additionalContext": guidance
315                    }
316                });
317                return Self::serialize_response(&response);
318            }
319            return Self::empty_response();
320        }
321
322        if !self.should_lookup(tool_name) {
323            return Self::empty_response();
324        }
325
326        let query = self
327            .extract_query(tool_name, tool_input)
328            .filter(|q| !q.is_empty());
329        let Some(query) = query else {
330            return Self::empty_response();
331        };
332
333        let memories = self.find_related_memories(&query)?;
334        *lookup_performed = true;
335        *memories_found = memories.len();
336        span.record("lookup_performed", *lookup_performed);
337        span.record("memories_found", *memories_found);
338
339        let response = Self::build_memories_response(tool_name, &query, &memories);
340        Self::serialize_response(&response)
341    }
342}
343
/// Truncates `content` so the result is at most `max_len` bytes long.
///
/// When truncation is needed, the kept prefix leaves room for a trailing
/// `"..."` marker. The cut point is moved back to the nearest UTF-8 character
/// boundary, because slicing a `str` inside a multi-byte character panics.
///
/// If `max_len < 3`, a truncated result is just `"..."`, which is longer than
/// `max_len`.
fn truncate_content(content: &str, max_len: usize) -> String {
    if content.len() <= max_len {
        return content.to_string();
    }

    let mut end = max_len.saturating_sub(3);
    // Index 0 is always a char boundary, so this loop terminates.
    while !content.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &content[..end])
}
352
353impl Default for PostToolUseHandler {
354    fn default() -> Self {
355        Self::new()
356    }
357}
358
359impl HookHandler for PostToolUseHandler {
360    fn event_type(&self) -> &'static str {
361        "PostToolUse"
362    }
363
364    #[instrument(
365        name = "subcog.hook.post_tool_use",
366        skip(self, input),
367        fields(
368            request_id = tracing::field::Empty,
369            component = "hooks",
370            operation = "post_tool_use",
371            hook = "PostToolUse",
372            tool_name = tracing::field::Empty,
373            lookup_performed = tracing::field::Empty,
374            memories_found = tracing::field::Empty
375        )
376    )]
377    fn handle(&self, input: &str) -> Result<String> {
378        let start = Instant::now();
379        let mut lookup_performed = false;
380        let mut memories_found = 0usize;
381        if let Some(request_id) = current_request_id() {
382            tracing::Span::current().record("request_id", request_id.as_str());
383        }
384
385        let result = self.handle_inner(input, &mut lookup_performed, &mut memories_found);
386
387        let status = if result.is_ok() { "success" } else { "error" };
388        metrics::counter!(
389            "hook_executions_total",
390            "hook_type" => "PostToolUse",
391            "status" => status
392        )
393        .increment(1);
394        metrics::histogram!("hook_duration_ms", "hook_type" => "PostToolUse")
395            .record(start.elapsed().as_secs_f64() * 1000.0);
396        if lookup_performed {
397            metrics::counter!(
398                "hook_memory_lookup_total",
399                "hook_type" => "PostToolUse",
400                "result" => if memories_found > 0 { "hit" } else { "miss" }
401            )
402            .increment(1);
403        }
404
405        result
406    }
407}
408
/// A related memory surfaced by the handler.
///
/// Built from a recall search hit and rendered into the hook's
/// `additionalContext` output.
#[derive(Debug, Clone, PartialEq)]
pub struct RelatedMemory {
    /// Full URN (`subcog://{domain}/{namespace}/{id}`).
    pub urn: String,
    /// Namespace the memory belongs to.
    pub namespace: String,
    /// Memory content, truncated for display.
    pub content: String,
    /// Relevance score reported by the search.
    pub relevance: f32,
}
421
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_handler_creation() {
        let handler = PostToolUseHandler::default();
        assert_eq!(handler.event_type(), "PostToolUse");
    }

    #[test]
    fn test_should_lookup() {
        let handler = PostToolUseHandler::default();

        assert!(handler.should_lookup("Read"));
        assert!(handler.should_lookup("read"));
        assert!(handler.should_lookup("Write"));
        assert!(handler.should_lookup("Bash"));
        assert!(handler.should_lookup("Grep"));
        assert!(!handler.should_lookup("Unknown"));
        assert!(!handler.should_lookup(""));
    }

    #[test]
    fn test_extract_query_read() {
        let handler = PostToolUseHandler::default();

        let input = serde_json::json!({
            "file_path": "/src/services/capture.rs"
        });

        let query = handler.extract_query("Read", &input);
        assert_eq!(query.as_deref(), Some("src services capture.rs"));
    }

    #[test]
    fn test_extract_query_read_path_fallback() {
        let handler = PostToolUseHandler::default();

        // `path` is used when `file_path` is absent.
        let input = serde_json::json!({
            "path": "docs/guide.md"
        });

        let query = handler.extract_query("edit", &input);
        assert_eq!(query.as_deref(), Some("docs guide.md"));
    }

    #[test]
    fn test_extract_query_bash() {
        let handler = PostToolUseHandler::default();

        let input = serde_json::json!({
            "command": "cargo test --all-features"
        });

        let query = handler.extract_query("Bash", &input);
        assert_eq!(query.as_deref(), Some("cargo test --all-features"));
    }

    #[test]
    fn test_extract_query_bash_keeps_first_five_words() {
        let handler = PostToolUseHandler::default();

        let input = serde_json::json!({
            "command": "git commit -m fix the flaky test"
        });

        let query = handler.extract_query("bash", &input);
        assert_eq!(query.as_deref(), Some("git commit -m fix the"));
    }

    #[test]
    fn test_extract_query_grep() {
        let handler = PostToolUseHandler::default();

        let input = serde_json::json!({
            "pattern": "fn capture"
        });

        let query = handler.extract_query("grep", &input);
        assert_eq!(query, Some("fn capture".to_string()));
    }

    #[test]
    fn test_extract_query_search_query_fallback() {
        let handler = PostToolUseHandler::default();

        // `query` is used when `pattern` is absent.
        let input = serde_json::json!({
            "query": "recall service"
        });

        let query = handler.extract_query("Search", &input);
        assert_eq!(query.as_deref(), Some("recall service"));
    }

    #[test]
    fn test_extract_query_glob() {
        let handler = PostToolUseHandler::default();

        let input = serde_json::json!({
            "pattern": "**/*.rs"
        });

        // Every `*` and `.` becomes a space.
        let query = handler.extract_query("Glob", &input);
        assert_eq!(query.as_deref(), Some("  /  rs"));
    }

    #[test]
    fn test_extract_query_lsp() {
        let handler = PostToolUseHandler::default();

        let with_symbol = serde_json::json!({
            "symbol": "RecallService",
            "file_path": "/src/lib.rs"
        });
        assert_eq!(
            handler.extract_query("LSP", &with_symbol).as_deref(),
            Some("RecallService")
        );

        let file_only = serde_json::json!({
            "file_path": "/src/lib.rs"
        });
        assert_eq!(
            handler.extract_query("lsp", &file_only).as_deref(),
            Some("/src/lib.rs")
        );
    }

    #[test]
    fn test_extract_query_unsupported() {
        let handler = PostToolUseHandler::default();

        let input = serde_json::json!({
            "file_path": "/src/main.rs"
        });
        assert_eq!(handler.extract_query("Unknown", &input), None);

        let empty = serde_json::json!({});
        assert_eq!(handler.extract_query("Read", &empty), None);
    }

    #[test]
    fn test_handle_non_contextual_tool() {
        let handler = PostToolUseHandler::default();

        let input = r#"{"tool_name": "SomeOtherTool", "tool_input": {}}"#;

        let result = handler.handle(input);
        assert!(result.is_ok());

        let response: serde_json::Value = serde_json::from_str(&result.unwrap()).unwrap();
        // Claude Code hook format - empty response for non-contextual tools
        assert!(response.as_object().unwrap().is_empty());
    }

    #[test]
    fn test_handle_contextual_tool() {
        let handler = PostToolUseHandler::default();

        let input = r#"{"tool_name": "Read", "tool_input": {"file_path": "/src/main.rs"}}"#;

        let result = handler.handle(input);
        assert!(result.is_ok());

        let response: serde_json::Value = serde_json::from_str(&result.unwrap()).unwrap();
        // Without recall service, no memories found - empty response
        // (memories would be returned in hookSpecificOutput.additionalContext if found)
        assert!(response.as_object().unwrap().is_empty());
    }

    #[test]
    fn test_handle_invalid_json_input() {
        let handler = PostToolUseHandler::default();

        // Malformed input is handled best-effort: no error, empty response.
        let result = handler.handle("not valid json");
        assert!(result.is_ok());

        let response: serde_json::Value = serde_json::from_str(&result.unwrap()).unwrap();
        assert!(response.as_object().unwrap().is_empty());
    }

    #[test]
    fn test_truncate_content() {
        let short = "Short text";
        assert_eq!(truncate_content(short, 100), short);

        let long =
            "This is a much longer text that should be truncated because it exceeds the limit";
        let truncated = truncate_content(long, 30);
        assert!(truncated.ends_with("..."));
        assert!(truncated.len() <= 30);
    }

    #[test]
    fn test_truncate_content_boundaries() {
        // Exactly at the limit: unchanged.
        assert_eq!(truncate_content("abcde", 5), "abcde");
        // One over the limit: prefix shortened to make room for "...".
        assert_eq!(truncate_content("abcdef", 5), "ab...");
    }

    #[test]
    fn test_configuration() {
        let handler = PostToolUseHandler::default()
            .with_max_memories(5)
            .with_min_relevance(0.7);

        assert_eq!(handler.max_memories, 5);
        assert!((handler.min_relevance - 0.7).abs() < f32::EPSILON);
    }

    #[test]
    fn test_is_prompt_save_tool() {
        assert!(PostToolUseHandler::is_prompt_save_tool("prompt_save"));
        assert!(PostToolUseHandler::is_prompt_save_tool("PROMPT_SAVE"));
        assert!(PostToolUseHandler::is_prompt_save_tool("prompt.save"));
        assert!(PostToolUseHandler::is_prompt_save_tool(
            "subcog_prompt_save"
        ));
        assert!(PostToolUseHandler::is_prompt_save_tool(
            "Subcog_Prompt_Save"
        ));
        assert!(!PostToolUseHandler::is_prompt_save_tool("prompt_get"));
        assert!(!PostToolUseHandler::is_prompt_save_tool("subcog_capture"));
    }

    #[test]
    fn test_handle_prompt_save_valid() {
        let handler = PostToolUseHandler::default();

        let input = serde_json::json!({
            "tool_name": "prompt_save",
            "tool_input": {
                "name": "test-prompt",
                "content": "Hello {{name}}, welcome to {{place}}!"
            }
        });

        let result = handler.handle(&serde_json::to_string(&input).unwrap());
        assert!(result.is_ok());

        let response: serde_json::Value = serde_json::from_str(&result.unwrap()).unwrap();
        // Valid prompt - empty response (no validation issues)
        assert!(response.as_object().unwrap().is_empty());
    }

    #[test]
    fn test_handle_prompt_save_invalid_braces() {
        let handler = PostToolUseHandler::default();

        let input = serde_json::json!({
            "tool_name": "prompt_save",
            "tool_input": {
                "name": "test-prompt",
                "content": "Hello {{name, this is broken"
            }
        });

        let result = handler.handle(&serde_json::to_string(&input).unwrap());
        assert!(result.is_ok());

        let response: serde_json::Value = serde_json::from_str(&result.unwrap()).unwrap();
        // Invalid prompt - should have validation guidance
        assert!(response.get("hookSpecificOutput").is_some());

        let additional_context = response
            .get("hookSpecificOutput")
            .and_then(|o| o.get("additionalContext"))
            .and_then(|v| v.as_str())
            .unwrap_or("");
        assert!(additional_context.contains("Prompt Validation Issues"));
        assert!(additional_context.contains("subcog://help/prompts"));
    }

    #[test]
    fn test_validate_prompt_empty_content() {
        let handler = PostToolUseHandler::default();

        let input = serde_json::json!({
            "content": ""
        });

        // Empty content should return None (no validation needed)
        let guidance = handler.validate_prompt(&input);
        assert!(guidance.is_none());
    }

    #[test]
    fn test_validate_prompt_missing_content() {
        let handler = PostToolUseHandler::default();

        let input = serde_json::json!({
            "name": "test"
        });

        // Missing content should return None
        let guidance = handler.validate_prompt(&input);
        assert!(guidance.is_none());
    }
}