// subcog/embedding/fastembed.rs

//! FastEmbed-based embedder.
//!
//! Provides semantic embeddings using the all-MiniLM-L6-v2 model via fastembed-rs.
//! When the `fastembed-embeddings` feature is enabled, this uses real ONNX-based
//! semantic embeddings. Otherwise, falls back to deterministic hash-based pseudo-embeddings.

use super::{DEFAULT_DIMENSIONS, Embedder};
use crate::{Error, Result};

// ============================================================================
// Native FastEmbed Implementation (with feature)
// ============================================================================
#[cfg(feature = "fastembed-embeddings")]
mod native {
    use super::{DEFAULT_DIMENSIONS, Embedder, Error, Result};
    use std::panic::{AssertUnwindSafe, catch_unwind};
    use std::sync::OnceLock;
    use std::time::Instant;

    /// Thread-safe singleton for the embedding model.
    ///
    /// `OnceLock` provides lazy, race-free initialization on first use; the
    /// inner `Mutex` is needed because `fastembed::TextEmbedding::embed`
    /// takes `&mut self`.
    static EMBEDDING_MODEL: OnceLock<std::sync::Mutex<fastembed::TextEmbedding>> = OnceLock::new();

    /// Extracts a human-readable message from a `catch_unwind` panic payload.
    ///
    /// Panic payloads are `&str` for `panic!("literal")` and `String` for
    /// formatted panics; anything else degrades to `"unknown panic"`.
    /// Shared by `embed` and `embed_batch` so the two ONNX panic-recovery
    /// paths cannot drift apart.
    fn panic_message(payload: &(dyn std::any::Any + Send)) -> String {
        payload
            .downcast_ref::<&str>()
            .map(|s| (*s).to_string())
            .or_else(|| payload.downcast_ref::<String>().cloned())
            .unwrap_or_else(|| "unknown panic".to_string())
    }

    /// `FastEmbed` embedder using all-MiniLM-L6-v2.
    ///
    /// Uses the fastembed-rs library for real semantic embeddings.
    /// The model is lazily loaded on first embed call to preserve cold start time.
    pub struct FastEmbedEmbedder {
        /// Model name for logging/debugging.
        model_name: &'static str,
    }

    impl FastEmbedEmbedder {
        /// Default embedding dimensions for all-MiniLM-L6-v2.
        pub const DEFAULT_DIMENSIONS: usize = DEFAULT_DIMENSIONS;

        /// Creates a new `FastEmbed` embedder.
        ///
        /// Note: Model is lazily loaded on first `embed()` call.
        #[must_use]
        pub const fn new() -> Self {
            Self {
                model_name: "all-MiniLM-L6-v2",
            }
        }

        /// Creates a new embedder with custom dimensions.
        ///
        /// Note: This is provided for API compatibility but dimensions are
        /// fixed by the model (384 for all-MiniLM-L6-v2).
        #[must_use]
        #[allow(clippy::unused_self)]
        pub const fn with_dimensions(_dimensions: usize) -> Self {
            // Dimensions are fixed by the model
            Self::new()
        }

        /// Gets or initializes the embedding model (thread-safe).
        ///
        /// # Performance Note
        ///
        /// The model is loaded lazily on first use to preserve cold start time.
        /// Subsequent calls return the cached instance.
        ///
        /// The first call blocks synchronously (~100-500ms) while loading the ONNX model.
        /// This is an intentional design decision:
        /// - One-time cost amortized over all subsequent calls (instant)
        /// - Sync API is simpler and doesn't require async runtime everywhere
        /// - Alternative (`tokio::spawn_blocking`) would require async `Embedder` trait
        ///
        /// For applications sensitive to first-call latency, consider warming up the
        /// embedder during startup: `FastEmbedEmbedder::new().embed("warmup").ok();`
        fn get_model() -> Result<&'static std::sync::Mutex<fastembed::TextEmbedding>> {
            // Fast path: model already initialized by an earlier call.
            if let Some(model) = EMBEDDING_MODEL.get() {
                return Ok(model);
            }

            // Slow path: load the model. Note: concurrent first calls may each
            // load a model; only the first `set` wins and the losers' copies
            // are dropped. That wastes work once but is correct.
            tracing::info!("Loading embedding model (first use)...");
            let start = Instant::now();

            let options = fastembed::InitOptions::new(fastembed::EmbeddingModel::AllMiniLML6V2)
                .with_show_download_progress(false);

