1#[cfg(feature = "postgres")]
6mod implementation {
7 use crate::embedding::DEFAULT_DIMENSIONS;
8 use crate::models::MemoryId;
9 use crate::storage::migrations::{Migration, MigrationRunner};
10 use crate::storage::traits::{VectorBackend, VectorFilter};
11 use crate::{Error, Result};
12 use deadpool_postgres::{Config, Pool, Runtime};
13 use tokio::runtime::Handle;
14 use tokio_postgres::NoTls;
15
    /// Ordered schema migrations for the vectors table, executed by
    /// [`MigrationRunner`] in `run_migrations`. The `{table}` token in each
    /// `sql` string is presumably substituted with the configured table name
    /// by the runner — confirm against `MigrationRunner::run`.
    const MIGRATIONS: &[Migration] = &[
        Migration {
            version: 1,
            description: "Initial vectors table",
            // NOTE(review): the column is hardcoded to `vector(384)` even though
            // `PgvectorBackend::new` accepts an arbitrary `dimensions` value —
            // a backend configured with a different dimensionality will not match
            // this column type. TODO confirm whether non-384 setups are supported.
            sql: r"
            CREATE TABLE IF NOT EXISTS {table} (
                id TEXT PRIMARY KEY,
                embedding vector(384),
                namespace TEXT,
                created_at BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW())::BIGINT
            );
        ",
        },
        Migration {
            version: 2,
            description: "Add HNSW index for cosine similarity",
            // HNSW index over the cosine-distance operator class; m/ef_construction
            // are pgvector index build parameters.
            sql: r"
            CREATE INDEX IF NOT EXISTS {table}_embedding_idx
            ON {table} USING hnsw (embedding vector_cosine_ops)
            WITH (m = 16, ef_construction = 64);
        ",
        },
        Migration {
            version: 3,
            description: "Add namespace index for filtering",
            // Plain b-tree index supporting the `namespace IN (...)` filter
            // built by `build_namespace_filter`.
            sql: r"
            CREATE INDEX IF NOT EXISTS {table}_namespace_idx ON {table} (namespace);
        ",
        },
    ];
49
    /// Vector storage backend backed by PostgreSQL with the pgvector extension.
    pub struct PgvectorBackend {
        /// Deadpool connection pool to the Postgres database.
        pool: Pool,
        /// Table the vectors live in; interpolated directly into SQL statements.
        table_name: String,
        /// Embedding dimensionality reported via `VectorBackend::dimensions`.
        dimensions: usize,
    }
