1use crate::models::{Memory, MemoryId, SearchFilter};
6use crate::storage::traits::IndexBackend;
7use crate::{Error, Result};
8use rusqlite::{Connection, OptionalExtension, params};
9use std::path::{Path, PathBuf};
10use std::sync::Mutex;
11
12pub struct SqliteBackend {
14 conn: Mutex<Connection>,
16 db_path: Option<PathBuf>,
18}
19
20impl SqliteBackend {
21 pub fn new(db_path: impl Into<PathBuf>) -> Result<Self> {
27 let db_path = db_path.into();
28 let conn = Connection::open(&db_path).map_err(|e| Error::OperationFailed {
29 operation: "open_sqlite".to_string(),
30 cause: e.to_string(),
31 })?;
32
33 let backend = Self {
34 conn: Mutex::new(conn),
35 db_path: Some(db_path),
36 };
37
38 backend.initialize()?;
39 Ok(backend)
40 }
41
42 pub fn in_memory() -> Result<Self> {
48 let conn = Connection::open_in_memory().map_err(|e| Error::OperationFailed {
49 operation: "open_sqlite_memory".to_string(),
50 cause: e.to_string(),
51 })?;
52
53 let backend = Self {
54 conn: Mutex::new(conn),
55 db_path: None,
56 };
57
58 backend.initialize()?;
59 Ok(backend)
60 }
61
62 #[must_use]
64 pub fn db_path(&self) -> Option<&Path> {
65 self.db_path.as_deref()
66 }
67
68 fn initialize(&self) -> Result<()> {
70 let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
71 operation: "lock_connection".to_string(),
72 cause: e.to_string(),
73 })?;
74
75 conn.execute(
77 "CREATE TABLE IF NOT EXISTS memories (
78 id TEXT PRIMARY KEY,
79 namespace TEXT NOT NULL,
80 domain TEXT,
81 status TEXT NOT NULL,
82 created_at INTEGER NOT NULL,
83 tags TEXT,
84 source TEXT
85 )",
86 [],
87 )
88 .map_err(|e| Error::OperationFailed {
89 operation: "create_memories_table".to_string(),
90 cause: e.to_string(),
91 })?;
92
93 let _ = conn.execute("ALTER TABLE memories ADD COLUMN source TEXT", []);
95
96 conn.execute(
98 "CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
99 id,
100 content,
101 tags
102 )",
103 [],
104 )
105 .map_err(|e| Error::OperationFailed {
106 operation: "create_fts_table".to_string(),
107 cause: e.to_string(),
108 })?;
109
110 Ok(())
111 }
112
113 fn build_filter_clause_numbered(
116 &self,
117 filter: &SearchFilter,
118 start_param: usize,
119 ) -> (String, Vec<String>, usize) {
120 let mut conditions = Vec::new();
121 let mut params = Vec::new();
122 let mut param_idx = start_param;
123
124 if !filter.namespaces.is_empty() {
125 let placeholders: Vec<String> = filter
126 .namespaces
127 .iter()
128 .map(|_| {
129 let p = format!("?{param_idx}");
130 param_idx += 1;
131 p
132 })
133 .collect();
134 conditions.push(format!("m.namespace IN ({})", placeholders.join(",")));
135 for ns in &filter.namespaces {
136 params.push(ns.as_str().to_string());
137 }
138 }
139
140 if !filter.statuses.is_empty() {
141 let placeholders: Vec<String> = filter
142 .statuses
143 .iter()
144 .map(|_| {
145 let p = format!("?{param_idx}");
146 param_idx += 1;
147 p
148 })
149 .collect();
150 conditions.push(format!("m.status IN ({})", placeholders.join(",")));
151 for s in &filter.statuses {
152 params.push(s.as_str().to_string());
153 }
154 }
155
156 for tag in &filter.tags {
159 conditions.push(format!("(',' || m.tags || ',') LIKE ?{param_idx}"));
160 param_idx += 1;
161 params.push(format!("%,{tag},%"));
162 }
163
164 if !filter.tags_any.is_empty() {
166 let or_conditions: Vec<String> = filter
167 .tags_any
168 .iter()
169 .map(|tag| {
170 let cond = format!("(',' || m.tags || ',') LIKE ?{param_idx}");
171 param_idx += 1;
172 params.push(format!("%,{tag},%"));
173 cond
174 })
175 .collect();
176 conditions.push(format!("({})", or_conditions.join(" OR ")));
177 }
178
179 for tag in &filter.excluded_tags {
181 conditions.push(format!("(',' || m.tags || ',') NOT LIKE ?{param_idx}"));
182 param_idx += 1;
183 params.push(format!("%,{tag},%"));
184 }
185
186 if let Some(ref pattern) = filter.source_pattern {
188 let sql_pattern = pattern.replace('*', "%").replace('?', "_");
190 conditions.push(format!("m.source LIKE ?{param_idx}"));
191 param_idx += 1;
192 params.push(sql_pattern);
193 }
194
195 if let Some(after) = filter.created_after {
196 conditions.push(format!("m.created_at >= ?{param_idx}"));
197 param_idx += 1;
198 params.push(after.to_string());
199 }
200
201 if let Some(before) = filter.created_before {
202 conditions.push(format!("m.created_at <= ?{param_idx}"));
203 param_idx += 1;
204 params.push(before.to_string());
205 }
206
207 let clause = if conditions.is_empty() {
208 String::new()
209 } else {
210 format!(" AND {}", conditions.join(" AND "))
211 };
212
213 (clause, params, param_idx)
214 }
215}
216
217impl IndexBackend for SqliteBackend {
218 fn index(&mut self, memory: &Memory) -> Result<()> {
219 let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
220 operation: "lock_connection".to_string(),
221 cause: e.to_string(),
222 })?;
223
224 let tags_str = memory.tags.join(",");
225 let domain_str = memory.domain.to_string();
226
227 conn.execute(
229 "INSERT OR REPLACE INTO memories (id, namespace, domain, status, created_at, tags, source)
230 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
231 params![
232 memory.id.as_str(),
233 memory.namespace.as_str(),
234 domain_str,
235 memory.status.as_str(),
236 memory.created_at,
237 tags_str,
238 memory.source.as_deref()
239 ],
240 )
241 .map_err(|e| Error::OperationFailed {
242 operation: "insert_memory".to_string(),
243 cause: e.to_string(),
244 })?;
245
246 conn.execute(
248 "DELETE FROM memories_fts WHERE id = ?1",
249 params![memory.id.as_str()],
250 )
251 .map_err(|e| Error::OperationFailed {
252 operation: "delete_fts".to_string(),
253 cause: e.to_string(),
254 })?;
255
256 conn.execute(
258 "INSERT INTO memories_fts (id, content, tags) VALUES (?1, ?2, ?3)",
259 params![memory.id.as_str(), memory.content, tags_str],
260 )
261 .map_err(|e| Error::OperationFailed {
262 operation: "insert_fts".to_string(),
263 cause: e.to_string(),
264 })?;
265
266 Ok(())
267 }
268
269 fn remove(&mut self, id: &MemoryId) -> Result<bool> {
270 let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
271 operation: "lock_connection".to_string(),
272 cause: e.to_string(),
273 })?;
274
275 conn.execute(
277 "DELETE FROM memories_fts WHERE id = ?1",
278 params![id.as_str()],
279 )
280 .map_err(|e| Error::OperationFailed {
281 operation: "delete_fts".to_string(),
282 cause: e.to_string(),
283 })?;
284
285 let deleted = conn
287 .execute("DELETE FROM memories WHERE id = ?1", params![id.as_str()])
288 .map_err(|e| Error::OperationFailed {
289 operation: "delete_memory".to_string(),
290 cause: e.to_string(),
291 })?;
292
293 Ok(deleted > 0)
294 }
295
296 fn search(
297 &self,
298 query: &str,
299 filter: &SearchFilter,
300 limit: usize,
301 ) -> Result<Vec<(MemoryId, f32)>> {
302 let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
303 operation: "lock_connection".to_string(),
304 cause: e.to_string(),
305 })?;
306
307 let (filter_clause, filter_params, next_param) =
310 self.build_filter_clause_numbered(filter, 2);
311
312 let sql = format!(
315 "SELECT f.id, bm25(memories_fts) as score
316 FROM memories_fts f
317 JOIN memories m ON f.id = m.id
318 WHERE memories_fts MATCH ?1 {filter_clause}
319 ORDER BY score
320 LIMIT ?{next_param}"
321 );
322
323 let mut stmt = conn.prepare(&sql).map_err(|e| Error::OperationFailed {
324 operation: "prepare_search".to_string(),
325 cause: e.to_string(),
326 })?;
327
328 let mut results = Vec::new();
330
331 let fts_query = query
334 .split_whitespace()
335 .map(|term| {
336 let escaped = term.replace('"', "\"\"");
338 format!("\"{escaped}\"")
339 })
340 .collect::<Vec<_>>()
341 .join(" OR ");
342
343 let rows = stmt
344 .query_map(
345 rusqlite::params_from_iter(
346 std::iter::once(fts_query)
347 .chain(filter_params.into_iter())
348 .chain(std::iter::once(limit.to_string())),
349 ),
350 |row| {
351 let id: String = row.get(0)?;
352 let score: f64 = row.get(1)?;
353 Ok((id, score))
354 },
355 )
356 .map_err(|e| Error::OperationFailed {
357 operation: "execute_search".to_string(),
358 cause: e.to_string(),
359 })?;
360
361 for row in rows {
362 let (id, score) = row.map_err(|e| Error::OperationFailed {
363 operation: "read_search_row".to_string(),
364 cause: e.to_string(),
365 })?;
366
367 #[allow(clippy::cast_possible_truncation)]
370 let normalized_score = (1.0 / (1.0 - score)).min(1.0) as f32;
371
372 results.push((MemoryId::new(id), normalized_score));
373 }
374
375 if let Some(min_score) = filter.min_score {
377 results.retain(|(_, score)| *score >= min_score);
378 }
379
380 Ok(results)
381 }
382
383 fn clear(&mut self) -> Result<()> {
384 let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
385 operation: "lock_connection".to_string(),
386 cause: e.to_string(),
387 })?;
388
389 conn.execute("DELETE FROM memories_fts", [])
390 .map_err(|e| Error::OperationFailed {
391 operation: "clear_fts".to_string(),
392 cause: e.to_string(),
393 })?;
394
395 conn.execute("DELETE FROM memories", [])
396 .map_err(|e| Error::OperationFailed {
397 operation: "clear_memories".to_string(),
398 cause: e.to_string(),
399 })?;
400
401 Ok(())
402 }
403
404 fn list_all(&self, filter: &SearchFilter, limit: usize) -> Result<Vec<(MemoryId, f32)>> {
405 let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
406 operation: "lock_connection".to_string(),
407 cause: e.to_string(),
408 })?;
409
410 let (filter_clause, filter_params, next_param) =
412 self.build_filter_clause_numbered(filter, 1);
413
414 let sql = format!(
416 "SELECT m.id, 1.0 as score
417 FROM memories m
418 WHERE 1=1 {filter_clause}
419 ORDER BY m.created_at DESC
420 LIMIT ?{next_param}"
421 );
422
423 let mut stmt = conn.prepare(&sql).map_err(|e| Error::OperationFailed {
424 operation: "prepare_list_all".to_string(),
425 cause: e.to_string(),
426 })?;
427
428 let mut results = Vec::new();
429
430 let rows = stmt
431 .query_map(
432 rusqlite::params_from_iter(
433 filter_params
434 .into_iter()
435 .chain(std::iter::once(limit.to_string())),
436 ),
437 |row| {
438 let id: String = row.get(0)?;
439 let score: f64 = row.get(1)?;
440 Ok((id, score))
441 },
442 )
443 .map_err(|e| Error::OperationFailed {
444 operation: "list_all".to_string(),
445 cause: e.to_string(),
446 })?;
447
448 for row in rows {
449 let (id, score) = row.map_err(|e| Error::OperationFailed {
450 operation: "read_list_row".to_string(),
451 cause: e.to_string(),
452 })?;
453
454 #[allow(clippy::cast_possible_truncation)]
455 results.push((MemoryId::new(id), score as f32));
456 }
457
458 Ok(results)
459 }
460
461 fn get_memory(&self, id: &MemoryId) -> Result<Option<Memory>> {
462 use crate::models::{Domain, MemoryStatus, Namespace};
463
464 let conn = self.conn.lock().map_err(|e| Error::OperationFailed {
465 operation: "lock_connection".to_string(),
466 cause: e.to_string(),
467 })?;
468
469 let mut stmt = conn
471 .prepare(
472 "SELECT m.id, m.namespace, m.domain, m.status, m.created_at, m.tags, f.content
473 FROM memories m
474 JOIN memories_fts f ON m.id = f.id
475 WHERE m.id = ?1",
476 )
477 .map_err(|e| Error::OperationFailed {
478 operation: "prepare_get_memory".to_string(),
479 cause: e.to_string(),
480 })?;
481
482 let result: std::result::Result<Option<_>, _> = stmt
483 .query_row(params![id.as_str()], |row| {
484 let id_str: String = row.get(0)?;
485 let namespace_str: String = row.get(1)?;
486 let domain_str: Option<String> = row.get(2)?;
487 let status_str: String = row.get(3)?;
488 let created_at: i64 = row.get(4)?;
489 let tags_str: Option<String> = row.get(5)?;
490 let content: String = row.get(6)?;
491
492 Ok((
493 id_str,
494 namespace_str,
495 domain_str,
496 status_str,
497 created_at,
498 tags_str,
499 content,
500 ))
501 })
502 .optional();
503
504 let result = result.map_err(|e| Error::OperationFailed {
505 operation: "get_memory".to_string(),
506 cause: e.to_string(),
507 })?;
508
509 let Some((id_str, namespace_str, domain_str, status_str, created_at, tags_str, content)) =
510 result
511 else {
512 return Ok(None);
513 };
514
515 let namespace = Namespace::parse(&namespace_str).unwrap_or_default();
517
518 let domain = domain_str.map_or_else(Domain::new, |d: String| {
520 if d.is_empty() || d == "global" {
521 Domain::new()
522 } else {
523 let parts: Vec<&str> = d.split('/').collect();
524 match parts.len() {
525 1 => Domain {
526 organization: Some(parts[0].to_string()),
527 project: None,
528 repository: None,
529 },
530 2 => Domain {
531 organization: Some(parts[0].to_string()),
532 project: None,
533 repository: Some(parts[1].to_string()),
534 },
535 _ => Domain::new(),
536 }
537 }
538 });
539
540 let status = match status_str.to_lowercase().as_str() {
542 "active" => MemoryStatus::Active,
543 "archived" => MemoryStatus::Archived,
544 "superseded" => MemoryStatus::Superseded,
545 "pending" => MemoryStatus::Pending,
546 "deleted" => MemoryStatus::Deleted,
547 _ => MemoryStatus::Active,
548 };
549
550 let tags: Vec<String> = tags_str
552 .map(|t: String| {
553 t.split(',')
554 .map(|s| s.trim().to_string())
555 .filter(|s| !s.is_empty())
556 .collect()
557 })
558 .unwrap_or_default();
559
560 #[allow(clippy::cast_sign_loss)]
561 let created_at_u64 = created_at as u64;
562
563 Ok(Some(Memory {
564 id: MemoryId::new(id_str),
565 content,
566 namespace,
567 domain,
568 status,
569 created_at: created_at_u64,
570 updated_at: created_at_u64,
571 embedding: None,
572 tags,
573 source: None,
574 }))
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{Domain, MemoryStatus, Namespace};

    /// Builds an active memory tagged `test`, with identical fixed created/updated
    /// timestamps and no domain, embedding or source.
    fn create_test_memory(id: &str, content: &str, namespace: Namespace) -> Memory {
        let timestamp = 1_234_567_890;
        Memory {
            id: MemoryId::new(id),
            content: content.to_string(),
            namespace,
            domain: Domain::new(),
            status: MemoryStatus::Active,
            created_at: timestamp,
            updated_at: timestamp,
            embedding: None,
            tags: vec!["test".to_string()],
            source: None,
        }
    }

    /// Opens a throwaway in-memory backend, panicking if that fails.
    fn fresh_backend() -> SqliteBackend {
        SqliteBackend::in_memory().unwrap()
    }

    /// Extracts just the ids from a list of search hits.
    fn hit_ids(hits: &[(MemoryId, f32)]) -> Vec<&str> {
        hits.iter().map(|(id, _)| id.as_str()).collect()
    }

    #[test]
    fn test_index_and_search() {
        let mut backend = fresh_backend();
        let fixtures = [
            ("id1", "Rust programming language", Namespace::Decisions),
            ("id2", "Python scripting", Namespace::Learnings),
            ("id3", "Rust ownership and borrowing", Namespace::Patterns),
        ];
        for (id, content, namespace) in fixtures {
            backend
                .index(&create_test_memory(id, content, namespace))
                .unwrap();
        }

        let hits = backend.search("Rust", &SearchFilter::new(), 10).unwrap();

        assert_eq!(hits.len(), 2);
        let ids = hit_ids(&hits);
        assert!(ids.contains(&"id1"));
        assert!(ids.contains(&"id3"));
    }

    #[test]
    fn test_search_with_namespace_filter() {
        let mut backend = fresh_backend();
        backend
            .index(&create_test_memory("id1", "Rust programming", Namespace::Decisions))
            .unwrap();
        backend
            .index(&create_test_memory("id2", "Rust patterns", Namespace::Patterns))
            .unwrap();

        let only_patterns = SearchFilter::new().with_namespace(Namespace::Patterns);
        let hits = backend.search("Rust", &only_patterns, 10).unwrap();

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0.as_str(), "id2");
    }

    #[test]
    fn test_remove() {
        let mut backend = fresh_backend();
        backend
            .index(&create_test_memory("to_remove", "Test content", Namespace::Decisions))
            .unwrap();

        let before = backend.search("content", &SearchFilter::new(), 10).unwrap();
        assert_eq!(before.len(), 1);

        assert!(backend.remove(&MemoryId::new("to_remove")).unwrap());

        let after = backend.search("content", &SearchFilter::new(), 10).unwrap();
        assert!(after.is_empty());
    }

    #[test]
    fn test_clear() {
        let mut backend = fresh_backend();
        for (id, content) in [("id1", "content1"), ("id2", "content2")] {
            backend
                .index(&create_test_memory(id, content, Namespace::Decisions))
                .unwrap();
        }

        backend.clear().unwrap();

        let hits = backend.search("content", &SearchFilter::new(), 10).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn test_reindex() {
        let mut backend = fresh_backend();
        let memories: Vec<Memory> = [
            ("id1", "memory one", Namespace::Decisions),
            ("id2", "memory two", Namespace::Learnings),
            ("id3", "memory three", Namespace::Patterns),
        ]
        .into_iter()
        .map(|(id, content, namespace)| create_test_memory(id, content, namespace))
        .collect();

        backend.reindex(&memories).unwrap();

        let hits = backend.search("memory", &SearchFilter::new(), 10).unwrap();
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn test_update_index() {
        let mut backend = fresh_backend();
        let mut memory =
            create_test_memory("update_test", "original content", Namespace::Decisions);
        backend.index(&memory).unwrap();

        // Re-indexing the same id must replace, not append, the searchable text.
        memory.content = "updated content completely different".to_string();
        backend.index(&memory).unwrap();

        let stale = backend
            .search("original", &SearchFilter::new(), 10)
            .unwrap();
        assert!(stale.is_empty());

        let fresh = backend
            .search("different", &SearchFilter::new(), 10)
            .unwrap();
        assert_eq!(fresh.len(), 1);
    }
}