Skip to main content

subcog/hooks/search_intent/
llm.rs

1//! LLM-based search intent classification.
2//!
3//! This module provides high-accuracy intent classification using language models.
4//! Classification includes a 200ms timeout by default for responsive user experience.
5
6use super::types::{DetectionSource, SearchIntent, SearchIntentType};
7use crate::Result;
8use crate::llm::LlmProvider as LlmProviderTrait;
9use serde::{Deserialize, Serialize};
10
/// Instruction template sent to the model for intent classification.
///
/// The `<<PROMPT>>` placeholder is replaced with the user's text before the
/// request is issued. The model is asked to answer with a single JSON object
/// carrying `intent_type`, `confidence`, `topics` and `reasoning`.
const LLM_INTENT_PROMPT: &str = r#"Classify the search intent of the following user prompt.

USER PROMPT:
<<PROMPT>>

Respond with a JSON object containing:
- "intent_type": one of "howto", "location", "explanation", "comparison", "troubleshoot", "general"
- "confidence": a float from 0.0 to 1.0
- "topics": array of up to 5 relevant topic strings
- "reasoning": brief explanation of classification

Response (JSON only):"#;
24
/// LLM classification result structure.
///
/// Mirrors the JSON object requested by the classification prompt. Only
/// `intent_type` and `confidence` are required when deserializing; the other
/// fields fall back to their `Default` values when the model omits them.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct LlmIntentResponse {
    /// Intent label exactly as returned by the model (e.g. `"howto"`).
    pub intent_type: String,
    /// Confidence as reported by the model; not range-checked at this level.
    pub confidence: f32,
    /// Topic strings suggested by the model; empty when the field is absent.
    #[serde(default)]
    pub topics: Vec<String>,
    /// Free-form explanation from the model; empty when the field is absent.
    #[serde(default)]
    pub reasoning: String,
}
35
36/// Classifies search intent using an LLM provider.
37///
38/// # Arguments
39///
40/// * `provider` - The LLM provider to use for classification.
41/// * `prompt` - The user prompt to classify.
42///
43/// # Returns
44///
45/// A `SearchIntent` with LLM classification results.
46///
47/// # Errors
48///
49/// Returns an error if the LLM call fails or response parsing fails.
50pub fn classify_intent_with_llm<P: LlmProviderTrait + ?Sized>(
51    provider: &P,
52    prompt: &str,
53) -> Result<SearchIntent> {
54    let classification_prompt = LLM_INTENT_PROMPT.replace("<<PROMPT>>", prompt);
55    let response = provider.complete(&classification_prompt)?;
56    parse_llm_response(&response)
57}
58
59/// Parses LLM response into a `SearchIntent`.
60fn parse_llm_response(response: &str) -> Result<SearchIntent> {
61    // Try to extract JSON from response (handle markdown code blocks)
62    let json_str = extract_json_from_response(response);
63
64    let parsed: LlmIntentResponse =
65        serde_json::from_str(json_str).map_err(|e| crate::Error::OperationFailed {
66            operation: "parse_llm_intent_response".to_string(),
67            cause: format!("Invalid JSON: {e}"),
68        })?;
69
70    let intent_type =
71        SearchIntentType::parse(&parsed.intent_type).unwrap_or(SearchIntentType::General);
72
73    Ok(SearchIntent {
74        intent_type,
75        confidence: parsed.confidence.clamp(0.0, 1.0),
76        keywords: Vec::new(), // LLM doesn't provide keywords
77        topics: parsed.topics,
78        source: DetectionSource::Llm,
79    })
80}
81
/// Extracts the JSON payload from an LLM response.
///
/// The response is trimmed and then matched against three shapes, in order:
///
/// 1. A fenced block opened with the `json` language tag: the text between
///    the opening fence and the next fence is returned, trimmed.
/// 2. Any other fenced block: extraction starts at the first `{` after the
///    opening fence (skipping a language identifier, if present) or directly
///    after the fence when there is no `{`, and stops at the next fence.
/// 3. Raw text: the span from the first `{` to the last `}` is returned.
///
/// If none of these match, the trimmed response is returned unchanged so the
/// caller's JSON parser can report the failure.
///
/// This function never panics: when the last `}` appears before the first `{`
/// (for example `"} oops {"`), the brace span is not a valid range, so the
/// trimmed text is returned instead of being sliced.
fn extract_json_from_response(response: &str) -> &str {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    let trimmed = response.trim();

    // Shape 1: ```json ... ```
    if let Some(open) = trimmed.find(JSON_FENCE) {
        let body = &trimmed[open + JSON_FENCE.len()..];
        if let Some(close) = body.find(FENCE) {
            return body[..close].trim();
        }
    }

    // Shape 2: ``` ... ``` without the json marker.
    if let Some(open) = trimmed.find(FENCE) {
        let after_fence = &trimmed[open + FENCE.len()..];
        // Skip a language identifier (e.g. "JSON\n") by starting at the first
        // '{' when there is one.
        let body = after_fence
            .find('{')
            .map_or(after_fence, |pos| &after_fence[pos..]);
        if let Some(close) = body.find(FENCE) {
            return body[..close].trim();
        }
    }

    // Shape 3: raw JSON, first '{' through last '}'. The ordering guard keeps
    // inputs such as "}{" from producing an inverted (panicking) range.
    match (trimmed.find('{'), trimmed.rfind('}')) {
        (Some(start), Some(end)) if start < end => &trimmed[start..=end],
        _ => trimmed,
    }
}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// A fenced `json` block yields just the object inside the fence.
    #[test]
    fn test_extract_json_from_code_block() {
        let fenced = r#"```json
{"intent_type": "howto", "confidence": 0.9, "topics": ["rust"]}
```"#;
        let extracted = extract_json_from_response(fenced);
        assert!(extracted.starts_with('{'));
        assert!(extracted.contains("howto"));
    }

    /// Bare JSON passes through untouched.
    #[test]
    fn test_extract_json_raw() {
        let bare = r#"{"intent_type": "location", "confidence": 0.8, "topics": []}"#;
        assert_eq!(extract_json_from_response(bare), bare);
    }

    /// Leading prose before the object is discarded.
    #[test]
    fn test_extract_json_with_text_before() {
        let chatty = r#"Here's the classification:
{"intent_type": "troubleshoot", "confidence": 0.75, "topics": ["error"]}"#;
        let extracted = extract_json_from_response(chatty);
        assert!(extracted.starts_with('{'));
        assert!(extracted.contains("troubleshoot"));
    }

    /// A well-formed reply maps field-for-field onto `SearchIntent`.
    #[test]
    fn test_parse_llm_response_valid() {
        let reply = r#"{"intent_type": "howto", "confidence": 0.85, "topics": ["authentication", "oauth"]}"#;
        let outcome = parse_llm_response(reply);
        assert!(outcome.is_ok());
        let parsed = outcome.unwrap();
        assert_eq!(parsed.intent_type, SearchIntentType::HowTo);
        assert!((parsed.confidence - 0.85).abs() < f32::EPSILON);
        assert_eq!(parsed.topics, vec!["authentication", "oauth"]);
        assert_eq!(parsed.source, DetectionSource::Llm);
    }

    /// Unrecognised intent labels fall back to the general intent.
    #[test]
    fn test_parse_llm_response_unknown_intent_defaults_to_general() {
        let reply = r#"{"intent_type": "unknown_type", "confidence": 0.5, "topics": []}"#;
        let outcome = parse_llm_response(reply);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().intent_type, SearchIntentType::General);
    }

    /// Out-of-range confidences are clamped to the nearest bound.
    #[test]
    fn test_parse_llm_response_confidence_clamped() {
        let cases: [(&str, f32); 2] = [
            (
                r#"{"intent_type": "howto", "confidence": 1.5, "topics": []}"#,
                1.0,
            ),
            (
                r#"{"intent_type": "howto", "confidence": -0.5, "topics": []}"#,
                0.0,
            ),
        ];

        for (reply, expected) in cases {
            let outcome = parse_llm_response(reply);
            assert!(outcome.is_ok());
            assert!((outcome.unwrap().confidence - expected).abs() < f32::EPSILON);
        }
    }

    /// Non-JSON replies surface as an error instead of a default intent.
    #[test]
    fn test_parse_llm_response_invalid_json() {
        assert!(parse_llm_response("not valid json").is_err());
    }

    /// `topics` may be omitted and then defaults to an empty list.
    #[test]
    fn test_parse_llm_response_missing_optional_fields() {
        let reply = r#"{"intent_type": "location", "confidence": 0.7}"#;
        let outcome = parse_llm_response(reply);
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().topics.is_empty());
    }
}