1use serde::{Deserialize, Serialize};
30use std::fmt;
31
32use crate::models::prompt::{ExtractedVariable, PromptVariable, extract_variables};
33
/// Variable names that are filled in automatically when a template is rendered,
/// rather than supplied by the user.
///
/// The explicit `memory.*` entries are also matched by [`AUTO_VARIABLE_PREFIXES`];
/// they are presumably listed here for discoverability.
pub const AUTO_VARIABLES: &[&str] = &[
    "memories",
    "memory.id",
    "memory.content",
    "memory.namespace",
    "memory.tags",
    "memory.score",
    "memory.created_at",
    "memory.updated_at",
    "memory.domain",
    "statistics",
    "total_count",
    "namespace_counts",
];

/// Name prefixes that mark a variable as automatic (e.g. any `memory.<field>`).
pub const AUTO_VARIABLE_PREFIXES: &[&str] = &["memory."];

/// Returns `true` if `name` is an automatically populated variable.
///
/// A name qualifies when it exactly equals an entry of [`AUTO_VARIABLES`] or starts
/// with one of [`AUTO_VARIABLE_PREFIXES`]. Matching is case-sensitive, and merely
/// containing an auto name (e.g. `my_memories`) is not enough.
#[must_use]
pub fn is_auto_variable(name: &str) -> bool {
    // Exact-name hit first; fall back to the prefix rule.
    if AUTO_VARIABLES.iter().any(|&known| known == name) {
        return true;
    }
    AUTO_VARIABLE_PREFIXES
        .iter()
        .copied()
        .any(|prefix| name.starts_with(prefix))
}
69
70#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
72#[serde(rename_all = "lowercase")]
73pub enum OutputFormat {
74 #[default]
76 Markdown,
77 Json,
79 Xml,
81}
82
83impl OutputFormat {
84 #[must_use]
86 pub const fn as_str(&self) -> &'static str {
87 match self {
88 Self::Markdown => "markdown",
89 Self::Json => "json",
90 Self::Xml => "xml",
91 }
92 }
93}
94
95impl fmt::Display for OutputFormat {
96 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
97 write!(f, "{}", self.as_str())
98 }
99}
100
101impl std::str::FromStr for OutputFormat {
102 type Err = crate::Error;
103
104 fn from_str(s: &str) -> Result<Self, Self::Err> {
105 match s.to_lowercase().as_str() {
106 "markdown" | "md" => Ok(Self::Markdown),
107 "json" => Ok(Self::Json),
108 "xml" => Ok(Self::Xml),
109 _ => Err(crate::Error::InvalidInput(format!(
110 "Invalid output format: {s}. Expected: markdown, json, or xml"
111 ))),
112 }
113 }
114}
115
116#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
118#[serde(rename_all = "lowercase")]
119pub enum VariableType {
120 #[default]
122 User,
123 Auto,
125}
126
127impl VariableType {
128 #[must_use]
130 pub const fn as_str(&self) -> &'static str {
131 match self {
132 Self::User => "user",
133 Self::Auto => "auto",
134 }
135 }
136}
137
138impl fmt::Display for VariableType {
139 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140 write!(f, "{}", self.as_str())
141 }
142}
143
/// A variable referenced by a context template (e.g. `user_name` or `memory.id`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct TemplateVariable {
    /// Variable name as written in the template.
    pub name: String,
    /// Whether the value is user-supplied or populated automatically.
    /// Defaults to `VariableType::User` when missing from serialized input.
    #[serde(default)]
    pub var_type: VariableType,
    /// Optional human-readable description.
    #[serde(default)]
    pub description: Option<String>,
    /// Optional default value. Setting one through `with_default` also clears `required`.
    #[serde(default)]
    pub default: Option<String>,
    /// Whether a value must be provided. Defaults to `true` when missing from serialized input.
    #[serde(default = "default_required")]
    pub required: bool,
}

