1use crate::llm::{LlmProvider, sanitize_llm_response_for_error};
51use crate::models::PromptVariable;
52use crate::{Error, Result};
53use serde::{Deserialize, Serialize};
54use std::time::Duration;
55use tracing::instrument;
56
/// System prompt sent alongside every enrichment request.
///
/// It asks the model to reply with bare JSON containing `description`, `tags`
/// and `variables`. The private `LlmEnrichmentResponse` type is the matching
/// deserialization target; `tags` and `variables` default to empty when the
/// model omits them. The user message built by [`PromptEnrichmentService`]
/// supplies the template inside `<prompt_content>` and the detected variable
/// names inside `<detected_variables>`.
pub const PROMPT_ENRICHMENT_SYSTEM_PROMPT: &str = r#"<task>
You are analyzing a prompt template to generate helpful metadata.
Your goal is to understand the prompt's purpose and generate accurate descriptions
for both the prompt itself and its variables.
</task>

<output_format>
Respond with ONLY valid JSON, no markdown formatting.

{
 "description": "One sentence describing what this prompt does",
 "tags": ["tag1", "tag2", "tag3"],
 "variables": [
 {
 "name": "variable_name",
 "description": "What this variable represents",
 "required": true,
 "default": null
 }
 ]
}
</output_format>

<guidelines>
- description: Clear, one-sentence summary of the prompt's purpose
- tags: 2-5 lowercase, hyphenated tags (e.g., "code-review", "documentation")
- variables: For each detected variable:
 - description: What value the user should provide
 - required: true if the prompt makes no sense without it
 - default: Sensible default value, or null if none appropriate
</guidelines>