            let model =
                fastembed::TextEmbedding::try_new(options).map_err(|e| Error::OperationFailed {
                    operation: "load_embedding_model".to_string(),
                    cause: e.to_string(),
                })?;

            tracing::info!(
                elapsed_ms = start.elapsed().as_millis() as u64,
                model = "all-MiniLM-L6-v2",
                "Embedding model loaded successfully"
            );

            // Store the model, ignoring the result if another thread beat us to it.
            let _ = EMBEDDING_MODEL.set(std::sync::Mutex::new(model));
            // A successful `set` (ours or a racing thread's) guarantees the cell
            // is populated, so this error branch is effectively unreachable;
            // kept as a defensive `Result` rather than an `expect`.
            EMBEDDING_MODEL.get().ok_or_else(|| Error::OperationFailed {
                operation: "get_embedding_model".to_string(),
                cause: "Model initialization race condition".to_string(),
            })
        }

        /// Returns the model name.
        #[must_use]
        pub const fn model_name(&self) -> &'static str {
            self.model_name
        }
    }

    impl Default for FastEmbedEmbedder {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Embedder for FastEmbedEmbedder {
        fn dimensions(&self) -> usize {
            Self::DEFAULT_DIMENSIONS
        }

        fn embed(&self, text: &str) -> Result<Vec<f32>> {
            if text.is_empty() {
                return Err(Error::InvalidInput("Cannot embed empty text".to_string()));
            }

            let model = Self::get_model()?;
            let mut model = model.lock().map_err(|e| Error::OperationFailed {
                operation: "lock_embedding_model".to_string(),
                cause: e.to_string(),
            })?;

            // PERF-HIGH-004: Use slice reference instead of allocating String.
            // fastembed accepts impl AsRef<[S]> where S: AsRef<str>, so &[&str] works.
            let texts = [text];

            // Wrap ONNX runtime call in catch_unwind for graceful degradation (RES-M1).
            // ONNX runtime can panic on malformed inputs or internal errors.
            // AssertUnwindSafe is safe here because we don't access any mutable state
            // after the panic, and fastembed::TextEmbedding is Send + Sync.
            let result = catch_unwind(AssertUnwindSafe(|| model.embed(texts, None)));

            let embeddings = result
                .map_err(|panic_info| {
                    let panic_msg = panic_message(panic_info.as_ref());
                    tracing::error!(
                        panic_message = %panic_msg,
                        "ONNX runtime panicked during embedding"
                    );
                    Error::OperationFailed {
                        operation: "embed".to_string(),
                        cause: format!("ONNX runtime panic: {panic_msg}"),
                    }
                })?
                .map_err(|e| Error::OperationFailed {
                    operation: "embed".to_string(),
                    cause: e.to_string(),
                })?;

            embeddings
                .into_iter()
                .next()
                .ok_or_else(|| Error::OperationFailed {
                    operation: "embed".to_string(),
                    cause: "No embedding returned from model".to_string(),
                })
        }

        fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            if texts.is_empty() {
                return Ok(Vec::new());
            }

            if texts.iter().any(|t| t.is_empty()) {
                return Err(Error::InvalidInput("Cannot embed empty text".to_string()));
            }

            let model = Self::get_model()?;
            let mut model = model.lock().map_err(|e| Error::OperationFailed {
                operation: "lock_embedding_model".to_string(),
                cause: e.to_string(),
            })?;

            // PERF-HIGH-004: Pass slice directly instead of allocating Vec<String>.
            // fastembed accepts impl AsRef<[S]> where S: AsRef<str>, so &[&str] works.

            // Wrap ONNX runtime call in catch_unwind for graceful degradation (RES-M1).
            let result = catch_unwind(AssertUnwindSafe(|| model.embed(texts, None)));

            result
                .map_err(|panic_info| {
                    let panic_msg = panic_message(panic_info.as_ref());
                    tracing::error!(
                        panic_message = %panic_msg,
                        batch_size = texts.len(),
                        "ONNX runtime panicked during batch embedding"
                    );
                    Error::OperationFailed {
                        operation: "embed_batch".to_string(),
                        cause: format!("ONNX runtime panic: {panic_msg}"),
                    }
                })?
                .map_err(|e| Error::OperationFailed {
                    operation: "embed_batch".to_string(),
                    cause: e.to_string(),
                })
        }
    }
}

