Skip to main content

subcog/storage/prompt/
sqlite.rs

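//! SQLite-based prompt storage for user and organization scopes.
//!
//! By default, user-scope prompts are stored in `~/.config/subcog/memories.db`
//! and org-scope prompts in `~/.config/subcog/orgs/{org}/memories.db`.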
4
5use super::PromptStorage;
6use crate::models::PromptTemplate;
7use crate::{Error, Result};
8use rusqlite::{Connection, OptionalExtension, params};
9use std::path::{Path, PathBuf};
10use std::sync::Mutex;
11
12/// `SQLite`-based prompt storage.
13pub struct SqlitePromptStorage {
14    /// Connection to the `SQLite` database.
15    conn: Mutex<Connection>,
16    /// Path to the `SQLite` database.
17    db_path: PathBuf,
18}
19
20impl SqlitePromptStorage {
21    /// Creates a new `SQLite` prompt storage.
22    ///
23    /// # Arguments
24    ///
25    /// * `db_path` - Path to the `SQLite` database file
26    ///
27    /// # Errors
28    ///
29    /// Returns an error if the database cannot be opened or initialized.
30    pub fn new(db_path: impl Into<PathBuf>) -> Result<Self> {
31        let db_path = db_path.into();
32
33        // Ensure parent directory exists
34        if let Some(parent) = db_path.parent() {
35            std::fs::create_dir_all(parent).map_err(|e| Error::OperationFailed {
36                operation: "create_prompt_dir".to_string(),
37                cause: e.to_string(),
38            })?;
39        }
40
41        let conn = Connection::open(&db_path).map_err(|e| Error::OperationFailed {
42            operation: "open_prompt_db".to_string(),
43            cause: e.to_string(),
44        })?;
45
46        let storage = Self {
47            conn: Mutex::new(conn),
48            db_path,
49        };
50
51        storage.initialize()?;
52        Ok(storage)
53    }
54
55    /// Creates an in-memory `SQLite` storage (useful for testing).
56    ///
57    /// # Errors
58    ///
59    /// Returns an error if the database cannot be initialized.
60    pub fn in_memory() -> Result<Self> {
61        let conn = Connection::open_in_memory().map_err(|e| Error::OperationFailed {
62            operation: "open_prompt_db_memory".to_string(),
63            cause: e.to_string(),
64        })?;
65
66        let storage = Self {
67            conn: Mutex::new(conn),
68            db_path: PathBuf::from(":memory:"),
69        };
70
71        storage.initialize()?;
72        Ok(storage)
73    }
74
75    /// Returns the default user-scope database path.
76    ///
77    /// Returns `~/.config/subcog/memories.db`.
78    #[must_use]
79    pub fn default_user_path() -> Option<PathBuf> {
80        directories::BaseDirs::new().map(|d| {
81            d.home_dir()
82                .join(".config")
83                .join("subcog")
84                .join("memories.db")
85        })
86    }
87
88    /// Returns the default org-scope database path.
89    ///
90    /// Returns `~/.config/subcog/orgs/{org}/memories.db`.
91    #[must_use]
92    pub fn default_org_path(org: &str) -> Option<PathBuf> {
93        directories::BaseDirs::new().map(|d| {
94            d.home_dir()
95                .join(".config")
96                .join("subcog")
97                .join("orgs")
98                .join(org)
99                .join("memories.db")
100        })
101    }
102
103    /// Returns the database path.
104    #[must_use]
105    pub fn db_path(&self) -> &Path {
106        &self.db_path
107    }
108
109    /// Initializes the database schema and configures pragmas.
110    fn initialize(&self) -> Result<()> {
111        let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
112            operation: "lock_prompt_db".to_string(),
113            cause: e.to_string(),
114        })?;
115
116        // Configure SQLite pragmas for performance and reliability
117        // WAL mode: better concurrent read performance
118        // synchronous=NORMAL: good balance of safety vs performance
119        // busy_timeout: wait up to 5 seconds if database is locked
120        let _ = conn.pragma_update(None, "journal_mode", "WAL");
121        let _ = conn.pragma_update(None, "synchronous", "NORMAL");
122        let _ = conn.pragma_update(None, "busy_timeout", "5000");
123
124        conn.execute(
125            "CREATE TABLE IF NOT EXISTS prompts (
126                name TEXT PRIMARY KEY,
127                description TEXT NOT NULL DEFAULT '',
128                content TEXT NOT NULL,
129                variables TEXT NOT NULL DEFAULT '[]',
130                tags TEXT NOT NULL DEFAULT '[]',
131                author TEXT,
132                usage_count INTEGER NOT NULL DEFAULT 0,
133                created_at INTEGER NOT NULL,
134                updated_at INTEGER NOT NULL
135            )",
136            [],
137        )
138        .map_err(|e| Error::OperationFailed {
139            operation: "create_prompts_table".to_string(),
140            cause: e.to_string(),
141        })?;
142
143        // Create index on tags for faster filtering
144        conn.execute(
145            "CREATE INDEX IF NOT EXISTS idx_prompts_tags ON prompts(tags)",
146            [],
147        )
148        .map_err(|e| Error::OperationFailed {
149            operation: "create_prompts_tags_index".to_string(),
150            cause: e.to_string(),
151        })?;
152
153        Ok(())
154    }
155
156    /// Locks the connection and returns a guard.
157    fn lock_conn(&self) -> Result<std::sync::MutexGuard<'_, Connection>> {
158        self.conn.lock().map_err(|e| Error::OperationFailed {
159            operation: "lock_prompt_db".to_string(),
160            cause: e.to_string(),
161        })
162    }
163
164    /// Runs database maintenance (VACUUM and ANALYZE).
165    ///
166    /// Call this periodically (e.g., on status command or admin trigger) to:
167    /// - VACUUM: Reclaim space from deleted rows and defragment
168    /// - ANALYZE: Update query planner statistics for optimal performance
169    ///
170    /// # Errors
171    ///
172    /// Returns an error if maintenance commands fail.
173    pub fn vacuum_and_analyze(&self) -> Result<()> {
174        let conn = self.lock_conn()?;
175
176        // VACUUM must run outside a transaction
177        conn.execute("VACUUM", [])
178            .map_err(|e| Error::OperationFailed {
179                operation: "prompt_db_vacuum".to_string(),
180                cause: e.to_string(),
181            })?;
182
183        conn.execute("ANALYZE", [])
184            .map_err(|e| Error::OperationFailed {
185                operation: "prompt_db_analyze".to_string(),
186                cause: e.to_string(),
187            })?;
188
189        Ok(())
190    }
191
192    /// Returns database statistics for monitoring.
193    ///
194    /// Useful for admin/status commands to show database health.
195    #[must_use]
196    pub fn stats(&self) -> Option<PromptDbStats> {
197        let conn = self.lock_conn().ok()?;
198
199        let prompt_count: i64 = conn
200            .query_row("SELECT COUNT(*) FROM prompts", [], |row| row.get(0))
201            .unwrap_or(0);
202
203        let page_count: i64 = conn
204            .pragma_query_value(None, "page_count", |row| row.get(0))
205            .unwrap_or(0);
206
207        let page_size: i64 = conn
208            .pragma_query_value(None, "page_size", |row| row.get(0))
209            .unwrap_or(4096);
210
211        let freelist_count: i64 = conn
212            .pragma_query_value(None, "freelist_count", |row| row.get(0))
213            .unwrap_or(0);
214
215        // Safe casts: counts and sizes are always non-negative from SQLite
216        Some(PromptDbStats {
217            prompt_count: u64::try_from(prompt_count).unwrap_or(0),
218            db_size_bytes: u64::try_from(page_count.saturating_mul(page_size)).unwrap_or(0),
219            freelist_pages: u64::try_from(freelist_count).unwrap_or(0),
220        })
221    }
222}
223
/// Database statistics for prompt storage.
///
/// A plain snapshot of values. It is cheap to copy and compare, for example
/// when asserting on status output in tests.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct PromptDbStats {
    /// Number of prompts stored.
    pub prompt_count: u64,
    /// Total database size in bytes (page count × page size).
    pub db_size_bytes: u64,
    /// Number of freelist pages (reclaimable with VACUUM).
    pub freelist_pages: u64,
}
234
235impl PromptStorage for SqlitePromptStorage {
236    #[allow(clippy::cast_possible_wrap)]
237    fn save(&self, template: &PromptTemplate) -> Result<String> {
238        let conn = self.lock_conn()?;
239
240        let variables_json =
241            serde_json::to_string(&template.variables).map_err(|e| Error::OperationFailed {
242                operation: "serialize_variables".to_string(),
243                cause: e.to_string(),
244            })?;
245
246        let tags_json =
247            serde_json::to_string(&template.tags).map_err(|e| Error::OperationFailed {
248                operation: "serialize_tags".to_string(),
249                cause: e.to_string(),
250            })?;
251
252        let now = std::time::SystemTime::now()
253            .duration_since(std::time::UNIX_EPOCH)
254            .map(|d| d.as_secs())
255            .unwrap_or(0);
256
257        // Use INSERT OR REPLACE to handle updates
258        conn.execute(
259            "INSERT OR REPLACE INTO prompts
260             (name, description, content, variables, tags, author, usage_count, created_at, updated_at)
261             VALUES (?1, ?2, ?3, ?4, ?5, ?6,
262                     COALESCE((SELECT usage_count FROM prompts WHERE name = ?1), 0),
263                     COALESCE((SELECT created_at FROM prompts WHERE name = ?1), ?7),
264                     ?7)",
265            params![
266                template.name,
267                template.description,
268                template.content,
269                variables_json,
270                tags_json,
271                template.author,
272                now as i64,
273            ],
274        )
275        .map_err(|e| Error::OperationFailed {
276            operation: "save_prompt".to_string(),
277            cause: e.to_string(),
278        })?;
279
280        Ok(format!("prompt_user_{}", template.name))
281    }
282
283    #[allow(clippy::cast_sign_loss)]
284    fn get(&self, name: &str) -> Result<Option<PromptTemplate>> {
285        let conn = self.lock_conn()?;
286
287        let result = conn
288            .query_row(
289                "SELECT name, description, content, variables, tags, author, usage_count, created_at, updated_at
290                 FROM prompts WHERE name = ?1",
291                params![name],
292                |row| {
293                    Ok((
294                        row.get::<_, String>(0)?,
295                        row.get::<_, String>(1)?,
296                        row.get::<_, String>(2)?,
297                        row.get::<_, String>(3)?,
298                        row.get::<_, String>(4)?,
299                        row.get::<_, Option<String>>(5)?,
300                        row.get::<_, i64>(6)?,
301                        row.get::<_, i64>(7)?,
302                        row.get::<_, i64>(8)?,
303                    ))
304                },
305            )
306            .optional()
307            .map_err(|e| Error::OperationFailed {
308                operation: "get_prompt".to_string(),
309                cause: e.to_string(),
310            })?;
311
312        match result {
313            Some((
314                name,
315                description,
316                content,
317                variables_json,
318                tags_json,
319                author,
320                usage_count,
321                created_at,
322                updated_at,
323            )) => {
324                let variables = serde_json::from_str(&variables_json).unwrap_or_default();
325                let tags = serde_json::from_str(&tags_json).unwrap_or_default();
326
327                Ok(Some(PromptTemplate {
328                    name,
329                    description,
330                    content,
331                    variables,
332                    tags,
333                    author,
334                    usage_count: usage_count as u64,
335                    created_at: created_at as u64,
336                    updated_at: updated_at as u64,
337                }))
338            },
339            None => Ok(None),
340        }
341    }
342
343    #[allow(clippy::cast_sign_loss)]
344    fn list(
345        &self,
346        tags: Option<&[String]>,
347        name_pattern: Option<&str>,
348    ) -> Result<Vec<PromptTemplate>> {
349        let conn = self.lock_conn()?;
350
351        let mut sql = String::from(
352            "SELECT name, description, content, variables, tags, author, usage_count, created_at, updated_at
353             FROM prompts WHERE 1=1",
354        );
355        let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
356
357        // Add name pattern filter (convert glob to SQL LIKE)
358        if let Some(pattern) = name_pattern {
359            let like_pattern = pattern.replace('*', "%").replace('?', "_");
360            sql.push_str(" AND name LIKE ?");
361            params_vec.push(Box::new(like_pattern));
362        }
363
364        sql.push_str(" ORDER BY usage_count DESC, name ASC");
365
366        let mut stmt = conn.prepare(&sql).map_err(|e| Error::OperationFailed {
367            operation: "prepare_list_prompts".to_string(),
368            cause: e.to_string(),
369        })?;
370
371        let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec.iter().map(AsRef::as_ref).collect();
372
373        let rows = stmt
374            .query_map(params_refs.as_slice(), |row| {
375                Ok((
376                    row.get::<_, String>(0)?,
377                    row.get::<_, String>(1)?,
378                    row.get::<_, String>(2)?,
379                    row.get::<_, String>(3)?,
380                    row.get::<_, String>(4)?,
381                    row.get::<_, Option<String>>(5)?,
382                    row.get::<_, i64>(6)?,
383                    row.get::<_, i64>(7)?,
384                    row.get::<_, i64>(8)?,
385                ))
386            })
387            .map_err(|e| Error::OperationFailed {
388                operation: "list_prompts".to_string(),
389                cause: e.to_string(),
390            })?;
391
392        let mut results = Vec::new();
393        for row in rows {
394            let (
395                name,
396                description,
397                content,
398                variables_json,
399                tags_json,
400                author,
401                usage_count,
402                created_at,
403                updated_at,
404            ) = row.map_err(|e| Error::OperationFailed {
405                operation: "read_prompt_row".to_string(),
406                cause: e.to_string(),
407            })?;
408
409            let variables = serde_json::from_str(&variables_json).unwrap_or_default();
410            let prompt_tags: Vec<String> = serde_json::from_str(&tags_json).unwrap_or_default();
411
412            // Filter by tags if specified
413            let has_all_required_tags = tags
414                .is_none_or(|required_tags| required_tags.iter().all(|t| prompt_tags.contains(t)));
415            if !has_all_required_tags {
416                continue;
417            }
418
419            results.push(PromptTemplate {
420                name,
421                description,
422                content,
423                variables,
424                tags: prompt_tags,
425                author,
426                usage_count: usage_count as u64,
427                created_at: created_at as u64,
428                updated_at: updated_at as u64,
429            });
430        }
431
432        Ok(results)
433    }
434
435    fn delete(&self, name: &str) -> Result<bool> {
436        let conn = self.lock_conn()?;
437
438        let rows_affected = conn
439            .execute("DELETE FROM prompts WHERE name = ?1", params![name])
440            .map_err(|e| Error::OperationFailed {
441                operation: "delete_prompt".to_string(),
442                cause: e.to_string(),
443            })?;
444
445        Ok(rows_affected > 0)
446    }
447
448    #[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)]
449    fn increment_usage(&self, name: &str) -> Result<u64> {
450        let conn = self.lock_conn()?;
451
452        conn.execute(
453            "UPDATE prompts SET usage_count = usage_count + 1, updated_at = ?1 WHERE name = ?2",
454            params![
455                std::time::SystemTime::now()
456                    .duration_since(std::time::UNIX_EPOCH)
457                    .map(|d| d.as_secs() as i64)
458                    .unwrap_or(0),
459                name
460            ],
461        )
462        .map_err(|e| Error::OperationFailed {
463            operation: "increment_usage".to_string(),
464            cause: e.to_string(),
465        })?;
466
467        // Get the new count
468        let count: i64 = conn
469            .query_row(
470                "SELECT usage_count FROM prompts WHERE name = ?1",
471                params![name],
472                |row| row.get(0),
473            )
474            .map_err(|e| Error::OperationFailed {
475                operation: "get_usage_count".to_string(),
476                cause: e.to_string(),
477            })?;
478
479        Ok(count as u64)
480    }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh, empty in-memory store for a single test.
    fn fresh_storage() -> SqlitePromptStorage {
        SqlitePromptStorage::in_memory().unwrap()
    }

    #[test]
    fn test_sqlite_prompt_storage_creation() {
        let store = fresh_storage();
        assert_eq!(store.db_path().to_str(), Some(":memory:"));
    }

    #[test]
    fn test_save_and_get_prompt() {
        let store = fresh_storage();

        let prompt =
            PromptTemplate::new("test-prompt", "Hello {{name}}!").with_description("A test prompt");

        let saved_id = store.save(&prompt).unwrap();
        assert!(saved_id.contains("test-prompt"));

        let Some(loaded) = store.get("test-prompt").unwrap() else {
            panic!("saved prompt should be retrievable");
        };
        assert_eq!(loaded.name, "test-prompt");
        assert_eq!(loaded.content, "Hello {{name}}!");
        assert_eq!(loaded.description, "A test prompt");
    }

    #[test]
    fn test_list_prompts() {
        let store = fresh_storage();

        let fixtures = [
            PromptTemplate::new("alpha", "A").with_tags(vec!["tag1".to_string()]),
            PromptTemplate::new("beta", "B")
                .with_tags(vec!["tag1".to_string(), "tag2".to_string()]),
            PromptTemplate::new("gamma", "C"),
        ];
        for prompt in &fixtures {
            store.save(prompt).unwrap();
        }

        // No filters: everything comes back.
        assert_eq!(store.list(None, None).unwrap().len(), 3);

        // Tag filter: only prompts carrying "tag1".
        let wanted_tags = ["tag1".to_string()];
        assert_eq!(store.list(Some(wanted_tags.as_slice()), None).unwrap().len(), 2);

        // Name glob: only names starting with "a".
        let matched = store.list(None, Some("a*")).unwrap();
        let matched_names: Vec<&str> = matched.iter().map(|p| p.name.as_str()).collect();
        assert_eq!(matched_names, ["alpha"]);
    }

    #[test]
    fn test_delete_prompt() {
        let store = fresh_storage();
        store.save(&PromptTemplate::new("to-delete", "Content")).unwrap();

        assert!(store.get("to-delete").unwrap().is_some());
        assert!(store.delete("to-delete").unwrap());
        assert!(store.get("to-delete").unwrap().is_none());
        // A second delete finds nothing to remove.
        assert!(!store.delete("to-delete").unwrap());
    }

    #[test]
    fn test_increment_usage() {
        let store = fresh_storage();
        store.save(&PromptTemplate::new("used-prompt", "Content")).unwrap();

        for expected in 1..=2_u64 {
            assert_eq!(store.increment_usage("used-prompt").unwrap(), expected);
        }

        let reloaded = store.get("used-prompt").unwrap().unwrap();
        assert_eq!(reloaded.usage_count, 2);
    }

    #[test]
    fn test_update_existing_prompt() {
        let store = fresh_storage();

        store.save(&PromptTemplate::new("update-me", "Version 1")).unwrap();
        store.increment_usage("update-me").unwrap();

        // Re-saving under the same name overwrites content but keeps usage stats.
        let revised = PromptTemplate::new("update-me", "Version 2").with_description("Updated");
        store.save(&revised).unwrap();

        let reloaded = store.get("update-me").unwrap().unwrap();
        assert_eq!(reloaded.content, "Version 2");
        assert_eq!(reloaded.description, "Updated");
        assert_eq!(reloaded.usage_count, 1);
    }

    #[test]
    fn test_default_user_path() {
        // `None` is tolerated on systems without a resolvable home directory.
        if let Some(path) = SqlitePromptStorage::default_user_path() {
            let rendered = path.to_string_lossy();
            assert!(rendered.contains("subcog"));
            assert!(rendered.ends_with("memories.db"));
        }
    }

    #[test]
    fn test_vacuum_and_analyze() {
        let store = fresh_storage();

        // Churn some rows so there is free space for VACUUM to reclaim.
        let names: Vec<String> = (0..10).map(|i| format!("temp-{i}")).collect();
        for name in &names {
            store.save(&PromptTemplate::new(name.clone(), "Content")).unwrap();
        }
        for name in &names {
            store.delete(name).unwrap();
        }

        assert!(store.vacuum_and_analyze().is_ok());
    }

    #[test]
    fn test_stats() {
        let store = fresh_storage();

        assert_eq!(store.stats().unwrap().prompt_count, 0);

        for (name, content) in [("stats-test-1", "Content 1"), ("stats-test-2", "Content 2")] {
            store.save(&PromptTemplate::new(name, content)).unwrap();
        }

        let snapshot = store.stats().unwrap();
        assert_eq!(snapshot.prompt_count, 2);
        assert!(snapshot.db_size_bytes > 0);
    }
}