subcog/hooks/search_intent/
llm.rs1use super::types::{DetectionSource, SearchIntent, SearchIntentType};
7use crate::Result;
8use crate::llm::LlmProvider as LlmProviderTrait;
9use serde::{Deserialize, Serialize};
10
/// Prompt template used to ask an LLM to classify a user's search intent.
///
/// The `<<PROMPT>>` placeholder is substituted with the user's prompt by
/// `classify_intent_with_llm` before the text is sent to the provider. The
/// template asks for a JSON object with the fields `intent_type`,
/// `confidence`, `topics` and `reasoning`, which is the shape
/// `LlmIntentResponse` deserializes.
const LLM_INTENT_PROMPT: &str = "Classify the search intent of the following user prompt.

USER PROMPT:
<<PROMPT>>

Respond with a JSON object containing:
- \"intent_type\": one of \"howto\", \"location\", \"explanation\", \"comparison\", \"troubleshoot\", \"general\"
- \"confidence\": a float from 0.0 to 1.0
- \"topics\": array of up to 5 relevant topic strings
- \"reasoning\": brief explanation of classification

Response (JSON only):";
24
/// Deserialized form of the JSON object the LLM is asked to return by
/// `LLM_INTENT_PROMPT`.
///
/// `intent_type` and `confidence` are required; `topics` and `reasoning`
/// fall back to their `Default` values when absent from the JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LlmIntentResponse {
    /// Intent label as returned by the model; `parse_llm_response` maps it
    /// through `SearchIntentType::parse` and falls back to `General`.
    pub intent_type: String,
    /// Model-reported confidence; `parse_llm_response` clamps it to `0.0..=1.0`.
    pub confidence: f32,
    /// Topic strings suggested by the model (empty when omitted).
    #[serde(default)]
    pub topics: Vec<String>,
    /// Free-text justification from the model (empty when omitted).
    /// Not read by `parse_llm_response` in this file.
    #[serde(default)]
    pub reasoning: String,
}
35
36pub fn classify_intent_with_llm<P: LlmProviderTrait + ?Sized>(
51 provider: &P,
52 prompt: &str,
53) -> Result<SearchIntent> {
54 let classification_prompt = LLM_INTENT_PROMPT.replace("<<PROMPT>>", prompt);
55 let response = provider.complete(&classification_prompt)?;
56 parse_llm_response(&response)
57}
58
59fn parse_llm_response(response: &str) -> Result<SearchIntent> {
61 let json_str = extract_json_from_response(response);
63
64 let parsed: LlmIntentResponse =
65 serde_json::from_str(json_str).map_err(|e| crate::Error::OperationFailed {
66 operation: "parse_llm_intent_response".to_string(),
67 cause: format!("Invalid JSON: {e}"),
68 })?;
69
70 let intent_type =
71 SearchIntentType::parse(&parsed.intent_type).unwrap_or(SearchIntentType::General);
72
73 Ok(SearchIntent {
74 intent_type,
75 confidence: parsed.confidence.clamp(0.0, 1.0),
76 keywords: Vec::new(), topics: parsed.topics,
78 source: DetectionSource::Llm,
79 })
80}
81
/// Isolates the JSON payload inside a free-form LLM reply.
///
/// The reply is trimmed and then probed in order; the first strategy that
/// succeeds wins:
///
/// 1. A fenced block opened with ```` ```json ````: the text between the
///    opening marker and the next ```` ``` ````, trimmed.
/// 2. Any fenced block opened with ```` ``` ````: the text from the first `{`
///    after the opening fence (or from right after the fence when there is
///    no `{`) up to the next ```` ``` ````, trimmed.
/// 3. Bare braces: the span from the first `{` to the last `}` that follows it.
/// 4. Otherwise the trimmed reply itself.
///
/// The returned slice always borrows from `response`; nothing is allocated.
/// No validation is performed — the result may still be invalid JSON.
///
/// This function never panics: in particular, a reply whose only `}` occurs
/// before its first `{` (e.g. `"} oops {"`) falls through to step 4 instead
/// of building an inverted slice range.
fn extract_json_from_response(response: &str) -> &str {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    let trimmed = response.trim();

    // 1. Explicitly tagged ```json fence.
    if let Some(open) = trimmed.find(JSON_FENCE) {
        let body = &trimmed[open + JSON_FENCE.len()..];
        if let Some(close) = body.find(FENCE) {
            return body[..close].trim();
        }
    }

    // 2. Untagged (or differently tagged) fence. Skipping ahead to the first
    //    `{` drops any language tag that follows the opening fence.
    if let Some(open) = trimmed.find(FENCE) {
        let after_fence = &trimmed[open + FENCE.len()..];
        let body = after_fence
            .find('{')
            .map_or(after_fence, |brace| &after_fence[brace..]);
        if let Some(close) = body.find(FENCE) {
            return body[..close].trim();
        }
    }

    // 3. Bare JSON object embedded in prose. The closing brace is searched
    //    only *after* the opening one so the range can never be inverted.
    //    `}` is a single byte, so the inclusive end is a char boundary.
    let braced = trimmed.find('{').and_then(|open| {
        trimmed[open..]
            .rfind('}')
            .map(|offset| &trimmed[open..=open + offset])
    });

    // 4. Nothing recognisable: hand back the trimmed reply unchanged.
    braced.unwrap_or(trimmed)
}
118
#[cfg(test)]
mod tests {
    use super::*;