/// Serde default for `TemplateVariable::required`: a variable is required unless stated otherwise.
const fn default_required() -> bool {
    true
}
167
168impl TemplateVariable {
169 #[must_use]
171 pub fn new(name: impl Into<String>) -> Self {
172 let name = name.into();
173 let var_type = if is_auto_variable(&name) {
174 VariableType::Auto
175 } else {
176 VariableType::User
177 };
178 let required = var_type == VariableType::User;
179
180 Self {
181 name,
182 var_type,
183 description: None,
184 default: None,
185 required,
186 }
187 }
188
189 #[must_use]
191 pub fn auto(name: impl Into<String>) -> Self {
192 Self {
193 name: name.into(),
194 var_type: VariableType::Auto,
195 description: None,
196 default: None,
197 required: false,
198 }
199 }
200
201 #[must_use]
203 pub fn with_description(mut self, description: impl Into<String>) -> Self {
204 self.description = Some(description.into());
205 self
206 }
207
208 #[must_use]
210 pub fn with_default(mut self, default: impl Into<String>) -> Self {
211 self.default = Some(default.into());
212 self.required = false;
213 self
214 }
215}
216
217impl Default for TemplateVariable {
218 fn default() -> Self {
219 Self {
220 name: String::new(),
221 var_type: VariableType::User,
222 description: None,
223 default: None,
224 required: true,
225 }
226 }
227}
228
229impl From<ExtractedVariable> for TemplateVariable {
230 fn from(extracted: ExtractedVariable) -> Self {
231 Self::new(extracted.name)
232 }
233}
234
235impl From<PromptVariable> for TemplateVariable {
236 fn from(prompt_var: PromptVariable) -> Self {
237 Self {
238 name: prompt_var.name,
239 var_type: VariableType::User,
240 description: prompt_var.description,
241 default: prompt_var.default,
242 required: prompt_var.required,
243 }
244 }
245}
246
247impl From<TemplateVariable> for PromptVariable {
248 fn from(template_var: TemplateVariable) -> Self {
249 Self {
250 name: template_var.name,
251 description: template_var.description,
252 default: template_var.default,
253 required: template_var.required,
254 }
255 }
256}
257
258#[derive(Debug, Clone, Default, Serialize, Deserialize)]
260pub struct ContextTemplate {
261 pub name: String,
263 #[serde(default)]
265 pub description: String,
266 pub content: String,
268 #[serde(default)]
270 pub variables: Vec<TemplateVariable>,
271 #[serde(default)]
273 pub tags: Vec<String>,
274 #[serde(default)]
276 pub output_format: OutputFormat,
277 #[serde(default)]
279 pub author: Option<String>,
280 #[serde(default = "default_version")]
282 pub version: u32,
283 #[serde(default)]
285 pub created_at: u64,
286 #[serde(default)]
288 pub updated_at: u64,
289}
290
291const fn default_version() -> u32 {
293 1
294}
295
296impl ContextTemplate {
297 #[must_use]
302 pub fn new(name: impl Into<String>, content: impl Into<String>) -> Self {
303 let content = content.into();
304 let variables = extract_variables(&content)
305 .into_iter()
306 .map(TemplateVariable::from)
307 .collect();
308
309 Self {
310 name: name.into(),
311 description: String::new(),
312 content,
313 variables,
314 tags: Vec::new(),
315 output_format: OutputFormat::default(),
316 author: None,
317 version: 1,
318 created_at: 0,
319 updated_at: 0,
320 }
321 }
322
323 #[must_use]
325 pub fn with_description(mut self, description: impl Into<String>) -> Self {
326 self.description = description.into();
327 self
328 }
329
330 #[must_use]
332 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
333 self.tags = tags;
334 self
335 }
336
337 #[must_use]
339 pub fn with_author(mut self, author: impl Into<String>) -> Self {
340 self.author = Some(author.into());
341 self
342 }
343
344 #[must_use]
346 pub fn with_variables(mut self, variables: Vec<TemplateVariable>) -> Self {
347 self.variables = variables;
348 self
349 }
350
351 #[must_use]
353 pub const fn with_version(mut self, version: u32) -> Self {
354 self.version = version;
355 self
356 }
357
358 #[must_use]
360 pub fn variable_names(&self) -> Vec<&str> {
361 self.variables.iter().map(|v| v.name.as_str()).collect()
362 }
363
364 #[must_use]
366 pub fn user_variables(&self) -> Vec<&TemplateVariable> {
367 self.variables
368 .iter()
369 .filter(|v| v.var_type == VariableType::User)
370 .collect()
371 }
372
373 #[must_use]
375 pub fn auto_variables(&self) -> Vec<&TemplateVariable> {
376 self.variables
377 .iter()
378 .filter(|v| v.var_type == VariableType::Auto)
379 .collect()
380 }
381
382 #[must_use]
384 pub fn has_iteration(&self) -> bool {
385 self.content.contains("{{#each")
386 }
387
388 #[must_use]
390 pub fn iteration_collections(&self) -> Vec<&str> {
391 let mut collections = Vec::new();
392 let mut search_start = 0;
393
394 while let Some(start) = self.content[search_start..].find("{{#each ") {
395 let abs_start = search_start + start + 8; let Some(end) = self.content[abs_start..].find("}}") else {
397 break;
398 };
399 let collection = self.content[abs_start..abs_start + end].trim();
400 if !collections.contains(&collection) {
401 collections.push(collection);
402 }
403 search_start = abs_start + end;
404 }
405
406 collections
407 }
408}
409
/// Metadata describing one version of a template.
///
/// NOTE(review): how this relates to `ContextTemplate::version` is not visible in this
/// file; presumably it is one entry in a template's version history. Confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemplateVersion {
    /// Version number.
    pub version: u32,
    /// Creation timestamp of this version. NOTE(review): units not shown here — confirm.
    pub created_at: u64,
    /// Optional author of this version.
    pub author: Option<String>,
}
420
#[cfg(test)]
mod tests {
    use super::*;