// ============================================================================
// Fallback Implementation (without feature)
// ============================================================================

228#[cfg(not(feature = "fastembed-embeddings"))]
229mod fallback {
230    use super::{DEFAULT_DIMENSIONS, Embedder, Error, Result};
231    use std::collections::hash_map::DefaultHasher;
232    use std::hash::{Hash, Hasher};
233
234    /// `FastEmbed` embedder using hash-based pseudo-embeddings.
235    ///
236    /// This is a placeholder implementation that generates deterministic
237    /// pseudo-embeddings based on content hashing. For production use,
238    /// enable the `fastembed-embeddings` feature.
239    ///
240    /// Note: Hash-based embeddings do NOT capture semantic similarity.
241    /// "database storage" and "PostgreSQL database" will NOT be similar.
242    pub struct FastEmbedEmbedder {
243        /// Embedding dimensions.
244        dimensions: usize,
245        /// Whether the embedder is initialized.
246        initialized: bool,
247    }
248
249    impl FastEmbedEmbedder {
250        /// Default embedding dimensions for all-MiniLM-L6-v2.
251        pub const DEFAULT_DIMENSIONS: usize = DEFAULT_DIMENSIONS;
252
253        /// Creates a new `FastEmbed` embedder.
254        #[must_use]
255        pub const fn new() -> Self {
256            Self {
257                dimensions: Self::DEFAULT_DIMENSIONS,
258                initialized: true,
259            }
260        }
261
262        /// Creates a new embedder with custom dimensions.
263        #[must_use]
264        pub const fn with_dimensions(dimensions: usize) -> Self {
265            Self {
266                dimensions,
267                initialized: true,
268            }
269        }
270
271        /// Generates a deterministic pseudo-embedding from text.
272        ///
273        /// This creates a normalized vector based on content hashing.
274        /// Not suitable for semantic similarity but useful for testing.
275        #[allow(clippy::cast_precision_loss)]
276        #[allow(clippy::cast_possible_truncation)]
277        fn pseudo_embed(&self, text: &str) -> Vec<f32> {
278            // Limit word iteration to prevent DoS on very long texts (PERF-H1)
279            const MAX_WORDS: usize = 1000;
280            let mut embedding = vec![0.0f32; self.dimensions];
281
282            // Generate deterministic values based on text content
283            // Iterate directly without collecting to avoid allocation
284            // Limit to MAX_WORDS to bound computation time
285            for (i, word) in text.split_whitespace().take(MAX_WORDS).enumerate() {
286                let mut hasher = DefaultHasher::new();
287                word.hash(&mut hasher);
288                let hash = hasher.finish();
289                Self::distribute_hash(&mut embedding, hash, i, self.dimensions);
290            }
291
292            Self::normalize_embedding(&mut embedding);
293            embedding
294        }
295
296        /// Distributes a hash value across embedding dimensions.
297        #[allow(clippy::cast_precision_loss)]
298        #[allow(clippy::cast_possible_truncation)]
299        fn distribute_hash(embedding: &mut [f32], hash: u64, word_idx: usize, dimensions: usize) {
300            for j in 0..8 {
301                let idx = ((hash >> (j * 8)) as usize + word_idx) % dimensions;
302                let value = ((hash >> (j * 4)) & 0xFF) as f32 / 255.0 - 0.5;
303                embedding[idx] += value;
304            }
305        }
306
307        /// Normalizes an embedding vector in-place.
308        fn normalize_embedding(embedding: &mut [f32]) {
309            let norm_sq: f32 = embedding.iter().map(|x| x * x).sum();
310            if norm_sq <= 0.0 {
311                return;
312            }
313            let inv_norm = norm_sq.sqrt().recip();
314            for v in embedding.iter_mut() {
315                *v *= inv_norm;
316            }
317        }
318    }
319
320    impl Default for FastEmbedEmbedder {
321        fn default() -> Self {
322            Self::new()
323        }
324    }
325
326    impl Embedder for FastEmbedEmbedder {
327        fn dimensions(&self) -> usize {
328            self.dimensions
329        }
330
331        fn embed(&self, text: &str) -> Result<Vec<f32>> {
332            if !self.initialized {
333                return Err(Error::OperationFailed {
334                    operation: "embed".to_string(),
335                    cause: "Embedder not initialized".to_string(),
336                });
337            }
338
339            if text.is_empty() {
340                return Err(Error::InvalidInput("Cannot embed empty text".to_string()));
341            }
342
343            // Use pseudo-embedding (hash-based fallback)
344            // WARNING: This does NOT provide semantic similarity
345            tracing::debug!(
346                "Using pseudo-embedding fallback (fastembed-embeddings feature not enabled)"
347            );
348            Ok(self.pseudo_embed(text))
349        }
350
351        fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
352            if !self.initialized {
353                return Err(Error::OperationFailed {
354                    operation: "embed_batch".to_string(),
355                    cause: "Embedder not initialized".to_string(),
356                });
357            }
358
359            texts.iter().map(|t| self.embed(t)).collect()
360        }
361    }
362}