    // ---- extract_json_from_response ----

    /// A ```json fenced reply is reduced to the JSON object inside the fence.
    #[test]
    fn test_extract_json_from_code_block() {
        let fenced = r#"```json
{"intent_type": "howto", "confidence": 0.9, "topics": ["rust"]}
```"#;
        let extracted = extract_json_from_response(fenced);
        assert!(extracted.starts_with('{'));
        assert!(extracted.contains("howto"));
    }

    /// A reply that is already a bare JSON object is returned unchanged.
    #[test]
    fn test_extract_json_raw() {
        let bare = r#"{"intent_type": "location", "confidence": 0.8, "topics": []}"#;
        assert_eq!(extract_json_from_response(bare), bare);
    }

    /// Prose in front of the JSON object is stripped.
    #[test]
    fn test_extract_json_with_text_before() {
        let chatty = r#"Here's the classification:
{"intent_type": "troubleshoot", "confidence": 0.75, "topics": ["error"]}"#;
        let extracted = extract_json_from_response(chatty);
        assert!(extracted.starts_with('{'));
        assert!(extracted.contains("troubleshoot"));
    }

    // ---- parse_llm_response ----

    /// A complete, well-formed reply maps onto every `SearchIntent` field.
    #[test]
    fn test_parse_llm_response_valid() {
        let raw = r#"{"intent_type": "howto", "confidence": 0.85, "topics": ["authentication", "oauth"]}"#;
        let intent = parse_llm_response(raw).unwrap();
        assert_eq!(intent.intent_type, SearchIntentType::HowTo);
        assert!((intent.confidence - 0.85).abs() < f32::EPSILON);
        assert_eq!(intent.topics, vec!["authentication", "oauth"]);
        assert_eq!(intent.source, DetectionSource::Llm);
    }

    /// An unrecognised intent label degrades to `General` rather than failing.
    #[test]
    fn test_parse_llm_response_unknown_intent_defaults_to_general() {
        let raw = r#"{"intent_type": "unknown_type", "confidence": 0.5, "topics": []}"#;
        let intent = parse_llm_response(raw).unwrap();
        assert_eq!(intent.intent_type, SearchIntentType::General);
    }

    /// Out-of-range confidence values are clamped into `0.0..=1.0`.
    #[test]
    fn test_parse_llm_response_confidence_clamped() {
        let cases = [
            (
                r#"{"intent_type": "howto", "confidence": 1.5, "topics": []}"#,
                1.0_f32,
            ),
            (
                r#"{"intent_type": "howto", "confidence": -0.5, "topics": []}"#,
                0.0_f32,
            ),
        ];

        for (raw, expected) in cases {
            let intent = parse_llm_response(raw).unwrap();
            assert!((intent.confidence - expected).abs() < f32::EPSILON);
        }
    }

    /// Text that is not JSON at all is reported as an error.
    #[test]
    fn test_parse_llm_response_invalid_json() {
        assert!(parse_llm_response("not valid json").is_err());
    }

    /// `topics` may be omitted and then defaults to an empty list.
    #[test]
    fn test_parse_llm_response_missing_optional_fields() {
        let raw = r#"{"intent_type": "location", "confidence": 0.7}"#;
        let intent = parse_llm_response(raw).unwrap();
        assert!(intent.topics.is_empty());
    }
}