    // Exact names, the `memory.` prefix, and names that merely contain an auto name.
    #[test]
    fn test_is_auto_variable() {
        assert!(is_auto_variable("memories"));
        assert!(is_auto_variable("memory.id"));
        assert!(is_auto_variable("memory.content"));
        assert!(is_auto_variable("statistics"));
        assert!(is_auto_variable("total_count"));

        // Not listed explicitly; matched through the `memory.` prefix.
        assert!(is_auto_variable("memory.custom_field"));

        assert!(!is_auto_variable("user_name"));
        assert!(!is_auto_variable("custom_var"));
        // Containing an auto name is not enough; matching is exact or by prefix.
        assert!(!is_auto_variable("my_memories"));
    }

    // Canonical names, the `md` alias, and rejection of unknown input.
    #[test]
    fn test_output_format_parsing() {
        assert_eq!(
            "markdown".parse::<OutputFormat>().unwrap(),
            OutputFormat::Markdown
        );
        assert_eq!(
            "md".parse::<OutputFormat>().unwrap(),
            OutputFormat::Markdown
        );
        assert_eq!("json".parse::<OutputFormat>().unwrap(), OutputFormat::Json);
        assert_eq!("xml".parse::<OutputFormat>().unwrap(), OutputFormat::Xml);
        assert!("invalid".parse::<OutputFormat>().is_err());
    }

    // Display emits the lowercase canonical name.
    #[test]
    fn test_output_format_display() {
        assert_eq!(OutputFormat::Markdown.to_string(), "markdown");
        assert_eq!(OutputFormat::Json.to_string(), "json");
        assert_eq!(OutputFormat::Xml.to_string(), "xml");
    }

    // `new` classifies by name: user variables are required, auto variables are not.
    #[test]
    fn test_template_variable_new() {
        let user_var = TemplateVariable::new("user_name");
        assert_eq!(user_var.var_type, VariableType::User);
        assert!(user_var.required);

        let auto_var = TemplateVariable::new("memory.id");
        assert_eq!(auto_var.var_type, VariableType::Auto);
        assert!(!auto_var.required);
    }

    // Setting a default value also clears `required`.
    #[test]
    fn test_template_variable_builders() {
        let var = TemplateVariable::new("name")
            .with_description("User's name")
            .with_default("Anonymous");

        assert_eq!(var.description, Some("User's name".to_string()));
        assert_eq!(var.default, Some("Anonymous".to_string()));
        assert!(!var.required);
    }

