1use super::HookHandler;
4use crate::Result;
5use crate::models::{IssueSeverity, SearchFilter, SearchMode, validate_prompt_content};
6use crate::observability::current_request_id;
7use crate::services::RecallService;
8use std::fmt::Write;
9use std::time::Instant;
10use tracing::instrument;
11
12pub struct PostToolUseHandler {
16 recall: Option<RecallService>,
18 max_memories: usize,
20 min_relevance: f32,
22}
23
/// Tool names whose invocation triggers a related-memory lookup.
///
/// Matching against this list is ASCII case-insensitive (see `should_lookup`).
const CONTEXTUAL_TOOLS: &[&str] = &[
    "Read",
    "Write",
    "Edit",
    "Bash",
    "Search",
    "Grep",
    "Glob",
    "LSP",
];
28
29impl PostToolUseHandler {
30 #[must_use]
32 pub const fn new() -> Self {
33 Self {
34 recall: None,
35 max_memories: 3,
36 min_relevance: 0.5,
37 }
38 }
39
40 #[must_use]
42 pub fn with_recall(mut self, recall: RecallService) -> Self {
43 self.recall = Some(recall);
44 self
45 }
46
47 #[must_use]
49 pub const fn with_max_memories(mut self, max: usize) -> Self {
50 self.max_memories = max;
51 self
52 }
53
54 #[must_use]
56 pub const fn with_min_relevance(mut self, min: f32) -> Self {
57 self.min_relevance = min;
58 self
59 }
60
61 #[allow(clippy::unused_self)]
64 fn should_lookup(&self, tool_name: &str) -> bool {
65 CONTEXTUAL_TOOLS
66 .iter()
67 .any(|t| t.eq_ignore_ascii_case(tool_name))
68 }
69
70 fn is_prompt_save_tool(tool_name: &str) -> bool {
72 let lower = tool_name.to_lowercase();
73 lower == "prompt_save" || lower == "prompt.save" || lower == "subcog_prompt_save"
74 }
75
76 fn validate_prompt(&self, tool_input: &serde_json::Value) -> Option<String> {
80 let content = tool_input.get("content").and_then(|v| v.as_str())?;
82
83 if content.is_empty() {
85 return None;
86 }
87
88 let validation = validate_prompt_content(content);
90
91 if validation.is_valid {
92 return None;
93 }
94
95 let mut guidance = vec!["**Prompt Validation Issues**\n".to_string()];
97
98 for issue in &validation.issues {
99 let severity_icon = match issue.severity {
100 IssueSeverity::Error => "\u{274c}", IssueSeverity::Warning => "\u{26a0}", };
103
104 let position_info = issue
105 .position
106 .map_or(String::new(), |pos| format!(" at position {pos}"));
107
108 guidance.push(format!(
109 "- {severity_icon} {}{position_info}",
110 issue.message
111 ));
112 }
113
114 guidance.push("\n**Tips:**".to_string());
115 guidance.push("- Variables use `{{variable_name}}` syntax".to_string());
116 guidance.push("- Ensure all `{{` have matching `}}`".to_string());
117 guidance.push("- Variable names should be alphanumeric with underscores".to_string());
118 guidance.push("- See `subcog://help/prompts` for format documentation".to_string());
119
120 Some(guidance.join("\n"))
121 }
122
123 #[allow(clippy::unused_self)]
126 fn extract_query(&self, tool_name: &str, tool_input: &serde_json::Value) -> Option<String> {
127 match tool_name.to_lowercase().as_str() {
128 "read" | "write" | "edit" => {
129 tool_input
131 .get("file_path")
132 .or_else(|| tool_input.get("path"))
133 .and_then(|v| v.as_str())
134 .map(|p| {
135 let parts: Vec<&str> = p.split('/').filter(|s| !s.is_empty()).collect();
137 parts.join(" ")
138 })
139 },
140 "bash" => {
141 tool_input.get("command").and_then(|v| v.as_str()).map(|c| {
143 c.split_whitespace().take(5).collect::<Vec<_>>().join(" ")
145 })
146 },
147 "search" | "grep" => {
148 tool_input
150 .get("pattern")
151 .or_else(|| tool_input.get("query"))
152 .and_then(|v| v.as_str())
153 .map(String::from)
154 },
155 "glob" => {
156 tool_input
158 .get("pattern")
159 .and_then(|v| v.as_str())
160 .map(|p| p.replace(['*', '.'], " "))
161 },
162 "lsp" => {
163 tool_input
165 .get("symbol")
166 .or_else(|| tool_input.get("file_path"))
167 .and_then(|v| v.as_str())
168 .map(String::from)
169 },
170 _ => None,
171 }
172 }
173
174 fn find_related_memories(&self, query: &str) -> Result<Vec<RelatedMemory>> {
176 let Some(recall) = &self.recall else {
177 return Ok(Vec::new());
178 };
179
180 let result = recall.search(
181 query,
182 SearchMode::Hybrid,
183 &SearchFilter::new(),
184 self.max_memories,
185 )?;
186
187 let memories: Vec<RelatedMemory> = result
188 .memories
189 .into_iter()
190 .filter(|hit| hit.score >= self.min_relevance)
191 .map(|hit| {
192 let domain_part = if hit.memory.domain.is_project_scoped() {
194 "project".to_string()
195 } else {
196 hit.memory.domain.to_string()
197 };
198 let urn = format!(
199 "subcog://{}/{}/{}",
200 domain_part,
201 hit.memory.namespace.as_str(),
202 hit.memory.id.as_str()
203 );
204 RelatedMemory {
205 urn,
206 namespace: hit.memory.namespace.as_str().to_string(),
207 content: truncate_content(&hit.memory.content, 200),
208 relevance: hit.score,
209 }
210 })
211 .collect();
212
213 Ok(memories)
214 }
215
216 fn empty_response() -> Result<String> {
217 Self::serialize_response(&serde_json::json!({}))
218 }
219
220 fn serialize_response(response: &serde_json::Value) -> Result<String> {
221 serde_json::to_string(response).map_err(|e| crate::Error::OperationFailed {
222 operation: "serialize_response".to_string(),
223 cause: e.to_string(),
224 })
225 }
226
227 fn build_memories_response(
228 tool_name: &str,
229 query: &str,
230 memories: &[RelatedMemory],
231 ) -> serde_json::Value {
232 if memories.is_empty() {
233 return serde_json::json!({});
234 }
235
236 let memories_json: Vec<serde_json::Value> = memories
237 .iter()
238 .map(|m| {
239 serde_json::json!({
240 "urn": m.urn,
241 "namespace": m.namespace,
242 "content": m.content,
243 "relevance": m.relevance
244 })
245 })
246 .collect();
247
248 let metadata = serde_json::json!({
249 "memories": memories_json,
250 "lookup_performed": true,
251 "query": query,
252 "tool_name": tool_name
253 });
254
255 let mut xml = String::from("<memories>");
257 for m in memories {
258 let content = m
260 .content
261 .replace('&', "&")
262 .replace('<', "<")
263 .replace('>', ">")
264 .replace('"', """);
265 let _ = write!(
266 xml,
267 "<m urn=\"{}\" ns=\"{}\" rel=\"{:.0}\">{}</m>",
268 m.urn,
269 m.namespace,
270 m.relevance * 100.0,
271 content
272 );
273 }
274 xml.push_str("</memories>");
275 let context = xml;
276
277 let metadata_str = serde_json::to_string(&metadata).unwrap_or_default();
278 let context_with_metadata =
279 format!("{context}\n\n<!-- subcog-metadata: {metadata_str} -->");
280
281 serde_json::json!({
282 "hookSpecificOutput": {
283 "hookEventName": "PostToolUse",
284 "additionalContext": context_with_metadata
285 }
286 })
287 }
288
289 fn handle_inner(
290 &self,
291 input: &str,
292 lookup_performed: &mut bool,
293 memories_found: &mut usize,
294 ) -> Result<String> {
295 let input_json: serde_json::Value =
296 serde_json::from_str(input).unwrap_or_else(|_| serde_json::json!({}));
297
298 let tool_name = input_json
299 .get("tool_name")
300 .and_then(|v| v.as_str())
301 .unwrap_or("");
302 let span = tracing::Span::current();
303 span.record("tool_name", tool_name);
304
305 let tool_input = input_json
306 .get("tool_input")
307 .unwrap_or(&serde_json::Value::Null);
308
309 if Self::is_prompt_save_tool(tool_name) {
310 if let Some(guidance) = self.validate_prompt(tool_input) {
311 let response = serde_json::json!({
312 "hookSpecificOutput": {
313 "hookEventName": "PostToolUse",
314 "additionalContext": guidance
315 }
316 });
317 return Self::serialize_response(&response);
318 }
319 return Self::empty_response();
320 }
321
322 if !self.should_lookup(tool_name) {
323 return Self::empty_response();
324 }
325
326 let query = self
327 .extract_query(tool_name, tool_input)
328 .filter(|q| !q.is_empty());
329 let Some(query) = query else {
330 return Self::empty_response();
331 };
332
333 let memories = self.find_related_memories(&query)?;
334 *lookup_performed = true;
335 *memories_found = memories.len();
336 span.record("lookup_performed", *lookup_performed);
337 span.record("memories_found", *memories_found);
338
339 let response = Self::build_memories_response(tool_name, &query, &memories);
340 Self::serialize_response(&response)
341 }
342}
343
344fn truncate_content(content: &str, max_len: usize) -> String {
346 if content.len() <= max_len {
347 content.to_string()
348 } else {
349 format!("{}...", &content[..max_len.saturating_sub(3)])
350 }
351}
352
353impl Default for PostToolUseHandler {
354 fn default() -> Self {
355 Self::new()
356 }
357}
358
359impl HookHandler for PostToolUseHandler {
360 fn event_type(&self) -> &'static str {
361 "PostToolUse"
362 }
363
364 #[instrument(
365 name = "subcog.hook.post_tool_use",
366 skip(self, input),
367 fields(
368 request_id = tracing::field::Empty,
369 component = "hooks",
370 operation = "post_tool_use",
371 hook = "PostToolUse",
372 tool_name = tracing::field::Empty,
373 lookup_performed = tracing::field::Empty,
374 memories_found = tracing::field::Empty
375 )
376 )]
377 fn handle(&self, input: &str) -> Result<String> {
378 let start = Instant::now();
379 let mut lookup_performed = false;
380 let mut memories_found = 0usize;
381 if let Some(request_id) = current_request_id() {
382 tracing::Span::current().record("request_id", request_id.as_str());
383 }
384
385 let result = self.handle_inner(input, &mut lookup_performed, &mut memories_found);
386
387 let status = if result.is_ok() { "success" } else { "error" };
388 metrics::counter!(
389 "hook_executions_total",
390 "hook_type" => "PostToolUse",
391 "status" => status
392 )
393 .increment(1);
394 metrics::histogram!("hook_duration_ms", "hook_type" => "PostToolUse")
395 .record(start.elapsed().as_secs_f64() * 1000.0);
396 if lookup_performed {
397 metrics::counter!(
398 "hook_memory_lookup_total",
399 "hook_type" => "PostToolUse",
400 "result" => if memories_found > 0 { "hit" } else { "miss" }
401 )
402 .increment(1);
403 }
404
405 result
406 }
407}
408
/// A memory surfaced as additional context after a tool invocation.
#[derive(Clone, Debug)]
pub struct RelatedMemory {
    /// Identifier of the form `subcog://{domain}/{namespace}/{id}`.
    pub urn: String,
    /// Namespace the memory belongs to.
    pub namespace: String,
    /// Memory content; the handler truncates it to at most 200 bytes.
    pub content: String,
    /// Search score of the hit that produced this memory.
    pub relevance: f32,
}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a handler response string into JSON.
    fn parse_response(raw: &str) -> serde_json::Value {
        serde_json::from_str(raw).unwrap()
    }

    #[test]
    fn test_handler_creation() {
        assert_eq!(PostToolUseHandler::default().event_type(), "PostToolUse");
    }

    #[test]
    fn test_should_lookup() {
        let handler = PostToolUseHandler::default();

        for name in ["Read", "read", "Write", "Bash", "Grep"] {
            assert!(handler.should_lookup(name), "expected lookup for {name:?}");
        }
        for name in ["Unknown", ""] {
            assert!(!handler.should_lookup(name), "unexpected lookup for {name:?}");
        }
    }

    #[test]
    fn test_extract_query_read() {
        let tool_input = serde_json::json!({ "file_path": "/src/services/capture.rs" });

        let query = PostToolUseHandler::default()
            .extract_query("Read", &tool_input)
            .expect("Read with a file_path should yield a query");
        assert!(query.contains("capture"));
    }

    #[test]
    fn test_extract_query_bash() {
        let tool_input = serde_json::json!({ "command": "cargo test --all-features" });

        let query = PostToolUseHandler::default()
            .extract_query("Bash", &tool_input)
            .expect("Bash with a command should yield a query");
        assert!(query.contains("cargo"));
    }

    #[test]
    fn test_extract_query_grep() {
        let tool_input = serde_json::json!({ "pattern": "fn capture" });

        let query = PostToolUseHandler::default().extract_query("grep", &tool_input);
        assert_eq!(query.as_deref(), Some("fn capture"));
    }

    #[test]
    fn test_handle_non_contextual_tool() {
        let raw = PostToolUseHandler::default()
            .handle(r#"{"tool_name": "SomeOtherTool", "tool_input": {}}"#)
            .expect("handler should succeed");

        let response = parse_response(&raw);
        assert!(response.as_object().unwrap().is_empty());
    }

    #[test]
    fn test_handle_contextual_tool() {
        // No recall service is configured, so the lookup yields nothing.
        let raw = PostToolUseHandler::default()
            .handle(r#"{"tool_name": "Read", "tool_input": {"file_path": "/src/main.rs"}}"#)
            .expect("handler should succeed");

        let response = parse_response(&raw);
        assert!(response.as_object().unwrap().is_empty());
    }

    #[test]
    fn test_truncate_content() {
        let short = "Short text";
        assert_eq!(truncate_content(short, 100), short);

        let long =
            "This is a much longer text that should be truncated because it exceeds the limit";
        let truncated = truncate_content(long, 30);
        assert!(truncated.ends_with("..."));
        assert!(truncated.len() <= 30);
    }

    #[test]
    fn test_configuration() {
        let configured = PostToolUseHandler::default()
            .with_max_memories(5)
            .with_min_relevance(0.7);

        assert_eq!(configured.max_memories, 5);
        assert!((configured.min_relevance - 0.7).abs() < f32::EPSILON);
    }

    #[test]
    fn test_is_prompt_save_tool() {
        for name in ["prompt_save", "PROMPT_SAVE", "prompt.save", "subcog_prompt_save"] {
            assert!(PostToolUseHandler::is_prompt_save_tool(name), "{name:?} should match");
        }
        for name in ["prompt_get", "subcog_capture"] {
            assert!(!PostToolUseHandler::is_prompt_save_tool(name), "{name:?} should not match");
        }
    }

    #[test]
    fn test_handle_prompt_save_valid() {
        let payload = serde_json::json!({
            "tool_name": "prompt_save",
            "tool_input": {
                "name": "test-prompt",
                "content": "Hello {{name}}, welcome to {{place}}!"
            }
        });

        let raw = PostToolUseHandler::default()
            .handle(&serde_json::to_string(&payload).unwrap())
            .expect("handler should succeed");

        let response = parse_response(&raw);
        assert!(response.as_object().unwrap().is_empty());
    }

    #[test]
    fn test_handle_prompt_save_invalid_braces() {
        let payload = serde_json::json!({
            "tool_name": "prompt_save",
            "tool_input": {
                "name": "test-prompt",
                "content": "Hello {{name, this is broken"
            }
        });

        let raw = PostToolUseHandler::default()
            .handle(&serde_json::to_string(&payload).unwrap())
            .expect("handler should succeed");

        let response = parse_response(&raw);
        let hook_output = response.get("hookSpecificOutput");
        assert!(hook_output.is_some());

        let additional_context = hook_output
            .and_then(|o| o.get("additionalContext"))
            .and_then(|v| v.as_str())
            .unwrap_or("");
        assert!(additional_context.contains("Prompt Validation Issues"));
        assert!(additional_context.contains("subcog://help/prompts"));
    }

    #[test]
    fn test_validate_prompt_empty_content() {
        let tool_input = serde_json::json!({ "content": "" });

        assert!(PostToolUseHandler::default()
            .validate_prompt(&tool_input)
            .is_none());
    }

    #[test]
    fn test_validate_prompt_missing_content() {
        let tool_input = serde_json::json!({ "name": "test" });

        assert!(PostToolUseHandler::default()
            .validate_prompt(&tool_input)
            .is_none());
    }
}