subcog/storage/index/
sqlite.rs

1//! `SQLite` + FTS5 index backend.
2//!
3//! Provides full-text search using `SQLite`'s FTS5 extension.
4
5use crate::models::{Memory, MemoryId, SearchFilter};
6use crate::storage::traits::IndexBackend;
7use crate::{Error, Result};
8use rusqlite::{Connection, OptionalExtension, params};
9use std::path::{Path, PathBuf};
10use std::sync::Mutex;
11
12/// SQLite-based index backend with FTS5.
13pub struct SqliteBackend {
14    /// Connection to the `SQLite` database.
15    conn: Mutex<Connection>,
16    /// Path to the `SQLite` database (None for in-memory).
17    db_path: Option<PathBuf>,
18}
19
20impl SqliteBackend {
21    /// Creates a new `SQLite` backend.
22    ///
23    /// # Errors
24    ///
25    /// Returns an error if the database cannot be opened or initialized.
26    pub fn new(db_path: impl Into<PathBuf>) -> Result<Self> {
27        let db_path = db_path.into();
28        let conn = Connection::open(&db_path).map_err(|e| Error::OperationFailed {
29            operation: "open_sqlite".to_string(),
30            cause: e.to_string(),
31        })?;
32
33        let backend = Self {
34            conn: Mutex::new(conn),
35            db_path: Some(db_path),
36        };
37
38        backend.initialize()?;
39        Ok(backend)
40    }
41
42    /// Creates an in-memory `SQLite` backend (useful for testing).
43    ///
44    /// # Errors
45    ///
46    /// Returns an error if the database cannot be initialized.
47    pub fn in_memory() -> Result<Self> {
48        let conn = Connection::open_in_memory().map_err(|e| Error::OperationFailed {
49            operation: "open_sqlite_memory".to_string(),
50            cause: e.to_string(),
51        })?;
52
53        let backend = Self {
54            conn: Mutex::new(conn),
55            db_path: None,
56        };
57
58        backend.initialize()?;
59        Ok(backend)
60    }
61
62    /// Returns the database path.
63    #[must_use]
64    pub fn db_path(&self) -> Option<&Path> {
65        self.db_path.as_deref()
66    }
67
68    /// Initializes the database schema.
69    fn initialize(&self) -> Result<()> {
70        let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
71            operation: "lock_connection".to_string(),
72            cause: e.to_string(),
73        })?;
74
75        // Create the main table for memory metadata
76        conn.execute(
77            "CREATE TABLE IF NOT EXISTS memories (
78                id TEXT PRIMARY KEY,
79                namespace TEXT NOT NULL,
80                domain TEXT,
81                status TEXT NOT NULL,
82                created_at INTEGER NOT NULL,
83                tags TEXT,
84                source TEXT
85            )",
86            [],
87        )
88        .map_err(|e| Error::OperationFailed {
89            operation: "create_memories_table".to_string(),
90            cause: e.to_string(),
91        })?;
92
93        // Add source column if it doesn't exist (for migration)
94        let _ = conn.execute("ALTER TABLE memories ADD COLUMN source TEXT", []);
95
96        // Create FTS5 virtual table for full-text search (standalone, not synced with memories)
97        conn.execute(
98            "CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
99                id,
100                content,
101                tags
102            )",
103            [],
104        )
105        .map_err(|e| Error::OperationFailed {
106            operation: "create_fts_table".to_string(),
107            cause: e.to_string(),
108        })?;
109
110        Ok(())
111    }
112
113    /// Builds a WHERE clause from a search filter with numbered parameters.
114    /// Returns the clause string, the parameters, and the next parameter index.
115    fn build_filter_clause_numbered(
116        &self,
117        filter: &SearchFilter,
118        start_param: usize,
119    ) -> (String, Vec<String>, usize) {
120        let mut conditions = Vec::new();
121        let mut params = Vec::new();
122        let mut param_idx = start_param;
123
124        if !filter.namespaces.is_empty() {
125            let placeholders: Vec<String> = filter
126                .namespaces
127                .iter()
128                .map(|_| {
129                    let p = format!("?{param_idx}");
130                    param_idx += 1;
131                    p
132                })
133                .collect();
134            conditions.push(format!("m.namespace IN ({})", placeholders.join(",")));
135            for ns in &filter.namespaces {
136                params.push(ns.as_str().to_string());
137            }
138        }
139
140        if !filter.statuses.is_empty() {
141            let placeholders: Vec<String> = filter
142                .statuses
143                .iter()
144                .map(|_| {
145                    let p = format!("?{param_idx}");
146                    param_idx += 1;
147                    p
148                })
149                .collect();
150            conditions.push(format!("m.status IN ({})", placeholders.join(",")));
151            for s in &filter.statuses {
152                params.push(s.as_str().to_string());
153            }
154        }
155
156        // Tag filtering (AND logic - must have ALL tags)
157        // Use ',tag,' pattern with wrapped column to match whole tags only
158        for tag in &filter.tags {
159            conditions.push(format!("(',' || m.tags || ',') LIKE ?{param_idx}"));
160            param_idx += 1;
161            params.push(format!("%,{tag},%"));
162        }
163
164        // Tag filtering (OR logic - must have ANY tag)
165        if !filter.tags_any.is_empty() {
166            let or_conditions: Vec<String> = filter
167                .tags_any
168                .iter()
169                .map(|tag| {
170                    let cond = format!("(',' || m.tags || ',') LIKE ?{param_idx}");
171                    param_idx += 1;
172                    params.push(format!("%,{tag},%"));
173                    cond
174                })
175                .collect();
176            conditions.push(format!("({})", or_conditions.join(" OR ")));
177        }
178
179        // Excluded tags (NOT LIKE) - match whole tags only
180        for tag in &filter.excluded_tags {
181            conditions.push(format!("(',' || m.tags || ',') NOT LIKE ?{param_idx}"));
182            param_idx += 1;
183            params.push(format!("%,{tag},%"));
184        }
185
186        // Source pattern (glob-style converted to SQL LIKE)
187        if let Some(ref pattern) = filter.source_pattern {
188            // Convert glob pattern to SQL LIKE pattern: * -> %, ? -> _
189            let sql_pattern = pattern.replace('*', "%").replace('?', "_");
190            conditions.push(format!("m.source LIKE ?{param_idx}"));
191            param_idx += 1;
192            params.push(sql_pattern);
193        }
194
195        if let Some(after) = filter.created_after {
196            conditions.push(format!("m.created_at >= ?{param_idx}"));
197            param_idx += 1;
198            params.push(after.to_string());
199        }
200
201        if let Some(before) = filter.created_before {
202            conditions.push(format!("m.created_at <= ?{param_idx}"));
203            param_idx += 1;
204            params.push(before.to_string());
205        }
206
207        let clause = if conditions.is_empty() {
208            String::new()
209        } else {
210            format!(" AND {}", conditions.join(" AND "))
211        };
212
213        (clause, params, param_idx)
214    }
215}
216
217impl IndexBackend for SqliteBackend {
218    fn index(&mut self, memory: &Memory) -> Result<()> {
219        let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
220            operation: "lock_connection".to_string(),
221            cause: e.to_string(),
222        })?;
223
224        let tags_str = memory.tags.join(",");
225        let domain_str = memory.domain.to_string();
226
227        // Insert or replace in main table
228        conn.execute(
229            "INSERT OR REPLACE INTO memories (id, namespace, domain, status, created_at, tags, source)
230             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
231            params![
232                memory.id.as_str(),
233                memory.namespace.as_str(),
234                domain_str,
235                memory.status.as_str(),
236                memory.created_at,
237                tags_str,
238                memory.source.as_deref()
239            ],
240        )
241        .map_err(|e| Error::OperationFailed {
242            operation: "insert_memory".to_string(),
243            cause: e.to_string(),
244        })?;
245
246        // Delete from FTS if exists (FTS5 uses rowid internally for matching)
247        conn.execute(
248            "DELETE FROM memories_fts WHERE id = ?1",
249            params![memory.id.as_str()],
250        )
251        .map_err(|e| Error::OperationFailed {
252            operation: "delete_fts".to_string(),
253            cause: e.to_string(),
254        })?;
255
256        // Insert into FTS table
257        conn.execute(
258            "INSERT INTO memories_fts (id, content, tags) VALUES (?1, ?2, ?3)",
259            params![memory.id.as_str(), memory.content, tags_str],
260        )
261        .map_err(|e| Error::OperationFailed {
262            operation: "insert_fts".to_string(),
263            cause: e.to_string(),
264        })?;
265
266        Ok(())
267    }
268
269    fn remove(&mut self, id: &MemoryId) -> Result<bool> {
270        let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
271            operation: "lock_connection".to_string(),
272            cause: e.to_string(),
273        })?;
274
275        // Delete from FTS
276        conn.execute(
277            "DELETE FROM memories_fts WHERE id = ?1",
278            params![id.as_str()],
279        )
280        .map_err(|e| Error::OperationFailed {
281            operation: "delete_fts".to_string(),
282            cause: e.to_string(),
283        })?;
284
285        // Delete from main table
286        let deleted = conn
287            .execute("DELETE FROM memories WHERE id = ?1", params![id.as_str()])
288            .map_err(|e| Error::OperationFailed {
289                operation: "delete_memory".to_string(),
290                cause: e.to_string(),
291            })?;
292
293        Ok(deleted > 0)
294    }
295
296    fn search(
297        &self,
298        query: &str,
299        filter: &SearchFilter,
300        limit: usize,
301    ) -> Result<Vec<(MemoryId, f32)>> {
302        let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
303            operation: "lock_connection".to_string(),
304            cause: e.to_string(),
305        })?;
306
307        // Build filter clause with numbered parameters starting from ?2
308        // ?1 is the FTS query
309        let (filter_clause, filter_params, next_param) =
310            self.build_filter_clause_numbered(filter, 2);
311
312        // Use FTS5 MATCH for search with BM25 ranking
313        // Limit parameter comes after all filter parameters
314        let sql = format!(
315            "SELECT f.id, bm25(memories_fts) as score
316             FROM memories_fts f
317             JOIN memories m ON f.id = m.id
318             WHERE memories_fts MATCH ?1 {filter_clause}
319             ORDER BY score
320             LIMIT ?{next_param}"
321        );
322
323        let mut stmt = conn.prepare(&sql).map_err(|e| Error::OperationFailed {
324            operation: "prepare_search".to_string(),
325            cause: e.to_string(),
326        })?;
327
328        // Build parameters: query, filter params, limit
329        let mut results = Vec::new();
330
331        // FTS5 query - escape special characters and wrap terms in quotes
332        // FTS5 special chars: - (NOT), * (prefix), " (phrase), : (column)
333        let fts_query = query
334            .split_whitespace()
335            .map(|term| {
336                // Escape double quotes and wrap each term in quotes for literal matching
337                let escaped = term.replace('"', "\"\"");
338                format!("\"{escaped}\"")
339            })
340            .collect::<Vec<_>>()
341            .join(" OR ");
342
343        let rows = stmt
344            .query_map(
345                rusqlite::params_from_iter(
346                    std::iter::once(fts_query)
347                        .chain(filter_params.into_iter())
348                        .chain(std::iter::once(limit.to_string())),
349                ),
350                |row| {
351                    let id: String = row.get(0)?;
352                    let score: f64 = row.get(1)?;
353                    Ok((id, score))
354                },
355            )
356            .map_err(|e| Error::OperationFailed {
357                operation: "execute_search".to_string(),
358                cause: e.to_string(),
359            })?;
360
361        for row in rows {
362            let (id, score) = row.map_err(|e| Error::OperationFailed {
363                operation: "read_search_row".to_string(),
364                cause: e.to_string(),
365            })?;
366
367            // Normalize BM25 score (BM25 returns negative values, lower is better)
368            // Convert to 0-1 range where higher is better
369            #[allow(clippy::cast_possible_truncation)]
370            let normalized_score = (1.0 / (1.0 - score)).min(1.0) as f32;
371
372            results.push((MemoryId::new(id), normalized_score));
373        }
374
375        // Apply minimum score filter if specified
376        if let Some(min_score) = filter.min_score {
377            results.retain(|(_, score)| *score >= min_score);
378        }
379
380        Ok(results)
381    }
382
383    fn clear(&mut self) -> Result<()> {
384        let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
385            operation: "lock_connection".to_string(),
386            cause: e.to_string(),
387        })?;
388
389        conn.execute("DELETE FROM memories_fts", [])
390            .map_err(|e| Error::OperationFailed {
391                operation: "clear_fts".to_string(),
392                cause: e.to_string(),
393            })?;
394
395        conn.execute("DELETE FROM memories", [])
396            .map_err(|e| Error::OperationFailed {
397                operation: "clear_memories".to_string(),
398                cause: e.to_string(),
399            })?;
400
401        Ok(())
402    }
403
404    fn list_all(&self, filter: &SearchFilter, limit: usize) -> Result<Vec<(MemoryId, f32)>> {
405        let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
406            operation: "lock_connection".to_string(),
407            cause: e.to_string(),
408        })?;
409
410        // Build filter clause (starting at parameter 1, no FTS query)
411        let (filter_clause, filter_params, next_param) =
412            self.build_filter_clause_numbered(filter, 1);
413
414        // Query all memories without FTS MATCH, ordered by created_at desc
415        let sql = format!(
416            "SELECT m.id, 1.0 as score
417             FROM memories m
418             WHERE 1=1 {filter_clause}
419             ORDER BY m.created_at DESC
420             LIMIT ?{next_param}"
421        );
422
423        let mut stmt = conn.prepare(&sql).map_err(|e| Error::OperationFailed {
424            operation: "prepare_list_all".to_string(),
425            cause: e.to_string(),
426        })?;
427
428        let mut results = Vec::new();
429
430        let rows = stmt
431            .query_map(
432                rusqlite::params_from_iter(
433                    filter_params
434                        .into_iter()
435                        .chain(std::iter::once(limit.to_string())),
436                ),
437                |row| {
438                    let id: String = row.get(0)?;
439                    let score: f64 = row.get(1)?;
440                    Ok((id, score))
441                },
442            )
443            .map_err(|e| Error::OperationFailed {
444                operation: "list_all".to_string(),
445                cause: e.to_string(),
446            })?;
447
448        for row in rows {
449            let (id, score) = row.map_err(|e| Error::OperationFailed {
450                operation: "read_list_row".to_string(),
451                cause: e.to_string(),
452            })?;
453
454            #[allow(clippy::cast_possible_truncation)]
455            results.push((MemoryId::new(id), score as f32));
456        }
457
458        Ok(results)
459    }
460
461    fn get_memory(&self, id: &MemoryId) -> Result<Option<Memory>> {
462        use crate::models::{Domain, MemoryStatus, Namespace};
463
464        let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
465            operation: "lock_connection".to_string(),
466            cause: e.to_string(),
467        })?;
468
469        // Join memories and memories_fts to get full memory data
470        let mut stmt = conn
471            .prepare(
472                "SELECT m.id, m.namespace, m.domain, m.status, m.created_at, m.tags, f.content
473                 FROM memories m
474                 JOIN memories_fts f ON m.id = f.id
475                 WHERE m.id = ?1",
476            )
477            .map_err(|e| Error::OperationFailed {
478                operation: "prepare_get_memory".to_string(),
479                cause: e.to_string(),
480            })?;
481
482        let result: std::result::Result<Option<_>, _> = stmt
483            .query_row(params![id.as_str()], |row| {
484                let id_str: String = row.get(0)?;
485                let namespace_str: String = row.get(1)?;
486                let domain_str: Option<String> = row.get(2)?;
487                let status_str: String = row.get(3)?;
488                let created_at: i64 = row.get(4)?;
489                let tags_str: Option<String> = row.get(5)?;
490                let content: String = row.get(6)?;
491
492                Ok((
493                    id_str,
494                    namespace_str,
495                    domain_str,
496                    status_str,
497                    created_at,
498                    tags_str,
499                    content,
500                ))
501            })
502            .optional();
503
504        let result = result.map_err(|e| Error::OperationFailed {
505            operation: "get_memory".to_string(),
506            cause: e.to_string(),
507        })?;
508
509        let Some((id_str, namespace_str, domain_str, status_str, created_at, tags_str, content)) =
510            result
511        else {
512            return Ok(None);
513        };
514
515        // Parse namespace
516        let namespace = Namespace::parse(&namespace_str).unwrap_or_default();
517
518        // Parse domain
519        let domain = domain_str.map_or_else(Domain::new, |d: String| {
520            if d.is_empty() || d == "global" {
521                Domain::new()
522            } else {
523                let parts: Vec<&str> = d.split('/').collect();
524                match parts.len() {
525                    1 => Domain {
526                        organization: Some(parts[0].to_string()),
527                        project: None,
528                        repository: None,
529                    },
530                    2 => Domain {
531                        organization: Some(parts[0].to_string()),
532                        project: None,
533                        repository: Some(parts[1].to_string()),
534                    },
535                    _ => Domain::new(),
536                }
537            }
538        });
539
540        // Parse status
541        let status = match status_str.to_lowercase().as_str() {
542            "active" => MemoryStatus::Active,
543            "archived" => MemoryStatus::Archived,
544            "superseded" => MemoryStatus::Superseded,
545            "pending" => MemoryStatus::Pending,
546            "deleted" => MemoryStatus::Deleted,
547            _ => MemoryStatus::Active,
548        };
549
550        // Parse tags
551        let tags: Vec<String> = tags_str
552            .map(|t: String| {
553                t.split(',')
554                    .map(|s| s.trim().to_string())
555                    .filter(|s| !s.is_empty())
556                    .collect()
557            })
558            .unwrap_or_default();
559
560        #[allow(clippy::cast_sign_loss)]
561        let created_at_u64 = created_at as u64;
562
563        Ok(Some(Memory {
564            id: MemoryId::new(id_str),
565            content,
566            namespace,
567            domain,
568            status,
569            created_at: created_at_u64,
570            updated_at: created_at_u64,
571            embedding: None,
572            tags,
573            source: None,
574        }))
575    }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{Domain, MemoryStatus, Namespace};

    /// Builds an active memory tagged `test` with fixed created/updated timestamps.
    fn create_test_memory(id: &str, content: &str, namespace: Namespace) -> Memory {
        Memory {
            id: MemoryId::new(id),
            content: content.to_string(),
            namespace,
            domain: Domain::new(),
            status: MemoryStatus::Active,
            created_at: 1_234_567_890,
            updated_at: 1_234_567_890,
            embedding: None,
            tags: vec!["test".to_string()],
            source: None,
        }
    }

    #[test]
    fn test_index_and_search() {
        let mut backend = SqliteBackend::in_memory().unwrap();

        let memory1 = create_test_memory("id1", "Rust programming language", Namespace::Decisions);
        let memory2 = create_test_memory("id2", "Python scripting", Namespace::Learnings);
        let memory3 =
            create_test_memory("id3", "Rust ownership and borrowing", Namespace::Patterns);

        backend.index(&memory1).unwrap();
        backend.index(&memory2).unwrap();
        backend.index(&memory3).unwrap();

        // Search for "Rust"
        let results = backend.search("Rust", &SearchFilter::new(), 10).unwrap();

        assert_eq!(results.len(), 2);
        let ids: Vec<_> = results.iter().map(|(id, _)| id.as_str()).collect();
        assert!(ids.contains(&"id1"));
        assert!(ids.contains(&"id3"));
    }

    #[test]
    fn test_search_with_namespace_filter() {
        let mut backend = SqliteBackend::in_memory().unwrap();

        let memory1 = create_test_memory("id1", "Rust programming", Namespace::Decisions);
        let memory2 = create_test_memory("id2", "Rust patterns", Namespace::Patterns);

        backend.index(&memory1).unwrap();
        backend.index(&memory2).unwrap();

        // Search with namespace filter
        let filter = SearchFilter::new().with_namespace(Namespace::Patterns);
        let results = backend.search("Rust", &filter, 10).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0.as_str(), "id2");
    }

    #[test]
    fn test_remove() {
        let mut backend = SqliteBackend::in_memory().unwrap();

        let memory = create_test_memory("to_remove", "Test content", Namespace::Decisions);
        backend.index(&memory).unwrap();

        // Verify it exists
        let results = backend.search("content", &SearchFilter::new(), 10).unwrap();
        assert_eq!(results.len(), 1);

        // Remove it
        let removed = backend.remove(&MemoryId::new("to_remove")).unwrap();
        assert!(removed);

        // Verify it's gone
        let results = backend.search("content", &SearchFilter::new(), 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_remove_missing_returns_false() {
        let mut backend = SqliteBackend::in_memory().unwrap();

        let removed = backend.remove(&MemoryId::new("never_indexed")).unwrap();
        assert!(!removed);
    }

    #[test]
    fn test_clear() {
        let mut backend = SqliteBackend::in_memory().unwrap();

        backend
            .index(&create_test_memory("id1", "content1", Namespace::Decisions))
            .unwrap();
        backend
            .index(&create_test_memory("id2", "content2", Namespace::Decisions))
            .unwrap();

        backend.clear().unwrap();

        let results = backend.search("content", &SearchFilter::new(), 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_reindex() {
        let mut backend = SqliteBackend::in_memory().unwrap();

        let memories = vec![
            create_test_memory("id1", "memory one", Namespace::Decisions),
            create_test_memory("id2", "memory two", Namespace::Learnings),
            create_test_memory("id3", "memory three", Namespace::Patterns),
        ];

        backend.reindex(&memories).unwrap();

        let results = backend.search("memory", &SearchFilter::new(), 10).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_update_index() {
        let mut backend = SqliteBackend::in_memory().unwrap();

        let mut memory =
            create_test_memory("update_test", "original content", Namespace::Decisions);
        backend.index(&memory).unwrap();

        // Update the memory
        memory.content = "updated content completely different".to_string();
        backend.index(&memory).unwrap();

        // Search for old content should not find it
        let old_results = backend
            .search("original", &SearchFilter::new(), 10)
            .unwrap();
        assert!(old_results.is_empty());

        // Search for new content should find it
        let new_results = backend
            .search("different", &SearchFilter::new(), 10)
            .unwrap();
        assert_eq!(new_results.len(), 1);
    }

    #[test]
    fn test_search_respects_limit() {
        let mut backend = SqliteBackend::in_memory().unwrap();

        for id in ["a", "b", "c"] {
            backend
                .index(&create_test_memory(id, "shared keyword", Namespace::Decisions))
                .unwrap();
        }

        let results = backend.search("keyword", &SearchFilter::new(), 2).unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_search_scores_within_unit_range() {
        let mut backend = SqliteBackend::in_memory().unwrap();

        backend
            .index(&create_test_memory("id1", "Rust programming", Namespace::Decisions))
            .unwrap();
        backend
            .index(&create_test_memory("id2", "Python scripting", Namespace::Learnings))
            .unwrap();

        let results = backend.search("Rust", &SearchFilter::new(), 10).unwrap();
        assert_eq!(results.len(), 1);
        for (_, score) in &results {
            assert!((0.0..=1.0).contains(score), "score out of range: {score}");
        }
    }

    #[test]
    fn test_list_all_orders_newest_first() {
        let mut backend = SqliteBackend::in_memory().unwrap();

        let mut older = create_test_memory("older", "first entry", Namespace::Decisions);
        older.created_at = 100;
        let mut newer = create_test_memory("newer", "second entry", Namespace::Decisions);
        newer.created_at = 200;

        backend.index(&older).unwrap();
        backend.index(&newer).unwrap();

        let results = backend.list_all(&SearchFilter::new(), 10).unwrap();
        let ids: Vec<_> = results.iter().map(|(id, _)| id.as_str()).collect();
        assert_eq!(ids, vec!["newer", "older"]);
    }

    #[test]
    fn test_get_memory_roundtrip() {
        let mut backend = SqliteBackend::in_memory().unwrap();

        let memory = create_test_memory("roundtrip", "stored content", Namespace::Decisions);
        backend.index(&memory).unwrap();

        let loaded = backend
            .get_memory(&MemoryId::new("roundtrip"))
            .unwrap()
            .expect("indexed memory should be retrievable");

        assert_eq!(loaded.id.as_str(), "roundtrip");
        assert_eq!(loaded.content, "stored content");
        assert_eq!(loaded.tags, vec!["test".to_string()]);
        assert_eq!(loaded.created_at, 1_234_567_890);
        assert!(matches!(loaded.status, MemoryStatus::Active));
    }

    #[test]
    fn test_get_memory_missing_returns_none() {
        let backend = SqliteBackend::in_memory().unwrap();

        let loaded = backend.get_memory(&MemoryId::new("absent")).unwrap();
        assert!(loaded.is_none());
    }
}