    // Variables are extracted from the content and classified on construction.
    #[test]
    fn test_context_template_new() {
        let template = ContextTemplate::new(
            "test-template",
            "Hello {{user_name}}, you have {{total_count}} memories.",
        );

        assert_eq!(template.name, "test-template");
        assert_eq!(template.variables.len(), 2);

        let user_var = template
            .variables
            .iter()
            .find(|v| v.name == "user_name")
            .unwrap();
        assert_eq!(user_var.var_type, VariableType::User);

        let auto_var = template
            .variables
            .iter()
            .find(|v| v.name == "total_count")
            .unwrap();
        assert_eq!(auto_var.var_type, VariableType::Auto);
    }

    // Builder methods set the corresponding metadata fields.
    #[test]
    fn test_context_template_builders() {
        let template = ContextTemplate::new("test", "{{var}}")
            .with_description("A test template")
            .with_tags(vec!["test".to_string(), "example".to_string()])
            .with_author("test-user")
            .with_version(5);

        assert_eq!(template.description, "A test template");
        assert_eq!(template.tags, vec!["test", "example"]);
        assert_eq!(template.author, Some("test-user".to_string()));
        assert_eq!(template.version, 5);
    }

    // user/auto filters partition the extracted variables.
    #[test]
    fn test_context_template_variable_helpers() {
        let template = ContextTemplate::new("test", "{{user_var}} {{memory.id}} {{total_count}}");

        let user_vars = template.user_variables();
        assert_eq!(user_vars.len(), 1);
        assert_eq!(user_vars[0].name, "user_var");

        let auto_vars = template.auto_variables();
        assert_eq!(auto_vars.len(), 2);
    }

    // Detection of `{{#each` blocks.
    #[test]
    fn test_context_template_has_iteration() {
        let with_iteration =
            ContextTemplate::new("test", "{{#each memories}}{{memory.id}}{{/each}}");
        assert!(with_iteration.has_iteration());

        let without_iteration = ContextTemplate::new("test", "{{total_count}}");
        assert!(!without_iteration.has_iteration());
    }

    // Every distinct collection named by an `{{#each ...}}` opener is reported.
    #[test]
    fn test_context_template_iteration_collections() {
        let template = ContextTemplate::new(
            "test",
            "{{#each memories}}{{memory.id}}{{/each}} and {{#each items}}{{item.name}}{{/each}}",
        );

        let collections = template.iteration_collections();
        assert_eq!(collections.len(), 2);
        assert!(collections.contains(&"memories"));
        assert!(collections.contains(&"items"));
    }

    // JSON round-trip preserves the metadata fields.
    #[test]
    fn test_context_template_serialization() {
        let template = ContextTemplate::new("test", "{{var}}")
            .with_description("A test")
            .with_tags(vec!["tag1".to_string()]);

        let json = serde_json::to_string(&template).unwrap();
        let parsed: ContextTemplate = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.name, "test");
        assert_eq!(parsed.description, "A test");
        assert_eq!(parsed.tags, vec!["tag1"]);
    }

    // Converting an extracted variable reclassifies it by name.
    #[test]
    fn test_template_variable_from_extracted() {
        let extracted = ExtractedVariable {
            name: "memory.content".to_string(),
            position: 0,
        };
        let var: TemplateVariable = extracted.into();

        assert_eq!(var.name, "memory.content");
        assert_eq!(var.var_type, VariableType::Auto);
    }

    // A prompt variable converts field-for-field and becomes a user variable.
    #[test]
    fn test_template_variable_from_prompt_variable() {
        let prompt_var = PromptVariable {
            name: "user_input".to_string(),
            description: Some("User input".to_string()),
            default: Some("default".to_string()),
            required: false,
        };
        let var: TemplateVariable = prompt_var.into();

        assert_eq!(var.name, "user_input");
        assert_eq!(var.var_type, VariableType::User);
        assert_eq!(var.description, Some("User input".to_string()));
        assert_eq!(var.default, Some("default".to_string()));
    }

    // Converting back to a prompt variable keeps name, description and default.
    #[test]
    fn test_template_variable_to_prompt_variable() {
        let var = TemplateVariable::new("test")
            .with_description("Test var")
            .with_default("default");

        let prompt_var: PromptVariable = var.into();

        assert_eq!(prompt_var.name, "test");
        assert_eq!(prompt_var.description, Some("Test var".to_string()));
        assert_eq!(prompt_var.default, Some("default".to_string()));
    }
}