// ============================================================================
// Public Re-exports
// ============================================================================

// Exactly one of the two `FastEmbedEmbedder` implementations is compiled in,
// selected by the `fastembed-embeddings` feature; both expose the same API.
#[cfg(feature = "fastembed-embeddings")]
pub use native::FastEmbedEmbedder;

#[cfg(not(feature = "fastembed-embeddings"))]
pub use fallback::FastEmbedEmbedder;

// ============================================================================
// Utility Functions
// ============================================================================

/// Computes cosine similarity between two embedding vectors.
///
/// # Arguments
///
/// * `a` - First embedding vector
/// * `b` - Second embedding vector
///
/// # Returns
///
/// Cosine similarity in range [-1.0, 1.0], or 0.0 if vectors are invalid
/// (empty, mismatched lengths, or zero magnitude).
#[must_use]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Reject empty or mismatched inputs up front.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Single pass accumulating dot product and both squared norms.
    let mut dot = 0.0f32;
    let mut norm_a_sq = 0.0f32;
    let mut norm_b_sq = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        norm_a_sq += x * x;
        norm_b_sq += y * y;
    }

    let norm_a = norm_a_sq.sqrt();
    let norm_b = norm_b_sq.sqrt();

    // A zero-magnitude vector has no defined direction; report 0.0.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    dot / (norm_a * norm_b)
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // Feature-independent tests: these exercise whichever `FastEmbedEmbedder`
    // implementation (native or fallback) is compiled in, plus the free
    // `cosine_similarity` helper.

    #[test]
    fn test_embedder_creation() {
        let embedder = FastEmbedEmbedder::new();
        assert_eq!(embedder.dimensions(), FastEmbedEmbedder::DEFAULT_DIMENSIONS);
    }

    #[test]
    fn test_embed_empty_text() {
        let embedder = FastEmbedEmbedder::new();
        let result = embedder.embed("");
        assert!(result.is_err());
    }

    #[test]
    fn test_embedder_default_trait() {
        let embedder = FastEmbedEmbedder::default();
        assert_eq!(embedder.dimensions(), FastEmbedEmbedder::DEFAULT_DIMENSIONS);
    }

    #[test]
    fn test_embed_batch_empty_list() {
        let embedder = FastEmbedEmbedder::new();
        let texts: Vec<&str> = vec![];

        let result = embedder.embed_batch(&texts);
        assert!(result.is_ok());
        assert!(result.expect("embed_batch failed").is_empty());
    }

    #[test]
    fn test_embed_batch_with_empty_fails() {
        let embedder = FastEmbedEmbedder::new();
        let texts = vec!["Valid text", "", "Another valid"];

        // Batch with empty string should fail
        let result = embedder.embed_batch(&texts);
        assert!(result.is_err());
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![1.0, 0.0, 0.0];
        let similarity = cosine_similarity(&v, &v);
        assert!(
            (similarity - 1.0).abs() < 0.001,
            "Identical vectors should have similarity ~1.0"
        );
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![0.0, 1.0, 0.0];
        let similarity = cosine_similarity(&v1, &v2);
        assert!(
            similarity.abs() < 0.001,
            "Orthogonal vectors should have similarity ~0.0"
        );
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![-1.0, 0.0, 0.0];
        let similarity = cosine_similarity(&v1, &v2);
        assert!(
            (similarity + 1.0).abs() < 0.001,
            "Opposite vectors should have similarity ~-1.0"
        );
    }

    #[test]
    fn test_cosine_similarity_different_lengths() {
        let v1 = vec![1.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];
        let similarity = cosine_similarity(&v1, &v2);
        assert!(
            similarity.abs() < f32::EPSILON,
            "Different length vectors should return 0.0, got {similarity}"
        );
    }

    #[test]
    fn test_cosine_similarity_empty() {
        let v1: Vec<f32> = vec![];
        let v2: Vec<f32> = vec![];
        let similarity = cosine_similarity(&v1, &v2);
        assert!(
            similarity.abs() < f32::EPSILON,
            "Empty vectors should return 0.0, got {similarity}"
        );
    }

    // Tests that require the fastembed feature and model download.
    // These tests are ignored by default because the model download from
    // Hugging Face can be flaky in CI environments. Run with:
    //   cargo test --all-features -- --ignored
    #[cfg(feature = "fastembed-embeddings")]
    mod fastembed_tests {
        use super::*;

        #[test]
        #[ignore = "requires fastembed model download - flaky in CI"]
        fn test_embed_success() {
            let embedder = FastEmbedEmbedder::new();
            let result = embedder.embed("Hello, world!");

            assert!(result.is_ok());
            let embedding = result.expect("embed failed");
            assert_eq!(embedding.len(), FastEmbedEmbedder::DEFAULT_DIMENSIONS);
        }

        #[test]
        #[ignore = "requires fastembed model download - flaky in CI"]
        fn test_embed_deterministic() {
            let embedder = FastEmbedEmbedder::new();
            let text = "Rust programming language";

            let result1 = embedder.embed(text);
            let result2 = embedder.embed(text);

            // Same text should produce same embedding
            assert!(result1.is_ok());
            assert!(result2.is_ok());

            let emb1 = result1.expect("embed failed");
            let emb2 = result2.expect("embed failed");

            for (v1, v2) in emb1.iter().zip(emb2.iter()) {
                assert!((v1 - v2).abs() < f32::EPSILON);
            }
        }

        #[test]
        #[ignore = "requires fastembed model download - flaky in CI"]
        fn test_embed_different_text() {
            let embedder = FastEmbedEmbedder::new();

            let result1 = embedder.embed("Rust programming");
            let result2 = embedder.embed("Python scripting");

            assert!(result1.is_ok());
            assert!(result2.is_ok());

            // Different text should produce different embeddings
            let emb1 = result1.expect("embed failed");
            let emb2 = result2.expect("embed failed");

            let different = emb1
                .iter()
                .zip(emb2.iter())
                .any(|(v1, v2)| (v1 - v2).abs() > f32::EPSILON);
            assert!(different);
        }

        #[test]
        #[ignore = "requires fastembed model download - flaky in CI"]
        fn test_embed_normalized() {
            let embedder = FastEmbedEmbedder::new();
            let result = embedder.embed("Test embedding normalization");

            assert!(result.is_ok());
            let emb = result.expect("embed failed");

            // Check that the embedding is normalized (magnitude ~= 1)
            let magnitude: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!(
                (magnitude - 1.0).abs() < 0.01,
                "Embedding magnitude should be ~1.0, got {magnitude}"
            );
        }

        #[test]
        #[ignore = "requires fastembed model download - flaky in CI"]
        fn test_embed_batch() {
            let embedder = FastEmbedEmbedder::new();
            let texts = vec!["First text", "Second text", "Third text"];

            let result = embedder.embed_batch(&texts);
            assert!(result.is_ok());

            let embeddings = result.expect("embed_batch failed");
            assert_eq!(embeddings.len(), 3);

            for emb in &embeddings {
                assert_eq!(emb.len(), FastEmbedEmbedder::DEFAULT_DIMENSIONS);
            }
        }

        #[test]
        #[ignore = "requires fastembed model download - flaky in CI"]
        fn test_semantic_similarity_related_text() {
            let embedder = FastEmbedEmbedder::new();

            let emb_db = embedder.embed("database storage").expect("embed failed");
            let emb_pg = embedder.embed("PostgreSQL database").expect("embed failed");
            let emb_cat = embedder.embed("cat dog pet animal").expect("embed failed");

            let sim_related = cosine_similarity(&emb_db, &emb_pg);
            let sim_unrelated = cosine_similarity(&emb_db, &emb_cat);

            assert!(
                sim_related > sim_unrelated,
                "Related text ({sim_related}) should be more similar than unrelated ({sim_unrelated})"
            );
            assert!(
                sim_related > 0.5,
                "Related text should have high similarity (>0.5), got {sim_related}"
            );
        }

        #[test]
        #[ignore = "requires fastembed model download - flaky in CI"]
        fn test_embed_unicode_text() {
            let embedder = FastEmbedEmbedder::new();

            // Unicode text should embed without error
            let result = embedder.embed("Hello 世界 🌍 café");
            assert!(result.is_ok());

            let embedding = result.expect("embed failed");
            assert_eq!(embedding.len(), FastEmbedEmbedder::DEFAULT_DIMENSIONS);
        }

        #[test]
        #[ignore = "requires fastembed model download - flaky in CI"]
        fn test_embed_single_word() {
            let embedder = FastEmbedEmbedder::new();
            let result = embedder.embed("hello");

            assert!(result.is_ok());
            let embedding = result.expect("embed failed");
            assert_eq!(embedding.len(), FastEmbedEmbedder::DEFAULT_DIMENSIONS);

            // Should be normalized
            let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!(magnitude > 0.9 && magnitude < 1.1);
        }

        #[test]
        #[ignore = "requires fastembed model download - flaky in CI"]
        fn test_embed_all_values_finite() {
            let embedder = FastEmbedEmbedder::new();
            let result = embedder.embed("Test for finite values");

            assert!(result.is_ok());
            let embedding = result.expect("embed failed");

            // All values should be finite (not NaN or Inf)
            for val in &embedding {
                assert!(
                    val.is_finite(),
                    "Embedding contains non-finite value: {val}"
                );
            }
        }

        #[test]
        #[ignore = "requires fastembed model download - flaky in CI"]
        fn test_embed_values_in_range() {
            let embedder = FastEmbedEmbedder::new();
            let result = embedder.embed("Test for value range");

            assert!(result.is_ok());
            let embedding = result.expect("embed failed");

            // Normalized embeddings should have values roughly in [-1, 1]
            for val in &embedding {
                assert!(
                    *val >= -2.0 && *val <= 2.0,
                    "Value {val} outside expected range"
                );
            }
        }
    }

    // Fallback-specific tests: only compiled when the hash-based
    // pseudo-embedding implementation is active. These run without any
    // model download, so they are not `#[ignore]`d.
    #[cfg(not(feature = "fastembed-embeddings"))]
    mod fallback_tests {
        use super::*;

        #[test]
        fn test_embed_success() {
            let embedder = FastEmbedEmbedder::new();
            let result = embedder.embed("Hello, world!");

            assert!(result.is_ok());
            let embedding = result.expect("embed failed");
            assert_eq!(embedding.len(), FastEmbedEmbedder::DEFAULT_DIMENSIONS);
        }

        #[test]
        fn test_embed_deterministic() {
            let embedder = FastEmbedEmbedder::new();
            let text = "Rust programming language";

            let result1 = embedder.embed(text);
            let result2 = embedder.embed(text);

            // Same text should produce same embedding
            assert!(result1.is_ok());
            assert!(result2.is_ok());

            let emb1 = result1.expect("embed failed");
            let emb2 = result2.expect("embed failed");

            for (v1, v2) in emb1.iter().zip(emb2.iter()) {
                assert!((v1 - v2).abs() < f32::EPSILON);
            }
        }

        #[test]
        fn test_custom_dimensions() {
            let embedder = FastEmbedEmbedder::with_dimensions(512);
            assert_eq!(embedder.dimensions(), 512);
        }

        #[test]
        fn test_custom_dimensions_embed() {
            let embedder = FastEmbedEmbedder::with_dimensions(128);

            let result = embedder.embed("Test with custom dimensions");
            assert!(result.is_ok());

            let embedding = result.expect("embed failed");
            assert_eq!(embedding.len(), 128);
        }

        #[test]
        fn test_embed_whitespace_only() {
            let embedder = FastEmbedEmbedder::new();

            // Whitespace-only should produce an embedding (not empty text)
            let result = embedder.embed("   \t\n  ");
            // Depending on implementation, could be error or valid embedding
            // Current implementation: whitespace splits to no words, produces zero vector
            assert!(result.is_ok());
        }

        #[test]
        fn test_embed_normalized() {
            let embedder = FastEmbedEmbedder::new();
            let result = embedder.embed("Test embedding normalization");

            assert!(result.is_ok());
            let emb = result.expect("embed failed");

            // Check that the embedding is normalized (magnitude ~= 1)
            let magnitude: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((magnitude - 1.0).abs() < 0.01);
        }

        #[test]
        fn test_embed_batch() {
            let embedder = FastEmbedEmbedder::new();
            let texts = vec!["First text", "Second text", "Third text"];

            let result = embedder.embed_batch(&texts);
            assert!(result.is_ok());

            let embeddings = result.expect("embed_batch failed");
            assert_eq!(embeddings.len(), 3);

            for emb in &embeddings {
                assert_eq!(emb.len(), FastEmbedEmbedder::DEFAULT_DIMENSIONS);
            }
        }

        #[test]
        fn test_embed_batch_single_item() {
            let embedder = FastEmbedEmbedder::new();
            let texts = vec!["Single item"];

            let result = embedder.embed_batch(&texts);
            assert!(result.is_ok());

            let embeddings = result.expect("embed_batch failed");
            assert_eq!(embeddings.len(), 1);
        }

        #[test]
        fn test_embed_case_sensitivity() {
            let embedder = FastEmbedEmbedder::new();

            let lower = embedder.embed("hello world").expect("embed failed");
            let upper = embedder.embed("HELLO WORLD").expect("embed failed");
            let mixed = embedder.embed("Hello World").expect("embed failed");

            // Different cases should produce different embeddings
            let lower_upper_different = lower
                .iter()
                .zip(upper.iter())
                .any(|(a, b)| (a - b).abs() > f32::EPSILON);
            let lower_mixed_different = lower
                .iter()
                .zip(mixed.iter())
                .any(|(a, b)| (a - b).abs() > f32::EPSILON);

            assert!(lower_upper_different);
            assert!(lower_mixed_different);
        }

        #[test]
        fn test_embed_word_order_matters() {
            let embedder = FastEmbedEmbedder::new();

            let emb1 = embedder.embed("the quick brown fox").expect("embed failed");
            let emb2 = embedder.embed("brown quick the fox").expect("embed failed");

            // Different word order should produce different embeddings
            let different = emb1
                .iter()
                .zip(emb2.iter())
                .any(|(a, b)| (a - b).abs() > f32::EPSILON);
            assert!(different);
        }

        #[test]
        fn test_embed_all_values_finite() {
            let embedder = FastEmbedEmbedder::new();
            let result = embedder.embed("Test for finite values");

            assert!(result.is_ok());
            let embedding = result.expect("embed failed");

            // All values should be finite (not NaN or Inf)
            for val in &embedding {
                assert!(
                    val.is_finite(),
                    "Embedding contains non-finite value: {val}"
                );
            }
        }

        #[test]
        fn test_embed_values_in_range() {
            let embedder = FastEmbedEmbedder::new();
            let result = embedder.embed("Test for value range");

            assert!(result.is_ok());
            let embedding = result.expect("embed failed");

            // Normalized embeddings should have values roughly in [-1, 1]
            for val in &embedding {
                assert!(
                    *val >= -2.0 && *val <= 2.0,
                    "Value {val} outside expected range"
                );
            }
        }

        #[test]
        fn test_embed_unicode_text() {
            let embedder = FastEmbedEmbedder::new();

            // Unicode text should embed without error
            let result = embedder.embed("Hello 世界 🌍 café");
            assert!(result.is_ok());

            let embedding = result.expect("embed failed");
            assert_eq!(embedding.len(), FastEmbedEmbedder::DEFAULT_DIMENSIONS);
        }

        #[test]
        fn test_embed_very_long_text() {
            let embedder = FastEmbedEmbedder::new();

            // Create a long text (exercises the MAX_WORDS cap in pseudo_embed)
            let long_text = "word ".repeat(10000);
            let result = embedder.embed(&long_text);

            assert!(result.is_ok());
            let embedding = result.expect("embed failed");
            assert_eq!(embedding.len(), FastEmbedEmbedder::DEFAULT_DIMENSIONS);
        }

        #[test]
        fn test_embed_special_characters() {
            let embedder = FastEmbedEmbedder::new();

            let result = embedder.embed("!@#$%^&*()_+-=[]{}|;':\",./<>?");
            assert!(result.is_ok());
        }

        #[test]
        fn test_embed_numeric_text() {
            let embedder = FastEmbedEmbedder::new();

            let result = embedder.embed("12345 67890");
            assert!(result.is_ok());

            let embedding = result.expect("embed failed");
            assert_eq!(embedding.len(), FastEmbedEmbedder::DEFAULT_DIMENSIONS);
        }
    }
}