Skip to main content

subcog/storage/index/
postgresql.rs

1//! PostgreSQL-based index backend.
2//!
3//! Provides full-text search using PostgreSQL's built-in tsvector/tsquery.
4//!
5//! # TLS Support (COMP-C3)
6//!
7//! Enable the `postgres-tls` feature for encrypted connections:
8//!
9//! ```toml
10//! [dependencies]
11//! subcog = { version = "0.1", features = ["postgres-tls"] }
12//! ```
13//!
14//! Then use a connection URL with `sslmode=require`:
15//! ```text
16//! postgresql://user:pass@host:5432/db?sslmode=require
17//! ```
18
19#[cfg(feature = "postgres")]
20mod implementation {
21    use crate::models::{Memory, MemoryId, SearchFilter};
22    use crate::storage::migrations::{Migration, MigrationRunner};
23    use crate::storage::traits::IndexBackend;
24    use crate::{Error, Result};
25    use deadpool_postgres::{Config, Pool, Runtime};
26    use tokio::runtime::Handle;
27
28    #[cfg(not(feature = "postgres-tls"))]
29    use tokio_postgres::NoTls;
30
31    #[cfg(feature = "postgres-tls")]
32    use tokio_postgres_rustls::MakeRustlsConnect;
33
    /// Embedded migrations compiled into the binary.
    ///
    /// Each `sql` string contains a `{table}` placeholder. The list is passed to
    /// `MigrationRunner::run` together with the whitelisted table name (see
    /// `run_migrations`); presumably the runner substitutes `{table}` and skips
    /// versions already applied — TODO confirm against `MigrationRunner`.
    ///
    /// NOTE(review): do not edit the SQL of a migration that may already have been
    /// applied somewhere; add a new `version` instead.
    const MIGRATIONS: &[Migration] = &[
        // v1: base table. `search_vector` is a STORED generated column that weights
        // the memory content as 'A' and the space-joined tags as 'B'.
        // NOTE(review): PostgreSQL requires generated-column expressions to be
        // IMMUTABLE, and `array_to_string` is catalogued as STABLE — confirm this
        // statement succeeds on the target server version.
        Migration {
            version: 1,
            description: "Initial index table with FTS",
            sql: r"
                CREATE TABLE IF NOT EXISTS {table} (
                    id TEXT PRIMARY KEY,
                    content TEXT NOT NULL,
                    namespace TEXT NOT NULL,
                    domain TEXT NOT NULL,
                    status TEXT NOT NULL,
                    tags TEXT[] DEFAULT '{}',
                    created_at BIGINT NOT NULL,
                    updated_at BIGINT NOT NULL,
                    search_vector TSVECTOR GENERATED ALWAYS AS (
                        setweight(to_tsvector('english', coalesce(content, '')), 'A') ||
                        setweight(to_tsvector('english', coalesce(array_to_string(tags, ' '), '')), 'B')
                    ) STORED
                );
            ",
        },
        // v2: GIN index backing the `search_vector @@ tsquery` match used by search.
        Migration {
            version: 2,
            description: "Add GIN index on search_vector",
            sql: r"
                CREATE INDEX IF NOT EXISTS {table}_search_idx ON {table} USING GIN (search_vector);
            ",
        },
        // v3: namespace filter and `ORDER BY updated_at DESC` listing.
        Migration {
            version: 3,
            description: "Add namespace and updated_at indexes",
            sql: r"
                CREATE INDEX IF NOT EXISTS {table}_namespace_idx ON {table} (namespace);
                CREATE INDEX IF NOT EXISTS {table}_updated_idx ON {table} (updated_at DESC);
            ",
        },
        // v4: status filter and creation-time ordering.
        Migration {
            version: 4,
            description: "Add status and created_at indexes",
            sql: r"
                CREATE INDEX IF NOT EXISTS {table}_status_idx ON {table} (status);
                CREATE INDEX IF NOT EXISTS {table}_created_idx ON {table} (created_at DESC);
            ",
        },
        // v5: nullable facet columns (project / branch / file path) plus indexes
        // matching the equality filters built in `add_*_filter`.
        Migration {
            version: 5,
            description: "Add facet columns (ADR-0048/0049)",
            sql: r"
                ALTER TABLE {table} ADD COLUMN IF NOT EXISTS project_id TEXT;
                ALTER TABLE {table} ADD COLUMN IF NOT EXISTS branch TEXT;
                ALTER TABLE {table} ADD COLUMN IF NOT EXISTS file_path TEXT;
                CREATE INDEX IF NOT EXISTS {table}_project_idx ON {table} (project_id);
                CREATE INDEX IF NOT EXISTS {table}_project_branch_idx ON {table} (project_id, branch);
                CREATE INDEX IF NOT EXISTS {table}_file_path_idx ON {table} (file_path);
            ",
        },
    ];
92
93    /// Allowed table names for SQL injection prevention.
94    const ALLOWED_TABLE_NAMES: &[&str] = &[
95        "memories_index",
96        "memories",
97        "subcog_memories",
98        "subcog_index",
99    ];
100
101    /// Validates that a table name is in the whitelist.
102    fn validate_table_name(name: &str) -> Result<()> {
103        if ALLOWED_TABLE_NAMES.contains(&name) {
104            Ok(())
105        } else {
106            Err(Error::InvalidInput(format!(
107                "Table name '{name}' is not allowed. Allowed names: {ALLOWED_TABLE_NAMES:?}",
108            )))
109        }
110    }
111
112    /// Validates PostgreSQL connection URL format (SEC-M2).
113    ///
114    /// Prevents connection string injection by validating:
115    /// - URL scheme is `postgresql://` or `postgres://`
116    /// - Host contains only valid characters (alphanumeric, `.`, `-`, `_`)
117    /// - Database name contains only valid characters
118    /// - No dangerous URL parameters that could alter connection behavior
119    ///
120    /// # Errors
121    ///
122    /// Returns `Error::InvalidInput` if the connection URL is invalid or contains
123    /// potentially dangerous parameters.
124    fn validate_connection_url(url_str: &str) -> Result<()> {
125        // Check scheme
126        if !url_str.starts_with("postgresql://") && !url_str.starts_with("postgres://") {
127            return Err(Error::InvalidInput(
128                "Connection URL must start with postgresql:// or postgres://".to_string(),
129            ));
130        }
131
132        // Parse URL to validate components using reqwest's re-exported url crate
133        let parsed = reqwest::Url::parse(url_str)
134            .map_err(|e| Error::InvalidInput(format!("Invalid connection URL format: {e}")))?;
135
136        // Validate host (prevent injection via malformed hostnames)
137        if let Some(host) = parsed.host_str() {
138            let is_valid_host = host
139                .chars()
140                .all(|c: char| c.is_ascii_alphanumeric() || c == '.' || c == '-' || c == '_');
141            if !is_valid_host {
142                tracing::warn!(
143                    host = host,
144                    "PostgreSQL connection URL contains suspicious host characters"
145                );
146                return Err(Error::InvalidInput(
147                    "Connection URL host contains invalid characters".to_string(),
148                ));
149            }
150        }
151
152        // Validate database name if present
153        if let Some(path) = parsed.path().strip_prefix('/') {
154            let is_valid_db = path.is_empty()
155                || path
156                    .chars()
157                    .all(|c: char| c.is_ascii_alphanumeric() || c == '_' || c == '-');
158            if !is_valid_db {
159                tracing::warn!(
160                    database = path,
161                    "PostgreSQL connection URL contains suspicious database name"
162                );
163                return Err(Error::InvalidInput(
164                    "Connection URL database name contains invalid characters".to_string(),
165                ));
166            }
167        }
168
169        // Block dangerous connection parameters that could alter behavior
170        let dangerous_params = ["host", "hostaddr", "client_encoding", "options"];
171        for (key, _) in parsed.query_pairs() {
172            if dangerous_params.contains(&key.as_ref()) {
173                tracing::warn!(
174                    param = key.as_ref(),
175                    "PostgreSQL connection URL contains blocked parameter"
176                );
177                return Err(Error::InvalidInput(format!(
178                    "Connection URL parameter '{key}' is not allowed in query string"
179                )));
180            }
181        }
182
183        Ok(())
184    }
185
186    /// PostgreSQL-based index backend.
187    ///
188    /// Uses `deadpool_postgres::Pool` for thread-safe connection pooling,
189    /// enabling `&self` methods without interior mutability wrappers.
190    pub struct PostgresIndexBackend {
191        /// Connection pool (thread-safe via internal `Arc`).
192        pool: Pool,
193        /// Table name for memories (validated against whitelist).
194        table_name: String,
195    }
196
197    /// Helper to map pool errors.
198    fn pool_error(e: impl std::fmt::Display) -> Error {
199        Error::OperationFailed {
200            operation: "postgres_get_client".to_string(),
201            cause: e.to_string(),
202        }
203    }
204
205    /// Helper to map query errors.
206    fn query_error(op: &str, e: impl std::fmt::Display) -> Error {
207        Error::OperationFailed {
208            operation: op.to_string(),
209            cause: e.to_string(),
210        }
211    }
212
213    impl PostgresIndexBackend {
214        /// Creates a new PostgreSQL index backend.
215        ///
216        /// # TLS Support (COMP-C3)
217        ///
218        /// When the `postgres-tls` feature is enabled, connections use TLS by default.
219        /// For production, use a connection URL with `sslmode=require`:
220        /// ```text
221        /// postgresql://user:pass@host:5432/db?sslmode=require
222        /// ```
223        ///
224        /// # Errors
225        ///
226        /// Returns an error if the connection pool fails to initialize
227        /// or if the table name is not in the allowed whitelist.
228        #[cfg(not(feature = "postgres-tls"))]
229        pub fn new(connection_url: &str, table_name: impl Into<String>) -> Result<Self> {
230            let table_name = table_name.into();
231
232            // Validate table name against whitelist to prevent SQL injection
233            validate_table_name(&table_name)?;
234
235            let config = Self::parse_connection_url(connection_url)?;
236            let cfg = Self::build_pool_config(&config);
237
238            let pool = cfg.create_pool(Some(Runtime::Tokio1), NoTls).map_err(|e| {
239                Error::OperationFailed {
240                    operation: "postgres_create_pool".to_string(),
241                    cause: e.to_string(),
242                }
243            })?;
244
245            let backend = Self { pool, table_name };
246            backend.run_migrations()?;
247            Ok(backend)
248        }
249
250        /// Creates a new PostgreSQL index backend with TLS encryption (COMP-C3).
251        ///
252        /// Uses rustls for TLS connections. The connection URL should include
253        /// `sslmode=require` or `sslmode=verify-full` for production use.
254        ///
255        /// # Errors
256        ///
257        /// Returns an error if the connection pool fails to initialize,
258        /// if TLS configuration fails, or if the table name is not allowed.
259        #[cfg(feature = "postgres-tls")]
260        pub fn new(connection_url: &str, table_name: impl Into<String>) -> Result<Self> {
261            let table_name = table_name.into();
262
263            // Validate table name against whitelist to prevent SQL injection
264            validate_table_name(&table_name)?;
265
266            let config = Self::parse_connection_url(connection_url)?;
267            let cfg = Self::build_pool_config(&config);
268
269            // Build TLS connector with rustls
270            let tls_config = rustls::ClientConfig::builder()
271                .with_root_certificates(Self::root_cert_store())
272                .with_no_client_auth();
273
274            let tls = MakeRustlsConnect::new(tls_config);
275
276            let pool = cfg.create_pool(Some(Runtime::Tokio1), tls).map_err(|e| {
277                Error::OperationFailed {
278                    operation: "postgres_create_pool_tls".to_string(),
279                    cause: e.to_string(),
280                }
281            })?;
282
283            let backend = Self { pool, table_name };
284            backend.run_migrations()?;
285            Ok(backend)
286        }
287
288        /// Builds root certificate store for TLS.
289        #[cfg(feature = "postgres-tls")]
290        fn root_cert_store() -> rustls::RootCertStore {
291            let mut roots = rustls::RootCertStore::empty();
292
293            // Try to load system certificates
294            #[cfg(feature = "postgres-tls")]
295            {
296                // Use webpki-roots for portable certificate bundle
297                roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
298            }
299
300            roots
301        }
302
303        /// Parses the connection URL into a tokio-postgres config (SEC-M2).
304        ///
305        /// Validates the URL for security before parsing to prevent injection attacks.
306        fn parse_connection_url(url: &str) -> Result<tokio_postgres::Config> {
307            // Validate URL format and block dangerous parameters (SEC-M2)
308            validate_connection_url(url)?;
309
310            url.parse::<tokio_postgres::Config>()
311                .map_err(|e| Error::OperationFailed {
312                    operation: "postgres_parse_url".to_string(),
313                    cause: e.to_string(),
314                })
315        }
316
317        /// Extracts host string from tokio-postgres Host.
318        #[cfg(unix)]
319        fn host_to_string(h: &tokio_postgres::config::Host) -> String {
320            match h {
321                tokio_postgres::config::Host::Tcp(s) => s.clone(),
322                tokio_postgres::config::Host::Unix(p) => p.to_string_lossy().to_string(),
323            }
324        }
325
326        /// Extracts host string from tokio-postgres Host (Windows: Tcp only).
327        #[cfg(not(unix))]
328        fn host_to_string(h: &tokio_postgres::config::Host) -> String {
329            let tokio_postgres::config::Host::Tcp(s) = h;
330            s.clone()
331        }
332
333        /// Maximum connections in pool (CHAOS-H1).
334        const POOL_MAX_SIZE: usize = 20;
335
336        /// Builds a deadpool config from tokio-postgres config.
337        ///
338        /// # Pool Exhaustion Protection (CHAOS-H1)
339        ///
340        /// Configures connection pool with safety limits:
341        /// - Max 20 connections (prevents pool exhaustion)
342        /// - Runtime pool builder sets timeouts for wait/create/recycle
343        ///
344        /// # Statement Caching (DB-H4)
345        ///
346        /// Statement caching is handled automatically by `tokio-postgres` connections.
347        /// Each connection maintains its own prepared statement cache. The
348        /// `RecyclingMethod::Fast` setting preserves connections (and their statement
349        /// caches) across uses, providing implicit statement caching without
350        /// additional configuration.
351        fn build_pool_config(config: &tokio_postgres::Config) -> Config {
352            let mut cfg = Config::new();
353            cfg.host = config.get_hosts().first().map(Self::host_to_string);
354            cfg.port = config.get_ports().first().copied();
355            cfg.user = config.get_user().map(String::from);
356            cfg.password = config
357                .get_password()
358                .map(|p| String::from_utf8_lossy(p).to_string());
359            cfg.dbname = config.get_dbname().map(String::from);
360
361            // Pool exhaustion protection (CHAOS-H1)
362            cfg.pool = Some(deadpool_postgres::PoolConfig {
363                max_size: Self::POOL_MAX_SIZE,
364                ..Default::default()
365            });
366
367            // Configure manager with fast recycling for statement cache reuse
368            cfg.manager = Some(deadpool_postgres::ManagerConfig {
369                recycling_method: deadpool_postgres::RecyclingMethod::Fast,
370            });
371
372            cfg
373        }
374
375        /// Creates a backend with default settings.
376        ///
377        /// # Errors
378        ///
379        /// Returns an error if the connection fails.
380        pub fn with_defaults() -> Result<Self> {
381            Self::new("postgresql://localhost/subcog", "memories_index")
382        }
383
384        /// Runs a blocking operation on the async pool.
385        fn block_on<F, T>(&self, f: F) -> Result<T>
386        where
387            F: std::future::Future<Output = Result<T>>,
388        {
389            if let Ok(handle) = Handle::try_current() {
390                handle.block_on(f)
391            } else {
392                let rt = tokio::runtime::Builder::new_current_thread()
393                    .enable_all()
394                    .build()
395                    .map_err(|e| Error::OperationFailed {
396                        operation: "postgres_create_runtime".to_string(),
397                        cause: e.to_string(),
398                    })?;
399                rt.block_on(f)
400            }
401        }
402
403        /// Runs migrations.
404        fn run_migrations(&self) -> Result<()> {
405            self.block_on(async {
406                let runner = MigrationRunner::new(self.pool.clone(), &self.table_name);
407                runner.run(MIGRATIONS).await
408            })
409        }
410
411        /// Builds WHERE clause for filters.
412        fn build_where_clause(filter: &SearchFilter, start_param: i32) -> (String, Vec<String>) {
413            let mut clauses = Vec::new();
414            let mut params = Vec::new();
415            let mut param_num = start_param;
416
417            Self::add_namespace_filter(filter, &mut clauses, &mut params, &mut param_num);
418            Self::add_domain_filter(filter, &mut clauses, &mut params, &mut param_num);
419            Self::add_project_filter(filter, &mut clauses, &mut params, &mut param_num);
420            Self::add_branch_filter(filter, &mut clauses, &mut params, &mut param_num);
421            Self::add_file_path_filter(filter, &mut clauses, &mut params, &mut param_num);
422            Self::add_status_filter(filter, &mut clauses, &mut params, &mut param_num);
423
424            let clause = if clauses.is_empty() {
425                String::new()
426            } else {
427                format!(" AND {}", clauses.join(" AND "))
428            };
429
430            (clause, params)
431        }
432
433        /// Adds namespace filter to WHERE clause.
434        fn add_namespace_filter(
435            filter: &SearchFilter,
436            clauses: &mut Vec<String>,
437            params: &mut Vec<String>,
438            param_num: &mut i32,
439        ) {
440            if filter.namespaces.is_empty() {
441                return;
442            }
443            let placeholders: Vec<String> = filter
444                .namespaces
445                .iter()
446                .map(|_| {
447                    let p = format!("${param_num}");
448                    *param_num += 1;
449                    p
450                })
451                .collect();
452            clauses.push(format!("namespace IN ({})", placeholders.join(", ")));
453            for ns in &filter.namespaces {
454                params.push(ns.as_str().to_string());
455            }
456        }
457
458        /// Adds domain filter to WHERE clause.
459        fn add_domain_filter(
460            filter: &SearchFilter,
461            clauses: &mut Vec<String>,
462            params: &mut Vec<String>,
463            param_num: &mut i32,
464        ) {
465            if filter.domains.is_empty() {
466                return;
467            }
468            let placeholders: Vec<String> = filter
469                .domains
470                .iter()
471                .map(|_| {
472                    let p = format!("${param_num}");
473                    *param_num += 1;
474                    p
475                })
476                .collect();
477            clauses.push(format!("domain IN ({})", placeholders.join(", ")));
478            for d in &filter.domains {
479                params.push(d.to_string());
480            }
481        }
482
483        /// Adds status filter to WHERE clause.
484        fn add_status_filter(
485            filter: &SearchFilter,
486            clauses: &mut Vec<String>,
487            params: &mut Vec<String>,
488            param_num: &mut i32,
489        ) {
490            if filter.statuses.is_empty() {
491                return;
492            }
493            let placeholders: Vec<String> = filter
494                .statuses
495                .iter()
496                .map(|_| {
497                    let p = format!("${param_num}");
498                    *param_num += 1;
499                    p
500                })
501                .collect();
502            clauses.push(format!("status IN ({})", placeholders.join(", ")));
503            for s in &filter.statuses {
504                params.push(s.as_str().to_string());
505            }
506        }
507
508        fn add_project_filter(
509            filter: &SearchFilter,
510            clauses: &mut Vec<String>,
511            params: &mut Vec<String>,
512            param_num: &mut i32,
513        ) {
514            let Some(project_id) = filter.project_id.as_ref() else {
515                return;
516            };
517            clauses.push(format!("project_id = ${param_num}"));
518            *param_num += 1;
519            params.push(project_id.clone());
520        }
521
522        fn add_branch_filter(
523            filter: &SearchFilter,
524            clauses: &mut Vec<String>,
525            params: &mut Vec<String>,
526            param_num: &mut i32,
527        ) {
528            let Some(branch) = filter.branch.as_ref() else {
529                return;
530            };
531            clauses.push(format!("branch = ${param_num}"));
532            *param_num += 1;
533            params.push(branch.clone());
534        }
535
536        fn add_file_path_filter(
537            filter: &SearchFilter,
538            clauses: &mut Vec<String>,
539            params: &mut Vec<String>,
540            param_num: &mut i32,
541        ) {
542            let Some(file_path) = filter.file_path.as_ref() else {
543                return;
544            };
545            clauses.push(format!("file_path = ${param_num}"));
546            *param_num += 1;
547            params.push(file_path.clone());
548        }
549
550        /// Async implementation of index operation.
551        #[allow(clippy::cast_possible_wrap)]
552        async fn index_async(&self, memory: &Memory) -> Result<()> {
553            let client = self.pool.get().await.map_err(pool_error)?;
554
555            let upsert = format!(
556                r"INSERT INTO {} (id, content, namespace, domain, project_id, branch, file_path, status, tags, created_at, updated_at)
557                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
558                ON CONFLICT (id) DO UPDATE SET
559                    content = EXCLUDED.content,
560                    namespace = EXCLUDED.namespace,
561                    domain = EXCLUDED.domain,
562                    project_id = EXCLUDED.project_id,
563                    branch = EXCLUDED.branch,
564                    file_path = EXCLUDED.file_path,
565                    status = EXCLUDED.status,
566                    tags = EXCLUDED.tags,
567                    updated_at = EXCLUDED.updated_at",
568                self.table_name
569            );
570
571            let tags: Vec<&str> = memory.tags.iter().map(String::as_str).collect();
572            let domain_str = memory.domain.to_string();
573            let namespace_str = memory.namespace.as_str();
574            let status_str = memory.status.as_str();
575
576            client
577                .execute(
578                    &upsert,
579                    &[
580                        &memory.id.as_str(),
581                        &memory.content,
582                        &namespace_str,
583                        &domain_str,
584                        &memory.project_id,
585                        &memory.branch,
586                        &memory.file_path,
587                        &status_str,
588                        &tags,
589                        &(memory.created_at as i64),
590                        &(memory.updated_at as i64),
591                    ],
592                )
593                .await
594                .map_err(|e| query_error("postgres_index", e))?;
595
596            Ok(())
597        }
598
599        /// Async implementation of remove operation.
600        async fn remove_async(&self, id: &MemoryId) -> Result<bool> {
601            let client = self.pool.get().await.map_err(pool_error)?;
602            let delete = format!("DELETE FROM {} WHERE id = $1", self.table_name);
603            let rows = client
604                .execute(&delete, &[&id.as_str()])
605                .await
606                .map_err(|e| query_error("postgres_remove", e))?;
607            Ok(rows > 0)
608        }
609
610        /// Async implementation of search operation.
611        async fn search_async(
612            &self,
613            query: &str,
614            filter: &SearchFilter,
615            limit: usize,
616        ) -> Result<Vec<(MemoryId, f32)>> {
617            let client = self.pool.get().await.map_err(pool_error)?;
618            let (filter_clause, filter_params) = Self::build_where_clause(filter, 2);
619
620            let search_query = format!(
621                r"SELECT id, ts_rank(search_vector, websearch_to_tsquery('english', $1)) as score
622                FROM {}
623                WHERE search_vector @@ websearch_to_tsquery('english', $1)
624                {}
625                ORDER BY score DESC
626                LIMIT {}",
627                self.table_name, filter_clause, limit
628            );
629
630            let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = Vec::new();
631            params.push(&query);
632            for p in &filter_params {
633                params.push(p);
634            }
635
636            let rows = client
637                .query(&search_query, &params)
638                .await
639                .map_err(|e| query_error("postgres_search", e))?;
640
641            Ok(rows
642                .iter()
643                .map(|row| {
644                    let id: String = row.get(0);
645                    let score: f32 = row.get(1);
646                    (MemoryId::new(&id), score)
647                })
648                .collect())
649        }
650
651        /// Async implementation of `list_all` operation.
652        async fn list_all_async(
653            &self,
654            filter: &SearchFilter,
655            limit: usize,
656        ) -> Result<Vec<(MemoryId, f32)>> {
657            let client = self.pool.get().await.map_err(pool_error)?;
658            let (filter_clause, filter_params) = Self::build_where_clause(filter, 1);
659
660            let where_prefix = if filter_clause.is_empty() {
661                String::new()
662            } else {
663                format!("WHERE {}", filter_clause.trim_start_matches(" AND "))
664            };
665
666            let list_query = format!(
667                r"SELECT id, 1.0::real as score
668                FROM {}
669                {}
670                ORDER BY updated_at DESC
671                LIMIT {}",
672                self.table_name, where_prefix, limit
673            );
674
675            let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
676                filter_params.iter().map(|p| p as _).collect();
677
678            let rows = client
679                .query(&list_query, &params)
680                .await
681                .map_err(|e| query_error("postgres_list_all", e))?;
682
683            Ok(rows
684                .iter()
685                .map(|row| {
686                    let id: String = row.get(0);
687                    let score: f32 = row.get(1);
688                    (MemoryId::new(&id), score)
689                })
690                .collect())
691        }
692
693        /// Async implementation of clear operation.
694        async fn clear_async(&self) -> Result<()> {
695            let client = self.pool.get().await.map_err(pool_error)?;
696            let truncate = format!("TRUNCATE TABLE {}", self.table_name);
697            client
698                .execute(&truncate, &[])
699                .await
700                .map_err(|e| query_error("postgres_clear", e))?;
701            Ok(())
702        }
703    }
704
705    impl IndexBackend for PostgresIndexBackend {
706        fn index(&self, memory: &Memory) -> Result<()> {
707            self.block_on(self.index_async(memory))
708        }
709
710        fn remove(&self, id: &MemoryId) -> Result<bool> {
711            self.block_on(self.remove_async(id))
712        }
713
714        fn search(
715            &self,
716            query: &str,
717            filter: &SearchFilter,
718            limit: usize,
719        ) -> Result<Vec<(MemoryId, f32)>> {
720            self.block_on(self.search_async(query, filter, limit))
721        }
722
723        fn list_all(&self, filter: &SearchFilter, limit: usize) -> Result<Vec<(MemoryId, f32)>> {
724            self.block_on(self.list_all_async(filter, limit))
725        }
726
727        fn get_memory(&self, _id: &MemoryId) -> Result<Option<Memory>> {
728            // Index backend stores minimal data for search, not full memories
729            Ok(None)
730        }
731
732        fn clear(&self) -> Result<()> {
733            self.block_on(self.clear_async())
734        }
735    }
736
737    #[cfg(test)]
738    mod tests {
739        use super::*;
740
741        #[test]
742        fn test_validate_connection_url_valid() {
743            // Valid PostgreSQL URLs
744            assert!(validate_connection_url("postgresql://localhost/mydb").is_ok());
745            assert!(validate_connection_url("postgres://user:pass@localhost:5432/mydb").is_ok());
746            assert!(
747                validate_connection_url(
748                    "postgresql://user:pass@db.example.com:5432/mydb?sslmode=require"
749                )
750                .is_ok()
751            );
752            assert!(validate_connection_url("postgresql://localhost/my_db-test").is_ok());
753        }
754
755        #[test]
756        fn test_validate_connection_url_invalid_scheme() {
757            // Invalid scheme
758            assert!(validate_connection_url("mysql://localhost/mydb").is_err());
759            assert!(validate_connection_url("http://localhost/mydb").is_err());
760            assert!(validate_connection_url("localhost/mydb").is_err());
761        }
762
763        #[test]
764        fn test_validate_connection_url_invalid_host() {
765            // Invalid host characters (injection attempts)
766            assert!(validate_connection_url("postgresql://local<script>host/mydb").is_err());
767            assert!(validate_connection_url("postgresql://host;drop table/mydb").is_err());
768        }
769
770        #[test]
771        fn test_validate_connection_url_invalid_database() {
772            // Invalid database name characters
773            assert!(validate_connection_url("postgresql://localhost/my;db").is_err());
774            assert!(validate_connection_url("postgresql://localhost/db<script>").is_err());
775        }
776
777        #[test]
778        fn test_validate_connection_url_blocked_params() {
779            // Blocked dangerous parameters
780            assert!(validate_connection_url("postgresql://localhost/mydb?host=evil.com").is_err());
781            assert!(
782                validate_connection_url("postgresql://localhost/mydb?hostaddr=1.2.3.4").is_err()
783            );
784            assert!(
785                validate_connection_url("postgresql://localhost/mydb?options=-c log_statement=all")
786                    .is_err()
787            );
788            assert!(
789                validate_connection_url("postgresql://localhost/mydb?client_encoding=SQL_ASCII")
790                    .is_err()
791            );
792        }
793
794        #[test]
795        fn test_validate_connection_url_allowed_params() {
796            // Allowed parameters should pass
797            assert!(validate_connection_url("postgresql://localhost/mydb?sslmode=require").is_ok());
798            assert!(
799                validate_connection_url(
800                    "postgresql://localhost/mydb?connect_timeout=10&application_name=subcog"
801                )
802                .is_ok()
803            );
804        }
805
806        #[test]
807        fn test_validate_table_name() {
808            // Valid table names
809            assert!(validate_table_name("memories_index").is_ok());
810            assert!(validate_table_name("subcog_memories").is_ok());
811
812            // Invalid table names
813            assert!(validate_table_name("users").is_err());
814            assert!(validate_table_name("memories_index; DROP TABLE users").is_err());
815        }
816    }
817}
818
819#[cfg(feature = "postgres")]
820pub use implementation::PostgresIndexBackend;
821
822#[cfg(not(feature = "postgres"))]
823mod stub {
824    use crate::models::{Memory, MemoryId, SearchFilter};
825    use crate::storage::traits::IndexBackend;
826    use crate::{Error, Result};
827
828    /// Stub PostgreSQL backend when feature is not enabled.
829    pub struct PostgresIndexBackend;
830
831    impl PostgresIndexBackend {
832        /// Creates a new PostgreSQL index backend (stub).
833        ///
834        /// # Errors
835        ///
836        /// Always returns an error because the feature is not enabled.
837        pub fn new(_connection_url: &str, _table_name: impl Into<String>) -> Result<Self> {
838            Err(Error::FeatureNotEnabled("postgres".to_string()))
839        }
840
841        /// Creates a backend with default settings (stub).
842        ///
843        /// # Errors
844        ///
845        /// Always returns an error because the feature is not enabled.
846        pub fn with_defaults() -> Result<Self> {
847            Err(Error::FeatureNotEnabled("postgres".to_string()))
848        }
849    }
850
851    impl IndexBackend for PostgresIndexBackend {
852        fn index(&self, _memory: &Memory) -> Result<()> {
853            Err(Error::FeatureNotEnabled("postgres".to_string()))
854        }
855
856        fn remove(&self, _id: &MemoryId) -> Result<bool> {
857            Err(Error::FeatureNotEnabled("postgres".to_string()))
858        }
859
860        fn search(
861            &self,
862            _query: &str,
863            _filter: &SearchFilter,
864            _limit: usize,
865        ) -> Result<Vec<(MemoryId, f32)>> {
866            Err(Error::FeatureNotEnabled("postgres".to_string()))
867        }
868
869        fn list_all(&self, _filter: &SearchFilter, _limit: usize) -> Result<Vec<(MemoryId, f32)>> {
870            Err(Error::FeatureNotEnabled("postgres".to_string()))
871        }
872
873        fn get_memory(&self, _id: &MemoryId) -> Result<Option<Memory>> {
874            Err(Error::FeatureNotEnabled("postgres".to_string()))
875        }
876
877        fn clear(&self) -> Result<()> {
878            Err(Error::FeatureNotEnabled("postgres".to_string()))
879        }
880    }
881}
882
883#[cfg(not(feature = "postgres"))]
884pub use stub::PostgresIndexBackend;