Skip to main content

subcog/storage/vector/
pgvector.rs

1//! pgvector-based vector backend.
2//!
3//! Provides vector similarity search using PostgreSQL with pgvector extension.
4
#[cfg(feature = "postgres")]
mod implementation {
    use crate::embedding::DEFAULT_DIMENSIONS;
    use crate::models::MemoryId;
    use crate::storage::migrations::{Migration, MigrationRunner};
    use crate::storage::traits::{VectorBackend, VectorFilter};
    use crate::{Error, Result};
    use deadpool_postgres::{Config, Pool, Runtime};
    use std::fmt::Write as _;
    use tokio::runtime::{Handle, RuntimeFlavor};
    use tokio_postgres::NoTls;

    /// Embedded migrations compiled into the binary.
    ///
    /// Migration 1 assumes the pgvector extension is already installed; run
    /// `CREATE EXTENSION IF NOT EXISTS vector;` before using this backend.
    ///
    /// NOTE(review): the embedding column is declared as `vector(384)` regardless of
    /// the `dimensions` passed to [`PgvectorBackend::new`], so embeddings of any other
    /// length are rejected by PostgreSQL. The `{table}` placeholder is presumably
    /// substituted with the table name by `MigrationRunner` — confirm there.
    const MIGRATIONS: &[Migration] = &[
        Migration {
            version: 1,
            description: "Initial vectors table",
            sql: r"
                CREATE TABLE IF NOT EXISTS {table} (
                    id TEXT PRIMARY KEY,
                    embedding vector(384),
                    namespace TEXT,
                    created_at BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW())::BIGINT
                );
            ",
        },
        Migration {
            version: 2,
            description: "Add HNSW index for cosine similarity",
            sql: r"
                CREATE INDEX IF NOT EXISTS {table}_embedding_idx
                ON {table} USING hnsw (embedding vector_cosine_ops)
                WITH (m = 16, ef_construction = 64);
            ",
        },
        Migration {
            version: 3,
            description: "Add namespace index for filtering",
            sql: r"
                CREATE INDEX IF NOT EXISTS {table}_namespace_idx ON {table} (namespace);
            ",
        },
    ];

    /// pgvector-based vector backend.
    ///
    /// Every [`VectorBackend`] operation is a synchronous wrapper that drives an
    /// async query on a `deadpool_postgres` connection pool to completion.
    pub struct PgvectorBackend {
        /// Connection pool.
        pool: Pool,
        /// Table name for vectors; validated as a plain SQL identifier in `new`
        /// because it is interpolated directly into SQL text.
        table_name: String,
        /// Expected embedding dimensionality, checked on upsert and search.
        dimensions: usize,
    }

    /// Helper to map pool errors.
    fn pool_error(e: impl std::fmt::Display) -> Error {
        Error::OperationFailed {
            operation: "pgvector_get_client".to_string(),
            cause: e.to_string(),
        }
    }

    /// Helper to map query errors.
    fn query_error(op: &str, e: impl std::fmt::Display) -> Error {
        Error::OperationFailed {
            operation: op.to_string(),
            cause: e.to_string(),
        }
    }

    /// Checks that `name` is a plain, unquoted SQL identifier.
    ///
    /// Identifiers cannot be bound as query parameters, so the table name is spliced
    /// into SQL text. Restricting it to `[A-Za-z_][A-Za-z0-9_$]*` rules out SQL
    /// injection and quoting surprises.
    ///
    /// # Errors
    ///
    /// Returns `Error::OperationFailed` if `name` is empty or contains any other
    /// character.
    fn validate_table_name(name: &str) -> Result<()> {
        let mut chars = name.chars();
        let valid = matches!(chars.next(), Some(c) if c.is_ascii_alphabetic() || c == '_')
            && chars.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '$');
        if valid {
            Ok(())
        } else {
            Err(query_error(
                "pgvector_validate_table_name",
                format!("invalid table name {name:?}: expected [A-Za-z_][A-Za-z0-9_$]*"),
            ))
        }
    }

    impl PgvectorBackend {
        /// Creates a new pgvector backend and runs the embedded migrations.
        ///
        /// # Errors
        ///
        /// Returns an error if any of the following happens:
        /// - `table_name` is not a plain SQL identifier (`[A-Za-z_][A-Za-z0-9_$]*`).
        /// - The connection URL cannot be parsed.
        /// - The connection pool fails to initialize.
        /// - Migrations fail, which can happen if the pgvector extension is not installed.
        pub fn new(
            connection_url: &str,
            table_name: impl Into<String>,
            dimensions: usize,
        ) -> Result<Self> {
            let table_name = table_name.into();
            // Reject unsafe names before they ever reach a query string.
            validate_table_name(&table_name)?;

            let config = Self::parse_connection_url(connection_url)?;
            let cfg = Self::build_pool_config(&config);

            let pool = cfg
                .create_pool(Some(Runtime::Tokio1), NoTls)
                .map_err(|e| query_error("pgvector_create_pool", e))?;

            let backend = Self {
                pool,
                table_name,
                dimensions,
            };
            backend.run_migrations()?;
            Ok(backend)
        }

        /// Parses the connection URL into a tokio-postgres config.
        fn parse_connection_url(url: &str) -> Result<tokio_postgres::Config> {
            url.parse::<tokio_postgres::Config>()
                .map_err(|e| query_error("pgvector_parse_url", e))
        }

        /// Extracts host string from tokio-postgres Host.
        #[cfg(unix)]
        fn host_to_string(h: &tokio_postgres::config::Host) -> String {
            match h {
                tokio_postgres::config::Host::Tcp(s) => s.clone(),
                tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().to_string(),
            }
        }

        /// Extracts host string from tokio-postgres Host (Windows: Tcp only).
        #[cfg(not(unix))]
        fn host_to_string(h: &tokio_postgres::config::Host) -> String {
            let tokio_postgres::config::Host::Tcp(s) = h;
            s.clone()
        }

        /// Builds a deadpool config from a tokio-postgres config.
        ///
        /// Only the first host, the first port, user, password and dbname are copied.
        /// Any other options in the URL, and any additional hosts, are ignored.
        fn build_pool_config(config: &tokio_postgres::Config) -> Config {
            let mut cfg = Config::new();
            cfg.host = config.get_hosts().first().map(Self::host_to_string);
            cfg.port = config.get_ports().first().copied();
            cfg.user = config.get_user().map(String::from);
            cfg.password = config
                .get_password()
                .map(|p| String::from_utf8_lossy(p).to_string());
            cfg.dbname = config.get_dbname().map(String::from);
            cfg
        }

        /// Creates a backend with default settings.
        ///
        /// # Errors
        ///
        /// Returns an error if the connection fails.
        pub fn with_defaults() -> Result<Self> {
            Self::new(
                "postgresql://localhost/subcog",
                "memory_vectors",
                DEFAULT_DIMENSIONS,
            )
        }

        /// Runs an async operation to completion from synchronous code.
        ///
        /// Behavior depends on the calling context:
        /// - Inside a multi-thread Tokio runtime, the future runs via `block_in_place`.
        ///   `Handle::block_on` alone panics on a runtime worker thread.
        /// - With a current-thread runtime handle, `Handle::block_on` is used directly.
        ///   This still panics if called from inside that runtime's own execution
        ///   context, because blocking there would stall the only thread.
        /// - Outside any runtime, a temporary current-thread runtime is built for
        ///   this single call. Connections opened on it lose their driver task when
        ///   it is dropped, so the pool has to replace them on the next use.
        fn block_on<F, T>(&self, f: F) -> Result<T>
        where
            F: std::future::Future<Output = Result<T>>,
        {
            match Handle::try_current() {
                Ok(handle) if handle.runtime_flavor() == RuntimeFlavor::MultiThread => {
                    tokio::task::block_in_place(|| handle.block_on(f))
                },
                Ok(handle) => handle.block_on(f),
                Err(_) => {
                    let rt = tokio::runtime::Builder::new_current_thread()
                        .enable_all()
                        .build()
                        .map_err(|e| query_error("pgvector_create_runtime", e))?;
                    rt.block_on(f)
                },
            }
        }

        /// Runs migrations.
        fn run_migrations(&self) -> Result<()> {
            self.block_on(async {
                let runner = MigrationRunner::new(self.pool.clone(), &self.table_name);
                runner.run(MIGRATIONS).await
            })
        }

        /// Fails with `operation` as the error context unless `embedding` has
        /// exactly `self.dimensions` components.
        fn check_dimensions(&self, operation: &str, embedding: &[f32]) -> Result<()> {
            if embedding.len() == self.dimensions {
                Ok(())
            } else {
                Err(query_error(
                    operation,
                    format!(
                        "expected {} embedding dimensions, got {}",
                        self.dimensions,
                        embedding.len()
                    ),
                ))
            }
        }

        /// Formats an embedding as pgvector text input: `[1,2.5,3]`.
        fn format_embedding(embedding: &[f32]) -> String {
            let mut out = String::with_capacity(embedding.len().saturating_mul(12));
            out.push('[');
            for (i, value) in embedding.iter().enumerate() {
                if i > 0 {
                    out.push(',');
                }
                // Formatting into a String cannot fail.
                let _ = write!(out, "{value}");
            }
            out.push(']');
            out
        }

        /// Async implementation of upsert operation.
        ///
        /// NOTE(review): only `id` and `embedding` are written; `namespace` is never
        /// set here, so namespace-filtered searches match only rows whose namespace
        /// was populated elsewhere. Confirm that this is intended.
        async fn upsert_async(&self, id: &MemoryId, embedding: &[f32]) -> Result<()> {
            self.check_dimensions("pgvector_upsert", embedding)?;
            let client = self.pool.get().await.map_err(pool_error)?;

            let embedding_str = Self::format_embedding(embedding);

            // The embedding is bound as TEXT and then cast. A bare `$2::vector` makes
            // PostgreSQL type the parameter as `vector`, which `String` cannot be
            // serialized as.
            let upsert = format!(
                r"INSERT INTO {} (id, embedding)
                VALUES ($1, $2::text::vector)
                ON CONFLICT (id) DO UPDATE SET
                    embedding = EXCLUDED.embedding",
                self.table_name
            );

            client
                .execute(&upsert, &[&id.as_str(), &embedding_str])
                .await
                .map_err(|e| query_error("pgvector_upsert", e))?;

            Ok(())
        }

        /// Async implementation of remove operation; returns whether a row was deleted.
        async fn remove_async(&self, id: &MemoryId) -> Result<bool> {
            let client = self.pool.get().await.map_err(pool_error)?;
            let delete = format!("DELETE FROM {} WHERE id = $1", self.table_name);
            let rows = client
                .execute(&delete, &[&id.as_str()])
                .await
                .map_err(|e| query_error("pgvector_remove", e))?;
            Ok(rows > 0)
        }

        /// Async implementation of search operation.
        ///
        /// Returns up to `limit` `(id, similarity)` pairs ordered by ascending cosine
        /// distance. The similarity is `1 - cosine_distance`. Rows whose stored
        /// embedding is NULL are skipped.
        // Similarity scores are computed as f64 by PostgreSQL; narrowing to f32 is intended.
        #[allow(clippy::cast_possible_truncation)]
        async fn search_async(
            &self,
            query_embedding: &[f32],
            filter: &VectorFilter,
            limit: usize,
        ) -> Result<Vec<(MemoryId, f32)>> {
            self.check_dimensions("pgvector_search", query_embedding)?;
            let client = self.pool.get().await.map_err(pool_error)?;
            let embedding_str = Self::format_embedding(query_embedding);

            // Build namespace filter if present (placeholders start at $2).
            let (namespace_clause, namespace_params) = Self::build_namespace_filter(filter);

            // `$1::text::vector`: bind as TEXT, then cast (see `upsert_async`).
            let search_query = format!(
                r"SELECT id, 1 - (embedding <=> $1::text::vector) as similarity
                FROM {}
                {}
                ORDER BY embedding <=> $1::text::vector
                LIMIT {}",
                self.table_name, namespace_clause, limit
            );

            let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
                Vec::with_capacity(1 + namespace_params.len());
            params.push(&embedding_str);
            for p in &namespace_params {
                params.push(p);
            }

            let rows = client
                .query(&search_query, &params)
                .await
                .map_err(|e| query_error("pgvector_search", e))?;

            let mut results = Vec::with_capacity(rows.len());
            for row in &rows {
                let id: String = row
                    .try_get(0)
                    .map_err(|e| query_error("pgvector_search", e))?;
                // NULL when the row's (nullable) embedding is NULL; such rows have no score.
                let similarity: Option<f64> = row
                    .try_get(1)
                    .map_err(|e| query_error("pgvector_search", e))?;
                if let Some(similarity) = similarity {
                    results.push((MemoryId::new(&id), similarity as f32));
                }
            }
            Ok(results)
        }

        /// Builds the `WHERE namespace IN ($2, ...)` clause and its bind values.
        ///
        /// Returns an empty clause and no params when no namespaces are requested.
        /// Placeholders start at `$2` because `$1` is the query embedding.
        fn build_namespace_filter(filter: &VectorFilter) -> (String, Vec<String>) {
            if filter.namespaces.is_empty() {
                return (String::new(), Vec::new());
            }

            let placeholders: Vec<String> = (0..filter.namespaces.len())
                .map(|i| format!("${}", i + 2))
                .collect();

            let clause = format!("WHERE namespace IN ({})", placeholders.join(", "));
            let params: Vec<String> = filter
                .namespaces
                .iter()
                .map(|ns| ns.as_str().to_string())
                .collect();

            (clause, params)
        }

        /// Async implementation of count operation.
        async fn count_async(&self) -> Result<usize> {
            let client = self.pool.get().await.map_err(pool_error)?;
            let query = format!("SELECT COUNT(*) FROM {}", self.table_name);
            let row = client
                .query_one(&query, &[])
                .await
                .map_err(|e| query_error("pgvector_count", e))?;
            let count: i64 = row
                .try_get(0)
                .map_err(|e| query_error("pgvector_count", e))?;
            usize::try_from(count).map_err(|e| query_error("pgvector_count", e))
        }

        /// Async implementation of clear operation (truncates the table).
        async fn clear_async(&self) -> Result<()> {
            let client = self.pool.get().await.map_err(pool_error)?;
            let truncate = format!("TRUNCATE TABLE {}", self.table_name);
            client
                .execute(&truncate, &[])
                .await
                .map_err(|e| query_error("pgvector_clear", e))?;
            Ok(())
        }
    }

    impl VectorBackend for PgvectorBackend {
        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn upsert(&self, id: &MemoryId, embedding: &[f32]) -> Result<()> {
            self.block_on(self.upsert_async(id, embedding))
        }

        fn remove(&self, id: &MemoryId) -> Result<bool> {
            self.block_on(self.remove_async(id))
        }

        fn search(
            &self,
            query_embedding: &[f32],
            filter: &VectorFilter,
            limit: usize,
        ) -> Result<Vec<(MemoryId, f32)>> {
            self.block_on(self.search_async(query_embedding, filter, limit))
        }

        fn count(&self) -> Result<usize> {
            self.block_on(self.count_async())
        }

        fn clear(&self) -> Result<()> {
            self.block_on(self.clear_async())
        }
    }
}
352
353#[cfg(feature = "postgres")]
354pub use implementation::PgvectorBackend;
355
356// Re-export centralized DEFAULT_DIMENSIONS from embedding module
357pub use crate::embedding::DEFAULT_DIMENSIONS;
358
359#[cfg(not(feature = "postgres"))]
360mod stub {
361    use crate::embedding::DEFAULT_DIMENSIONS;
362    use crate::models::MemoryId;
363    use crate::storage::traits::{VectorBackend, VectorFilter};
364    use crate::{Error, Result};
365
366    /// pgvector-based vector backend (stub).
367    pub struct PgvectorBackend {
368        /// PostgreSQL connection URL.
369        connection_url: String,
370        /// Table name for vectors.
371        table_name: String,
372        /// Embedding dimensions.
373        dimensions: usize,
374    }
375
376    impl PgvectorBackend {
377        /// Creates a new pgvector backend (stub).
378        #[must_use]
379        pub fn new(
380            connection_url: impl Into<String>,
381            table_name: impl Into<String>,
382            dimensions: usize,
383        ) -> Self {
384            Self {
385                connection_url: connection_url.into(),
386                table_name: table_name.into(),
387                dimensions,
388            }
389        }
390
391        /// Creates a backend with default settings (stub).
392        #[must_use]
393        pub fn with_defaults() -> Self {
394            Self::new(
395                "postgresql://localhost/subcog",
396                "memory_vectors",
397                DEFAULT_DIMENSIONS,
398            )
399        }
400    }
401
402    impl VectorBackend for PgvectorBackend {
403        fn dimensions(&self) -> usize {
404            self.dimensions
405        }
406
407        fn upsert(&self, _id: &MemoryId, _embedding: &[f32]) -> Result<()> {
408            Err(Error::NotImplemented(format!(
409                "PgvectorBackend::upsert for {} on {}",
410                self.table_name, self.connection_url
411            )))
412        }
413
414        fn remove(&self, _id: &MemoryId) -> Result<bool> {
415            Err(Error::NotImplemented(format!(
416                "PgvectorBackend::remove for {} on {}",
417                self.table_name, self.connection_url
418            )))
419        }
420
421        fn search(
422            &self,
423            _query_embedding: &[f32],
424            _filter: &VectorFilter,
425            _limit: usize,
426        ) -> Result<Vec<(MemoryId, f32)>> {
427            Err(Error::NotImplemented(format!(
428                "PgvectorBackend::search for {} on {}",
429                self.table_name, self.connection_url
430            )))
431        }
432
433        fn count(&self) -> Result<usize> {
434            Err(Error::NotImplemented(format!(
435                "PgvectorBackend::count for {} on {}",
436                self.table_name, self.connection_url
437            )))
438        }
439
440        fn clear(&self) -> Result<()> {
441            Err(Error::NotImplemented(format!(
442                "PgvectorBackend::clear for {} on {}",
443                self.table_name, self.connection_url
444            )))
445        }
446    }
447}
448
449#[cfg(not(feature = "postgres"))]
450pub use stub::PgvectorBackend;