1use crate::models::{MemoryId, SearchFilter};
10use crate::storage::traits::VectorBackend;
11use crate::{Error, Result};
12use std::collections::HashMap;
13use std::fs;
14use std::path::PathBuf;
15
/// Vector store that keeps every embedding in memory and ranks queries by
/// scanning all stored vectors, optionally persisting them to a JSON file.
///
/// NOTE(review): despite the name, nothing visible here calls the `usearch`
/// crate; confirm whether that is intentional.
pub struct UsearchBackend {
    /// Width that every stored and query embedding must have.
    dimensions: usize,
    /// Stored embeddings keyed by memory id string.
    vectors: HashMap<String, Vec<f32>>,
    /// True when `vectors` holds changes that have not been written to disk yet.
    dirty: bool,
    /// File the index is loaded from and saved to; an empty path means in-memory only.
    index_path: PathBuf,
}
31
32impl UsearchBackend {
33 pub const DEFAULT_DIMENSIONS: usize = 384;
35
36 #[must_use]
38 pub fn new(index_path: impl Into<PathBuf>, dimensions: usize) -> Self {
39 Self {
40 index_path: index_path.into(),
41 dimensions,
42 vectors: HashMap::new(),
43 dirty: false,
44 }
45 }
46
47 #[must_use]
49 pub fn with_default_dimensions(index_path: impl Into<PathBuf>) -> Self {
50 Self::new(index_path, Self::DEFAULT_DIMENSIONS)
51 }
52
53 #[must_use]
55 pub fn in_memory(dimensions: usize) -> Self {
56 Self {
57 index_path: PathBuf::new(),
58 dimensions,
59 vectors: HashMap::new(),
60 dirty: false,
61 }
62 }
63
64 #[must_use]
66 pub const fn index_path(&self) -> &PathBuf {
67 &self.index_path
68 }
69
70 pub fn load(&mut self) -> Result<()> {
76 if self.index_path.as_os_str().is_empty() {
77 return Ok(());
78 }
79
80 if !self.index_path.exists() {
81 return Ok(());
82 }
83
84 let content = fs::read_to_string(&self.index_path).map_err(|e| Error::OperationFailed {
85 operation: "load_index".to_string(),
86 cause: e.to_string(),
87 })?;
88
89 let data: IndexData =
90 serde_json::from_str(&content).map_err(|e| Error::OperationFailed {
91 operation: "parse_index".to_string(),
92 cause: e.to_string(),
93 })?;
94
95 if data.dimensions != self.dimensions {
96 return Err(Error::InvalidInput(format!(
97 "Index dimensions mismatch: expected {}, got {}",
98 self.dimensions, data.dimensions
99 )));
100 }
101
102 self.vectors = data.vectors;
103 self.dirty = false;
104
105 Ok(())
106 }
107
108 pub fn save(&mut self) -> Result<()> {
114 if self.index_path.as_os_str().is_empty() {
115 return Ok(());
116 }
117
118 if !self.dirty {
119 return Ok(());
120 }
121
122 let data = IndexData {
123 dimensions: self.dimensions,
124 vectors: self.vectors.clone(),
125 };
126
127 let content = serde_json::to_string(&data).map_err(|e| Error::OperationFailed {
128 operation: "serialize_index".to_string(),
129 cause: e.to_string(),
130 })?;
131
132 if let Some(parent) = self.index_path.parent() {
134 fs::create_dir_all(parent).map_err(|e| Error::OperationFailed {
135 operation: "create_index_dir".to_string(),
136 cause: e.to_string(),
137 })?;
138 }
139
140 fs::write(&self.index_path, content).map_err(|e| Error::OperationFailed {
141 operation: "write_index".to_string(),
142 cause: e.to_string(),
143 })?;
144
145 self.dirty = false;
146 Ok(())
147 }
148
149 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
151 if a.len() != b.len() {
152 return 0.0;
153 }
154
155 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
156 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
157 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
158
159 if norm_a == 0.0 || norm_b == 0.0 {
160 return 0.0;
161 }
162
163 f32::midpoint(dot_product / (norm_a * norm_b), 1.0)
165 }
166
167 fn validate_embedding(&self, embedding: &[f32]) -> Result<()> {
169 if embedding.len() != self.dimensions {
170 return Err(Error::InvalidInput(format!(
171 "Embedding dimension mismatch: expected {}, got {}",
172 self.dimensions,
173 embedding.len()
174 )));
175 }
176 Ok(())
177 }
178}
179
/// On-disk JSON layout of a persisted index file.
///
/// Deserialized by `UsearchBackend::load`, which compares `dimensions`
/// against the backend's configured width before accepting the vectors.
#[derive(serde::Serialize, serde::Deserialize)]
struct IndexData {
    /// Embedding width recorded when the index was written.
    dimensions: usize,
    /// Stored embeddings keyed by memory id string.
    vectors: HashMap<String, Vec<f32>>,
}
186
187impl VectorBackend for UsearchBackend {
188 fn dimensions(&self) -> usize {
189 self.dimensions
190 }
191
192 fn upsert(&mut self, id: &MemoryId, embedding: &[f32]) -> Result<()> {
193 self.validate_embedding(embedding)?;
194
195 self.vectors
196 .insert(id.as_str().to_string(), embedding.to_vec());
197 self.dirty = true;
198
199 Ok(())
200 }
201
202 fn remove(&mut self, id: &MemoryId) -> Result<bool> {
203 let removed = self.vectors.remove(id.as_str()).is_some();
204 if removed {
205 self.dirty = true;
206 }
207 Ok(removed)
208 }
209
210 fn search(
211 &self,
212 query_embedding: &[f32],
213 _filter: &SearchFilter,
214 limit: usize,
215 ) -> Result<Vec<(MemoryId, f32)>> {
216 self.validate_embedding(query_embedding)?;
217
218 let mut scores: Vec<(String, f32)> = self
220 .vectors
221 .iter()
222 .map(|(id, vec)| {
223 let score = Self::cosine_similarity(query_embedding, vec);
224 (id.clone(), score)
225 })
226 .collect();
227
228 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
230
231 let results: Vec<(MemoryId, f32)> = scores
233 .into_iter()
234 .take(limit)
235 .map(|(id, score)| (MemoryId::new(id), score))
236 .collect();
237
238 Ok(results)
239 }
240
241 fn count(&self) -> Result<usize> {
242 Ok(self.vectors.len())
243 }
244
245 fn clear(&mut self) -> Result<()> {
246 self.vectors.clear();
247 self.dirty = true;
248 Ok(())
249 }
250}
251
252impl Drop for UsearchBackend {
253 fn drop(&mut self) {
254 let _ = self.save();
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Deterministic (not random) embedding: a repeating 0.0, 0.1, ..., 0.9 ramp.
    fn ramp_embedding(dimensions: usize) -> Vec<f32> {
        (0..dimensions).map(|i| (i % 10) as f32 / 10.0).collect()
    }

    /// Unit-length embedding built from `sin(i + seed)`; distinct seeds give
    /// distinct directions. Left unnormalized if the norm is zero.
    fn unit_sine_embedding(dimensions: usize, seed: f32) -> Vec<f32> {
        let mut values: Vec<f32> = (0..dimensions).map(|i| (i as f32 + seed).sin()).collect();
        let norm = values.iter().map(|v| v * v).sum::<f32>().sqrt();
        if norm > 0.0 {
            for v in &mut values {
                *v /= norm;
            }
        }
        values
    }

    #[test]
    fn test_usearch_backend_creation() {
        let explicit = UsearchBackend::new("/tmp/test.idx", 384);
        assert_eq!(explicit.dimensions(), 384);

        let defaulted = UsearchBackend::with_default_dimensions("/tmp/test.idx");
        assert_eq!(defaulted.dimensions(), UsearchBackend::DEFAULT_DIMENSIONS);

        let transient = UsearchBackend::in_memory(512);
        assert_eq!(transient.dimensions(), 512);
    }

    #[test]
    fn test_upsert_and_count() {
        let mut backend = UsearchBackend::in_memory(384);

        for (offset, name) in ["id1", "id2"].iter().enumerate() {
            backend
                .upsert(&MemoryId::new(*name), &ramp_embedding(384))
                .unwrap();
            assert_eq!(backend.count().unwrap(), offset + 1);
        }
    }

    #[test]
    fn test_upsert_dimension_mismatch() {
        let mut backend = UsearchBackend::in_memory(384);

        let outcome = backend.upsert(&MemoryId::new("test"), &ramp_embedding(256));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_remove() {
        let mut backend = UsearchBackend::in_memory(384);
        let id = MemoryId::new("test");

        backend.upsert(&id, &ramp_embedding(384)).unwrap();
        assert_eq!(backend.count().unwrap(), 1);

        assert!(backend.remove(&id).unwrap());
        assert_eq!(backend.count().unwrap(), 0);

        // Removing an id that is already gone reports `false`.
        assert!(!backend.remove(&id).unwrap());
    }

    #[test]
    fn test_search() {
        let mut backend = UsearchBackend::in_memory(384);

        for seed in 0..5_u8 {
            backend
                .upsert(
                    &MemoryId::new(format!("id{seed}")),
                    &unit_sine_embedding(384, f32::from(seed)),
                )
                .unwrap();
        }

        // The query equals the seed-0 vector, so id0 must rank first.
        let hits = backend
            .search(&unit_sine_embedding(384, 0.0), &SearchFilter::new(), 3)
            .unwrap();
        assert_eq!(hits.len(), 3);

        let (best_id, best_score) = &hits[0];
        assert_eq!(best_id.as_str(), "id0");
        assert!(*best_score > 0.99);
    }

    #[test]
    fn test_search_empty() {
        let backend = UsearchBackend::in_memory(384);

        let hits = backend
            .search(&ramp_embedding(384), &SearchFilter::new(), 10)
            .unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn test_clear() {
        let mut backend = UsearchBackend::in_memory(384);

        for n in 0..3 {
            backend
                .upsert(&MemoryId::new(format!("id{n}")), &ramp_embedding(384))
                .unwrap();
        }
        assert_eq!(backend.count().unwrap(), 3);

        backend.clear().unwrap();
        assert_eq!(backend.count().unwrap(), 0);
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis: [f32; 3] = [1.0, 0.0, 0.0];
        let y_axis: [f32; 3] = [0.0, 1.0, 0.0];
        let neg_x: [f32; 3] = [-1.0, 0.0, 0.0];

        // Scores are cosine rescaled to [0, 1]: same -> 1, orthogonal -> 0.5, opposite -> 0.
        let same = UsearchBackend::cosine_similarity(&x_axis, &x_axis);
        assert!((same - 1.0).abs() < 0.001);

        let orthogonal = UsearchBackend::cosine_similarity(&x_axis, &y_axis);
        assert!((orthogonal - 0.5).abs() < 0.001);

        let opposite = UsearchBackend::cosine_similarity(&x_axis, &neg_x);
        assert!(opposite < 0.001);
    }

    #[test]
    fn test_persistence() {
        let dir = TempDir::new().unwrap();
        let index_path = dir.path().join("test.idx");

        // First backend writes a single vector to disk.
        {
            let mut writer = UsearchBackend::new(&index_path, 384);
            writer
                .upsert(&MemoryId::new("persistent"), &ramp_embedding(384))
                .unwrap();
            writer.save().unwrap();
        }

        // A fresh backend on the same path sees it after loading.
        let mut reader = UsearchBackend::new(&index_path, 384);
        reader.load().unwrap();
        assert_eq!(reader.count().unwrap(), 1);
    }

    #[test]
    fn test_load_dimension_mismatch() {
        let dir = TempDir::new().unwrap();
        let index_path = dir.path().join("test.idx");

        {
            let mut writer = UsearchBackend::new(&index_path, 384);
            writer
                .upsert(&MemoryId::new("test"), &ramp_embedding(384))
                .unwrap();
            writer.save().unwrap();
        }

        // Reopening with a different width must be rejected.
        let mut reader = UsearchBackend::new(&index_path, 512);
        assert!(reader.load().is_err());
    }

    #[test]
    fn test_load_nonexistent() {
        let mut backend = UsearchBackend::new("/nonexistent/path/index.idx", 384);

        assert!(backend.load().is_ok());
        assert_eq!(backend.count().unwrap(), 0);
    }

    #[test]
    fn test_update_existing() {
        let mut backend = UsearchBackend::in_memory(384);
        let id = MemoryId::new("test");

        backend.upsert(&id, &unit_sine_embedding(384, 1.0)).unwrap();
        // Upserting the same id again replaces the vector instead of adding one.
        backend.upsert(&id, &unit_sine_embedding(384, 2.0)).unwrap();
        assert_eq!(backend.count().unwrap(), 1);

        // The stored vector is now the seed-2.0 one, so it matches that query.
        let hits = backend
            .search(&unit_sine_embedding(384, 2.0), &SearchFilter::new(), 1)
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert!(hits[0].1 > 0.99);
    }
}