1use super::PromptStorage;
6use crate::models::PromptTemplate;
7use crate::{Error, Result};
8use rusqlite::{Connection, OptionalExtension, params};
9use std::path::{Path, PathBuf};
10use std::sync::Mutex;
11
12pub struct SqlitePromptStorage {
14 conn: Mutex<Connection>,
16 db_path: PathBuf,
18}
19
20impl SqlitePromptStorage {
21 pub fn new(db_path: impl Into<PathBuf>) -> Result<Self> {
31 let db_path = db_path.into();
32
33 if let Some(parent) = db_path.parent() {
35 std::fs::create_dir_all(parent).map_err(|e| Error::OperationFailed {
36 operation: "create_prompt_dir".to_string(),
37 cause: e.to_string(),
38 })?;
39 }
40
41 let conn = Connection::open(&db_path).map_err(|e| Error::OperationFailed {
42 operation: "open_prompt_db".to_string(),
43 cause: e.to_string(),
44 })?;
45
46 let storage = Self {
47 conn: Mutex::new(conn),
48 db_path,
49 };
50
51 storage.initialize()?;
52 Ok(storage)
53 }
54
55 pub fn in_memory() -> Result<Self> {
61 let conn = Connection::open_in_memory().map_err(|e| Error::OperationFailed {
62 operation: "open_prompt_db_memory".to_string(),
63 cause: e.to_string(),
64 })?;
65
66 let storage = Self {
67 conn: Mutex::new(conn),
68 db_path: PathBuf::from(":memory:"),
69 };
70
71 storage.initialize()?;
72 Ok(storage)
73 }
74
75 #[must_use]
79 pub fn default_user_path() -> Option<PathBuf> {
80 directories::BaseDirs::new().map(|d| {
81 d.home_dir()
82 .join(".config")
83 .join("subcog")
84 .join("memories.db")
85 })
86 }
87
88 #[must_use]
92 pub fn default_org_path(org: &str) -> Option<PathBuf> {
93 directories::BaseDirs::new().map(|d| {
94 d.home_dir()
95 .join(".config")
96 .join("subcog")
97 .join("orgs")
98 .join(org)
99 .join("memories.db")
100 })
101 }
102
103 #[must_use]
105 pub fn db_path(&self) -> &Path {
106 &self.db_path
107 }
108
109 fn initialize(&self) -> Result<()> {
111 let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
112 operation: "lock_prompt_db".to_string(),
113 cause: e.to_string(),
114 })?;
115
116 let _ = conn.pragma_update(None, "journal_mode", "WAL");
121 let _ = conn.pragma_update(None, "synchronous", "NORMAL");
122 let _ = conn.pragma_update(None, "busy_timeout", "5000");
123
124 conn.execute(
125 "CREATE TABLE IF NOT EXISTS prompts (
126 name TEXT PRIMARY KEY,
127 description TEXT NOT NULL DEFAULT '',
128 content TEXT NOT NULL,
129 variables TEXT NOT NULL DEFAULT '[]',
130 tags TEXT NOT NULL DEFAULT '[]',
131 author TEXT,
132 usage_count INTEGER NOT NULL DEFAULT 0,
133 created_at INTEGER NOT NULL,
134 updated_at INTEGER NOT NULL
135 )",
136 [],
137 )
138 .map_err(|e| Error::OperationFailed {
139 operation: "create_prompts_table".to_string(),
140 cause: e.to_string(),
141 })?;
142
143 conn.execute(
145 "CREATE INDEX IF NOT EXISTS idx_prompts_tags ON prompts(tags)",
146 [],
147 )
148 .map_err(|e| Error::OperationFailed {
149 operation: "create_prompts_tags_index".to_string(),
150 cause: e.to_string(),
151 })?;
152
153 Ok(())
154 }
155
156 fn lock_conn(&self) -> Result<std::sync::MutexGuard<'_, Connection>> {
158 self.conn.lock().map_err(|e| Error::OperationFailed {
159 operation: "lock_prompt_db".to_string(),
160 cause: e.to_string(),
161 })
162 }
163
164 pub fn vacuum_and_analyze(&self) -> Result<()> {
174 let conn = self.lock_conn()?;
175
176 conn.execute("VACUUM", [])
178 .map_err(|e| Error::OperationFailed {
179 operation: "prompt_db_vacuum".to_string(),
180 cause: e.to_string(),
181 })?;
182
183 conn.execute("ANALYZE", [])
184 .map_err(|e| Error::OperationFailed {
185 operation: "prompt_db_analyze".to_string(),
186 cause: e.to_string(),
187 })?;
188
189 Ok(())
190 }
191
192 #[must_use]
196 pub fn stats(&self) -> Option<PromptDbStats> {
197 let conn = self.lock_conn().ok()?;
198
199 let prompt_count: i64 = conn
200 .query_row("SELECT COUNT(*) FROM prompts", [], |row| row.get(0))
201 .unwrap_or(0);
202
203 let page_count: i64 = conn
204 .pragma_query_value(None, "page_count", |row| row.get(0))
205 .unwrap_or(0);
206
207 let page_size: i64 = conn
208 .pragma_query_value(None, "page_size", |row| row.get(0))
209 .unwrap_or(4096);
210
211 let freelist_count: i64 = conn
212 .pragma_query_value(None, "freelist_count", |row| row.get(0))
213 .unwrap_or(0);
214
215 Some(PromptDbStats {
217 prompt_count: u64::try_from(prompt_count).unwrap_or(0),
218 db_size_bytes: u64::try_from(page_count.saturating_mul(page_size)).unwrap_or(0),
219 freelist_pages: u64::try_from(freelist_count).unwrap_or(0),
220 })
221 }
222}
223
/// Size and row-count snapshot of the prompt database.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct PromptDbStats {
    /// Number of rows in the `prompts` table.
    pub prompt_count: u64,
    /// Database size in bytes (`page_count * page_size`).
    pub db_size_bytes: u64,
    /// Number of unused pages on the freelist.
    pub freelist_pages: u64,
}
234
235impl PromptStorage for SqlitePromptStorage {
236 #[allow(clippy::cast_possible_wrap)]
237 fn save(&self, template: &PromptTemplate) -> Result<String> {
238 let conn = self.lock_conn()?;
239
240 let variables_json =
241 serde_json::to_string(&template.variables).map_err(|e| Error::OperationFailed {
242 operation: "serialize_variables".to_string(),
243 cause: e.to_string(),
244 })?;
245
246 let tags_json =
247 serde_json::to_string(&template.tags).map_err(|e| Error::OperationFailed {
248 operation: "serialize_tags".to_string(),
249 cause: e.to_string(),
250 })?;
251
252 let now = std::time::SystemTime::now()
253 .duration_since(std::time::UNIX_EPOCH)
254 .map(|d| d.as_secs())
255 .unwrap_or(0);
256
257 conn.execute(
259 "INSERT OR REPLACE INTO prompts
260 (name, description, content, variables, tags, author, usage_count, created_at, updated_at)
261 VALUES (?1, ?2, ?3, ?4, ?5, ?6,
262 COALESCE((SELECT usage_count FROM prompts WHERE name = ?1), 0),
263 COALESCE((SELECT created_at FROM prompts WHERE name = ?1), ?7),
264 ?7)",
265 params![
266 template.name,
267 template.description,
268 template.content,
269 variables_json,
270 tags_json,
271 template.author,
272 now as i64,
273 ],
274 )
275 .map_err(|e| Error::OperationFailed {
276 operation: "save_prompt".to_string(),
277 cause: e.to_string(),
278 })?;
279
280 Ok(format!("prompt_user_{}", template.name))
281 }
282
283 #[allow(clippy::cast_sign_loss)]
284 fn get(&self, name: &str) -> Result<Option<PromptTemplate>> {
285 let conn = self.lock_conn()?;
286
287 let result = conn
288 .query_row(
289 "SELECT name, description, content, variables, tags, author, usage_count, created_at, updated_at
290 FROM prompts WHERE name = ?1",
291 params![name],
292 |row| {
293 Ok((
294 row.get::<_, String>(0)?,
295 row.get::<_, String>(1)?,
296 row.get::<_, String>(2)?,
297 row.get::<_, String>(3)?,
298 row.get::<_, String>(4)?,
299 row.get::<_, Option<String>>(5)?,
300 row.get::<_, i64>(6)?,
301 row.get::<_, i64>(7)?,
302 row.get::<_, i64>(8)?,
303 ))
304 },
305 )
306 .optional()
307 .map_err(|e| Error::OperationFailed {
308 operation: "get_prompt".to_string(),
309 cause: e.to_string(),
310 })?;
311
312 match result {
313 Some((
314 name,
315 description,
316 content,
317 variables_json,
318 tags_json,
319 author,
320 usage_count,
321 created_at,
322 updated_at,
323 )) => {
324 let variables = serde_json::from_str(&variables_json).unwrap_or_default();
325 let tags = serde_json::from_str(&tags_json).unwrap_or_default();
326
327 Ok(Some(PromptTemplate {
328 name,
329 description,
330 content,
331 variables,
332 tags,
333 author,
334 usage_count: usage_count as u64,
335 created_at: created_at as u64,
336 updated_at: updated_at as u64,
337 }))
338 },
339 None => Ok(None),
340 }
341 }
342
343 #[allow(clippy::cast_sign_loss)]
344 fn list(
345 &self,
346 tags: Option<&[String]>,
347 name_pattern: Option<&str>,
348 ) -> Result<Vec<PromptTemplate>> {
349 let conn = self.lock_conn()?;
350
351 let mut sql = String::from(
352 "SELECT name, description, content, variables, tags, author, usage_count, created_at, updated_at
353 FROM prompts WHERE 1=1",
354 );
355 let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
356
357 if let Some(pattern) = name_pattern {
359 let like_pattern = pattern.replace('*', "%").replace('?', "_");
360 sql.push_str(" AND name LIKE ?");
361 params_vec.push(Box::new(like_pattern));
362 }
363
364 sql.push_str(" ORDER BY usage_count DESC, name ASC");
365
366 let mut stmt = conn.prepare(&sql).map_err(|e| Error::OperationFailed {
367 operation: "prepare_list_prompts".to_string(),
368 cause: e.to_string(),
369 })?;
370
371 let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec.iter().map(AsRef::as_ref).collect();
372
373 let rows = stmt
374 .query_map(params_refs.as_slice(), |row| {
375 Ok((
376 row.get::<_, String>(0)?,
377 row.get::<_, String>(1)?,
378 row.get::<_, String>(2)?,
379 row.get::<_, String>(3)?,
380 row.get::<_, String>(4)?,
381 row.get::<_, Option<String>>(5)?,
382 row.get::<_, i64>(6)?,
383 row.get::<_, i64>(7)?,
384 row.get::<_, i64>(8)?,
385 ))
386 })
387 .map_err(|e| Error::OperationFailed {
388 operation: "list_prompts".to_string(),
389 cause: e.to_string(),
390 })?;
391
392 let mut results = Vec::new();
393 for row in rows {
394 let (
395 name,
396 description,
397 content,
398 variables_json,
399 tags_json,
400 author,
401 usage_count,
402 created_at,
403 updated_at,
404 ) = row.map_err(|e| Error::OperationFailed {
405 operation: "read_prompt_row".to_string(),
406 cause: e.to_string(),
407 })?;
408
409 let variables = serde_json::from_str(&variables_json).unwrap_or_default();
410 let prompt_tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
411
412 let has_all_required_tags = tags
414 .is_none_or(|required_tags| required_tags.iter().all(|t| prompt_tags.contains(t)));
415 if !has_all_required_tags {
416 continue;
417 }
418
419 results.push(PromptTemplate {
420 name,
421 description,
422 content,
423 variables,
424 tags: prompt_tags,
425 author,
426 usage_count: usage_count as u64,
427 created_at: created_at as u64,
428 updated_at: updated_at as u64,
429 });
430 }
431
432 Ok(results)
433 }
434
435 fn delete(&self, name: &str) -> Result<bool> {
436 let conn = self.lock_conn()?;
437
438 let rows_affected = conn
439 .execute("DELETE FROM prompts WHERE name = ?1", params![name])
440 .map_err(|e| Error::OperationFailed {
441 operation: "delete_prompt".to_string(),
442 cause: e.to_string(),
443 })?;
444
445 Ok(rows_affected > 0)
446 }
447
448 #[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)]
449 fn increment_usage(&self, name: &str) -> Result<u64> {
450 let conn = self.lock_conn()?;
451
452 conn.execute(
453 "UPDATE prompts SET usage_count = usage_count + 1, updated_at = ?1 WHERE name = ?2",
454 params![
455 std::time::SystemTime::now()
456 .duration_since(std::time::UNIX_EPOCH)
457 .map(|d| d.as_secs() as i64)
458 .unwrap_or(0),
459 name
460 ],
461 )
462 .map_err(|e| Error::OperationFailed {
463 operation: "increment_usage".to_string(),
464 cause: e.to_string(),
465 })?;
466
467 let count: i64 = conn
469 .query_row(
470 "SELECT usage_count FROM prompts WHERE name = ?1",
471 params![name],
472 |row| row.get(0),
473 )
474 .map_err(|e| Error::OperationFailed {
475 operation: "get_usage_count".to_string(),
476 cause: e.to_string(),
477 })?;
478
479 Ok(count as u64)
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty in-memory store so every test starts from a clean database.
    fn fresh_store() -> SqlitePromptStorage {
        SqlitePromptStorage::in_memory().unwrap()
    }

    #[test]
    fn test_sqlite_prompt_storage_creation() {
        let storage = fresh_store();
        assert_eq!(storage.db_path().to_str(), Some(":memory:"));
    }

    #[test]
    fn test_save_and_get_prompt() {
        let storage = fresh_store();
        let template =
            PromptTemplate::new("test-prompt", "Hello {{name}}!").with_description("A test prompt");

        let saved_id = storage.save(&template).unwrap();
        assert!(saved_id.contains("test-prompt"));

        let fetched = storage.get("test-prompt").unwrap();
        assert!(fetched.is_some());
        let fetched = fetched.unwrap();
        assert_eq!(
            (
                fetched.name.as_str(),
                fetched.content.as_str(),
                fetched.description.as_str()
            ),
            ("test-prompt", "Hello {{name}}!", "A test prompt")
        );
    }

    #[test]
    fn test_list_prompts() {
        let storage = fresh_store();
        let fixtures = [
            PromptTemplate::new("alpha", "A").with_tags(vec!["tag1".to_string()]),
            PromptTemplate::new("beta", "B")
                .with_tags(vec!["tag1".to_string(), "tag2".to_string()]),
            PromptTemplate::new("gamma", "C"),
        ];
        for template in &fixtures {
            storage.save(template).unwrap();
        }

        // No filters: every saved prompt comes back.
        assert_eq!(storage.list(None, None).unwrap().len(), 3);

        // Tag filter: only prompts carrying `tag1`.
        let wanted_tags = ["tag1".to_string()];
        assert_eq!(storage.list(Some(&wanted_tags[..]), None).unwrap().len(), 2);

        // Glob filter: names starting with `a`.
        let matched = storage.list(None, Some("a*")).unwrap();
        assert_eq!(matched.len(), 1);
        assert_eq!(matched[0].name, "alpha");
    }

    #[test]
    fn test_delete_prompt() {
        let storage = fresh_store();
        let name = "to-delete";
        storage.save(&PromptTemplate::new(name, "Content")).unwrap();

        assert!(storage.get(name).unwrap().is_some());
        assert!(storage.delete(name).unwrap());
        assert!(storage.get(name).unwrap().is_none());
        // A second delete finds nothing left to remove.
        assert!(!storage.delete(name).unwrap());
    }

    #[test]
    fn test_increment_usage() {
        let storage = fresh_store();
        storage
            .save(&PromptTemplate::new("used-prompt", "Content"))
            .unwrap();

        for expected in 1..=2 {
            assert_eq!(storage.increment_usage("used-prompt").unwrap(), expected);
        }

        let stored = storage.get("used-prompt").unwrap().unwrap();
        assert_eq!(stored.usage_count, 2);
    }

    #[test]
    fn test_update_existing_prompt() {
        let storage = fresh_store();
        let name = "update-me";

        storage.save(&PromptTemplate::new(name, "Version 1")).unwrap();
        storage.increment_usage(name).unwrap();
        storage
            .save(&PromptTemplate::new(name, "Version 2").with_description("Updated"))
            .unwrap();

        let stored = storage.get(name).unwrap().unwrap();
        assert_eq!(stored.content, "Version 2");
        assert_eq!(stored.description, "Updated");
        // Re-saving must keep the usage count accumulated before the update.
        assert_eq!(stored.usage_count, 1);
    }

    #[test]
    fn test_default_user_path() {
        let Some(path) = SqlitePromptStorage::default_user_path() else {
            // No home directory in this environment; nothing to verify.
            return;
        };
        let rendered = path.to_string_lossy();
        assert!(rendered.contains("subcog"));
        assert!(rendered.ends_with("memories.db"));
    }

    #[test]
    fn test_vacuum_and_analyze() {
        let storage = fresh_store();
        let names: Vec<String> = (0..10).map(|i| format!("temp-{i}")).collect();

        for name in &names {
            storage
                .save(&PromptTemplate::new(name.clone(), "Content"))
                .unwrap();
        }
        for name in &names {
            storage.delete(name).unwrap();
        }

        assert!(storage.vacuum_and_analyze().is_ok());
    }

    #[test]
    fn test_stats() {
        let storage = fresh_store();
        assert_eq!(storage.stats().unwrap().prompt_count, 0);

        for (name, content) in [("stats-test-1", "Content 1"), ("stats-test-2", "Content 2")] {
            storage.save(&PromptTemplate::new(name, content)).unwrap();
        }

        let stats = storage.stats().unwrap();
        assert_eq!(stats.prompt_count, 2);
        assert!(stats.db_size_bytes > 0);
    }
}