59
60 fn pool_error(e: impl std::fmt::Display) -> Error {
62 Error::OperationFailed {
63 operation: "pgvector_get_client".to_string(),
64 cause: e.to_string(),
65 }
66 }
67
68 fn query_error(op: &str, e: impl std::fmt::Display) -> Error {
70 Error::OperationFailed {
71 operation: op.to_string(),
72 cause: e.to_string(),
73 }
74 }
75
    impl PgvectorBackend {
        /// Creates a backend from a Postgres connection URL, builds the
        /// deadpool pool, and immediately runs the schema migrations.
        ///
        /// # Errors
        /// Returns `Error::OperationFailed` if the URL cannot be parsed, the
        /// pool cannot be created, or a migration fails.
        pub fn new(
            connection_url: &str,
            table_name: impl Into<String>,
            dimensions: usize,
        ) -> Result<Self> {
            let table_name = table_name.into();
            let config = Self::parse_connection_url(connection_url)?;
            let cfg = Self::build_pool_config(&config);

            let pool = cfg.create_pool(Some(Runtime::Tokio1), NoTls).map_err(|e| {
                Error::OperationFailed {
                    operation: "pgvector_create_pool".to_string(),
                    cause: e.to_string(),
                }
            })?;

            let backend = Self {
                pool,
                table_name,
                dimensions,
            };
            // Run migrations eagerly so the table/indexes exist before first use.
            backend.run_migrations()?;
            Ok(backend)
        }

        /// Parses a `postgresql://...` URL into a `tokio_postgres::Config`.
        fn parse_connection_url(url: &str) -> Result<tokio_postgres::Config> {
            url.parse::<tokio_postgres::Config>()
                .map_err(|e| Error::OperationFailed {
                    operation: "pgvector_parse_url".to_string(),
                    cause: e.to_string(),
                })
        }

        // On Unix, a host may be a TCP hostname or a Unix-domain socket path.
        #[cfg(unix)]
        fn host_to_string(h: &tokio_postgres::config::Host) -> String {
            match h {
                tokio_postgres::config::Host::Tcp(s) => s.clone(),
                tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().to_string(),
            }
        }

        // On non-Unix targets only the TCP variant exists, so an irrefutable
        // let-pattern suffices.
        #[cfg(not(unix))]
        fn host_to_string(h: &tokio_postgres::config::Host) -> String {
            let tokio_postgres::config::Host::Tcp(s) = h;
            s.clone()
        }

        /// Translates a parsed `tokio_postgres::Config` into a deadpool `Config`.
        /// Only the FIRST host/port pair is carried over; any additional
        /// fallback hosts in the URL are dropped.
        fn build_pool_config(config: &tokio_postgres::Config) -> Config {
            let mut cfg = Config::new();
            cfg.host = config.get_hosts().first().map(Self::host_to_string);
            cfg.port = config.get_ports().first().copied();
            cfg.user = config.get_user().map(String::from);
            // Password bytes are lossily re-encoded as UTF-8.
            cfg.password = config
                .get_password()
                .map(|p| String::from_utf8_lossy(p).to_string());
            cfg.dbname = config.get_dbname().map(String::from);
            cfg
        }

        /// Convenience constructor: local `subcog` database, `memory_vectors`
        /// table, and the crate-wide default embedding dimensionality.
        ///
        /// # Errors
        /// Same failure modes as [`Self::new`].
        pub fn with_defaults() -> Result<Self> {
            Self::new(
                "postgresql://localhost/subcog",
                "memory_vectors",
                DEFAULT_DIMENSIONS,
            )
        }

        /// Bridges the async internals to the synchronous `VectorBackend` trait:
        /// drives `f` to completion on the current runtime if one exists,
        /// otherwise on a throwaway current-thread runtime.
        ///
        /// NOTE(review): `Handle::block_on` panics when called from within an
        /// async execution context (e.g. inside a runtime worker thread). This
        /// assumes callers invoke the sync trait methods only from non-async
        /// threads — TODO confirm.
        fn block_on<F, T>(&self, f: F) -> Result<T>
        where
            F: std::future::Future<Output = Result<T>>,
        {
            if let Ok(handle) = Handle::try_current() {
                handle.block_on(f)
            } else {
                let rt = tokio::runtime::Builder::new_current_thread()
                    .enable_all()
                    .build()
                    .map_err(|e| Error::OperationFailed {
                        operation: "pgvector_create_runtime".to_string(),
                        cause: e.to_string(),
                    })?;
                rt.block_on(f)
            }
        }

        /// Applies `MIGRATIONS` against `table_name` via the shared runner.
        fn run_migrations(&self) -> Result<()> {
            self.block_on(async {
                let runner = MigrationRunner::new(self.pool.clone(), &self.table_name);
                runner.run(MIGRATIONS).await
            })
        }

        /// Renders an embedding as a pgvector text literal: `[v1,v2,...]`,
        /// suitable for a `$n::vector` cast.
        fn format_embedding(embedding: &[f32]) -> String {
            let values: Vec<String> = embedding
                .iter()
                .map(std::string::ToString::to_string)
                .collect();
            format!("[{}]", values.join(","))
        }

        /// Inserts or replaces the embedding for `id`.
        ///
        /// NOTE(review): the `namespace` column is never written here, so rows
        /// upserted through this backend have NULL namespace and will not match
        /// any `VectorFilter` namespace — verify against callers.
        async fn upsert_async(&self, id: &MemoryId, embedding: &[f32]) -> Result<()> {
            let client = self.pool.get().await.map_err(pool_error)?;

            let embedding_str = Self::format_embedding(embedding);

            // `table_name` is interpolated, not parameterized (identifiers
            // cannot be bound); it is assumed to come from trusted config.
            let upsert = format!(
                r"INSERT INTO {} (id, embedding)
                VALUES ($1, $2::vector)
                ON CONFLICT (id) DO UPDATE SET
                    embedding = EXCLUDED.embedding",
                self.table_name
            );

            client
                .execute(&upsert, &[&id.as_str(), &embedding_str])
                .await
                .map_err(|e| query_error("pgvector_upsert", e))?;

            Ok(())
        }

        /// Deletes the row for `id`; returns whether a row was actually removed.
        async fn remove_async(&self, id: &MemoryId) -> Result<bool> {
            let client = self.pool.get().await.map_err(pool_error)?;
            let delete = format!("DELETE FROM {} WHERE id = $1", self.table_name);
            let rows = client
                .execute(&delete, &[&id.as_str()])
                .await
                .map_err(|e| query_error("pgvector_remove", e))?;
            Ok(rows > 0)
        }

        /// K-nearest-neighbour search by cosine distance (`<=>`), returning
        /// `(id, similarity)` pairs where similarity = 1 - cosine distance.
        async fn search_async(
            &self,
            query_embedding: &[f32],
            filter: &VectorFilter,
            limit: usize,
        ) -> Result<Vec<(MemoryId, f32)>> {
            let client = self.pool.get().await.map_err(pool_error)?;
            let embedding_str = Self::format_embedding(query_embedding);

            let (namespace_clause, namespace_params) = Self::build_namespace_filter(filter);

            // `limit` is a usize formatted inline, so it cannot inject SQL;
            // ordering by raw distance ascending == similarity descending.
            let search_query = format!(
                r"SELECT id, 1 - (embedding <=> $1::vector) as similarity
                FROM {}
                {}
                ORDER BY embedding <=> $1::vector
                LIMIT {}",
                self.table_name, namespace_clause, limit
            );

            // $1 is the embedding; namespace params occupy $2.. (see
            // `build_namespace_filter`).
            let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = Vec::new();
            params.push(&embedding_str);
            for p in &namespace_params {
                params.push(p);
            }

            let rows = client
                .query(&search_query, &params)
                .await
                .map_err(|e| query_error("pgvector_search", e))?;

            Ok(rows
                .iter()
                .map(|row| {
                    let id: String = row.get(0);
                    // Postgres returns the similarity expression as f64.
                    let similarity: f64 = row.get(1);
                    #[allow(clippy::cast_possible_truncation)]
                    (MemoryId::new(&id), similarity as f32)
                })
                .collect())
        }

        /// Builds a `WHERE namespace IN (...)` clause plus its bind values.
        /// Returns empty clause/params when no namespaces are requested.
        /// Placeholders start at `$2` because `$1` is the query embedding.
        fn build_namespace_filter(filter: &VectorFilter) -> (String, Vec<String>) {
            if filter.namespaces.is_empty() {
                return (String::new(), Vec::new());
            }

            let placeholders: Vec<String> = filter
                .namespaces
                .iter()
                .enumerate()
                .map(|(i, _)| format!("${}", i + 2))
                .collect();

            let clause = format!("WHERE namespace IN ({})", placeholders.join(", "));
            let params: Vec<String> = filter
                .namespaces
                .iter()
                .map(|ns| ns.as_str().to_string())
                .collect();

            (clause, params)
        }

        /// Counts all rows in the vectors table.
        // COUNT(*) comes back as i64; the cast to usize is lossless for any
        // realistic row count, hence the allow.
        #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
        async fn count_async(&self) -> Result<usize> {
            let client = self.pool.get().await.map_err(pool_error)?;
            let query = format!("SELECT COUNT(*) FROM {}", self.table_name);
            let row = client
                .query_one(&query, &[])
                .await
                .map_err(|e| query_error("pgvector_count", e))?;
            let count: i64 = row.get(0);
            Ok(count as usize)
        }

        /// Removes every row via TRUNCATE (fast, resets the table in place).
        async fn clear_async(&self) -> Result<()> {
            let client = self.pool.get().await.map_err(pool_error)?;
            let truncate = format!("TRUNCATE TABLE {}", self.table_name);
            client
                .execute(&truncate, &[])
                .await
                .map_err(|e| query_error("pgvector_clear", e))?;
            Ok(())
        }
    }
