1#![allow(clippy::expect_used)]
4
5use super::HookHandler;
6use super::search_context::{AdaptiveContextConfig, MemoryContext, SearchContextBuilder};
7use super::search_intent::{
8 SearchIntent, detect_search_intent, detect_search_intent_hybrid,
9 detect_search_intent_with_timeout,
10};
11use crate::Result;
12use crate::config::SearchIntentConfig;
13use crate::llm::LlmProvider;
14use crate::models::Namespace;
15use crate::services::RecallService;
16use regex::Regex;
17use std::sync::{Arc, LazyLock};
18use tracing::instrument;
19
/// Hook handler for the `UserPromptSubmit` event.
///
/// Scans each submitted prompt for capture signals (decisions, patterns,
/// learnings, blockers, tech debt) and for search intent, then builds the
/// additional context returned from the hook.
pub struct UserPromptHandler {
    /// Minimum confidence a pattern-based signal must reach to be reported,
    /// and the bar used to decide whether a capture is suggested.
    confidence_threshold: f32,
    /// Minimum confidence a classified search intent must reach to be used.
    search_intent_threshold: f32,
    /// Configuration handed to the `SearchContextBuilder` when building memory context.
    context_config: AdaptiveContextConfig,
    /// Optional recall service attached to the context builder when present.
    recall_service: Option<RecallService>,
    /// Optional LLM provider; when present, intent classification goes through
    /// `detect_search_intent_hybrid`.
    llm_provider: Option<Arc<dyn LlmProvider>>,
    /// Search-intent settings passed to the intent detectors; its `use_llm`
    /// flag selects the detection path when no provider is configured.
    search_intent_config: SearchIntentConfig,
}
37
38static DECISION_PATTERNS: LazyLock<Vec<Regex>> = LazyLock::new(|| {
40 vec![
41 Regex::new(r"(?i)\b(we('re| are|'ll| will) (going to |gonna )?use|let's use|using)\b").ok(),
42 Regex::new(r"(?i)\b(decided|decision|choosing|chose|picked|selected)\b").ok(),
43 Regex::new(r"(?i)\b(architecture|design|approach|strategy|solution)\b").ok(),
44 Regex::new(r"(?i)\b(from now on|going forward|henceforth)\b").ok(),
45 Regex::new(r"(?i)\b(always|never) (do|use|implement)\b").ok(),
46 ]
47 .into_iter()
48 .flatten()
49 .collect()
50});
51
52static PATTERN_PATTERNS: LazyLock<Vec<Regex>> = LazyLock::new(|| {
53 vec![
54 Regex::new(r"(?i)\b(pattern|convention|standard|best practice)\b").ok(),
55 Regex::new(r"(?i)\b(always|never|should|must)\b.*\b(when|if|before|after)\b").ok(),
56 Regex::new(r"(?i)\b(rule|guideline|principle)\b").ok(),
57 ]
58 .into_iter()
59 .flatten()
60 .collect()
61});
62
63static LEARNING_PATTERNS: LazyLock<Vec<Regex>> = LazyLock::new(|| {
64 vec![
65 Regex::new(r"(?i)\b(learned|discovered|realized|found out|figured out)\b").ok(),
66 Regex::new(r"(?i)\b(TIL|turns out|apparently|actually)\b").ok(),
67 Regex::new(r"(?i)\b(gotcha|caveat|quirk|edge case)\b").ok(),
68 Regex::new(r"(?i)\b(insight|understanding|revelation)\b").ok(),
69 ]
70 .into_iter()
71 .flatten()
72 .collect()
73});
74
75static BLOCKER_PATTERNS: LazyLock<Vec<Regex>> = LazyLock::new(|| {
76 vec![
77 Regex::new(r"(?i)\b(blocked|stuck|issue|problem|bug|error)\b").ok(),
78 Regex::new(r"(?i)\b(fixed|solved|resolved|workaround|solution)\b").ok(),
79 Regex::new(r"(?i)\b(doesn't work|not working|broken|fails)\b").ok(),
80 ]
81 .into_iter()
82 .flatten()
83 .collect()
84});
85
86static TECH_DEBT_PATTERNS: LazyLock<Vec<Regex>> = LazyLock::new(|| {
87 vec![
88 Regex::new(r"(?i)\b(tech debt|technical debt|refactor|cleanup)\b").ok(),
89 Regex::new(r"(?i)\b(TODO|FIXME|HACK|XXX)\b").ok(),
90 Regex::new(r"(?i)\b(temporary|workaround|quick fix|shortcut)\b").ok(),
91 ]
92 .into_iter()
93 .flatten()
94 .collect()
95});
96
97static CAPTURE_COMMAND: LazyLock<Regex> = LazyLock::new(|| {
99 Regex::new(r"(?i)^@?subcog\s+(capture|remember|save|store)\b")
101 .expect("static regex: capture command pattern")
102});
103
/// A hint that a prompt contains something worth capturing as a memory.
#[derive(Debug, Clone)]
pub struct CaptureSignal {
    /// Namespace the content would be captured into.
    pub namespace: Namespace,
    /// Confidence score: `1.0` for an explicit capture command, otherwise
    /// the value computed by `calculate_confidence`.
    pub confidence: f32,
    /// Source text of every regex that matched, or `"explicit_command"` for
    /// an explicit capture command.
    pub matched_patterns: Vec<String>,
    /// `true` when the signal came from an explicit capture command rather
    /// than from pattern matching.
    pub is_explicit: bool,
}
116
117impl UserPromptHandler {
118 #[must_use]
120 pub fn new() -> Self {
121 Self {
122 confidence_threshold: 0.6,
123 search_intent_threshold: 0.5,
124 context_config: AdaptiveContextConfig::default(),
125 recall_service: None,
126 llm_provider: None,
127 search_intent_config: SearchIntentConfig::default(),
128 }
129 }
130
131 #[must_use]
133 pub const fn with_confidence_threshold(mut self, threshold: f32) -> Self {
134 self.confidence_threshold = threshold;
135 self
136 }
137
138 #[must_use]
140 pub const fn with_search_intent_threshold(mut self, threshold: f32) -> Self {
141 self.search_intent_threshold = threshold;
142 self
143 }
144
145 #[must_use]
147 pub const fn with_context_config(mut self, config: AdaptiveContextConfig) -> Self {
148 self.context_config = config;
149 self
150 }
151
152 #[must_use]
154 pub fn with_recall_service(mut self, service: RecallService) -> Self {
155 self.recall_service = Some(service);
156 self
157 }
158
159 #[must_use]
161 pub fn with_llm_provider(mut self, provider: Arc<dyn LlmProvider>) -> Self {
162 self.llm_provider = Some(provider);
163 self
164 }
165
166 #[must_use]
168 pub const fn with_search_intent_config(mut self, config: SearchIntentConfig) -> Self {
169 self.search_intent_config = config;
170 self
171 }
172
173 fn build_memory_context(&self, intent: &SearchIntent) -> MemoryContext {
175 let mut builder = SearchContextBuilder::new().with_config(self.context_config.clone());
176
177 if let Some(ref recall) = self.recall_service {
178 builder = builder.with_recall_service(recall);
179 }
180
181 builder
183 .build_context(intent)
184 .unwrap_or_else(|_| MemoryContext::empty())
185 }
186
187 fn detect_search_intent(&self, prompt: &str) -> Option<SearchIntent> {
192 let intent = self.classify_intent(prompt);
193
194 if intent.confidence >= self.search_intent_threshold {
195 Some(intent)
196 } else {
197 None
198 }
199 }
200
201 fn classify_intent(&self, prompt: &str) -> SearchIntent {
203 self.llm_provider.as_ref().map_or_else(
204 || self.classify_without_llm(prompt),
205 |provider| {
206 detect_search_intent_hybrid(
207 Some(provider.as_ref()),
208 prompt,
209 &self.search_intent_config,
210 )
211 },
212 )
213 }
214
215 fn classify_without_llm(&self, prompt: &str) -> SearchIntent {
217 if self.search_intent_config.use_llm {
218 detect_search_intent_with_timeout::<crate::llm::AnthropicClient>(
221 None,
222 prompt,
223 &self.search_intent_config,
224 )
225 } else {
226 detect_search_intent(prompt).unwrap_or_default()
228 }
229 }
230
231 fn detect_signals(&self, prompt: &str) -> Vec<CaptureSignal> {
233 let mut signals = Vec::new();
234
235 if CAPTURE_COMMAND.is_match(prompt) {
237 signals.push(CaptureSignal {
238 namespace: Namespace::Decisions,
239 confidence: 1.0,
240 matched_patterns: vec!["explicit_command".to_string()],
241 is_explicit: true,
242 });
243 return signals;
244 }
245
246 self.check_patterns(
248 &DECISION_PATTERNS,
249 Namespace::Decisions,
250 prompt,
251 &mut signals,
252 );
253 self.check_patterns(&PATTERN_PATTERNS, Namespace::Patterns, prompt, &mut signals);
254 self.check_patterns(
255 &LEARNING_PATTERNS,
256 Namespace::Learnings,
257 prompt,
258 &mut signals,
259 );
260 self.check_patterns(&BLOCKER_PATTERNS, Namespace::Blockers, prompt, &mut signals);
261 self.check_patterns(
262 &TECH_DEBT_PATTERNS,
263 Namespace::TechDebt,
264 prompt,
265 &mut signals,
266 );
267
268 signals.sort_by(|a, b| {
270 b.confidence
271 .partial_cmp(&a.confidence)
272 .unwrap_or(std::cmp::Ordering::Equal)
273 });
274
275 signals
276 }
277
278 fn check_patterns(
280 &self,
281 patterns: &[Regex],
282 namespace: Namespace,
283 prompt: &str,
284 signals: &mut Vec<CaptureSignal>,
285 ) {
286 let pattern_matches: Vec<String> = patterns
287 .iter()
288 .filter(|p| p.is_match(prompt))
289 .map(std::string::ToString::to_string)
290 .collect();
291
292 if pattern_matches.is_empty() {
293 return;
294 }
295
296 let confidence = calculate_confidence(&pattern_matches, prompt);
297 if confidence < self.confidence_threshold {
298 return;
299 }
300
301 signals.push(CaptureSignal {
302 namespace,
303 confidence,
304 matched_patterns: pattern_matches,
305 is_explicit: false,
306 });
307 }
308
309 fn extract_content(&self, prompt: &str) -> String {
311 let content = CAPTURE_COMMAND.replace(prompt, "").trim().to_string();
313
314 let content = content
316 .trim_start_matches(':')
317 .trim_start_matches('-')
318 .trim();
319
320 content.to_string()
321 }
322}
323
/// Scores how likely `prompt` is to contain something worth capturing.
///
/// The score starts at `0.5` and adds:
/// - `0.1` per matched pattern, capped at `0.15`;
/// - `0.1` when the prompt is longer than 50 bytes;
/// - `0.1` when the prompt contains `.`, `!` or `?`.
///
/// The result is capped at `0.95`.
#[allow(clippy::cast_precision_loss)]
fn calculate_confidence(pattern_matches: &[String], prompt: &str) -> f32 {
    const BASE: f32 = 0.5;
    const MATCH_BONUS_CAP: f32 = 0.15;
    const CEILING: f32 = 0.95;

    let match_bonus = MATCH_BONUS_CAP.min(pattern_matches.len() as f32 * 0.1);

    let length_bonus: f32 = match prompt.len() > 50 {
        true => 0.1,
        false => 0.0,
    };

    let has_sentence_punctuation = prompt.chars().any(|c| matches!(c, '.' | '!' | '?'));
    let punctuation_bonus: f32 = match has_sentence_punctuation {
        true => 0.1,
        false => 0.0,
    };

    let total = BASE + match_bonus + length_bonus + punctuation_bonus;
    total.min(CEILING)
}
342
343impl Default for UserPromptHandler {
344 fn default() -> Self {
345 Self::new()
346 }
347}
348
349impl HookHandler for UserPromptHandler {
350 fn event_type(&self) -> &'static str {
351 "UserPromptSubmit"
352 }
353
354 #[instrument(skip(self, input), fields(hook = "UserPromptSubmit"))]
355 fn handle(&self, input: &str) -> Result<String> {
356 let input_json: serde_json::Value =
358 serde_json::from_str(input).unwrap_or_else(|_| serde_json::json!({}));
359
360 let prompt = input_json
362 .get("prompt")
363 .and_then(|v| v.as_str())
364 .unwrap_or("");
365
366 if prompt.is_empty() {
367 let response = serde_json::json!({});
369 return serde_json::to_string(&response).map_err(|e| crate::Error::OperationFailed {
370 operation: "serialize_response".to_string(),
371 cause: e.to_string(),
372 });
373 }
374
375 let signals = self.detect_signals(prompt);
377
378 let should_capture = signals
380 .iter()
381 .any(|s| s.confidence >= self.confidence_threshold);
382
383 let content = if should_capture {
385 Some(self.extract_content(prompt))
386 } else {
387 None
388 };
389
390 let signals_json: Vec<serde_json::Value> = signals
392 .iter()
393 .map(|s| {
394 serde_json::json!({
395 "namespace": s.namespace.as_str(),
396 "confidence": s.confidence,
397 "matched_patterns": s.matched_patterns,
398 "is_explicit": s.is_explicit
399 })
400 })
401 .collect();
402
403 let mut metadata = serde_json::json!({
404 "signals": signals_json,
405 "should_capture": should_capture,
406 "confidence_threshold": self.confidence_threshold
407 });
408
409 let search_intent = self.detect_search_intent(prompt);
411
412 let memory_context = if let Some(ref intent) = search_intent {
414 let ctx = self.build_memory_context(intent);
415 metadata["search_intent"] = serde_json::json!({
416 "detected": ctx.search_intent_detected,
417 "intent_type": ctx.intent_type,
418 "confidence": intent.confidence,
419 "topics": ctx.topics,
420 "keywords": intent.keywords,
421 "source": intent.source.as_str()
422 });
423 metadata["memory_context"] =
424 serde_json::to_value(&ctx).unwrap_or(serde_json::Value::Null);
425 Some(ctx)
426 } else {
427 metadata["search_intent"] = serde_json::json!({
428 "detected": false
429 });
430 None
431 };
432
433 let context_message =
435 build_capture_context(should_capture, content.as_ref(), &signals, &mut metadata);
436
437 let search_context = memory_context.as_ref().map(build_memory_context_text);
439
440 let combined_context = match (context_message, search_context) {
442 (Some(capture), Some(search)) => Some(format!("{capture}\n\n---\n\n{search}")),
443 (Some(capture), None) => Some(capture),
444 (None, Some(search)) => Some(search),
445 (None, None) => None,
446 };
447
448 let response = combined_context.map_or_else(
451 || serde_json::json!({}),
452 |ctx| {
453 let metadata_str = serde_json::to_string(&metadata).unwrap_or_default();
455 let context_with_metadata =
456 format!("{ctx}\n\n<!-- subcog-metadata: {metadata_str} -->");
457 serde_json::json!({
458 "hookSpecificOutput": {
459 "hookEventName": "UserPromptSubmit",
460 "additionalContext": context_with_metadata
461 }
462 })
463 },
464 );
465
466 serde_json::to_string(&response).map_err(|e| crate::Error::OperationFailed {
467 operation: "serialize_response".to_string(),
468 cause: e.to_string(),
469 })
470 }
471}
472
473fn build_capture_context(
475 should_capture: bool,
476 content: Option<&String>,
477 signals: &[CaptureSignal],
478 metadata: &mut serde_json::Value,
479) -> Option<String> {
480 if !should_capture {
481 return None;
482 }
483
484 let content_str = content.map_or("", String::as_str);
485 if content_str.is_empty() {
486 return None;
487 }
488
489 let top_signal = signals.first()?;
491
492 metadata["capture_suggestion"] = serde_json::json!({
494 "namespace": top_signal.namespace.as_str(),
495 "content_preview": truncate_for_display(content_str, 100),
496 "confidence": top_signal.confidence,
497 });
498
499 let mut lines = vec!["**Subcog Capture Suggestion**\n".to_string()];
501
502 if top_signal.is_explicit {
503 lines.push(format!(
504 "Explicit capture command detected. Capturing to `{}`:\n",
505 top_signal.namespace.as_str()
506 ));
507 lines.push(format!("> {}", truncate_for_display(content_str, 200)));
508 lines.push(
509 "\nUse `mcp__plugin_subcog_subcog__subcog_capture` tool to save this memory."
510 .to_string(),
511 );
512 } else {
513 lines.push(format!(
514 "Detected {} signal (confidence: {:.0}%):\n",
515 top_signal.namespace.as_str(),
516 top_signal.confidence * 100.0
517 ));
518 lines.push(format!("> {}", truncate_for_display(content_str, 200)));
519 lines.push(format!(
520 "\n**Suggestion**: Consider capturing this as a `{}` memory.",
521 top_signal.namespace.as_str()
522 ));
523 lines.push(
524 "Use `mcp__plugin_subcog_subcog__subcog_capture` tool or ask: \"Should I save this to subcog?\"".to_string(),
525 );
526 }
527
528 Some(lines.join("\n"))
529}
530
/// Shortens `content` for display.
///
/// Content of at most `max_len` bytes is returned unchanged. Longer content
/// is cut to at most `max_len - 3` bytes and suffixed with `...`.
///
/// The cut never splits a UTF-8 character: if the target offset falls inside
/// a multi-byte character, the cut moves back to the start of that character,
/// so the result may be slightly shorter than `max_len`. This avoids the
/// panic that slicing a `str` at a non-boundary offset would cause.
fn truncate_for_display(content: &str, max_len: usize) -> String {
    if content.len() <= max_len {
        return content.to_string();
    }

    // `cut` is strictly less than `content.len()` here, and offset 0 is
    // always a char boundary, so this loop terminates.
    let mut cut = max_len.saturating_sub(3);
    while !content.is_char_boundary(cut) {
        cut -= 1;
    }

    format!("{}...", &content[..cut])
}
539
540fn build_memory_context_text(ctx: &MemoryContext) -> String {
542 let mut lines = vec!["**Subcog Memory Context**\n".to_string()];
543
544 if let Some(ref intent_type) = ctx.intent_type {
545 lines.push(format!("Intent type: **{intent_type}**\n"));
546 }
547
548 if !ctx.topics.is_empty() {
549 lines.push(format!("Topics: {}\n", ctx.topics.join(", ")));
550 }
551
552 if !ctx.injected_memories.is_empty() {
554 lines.push("\n**Relevant memories**:".to_string());
555 for memory in ctx.injected_memories.iter().take(5) {
556 lines.push(format!(
557 "- [{}] {}: {}",
558 memory.namespace,
559 memory.id,
560 truncate_for_display(&memory.content_preview, 80)
561 ));
562 }
563 }
564
565 if let Some(ref reminder) = ctx.reminder {
567 lines.push(format!("\n**Reminder**: {reminder}"));
568 }
569
570 if !ctx.suggested_resources.is_empty() {
572 lines.push("\n**Suggested resources**:".to_string());
573 for resource in ctx.suggested_resources.iter().take(4) {
574 lines.push(format!("- `{resource}`"));
575 }
576 }
577
578 lines.join("\n")
579}
580
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `handler` on raw hook input and parses the response JSON.
    fn handle_json(handler: &UserPromptHandler, input: &str) -> serde_json::Value {
        let output = handler.handle(input).expect("handle should succeed");
        serde_json::from_str(&output).expect("response should be valid JSON")
    }

    /// Returns the `hookSpecificOutput` object, failing the test when it is absent.
    fn hook_output(response: &serde_json::Value) -> &serde_json::Value {
        response
            .get("hookSpecificOutput")
            .expect("response should contain hookSpecificOutput")
    }

    /// Returns the `additionalContext` string of a hook response.
    fn additional_context(response: &serde_json::Value) -> &str {
        hook_output(response)
            .get("additionalContext")
            .and_then(serde_json::Value::as_str)
            .expect("additionalContext should be a string")
    }

    /// Asserts that the hook response names the `UserPromptSubmit` event.
    fn assert_event_name(response: &serde_json::Value) {
        assert_eq!(
            hook_output(response).get("hookEventName"),
            Some(&serde_json::Value::String("UserPromptSubmit".to_string()))
        );
    }

    /// Asserts that the response is the empty JSON object.
    fn assert_empty_response(response: &serde_json::Value) {
        let object = response
            .as_object()
            .expect("response should be a JSON object");
        assert!(object.is_empty());
    }

    /// Asserts that the default handler reports a signal for `expected` on `prompt`.
    fn assert_detects(prompt: &str, expected: Namespace) {
        let signals = UserPromptHandler::default().detect_signals(prompt);
        assert!(!signals.is_empty());
        assert!(signals.iter().any(|signal| signal.namespace == expected));
    }

    #[test]
    fn test_handler_creation() {
        assert_eq!(
            UserPromptHandler::default().event_type(),
            "UserPromptSubmit"
        );
    }

    #[test]
    fn test_explicit_capture_command() {
        let handler = UserPromptHandler::default();
        let response = handle_json(
            &handler,
            r#"{"prompt": "@subcog capture Use PostgreSQL for storage"}"#,
        );

        assert_event_name(&response);

        let context = additional_context(&response);
        assert!(context.contains("Subcog Capture Suggestion"));
        assert!(context.contains("subcog-metadata"));
    }

    #[test]
    fn test_decision_signal_detection() {
        assert_detects(
            "We're going to use Rust for this project",
            Namespace::Decisions,
        );
    }

    #[test]
    fn test_learning_signal_detection() {
        assert_detects(
            "TIL that SQLite has a row limit of 2GB",
            Namespace::Learnings,
        );
    }

    #[test]
    fn test_pattern_signal_detection() {
        assert_detects(
            "The best practice is to always validate input before processing",
            Namespace::Patterns,
        );
    }

    #[test]
    fn test_blocker_signal_detection() {
        assert_detects("I fixed the bug by adding a null check", Namespace::Blockers);
    }

    #[test]
    fn test_tech_debt_signal_detection() {
        assert_detects(
            "This is a temporary workaround, we need to refactor later",
            Namespace::TechDebt,
        );
    }

    #[test]
    fn test_no_signals_for_generic_prompt() {
        let signals = UserPromptHandler::default().detect_signals("Hello, how are you?");

        // A generic greeting may still match something, but never strongly.
        assert!(signals.iter().all(|signal| signal.confidence < 0.8));
    }

    #[test]
    fn test_empty_prompt() {
        let handler = UserPromptHandler::default();
        let response = handle_json(&handler, r#"{"prompt": ""}"#);

        assert_empty_response(&response);
    }

    #[test]
    fn test_confidence_threshold() {
        let handler = UserPromptHandler::default().with_confidence_threshold(0.9);
        let signals = handler.detect_signals("maybe use something");

        // Only explicit commands may reach this confidence level.
        assert!(
            signals
                .iter()
                .filter(|signal| signal.confidence >= 0.9)
                .all(|signal| signal.is_explicit)
        );
    }

    #[test]
    fn test_extract_content() {
        let handler = UserPromptHandler::default();

        assert_eq!(
            handler.extract_content("@subcog capture: Use PostgreSQL"),
            "Use PostgreSQL"
        );
        assert_eq!(
            handler.extract_content("Just a regular prompt"),
            "Just a regular prompt"
        );
    }

    #[test]
    fn test_calculate_confidence() {
        let single_match = ["pattern1".to_string()];
        let double_match = ["pattern1".to_string(), "pattern2".to_string()];

        let low = calculate_confidence(&single_match, "short");
        let high = calculate_confidence(
            &double_match,
            "This is a longer prompt with more context.",
        );

        assert!(high >= low);
    }

    #[test]
    fn test_search_intent_detection_in_handle() {
        let handler = UserPromptHandler::default();
        let response = handle_json(
            &handler,
            r#"{"prompt": "How do I implement authentication in this project?"}"#,
        );

        assert_event_name(&response);

        let context = additional_context(&response);
        for expected in [
            "subcog-metadata",
            "search_intent",
            "\"detected\":true",
            "\"intent_type\":\"howto\"",
        ] {
            assert!(context.contains(expected));
        }
    }

    #[test]
    fn test_search_intent_no_detection() {
        let handler = UserPromptHandler::default();
        let response = handle_json(&handler, r#"{"prompt": "I finished the task."}"#);

        assert_empty_response(&response);
    }

    #[test]
    fn test_search_intent_threshold() {
        let handler = UserPromptHandler::default().with_search_intent_threshold(0.9);
        let response = handle_json(&handler, r#"{"prompt": "how to"}"#);

        assert_empty_response(&response);
    }

    #[test]
    fn test_search_intent_topics_extraction() {
        let handler = UserPromptHandler::default();
        let response = handle_json(
            &handler,
            r#"{"prompt": "How do I configure the database connection?"}"#,
        );

        let context = additional_context(&response);

        assert!(context.contains("subcog-metadata"));
        assert!(context.contains("\"topics\""));
        assert!(context.contains("database") || context.contains("connection"));
    }
}