<rules>
- Only include variables that were detected in the prompt
- Use lowercase hyphenated format for tags
- Keep descriptions concise but informative
- Respond with valid JSON only, no explanation
</rules>"#;
98
/// Escapes the five XML special characters so arbitrary text can be embedded
/// inside the XML-style tags of the enrichment user message.
///
/// `&` becomes `&amp;`, `<` becomes `&lt;`, `>` becomes `&gt;`, `"` becomes
/// `&quot;` and `'` becomes `&apos;`; every other character is copied
/// unchanged. Without this, a template containing e.g. `</prompt_content>`
/// could close its wrapper tag early and inject instructions.
fn escape_xml(s: &str) -> String {
    // Most input needs no escaping, so the input length is a good lower bound.
    let mut result = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => result.push_str("&amp;"),
            '<' => result.push_str("&lt;"),
            '>' => result.push_str("&gt;"),
            '"' => result.push_str("&quot;"),
            '\'' => result.push_str("&apos;"),
            _ => result.push(c),
        }
    }
    result
}
113
/// Time budget for a single enrichment call: five seconds.
///
/// NOTE(review): nothing in this module applies this limit to the LLM call;
/// presumably callers enforce it around `enrich` — confirm.
pub const ENRICHMENT_TIMEOUT: Duration = Duration::from_millis(5_000);
116
/// Input to [`PromptEnrichmentService::enrich`] and
/// [`PromptEnrichmentService::enrich_with_fallback`].
#[derive(Debug, Clone, Serialize)]
pub struct EnrichmentRequest {
    /// Raw prompt template text; XML-escaped before being sent to the LLM.
    pub content: String,
    /// Variable names detected in `content`. The enrichment result contains
    /// one variable entry per listed name, in this order.
    pub variables: Vec<String>,
    /// Metadata the user already supplied. Its values are overlaid on the
    /// generated result. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub existing: Option<PartialMetadata>,
}
128
129impl EnrichmentRequest {
130 #[must_use]
132 pub fn new(content: impl Into<String>, variables: Vec<String>) -> Self {
133 Self {
134 content: content.into(),
135 variables,
136 existing: None,
137 }
138 }
139
140 #[must_use]
142 pub fn with_existing(mut self, existing: PartialMetadata) -> Self {
143 self.existing = Some(existing);
144 self
145 }
146
147 #[must_use]
149 pub fn with_optional_existing(mut self, existing: Option<PartialMetadata>) -> Self {
150 self.existing = existing;
151 self
152 }
153}
154
/// Partially specified prompt metadata, e.g. what a user already provided.
///
/// An unset description or an empty tag list leaves the generated value in
/// place when merged via [`PromptEnrichmentResult::merge_with_user`]. A
/// variable listed here is matched to a generated variable by name; its
/// `required` flag is always applied.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PartialMetadata {
    /// Description override; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Tag override; an empty list means "keep generated tags". Defaults to
    /// empty when absent on input and is omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tags: Vec<String>,
    /// Per-variable overrides, matched to generated variables by `name`.
    /// Defaults to empty when absent on input and is omitted from output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub variables: Vec<PromptVariable>,
}
170
171impl PartialMetadata {
172 #[must_use]
174 pub fn new() -> Self {
175 Self::default()
176 }
177
178 #[must_use]
180 pub fn with_description(mut self, description: impl Into<String>) -> Self {
181 self.description = Some(description.into());
182 self
183 }
184
185 #[must_use]
187 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
188 self.tags = tags;
189 self
190 }
191
192 #[must_use]
194 pub fn with_variables(mut self, variables: Vec<PromptVariable>) -> Self {
195 self.variables = variables;
196 self
197 }
198
199 #[must_use]
201 pub const fn is_empty(&self) -> bool {
202 self.description.is_none() && self.tags.is_empty() && self.variables.is_empty()
203 }
204
205 #[must_use]
207 pub fn get_variable(&self, name: &str) -> Option<&PromptVariable> {
208 self.variables.iter().find(|v| v.name == name)
209 }
210}
211
/// Metadata produced for a prompt template, either by the LLM or as a fallback.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptEnrichmentResult {
    /// One-sentence summary; empty in fallback results unless the user supplied one.
    pub description: String,
    /// Tags for the prompt; empty in fallback results unless the user supplied them.
    pub tags: Vec<String>,
    /// One entry per detected variable name, in the order the request listed them.
    pub variables: Vec<PromptVariable>,
    /// How this result was produced; deserializes to
    /// [`EnrichmentStatus::Full`] when the field is absent.
    #[serde(default)]
    pub status: EnrichmentStatus,
}
225
/// How a [`PromptEnrichmentResult`] was produced.
///
/// Serialized in snake_case: `"full"`, `"fallback"`, `"skipped"`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EnrichmentStatus {
    /// The LLM response was parsed successfully. This is the default.
    #[default]
    Full,
    /// Enrichment failed; placeholder metadata was built from the detected
    /// variable names.
    Fallback,
    /// Presumably: enrichment was not attempted.
    /// NOTE(review): not constructed anywhere in this module — confirm the
    /// meaning with callers.
    Skipped,
}
238
239impl PromptEnrichmentResult {
240 #[must_use]
244 pub fn basic_from_variables(variables: &[String]) -> Self {
245 Self {
246 description: String::new(),
247 tags: Vec::new(),
248 variables: variables
249 .iter()
250 .map(|name| PromptVariable {
251 name: name.clone(),
252 description: None,
253 default: None,
254 required: true,
255 })
256 .collect(),
257 status: EnrichmentStatus::Fallback,
258 }
259 }
260
261 #[must_use]
265 pub fn merge_with_user(mut self, user: &PartialMetadata) -> Self {
266 if let Some(ref desc) = user.description {
268 self.description.clone_from(desc);
269 }
270
271 if !user.tags.is_empty() {
273 self.tags.clone_from(&user.tags);
274 }
275
276 for var in &mut self.variables {
278 let Some(user_var) = user.get_variable(&var.name) else {
279 continue;
280 };
281 if let Some(ref desc) = user_var.description {
283 var.description = Some(desc.clone());
284 }
285 if let Some(ref default) = user_var.default {
287 var.default = Some(default.clone());
288 }
289 var.required = user_var.required;
291 }
292
293 self
294 }
295}
296
/// Raw JSON object expected from the LLM (shape requested by
/// [`PROMPT_ENRICHMENT_SYSTEM_PROMPT`]).
#[derive(Debug, Clone, Deserialize)]
struct LlmEnrichmentResponse {
    /// Mandatory: a response without `description` fails to deserialize.
    description: String,
    /// Defaults to empty when the LLM omits it.
    #[serde(default)]
    tags: Vec<String>,
    /// Defaults to empty when the LLM omits it.
    #[serde(default)]
    variables: Vec<LlmVariableResponse>,
}
306
/// One entry of the LLM's `variables` array.
#[derive(Debug, Clone, Deserialize)]
struct LlmVariableResponse {
    /// Variable name; matched against the request's detected variable names.
    name: String,
    /// Optional description; `None` when missing or `null`.
    #[serde(default)]
    description: Option<String>,
    /// Defaults to `true` (see `default_required`) when the LLM omits it.
    #[serde(default = "default_required")]
    required: bool,
    /// Optional default value; `None` when missing or `null`.
    #[serde(default)]
    default: Option<String>,
}
318
/// Serde default for `LlmVariableResponse::required`: a variable whose
/// `required` field is missing from the LLM response is treated as required.
const fn default_required() -> bool {
    true
}
323
/// Generates prompt metadata (description, tags, variable docs) by asking an LLM.
pub struct PromptEnrichmentService<P: LlmProvider> {
    /// Provider used for the enrichment completion call.
    llm: P,
}
329
330impl<P: LlmProvider> PromptEnrichmentService<P> {
331 #[must_use]
333 pub const fn new(llm: P) -> Self {
334 Self { llm }
335 }
336
337 #[instrument(skip(self), fields(operation = "prompt_enrich", variables_count = request.variables.len()))]
352 pub fn enrich(&self, request: &EnrichmentRequest) -> Result<PromptEnrichmentResult> {
353 let user_message = self.build_user_message(request);
355
356 let response = self
358 .llm
359 .complete_with_system(PROMPT_ENRICHMENT_SYSTEM_PROMPT, &user_message)?;
360
361 let llm_result = self.parse_response(&response, &request.variables)?;
363
364 let result = if let Some(ref existing) = request.existing {
366 llm_result.merge_with_user(existing)
367 } else {
368 llm_result
369 };
370
371 Ok(result)
372 }
373
374 #[instrument(skip(self), fields(operation = "prompt_enrich_fallback"))]
379 pub fn enrich_with_fallback(&self, request: &EnrichmentRequest) -> PromptEnrichmentResult {
380 match self.enrich(request) {
381 Ok(result) => result,
382 Err(e) => {
383 tracing::warn!("Prompt enrichment failed, using fallback: {}", e);
384 let mut result = PromptEnrichmentResult::basic_from_variables(&request.variables);
385
386 if let Some(ref existing) = request.existing {
388 result = result.merge_with_user(existing);
389 }
390
391 result
392 },
393 }
394 }
395
396 fn build_user_message(&self, request: &EnrichmentRequest) -> String {
398 let variables_str = if request.variables.is_empty() {
399 "No variables detected".to_string()
400 } else {
401 request.variables.join(", ")
402 };
403 let escaped_content = escape_xml(&request.content);
404 let escaped_variables = escape_xml(&variables_str);
405 format!(
406 "<prompt_content>\n{escaped_content}\n</prompt_content>\n\n<detected_variables>\n{escaped_variables}\n</detected_variables>"
407 )
408 }
409
410 fn parse_response(
412 &self,
413 response: &str,
414 expected_variables: &[String],
415 ) -> Result<PromptEnrichmentResult> {
416 let json_str = crate::llm::extract_json_from_response(response);
418
419 let sanitized = sanitize_llm_response_for_error(response);
421 let llm_response: LlmEnrichmentResponse =
422 serde_json::from_str(json_str).map_err(|e| Error::OperationFailed {
423 operation: "parse_enrichment_response".to_string(),
424 cause: format!("Failed to parse LLM response: {e}. Response was: {sanitized}"),
425 })?;
426
427 let mut variable_map: std::collections::HashMap<String, PromptVariable> = llm_response
429 .variables
430 .into_iter()
431 .map(|v| {
432 (
433 v.name.clone(),
434 PromptVariable {
435 name: v.name,
436 description: v.description,
437 default: v.default,
438 required: v.required,
439 },
440 )
441 })
442 .collect();
443
444 let variables: Vec<PromptVariable> = expected_variables
446 .iter()
447 .map(|name| {
448 variable_map.remove(name).unwrap_or_else(|| PromptVariable {
449 name: name.clone(),
450 description: None,
451 default: None,
452 required: true,
453 })
454 })
455 .collect();
456
457 Ok(PromptEnrichmentResult {
458 description: llm_response.description,
459 tags: llm_response.tags,
460 variables,
461 status: EnrichmentStatus::Full,
462 })
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double for [`LlmProvider`]. Both completion methods return either
    /// a canned response or a fixed error.
    struct MockLlmProvider {
        response: String,
        should_fail: bool,
    }

    impl MockLlmProvider {
        fn new(response: impl Into<String>) -> Self {
            Self {
                response: response.into(),
                should_fail: false,
            }
        }

        fn failing() -> Self {
            Self {
                response: String::new(),
                should_fail: true,
            }
        }

        /// Shared behavior of `complete` and `complete_with_system`.
        fn respond(&self) -> Result<String> {
            if self.should_fail {
                Err(Error::OperationFailed {
                    operation: "mock_complete".to_string(),
                    cause: "Mock LLM failure".to_string(),
                })
            } else {
                Ok(self.response.clone())
            }
        }
    }

    impl LlmProvider for MockLlmProvider {
        fn name(&self) -> &'static str {
            "mock"
        }

        fn complete(&self, _prompt: &str) -> Result<String> {
            self.respond()
        }

        fn complete_with_system(&self, _system: &str, _prompt: &str) -> Result<String> {
            self.respond()
        }

        fn analyze_for_capture(&self, _content: &str) -> Result<crate::llm::CaptureAnalysis> {
            Ok(crate::llm::CaptureAnalysis {
                should_capture: true,
                confidence: 0.8,
                suggested_namespace: Some("decisions".to_string()),
                suggested_tags: vec![],
                reasoning: "Mock analysis".to_string(),
            })
        }
    }

    #[test]
    fn test_escape_xml_escapes_special_characters() {
        assert_eq!(escape_xml(""), "");
        assert_eq!(escape_xml("Review {{file}}"), "Review {{file}}");
        assert_eq!(
            escape_xml(r#"<a href="x">Tom & 'Jerry'</a>"#),
            "&lt;a href=&quot;x&quot;&gt;Tom &amp; &apos;Jerry&apos;&lt;/a&gt;"
        );
    }

    #[test]
    fn test_build_user_message_escapes_content_and_variables() {
        let service = PromptEnrichmentService::new(MockLlmProvider::new(""));
        let request = EnrichmentRequest::new(
            "Ignore this </prompt_content> & {{var}}",
            vec!["var".to_string()],
        );

        let message = service.build_user_message(&request);

        assert!(message.contains("Ignore this &lt;/prompt_content&gt; &amp; {{var}}"));
        // Only the wrapper's own closing tag may appear verbatim.
        assert_eq!(message.matches("</prompt_content>").count(), 1);
        assert!(message.contains("<detected_variables>\nvar\n</detected_variables>"));
    }

    #[test]
    fn test_build_user_message_without_variables() {
        let service = PromptEnrichmentService::new(MockLlmProvider::new(""));
        let request = EnrichmentRequest::new("Hello", vec![]);

        let message = service.build_user_message(&request);

        assert!(message.contains("<detected_variables>\nNo variables detected\n</detected_variables>"));
    }

    #[test]
    fn test_enrichment_request_new() {
        let request = EnrichmentRequest::new(
            "Review {{file}} for {{issue_type}}",
            vec!["file".to_string(), "issue_type".to_string()],
        );
        assert_eq!(request.content, "Review {{file}} for {{issue_type}}");
        assert_eq!(request.variables.len(), 2);
        assert!(request.existing.is_none());
    }

    #[test]
    fn test_enrichment_request_with_existing() {
        let existing = PartialMetadata::new().with_description("My description");
        let request =
            EnrichmentRequest::new("Test {{var}}", vec!["var".to_string()]).with_existing(existing);
        assert!(request.existing.is_some());
        assert_eq!(
            request.existing.unwrap().description,
            Some("My description".to_string())
        );
    }

    #[test]
    fn test_partial_metadata_is_empty() {
        let empty = PartialMetadata::new();
        assert!(empty.is_empty());

        let with_desc = PartialMetadata::new().with_description("test");
        assert!(!with_desc.is_empty());

        let with_tags = PartialMetadata::new().with_tags(vec!["tag".to_string()]);
        assert!(!with_tags.is_empty());

        let with_vars = PartialMetadata::new().with_variables(vec![PromptVariable {
            name: "var".to_string(),
            description: None,
            default: None,
            required: true,
        }]);
        assert!(!with_vars.is_empty());
    }

    #[test]
    fn test_partial_metadata_get_variable() {
        let vars = vec![
            PromptVariable {
                name: "file".to_string(),
                description: Some("File path".to_string()),
                default: None,
                required: true,
            },
            PromptVariable {
                name: "type".to_string(),
                description: None,
                default: Some("general".to_string()),
                required: false,
            },
        ];
        let partial = PartialMetadata::new().with_variables(vars);

        let file_var = partial.get_variable("file");
        assert!(file_var.is_some());
        assert_eq!(file_var.unwrap().description, Some("File path".to_string()));

        let missing = partial.get_variable("missing");
        assert!(missing.is_none());
    }

    #[test]
    fn test_enrichment_result_basic_from_variables() {
        let result =
            PromptEnrichmentResult::basic_from_variables(&["file".to_string(), "type".to_string()]);

        assert!(result.description.is_empty());
        assert!(result.tags.is_empty());
        assert_eq!(result.variables.len(), 2);
        assert_eq!(result.status, EnrichmentStatus::Fallback);

        assert_eq!(result.variables[0].name, "file");
        assert!(result.variables[0].required);
        assert!(result.variables[0].description.is_none());
    }

    #[test]
    fn test_enrichment_result_merge_with_user() {
        let llm_result = PromptEnrichmentResult {
            description: "LLM description".to_string(),
            tags: vec!["llm-tag".to_string()],
            variables: vec![
                PromptVariable {
                    name: "file".to_string(),
                    description: Some("LLM file desc".to_string()),
                    default: None,
                    required: true,
                },
                PromptVariable {
                    name: "type".to_string(),
                    description: Some("LLM type desc".to_string()),
                    default: Some("llm-default".to_string()),
                    required: true,
                },
            ],
            status: EnrichmentStatus::Full,
        };

        let user = PartialMetadata::new()
            .with_description("User description")
            .with_variables(vec![PromptVariable {
                name: "file".to_string(),
                description: Some("User file desc".to_string()),
                default: Some("user-default".to_string()),
                required: false,
            }]);

        let merged = llm_result.merge_with_user(&user);

        // User description wins; empty user tags keep the LLM tags.
        assert_eq!(merged.description, "User description");
        assert_eq!(merged.tags, vec!["llm-tag".to_string()]);

        // "file" is fully overridden by the user.
        assert_eq!(
            merged.variables[0].description,
            Some("User file desc".to_string())
        );
        assert_eq!(
            merged.variables[0].default,
            Some("user-default".to_string())
        );
        assert!(!merged.variables[0].required);

        // "type" has no user override and keeps the LLM values.
        assert_eq!(
            merged.variables[1].description,
            Some("LLM type desc".to_string())
        );
        assert_eq!(merged.variables[1].default, Some("llm-default".to_string()));
    }

    #[test]
    fn test_enrichment_result_merge_with_user_tags_override() {
        let llm_result = PromptEnrichmentResult {
            description: "LLM description".to_string(),
            tags: vec!["llm-tag".to_string()],
            variables: vec![],
            status: EnrichmentStatus::Full,
        };
        let user = PartialMetadata::new().with_tags(vec!["user-tag".to_string()]);

        let merged = llm_result.merge_with_user(&user);

        assert_eq!(merged.tags, vec!["user-tag".to_string()]);
        assert_eq!(merged.description, "LLM description");
        assert_eq!(merged.status, EnrichmentStatus::Full);
    }

    #[test]
    fn test_enrichment_service_successful() {
        let mock_response = r#"{
            "description": "Code review prompt for specific files",
            "tags": ["code-review", "analysis"],
            "variables": [
                {
                    "name": "file",
                    "description": "Path to the file to review",
                    "required": true,
                    "default": null
                },
                {
                    "name": "issue_type",
                    "description": "Category of issues to look for",
                    "required": false,
                    "default": "general"
                }
            ]
        }"#;

        let llm = MockLlmProvider::new(mock_response);
        let service = PromptEnrichmentService::new(llm);

        let request = EnrichmentRequest::new(
            "Review {{file}} for {{issue_type}} issues",
            vec!["file".to_string(), "issue_type".to_string()],
        );

        let result = service.enrich(&request).unwrap();

        assert_eq!(result.description, "Code review prompt for specific files");
        assert_eq!(result.tags, vec!["code-review", "analysis"]);
        assert_eq!(result.variables.len(), 2);
        assert_eq!(result.status, EnrichmentStatus::Full);

        let file_var = result.variables.iter().find(|v| v.name == "file").unwrap();
        assert_eq!(
            file_var.description,
            Some("Path to the file to review".to_string())
        );
        assert!(file_var.required);

        let type_var = result
            .variables
            .iter()
            .find(|v| v.name == "issue_type")
            .unwrap();
        assert_eq!(type_var.default, Some("general".to_string()));
        assert!(!type_var.required);
    }

    #[test]
    fn test_enrichment_service_with_json_in_markdown() {
        let mock_response = r#"```json
{
  "description": "Test prompt",
  "tags": ["test"],
  "variables": []
}
```"#;

        let llm = MockLlmProvider::new(mock_response);
        let service = PromptEnrichmentService::new(llm);

        let request = EnrichmentRequest::new("Test content", vec![]);

        let result = service.enrich(&request).unwrap();
        assert_eq!(result.description, "Test prompt");
    }

    #[test]
    fn test_enrichment_service_fallback_on_error() {
        let llm = MockLlmProvider::failing();
        let service = PromptEnrichmentService::new(llm);

        let request = EnrichmentRequest::new("Review {{file}}", vec!["file".to_string()]);

        let result = service.enrich_with_fallback(&request);

        assert_eq!(result.status, EnrichmentStatus::Fallback);
        assert!(result.description.is_empty());
        assert_eq!(result.variables.len(), 1);
        assert_eq!(result.variables[0].name, "file");
    }

    #[test]
    fn test_enrichment_service_fallback_on_invalid_json() {
        let service = PromptEnrichmentService::new(MockLlmProvider::new("This is not JSON"));
        let request = EnrichmentRequest::new("Test {{var}}", vec!["var".to_string()]);

        let result = service.enrich_with_fallback(&request);

        assert_eq!(result.status, EnrichmentStatus::Fallback);
        assert_eq!(result.variables.len(), 1);
        assert_eq!(result.variables[0].name, "var");
    }

    #[test]
    fn test_enrichment_service_fallback_preserves_user_metadata() {
        let llm = MockLlmProvider::failing();
        let service = PromptEnrichmentService::new(llm);

        let existing = PartialMetadata::new()
            .with_description("User description")
            .with_tags(vec!["user-tag".to_string()]);

        let request = EnrichmentRequest::new("Review {{file}}", vec!["file".to_string()])
            .with_existing(existing);

        let result = service.enrich_with_fallback(&request);

        assert_eq!(result.description, "User description");
        assert_eq!(result.tags, vec!["user-tag".to_string()]);
    }

    #[test]
    fn test_enrichment_service_missing_variable_filled() {
        // The LLM only describes "file"; "issue_type" must still appear.
        let mock_response = r#"{
            "description": "Test",
            "tags": [],
            "variables": [
                {"name": "file", "description": "File path", "required": true}
            ]
        }"#;

        let llm = MockLlmProvider::new(mock_response);
        let service = PromptEnrichmentService::new(llm);

        let request = EnrichmentRequest::new(
            "Review {{file}} for {{issue_type}}",
            vec!["file".to_string(), "issue_type".to_string()],
        );

        let result = service.enrich(&request).unwrap();

        assert_eq!(result.variables.len(), 2);

        let missing = result
            .variables
            .iter()
            .find(|v| v.name == "issue_type")
            .unwrap();
        assert!(missing.description.is_none());
        assert!(missing.required);
    }

    #[test]
    fn test_enrichment_service_drops_unexpected_variables() {
        // "tags" and "required" are omitted to exercise the serde defaults.
        let mock_response = r#"{
            "description": "Test",
            "variables": [
                {"name": "extra", "description": "Not in the prompt"},
                {"name": "file", "description": "File path"}
            ]
        }"#;

        let service = PromptEnrichmentService::new(MockLlmProvider::new(mock_response));
        let request = EnrichmentRequest::new("Review {{file}}", vec!["file".to_string()]);

        let result = service.enrich(&request).unwrap();

        assert!(result.tags.is_empty());
        assert_eq!(result.variables.len(), 1);
        assert_eq!(result.variables[0].name, "file");
        assert_eq!(result.variables[0].description, Some("File path".to_string()));
        assert!(result.variables[0].required);
    }

    #[test]
    fn test_enrichment_service_empty_variables() {
        let mock_response = r#"{
            "description": "Static prompt with no variables",
            "tags": ["static"],
            "variables": []
        }"#;

        let llm = MockLlmProvider::new(mock_response);
        let service = PromptEnrichmentService::new(llm);

        let request = EnrichmentRequest::new("Hello, world!", vec![]);

        let result = service.enrich(&request).unwrap();

        assert_eq!(result.description, "Static prompt with no variables");
        assert!(result.variables.is_empty());
    }

    #[test]
    fn test_enrichment_service_invalid_json() {
        let mock_response = "This is not JSON";

        let llm = MockLlmProvider::new(mock_response);
        let service = PromptEnrichmentService::new(llm);

        let request = EnrichmentRequest::new("Test {{var}}", vec!["var".to_string()]);

        let result = service.enrich(&request);
        assert!(result.is_err());
    }

    #[test]
    fn test_enrichment_status_serialization() {
        let full = EnrichmentStatus::Full;
        let serialized = serde_json::to_string(&full).unwrap();
        assert_eq!(serialized, r#""full""#);

        let fallback = EnrichmentStatus::Fallback;
        let serialized = serde_json::to_string(&fallback).unwrap();
        assert_eq!(serialized, r#""fallback""#);

        let skipped = EnrichmentStatus::Skipped;
        let serialized = serde_json::to_string(&skipped).unwrap();
        assert_eq!(serialized, r#""skipped""#);
    }

    #[test]
    fn test_system_prompt_contains_required_sections() {
        assert!(PROMPT_ENRICHMENT_SYSTEM_PROMPT.contains("<task>"));
        assert!(PROMPT_ENRICHMENT_SYSTEM_PROMPT.contains("<output_format>"));
        assert!(PROMPT_ENRICHMENT_SYSTEM_PROMPT.contains("<guidelines>"));
        assert!(PROMPT_ENRICHMENT_SYSTEM_PROMPT.contains("<rules>"));
        assert!(PROMPT_ENRICHMENT_SYSTEM_PROMPT.contains("description"));
        assert!(PROMPT_ENRICHMENT_SYSTEM_PROMPT.contains("tags"));
        assert!(PROMPT_ENRICHMENT_SYSTEM_PROMPT.contains("variables"));
    }
}