320
    /// Synchronous `VectorBackend` facade: each method simply drives the
    /// corresponding `*_async` implementation to completion via `block_on`.
    impl VectorBackend for PgvectorBackend {
        fn dimensions(&self) -> usize {
            self.dimensions
        }

        fn upsert(&self, id: &MemoryId, embedding: &[f32]) -> Result<()> {
            self.block_on(self.upsert_async(id, embedding))
        }

        fn remove(&self, id: &MemoryId) -> Result<bool> {
            self.block_on(self.remove_async(id))
        }

        fn search(
            &self,
            query_embedding: &[f32],
            filter: &VectorFilter,
            limit: usize,
        ) -> Result<Vec<(MemoryId, f32)>> {
            self.block_on(self.search_async(query_embedding, filter, limit))
        }

        fn count(&self) -> Result<usize> {
            self.block_on(self.count_async())
        }

        fn clear(&self) -> Result<()> {
            self.block_on(self.clear_async())
        }
    }
351}
352
353#[cfg(feature = "postgres")]
354pub use implementation::PgvectorBackend;
355
356pub use crate::embedding::DEFAULT_DIMENSIONS;
358
359#[cfg(not(feature = "postgres"))]
360mod stub {
361 use crate::embedding::DEFAULT_DIMENSIONS;
362 use crate::models::MemoryId;
363 use crate::storage::traits::{VectorBackend, VectorFilter};
364 use crate::{Error, Result};
365
    /// Compile-time stub used when the `postgres` feature is disabled; stores
    /// the configuration but every operation returns `Error::NotImplemented`.
    pub struct PgvectorBackend {
        /// Connection URL the real backend would dial (kept for error messages).
        connection_url: String,
        /// Table name the real backend would use (kept for error messages).
        table_name: String,
        /// Dimensionality reported via `VectorBackend::dimensions`.
        dimensions: usize,
    }
375
376 impl PgvectorBackend {
377 #[must_use]
379 pub fn new(
380 connection_url: impl Into<String>,
381 table_name: impl Into<String>,
382 dimensions: usize,
383 ) -> Self {
384 Self {
385 connection_url: connection_url.into(),
386 table_name: table_name.into(),
387 dimensions,
388 }
389 }
390
391 #[must_use]
393 pub fn with_defaults() -> Self {
394 Self::new(
395 "postgresql://localhost/subcog",
396 "memory_vectors",
397 DEFAULT_DIMENSIONS,
398 )
399 }
400 }
401
402 impl VectorBackend for PgvectorBackend {
403 fn dimensions(&self) -> usize {
404 self.dimensions
405 }
406
407 fn upsert(&self, _id: &MemoryId, _embedding: &[f32]) -> Result<()> {
408 Err(Error::NotImplemented(format!(
409 "PgvectorBackend::upsert for {} on {}",
410 self.table_name, self.connection_url
411 )))
412 }
413
414 fn remove(&self, _id: &MemoryId) -> Result<bool> {
415 Err(Error::NotImplemented(format!(
416 "PgvectorBackend::remove for {} on {}",
417 self.table_name, self.connection_url
418 )))
419 }
420
421 fn search(
422 &self,
423 _query_embedding: &[f32],
424 _filter: &VectorFilter,
425 _limit: usize,
426 ) -> Result<Vec<(MemoryId, f32)>> {
427 Err(Error::NotImplemented(format!(
428 "PgvectorBackend::search for {} on {}",
429 self.table_name, self.connection_url
430 )))
431 }
432
433 fn count(&self) -> Result<usize> {
434 Err(Error::NotImplemented(format!(
435 "PgvectorBackend::count for {} on {}",
436 self.table_name, self.connection_url
437 )))
438 }
439
440 fn clear(&self) -> Result<()> {
441 Err(Error::NotImplemented(format!(
442 "PgvectorBackend::clear for {} on {}",
443 self.table_name, self.connection_url
444 )))
445 }
446 }
447}
448
449#[cfg(not(feature = "postgres"))]
450pub use stub::PgvectorBackend;