// subcog/storage/migrations.rs

//! PostgreSQL migration system for schema management.
//!
//! Provides a compile-time embedded migration system: migrations are declared
//! as `&'static` data in the binary and applied when `MigrationRunner::run` is
//! called (typically at application start-up), upgrading the database schema
//! to the latest version. Applied versions are recorded in a
//! `<table>_schema_migrations` tracking table.
//!
//! The runner is only compiled with the `postgres` feature; without it, only
//! `Migration` and `max_version` are exported.
//!
//! # Usage
//!
//! ```rust,ignore
//! use subcog::storage::migrations::{Migration, MigrationRunner};
//!
//! const MIGRATIONS: &[Migration] = &[
//!     Migration {
//!         version: 1,
//!         description: "Initial table",
//!         sql: "CREATE TABLE IF NOT EXISTS {table} (id SERIAL PRIMARY KEY);",
//!     },
//! ];
//!
//! let runner = MigrationRunner::new(pool, "my_table");
//! runner.run(MIGRATIONS).await?;
//! ```

#[cfg(feature = "postgres")]
// NOTE(review): presumably allowed for the nested async `map_err` chains below;
// confirm the lint still fires before removing this allowance.
#[allow(clippy::excessive_nesting)]
mod implementation {
    use crate::{Error, Result};
    use deadpool_postgres::Pool;

    /// A single migration with version and SQL.
    #[derive(Debug, Clone, Copy)]
    pub struct Migration {
        /// Migration version (sequential, starting at 1).
        pub version: i32,
        /// Human-readable description; recorded alongside the version in the
        /// tracking table.
        pub description: &'static str,
        /// SQL to apply. May contain multiple semicolon-separated statements
        /// (semicolons inside string literals or dollar-quoted bodies are fine);
        /// the whole script runs inside a single transaction.
        /// Use `{table}` as a placeholder for the table name.
        pub sql: &'static str,
    }

    /// Runs migrations for a PostgreSQL table.
    ///
    /// Applied versions are tracked in a `<table_name>_schema_migrations` table.
    /// No lock is taken, so concurrent runners against the same table are not
    /// coordinated with each other.
    pub struct MigrationRunner {
        pool: Pool,
        table_name: String,
    }

    impl MigrationRunner {
        /// Creates a new migration runner.
        ///
        /// `table_name` is interpolated verbatim into SQL (both the tracking
        /// table name and the `{table}` placeholder), so it must be a trusted,
        /// valid identifier.
        #[must_use]
        pub fn new(pool: Pool, table_name: impl Into<String>) -> Self {
            Self {
                pool,
                table_name: table_name.into(),
            }
        }

        /// Returns the table name.
        #[must_use]
        pub fn table_name(&self) -> &str {
            &self.table_name
        }

        /// Runs all pending migrations.
        ///
        /// Creates the tracking table if needed, then applies every migration
        /// whose version is greater than the highest recorded version, in
        /// ascending version order (independent of the slice's order). Each
        /// migration is applied in its own transaction, so a failure leaves the
        /// previously applied migrations committed.
        ///
        /// # Errors
        ///
        /// Returns an error if a connection cannot be obtained, the tracking
        /// table cannot be created or read, or a migration fails.
        pub async fn run(&self, migrations: &[Migration]) -> Result<()> {
            let mut client = self.connection().await?;

            // Ensure migrations tracking table exists
            self.ensure_migrations_table(&client).await?;

            // A read failure here is propagated: treating it as version 0 would
            // try to re-apply every migration from the start.
            let current_version = self.get_current_version(&client).await?;

            // Apply in ascending version order so a migration never runs before
            // an earlier one it may depend on, even if the slice is unsorted.
            let mut pending: Vec<&Migration> = migrations
                .iter()
                .filter(|m| m.version > current_version)
                .collect();
            pending.sort_by_key(|m| m.version);

            for migration in pending {
                self.apply_migration(&mut client, migration).await?;
            }

            Ok(())
        }

        /// Returns the current schema version (0 if no migrations were applied
        /// or the tracking table does not exist yet).
        ///
        /// # Errors
        ///
        /// Returns an error if the database cannot be queried.
        pub async fn current_version(&self) -> Result<i32> {
            let client = self.connection().await?;

            // Check if migrations table exists first
            let migrations_table = self.migrations_table_name();
            if !self.table_exists(&client, &migrations_table).await? {
                return Ok(0);
            }

            self.get_current_version(&client).await
        }

        /// Checks out a pooled connection, mapping pool errors to
        /// `Error::OperationFailed` with operation `migration_get_connection`.
        async fn connection(&self) -> Result<deadpool_postgres::Object> {
            self.pool.get().await.map_err(|e| Error::OperationFailed {
                operation: "migration_get_connection".to_string(),
                cause: e.to_string(),
            })
        }

        /// Returns the name of the migrations tracking table.
        fn migrations_table_name(&self) -> String {
            format!("{}_schema_migrations", self.table_name)
        }

        /// Ensures the `schema_migrations` table exists.
        async fn ensure_migrations_table(&self, client: &deadpool_postgres::Object) -> Result<()> {
            let migrations_table = self.migrations_table_name();

            let sql = format!(
                r"
                CREATE TABLE IF NOT EXISTS {migrations_table} (
                    version INTEGER PRIMARY KEY,
                    description TEXT NOT NULL,
                    applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
                )
                "
            );

            client
                .execute(&sql, &[])
                .await
                .map_err(|e| Error::OperationFailed {
                    operation: "create_migrations_table".to_string(),
                    cause: e.to_string(),
                })?;

            Ok(())
        }

        /// Checks if a table (relation) with the given name is visible.
        ///
        /// Uses `to_regclass`, which resolves the name the same way the
        /// unquoted identifiers in the generated SQL are resolved: case folding,
        /// optional schema qualification and `search_path` lookup. It yields
        /// NULL rather than an error when no such relation exists.
        async fn table_exists(
            &self,
            client: &deadpool_postgres::Object,
            table_name: &str,
        ) -> Result<bool> {
            let sql = "SELECT to_regclass($1::text) IS NOT NULL";

            let row = client
                .query_one(sql, &[&table_name])
                .await
                .map_err(|e| Error::OperationFailed {
                    operation: "migration_table_exists".to_string(),
                    cause: e.to_string(),
                })?;

            Ok(row.get(0))
        }

        /// Gets the current schema version (highest recorded version, or 0).
        ///
        /// Assumes the tracking table exists; query errors are propagated.
        async fn get_current_version(&self, client: &deadpool_postgres::Object) -> Result<i32> {
            let migrations_table = self.migrations_table_name();
            let sql = format!("SELECT COALESCE(MAX(version), 0) FROM {migrations_table}");

            let row = client
                .query_one(&sql, &[])
                .await
                .map_err(|e| Error::OperationFailed {
                    operation: "migration_get_version".to_string(),
                    cause: e.to_string(),
                })?;

            Ok(row.get(0))
        }

        /// Applies a single migration within a transaction.
        ///
        /// # Transaction Safety (CRIT-001)
        ///
        /// All migration statements and the version record are executed within
        /// a single transaction. If any statement fails, the entire migration
        /// is rolled back, preventing partial schema updates that could leave
        /// the database in an inconsistent state.
        async fn apply_migration(
            &self,
            client: &mut deadpool_postgres::Object,
            migration: &Migration,
        ) -> Result<()> {
            let migrations_table = self.migrations_table_name();

            // Replace {table} placeholder with actual table name
            let sql = migration.sql.replace("{table}", &self.table_name);

            // Start transaction for atomic migration application
            let tx = client
                .transaction()
                .await
                .map_err(|e| Error::OperationFailed {
                    operation: format!("migration_v{}_begin_tx", migration.version),
                    cause: e.to_string(),
                })?;

            // Send the whole script through the simple-query protocol. Postgres
            // splits it into statements itself, so semicolons inside literals or
            // `$$` bodies are handled correctly, and every statement runs inside
            // the transaction opened above.
            if !sql.trim().is_empty() {
                tx.batch_execute(&sql)
                    .await
                    .map_err(|e| Error::OperationFailed {
                        operation: format!(
                            "migration_v{}: {}",
                            migration.version, migration.description
                        ),
                        cause: e.to_string(),
                    })?;
            }

            // Record the migration within the same transaction
            let record_sql =
                format!("INSERT INTO {migrations_table} (version, description) VALUES ($1, $2)");

            tx.execute(&record_sql, &[&migration.version, &migration.description])
                .await
                .map_err(|e| Error::OperationFailed {
                    operation: "record_migration".to_string(),
                    cause: e.to_string(),
                })?;

            // Commit the transaction - all statements succeed or none do
            tx.commit().await.map_err(|e| Error::OperationFailed {
                operation: format!("migration_v{}_commit", migration.version),
                cause: e.to_string(),
            })?;

            tracing::info!(
                version = migration.version,
                description = migration.description,
                table = self.table_name,
                "Applied migration"
            );

            Ok(())
        }
    }

    /// Maximum version across a set of migrations (0 for an empty set).
    ///
    /// `const` so it can size or check schema versions at compile time, matching
    /// the non-`postgres` build of this function.
    #[must_use]
    pub const fn max_version(migrations: &[Migration]) -> i32 {
        if migrations.is_empty() {
            return 0;
        }
        let mut max = migrations[0].version;
        let mut i = 1;
        while i < migrations.len() {
            if migrations[i].version > max {
                max = migrations[i].version;
            }
            i += 1;
        }
        max
    }
}

// Public API when the PostgreSQL backend is enabled.
#[cfg(feature = "postgres")]
pub use self::implementation::{max_version, Migration, MigrationRunner};

#[cfg(not(feature = "postgres"))]
mod stub {
    /// A single migration with version and SQL (stub used when the `postgres`
    /// feature is disabled; no runner is compiled in this configuration).
    #[derive(Debug, Clone, Copy)]
    pub struct Migration {
        /// Migration version (sequential, starting at 1).
        pub version: i32,
        /// Human-readable description.
        pub description: &'static str,
        /// SQL to apply (unused without the `postgres` feature).
        pub sql: &'static str,
    }

    /// Maximum version across a set of migrations (0 for an empty set).
    ///
    /// Returns the true maximum even if every version is below zero, matching
    /// the `postgres` build of this function.
    #[must_use]
    pub const fn max_version(migrations: &[Migration]) -> i32 {
        if migrations.is_empty() {
            return 0;
        }
        // Seed with the first element (not 0) so negative versions are not masked.
        let mut max = migrations[0].version;
        let mut i = 1;
        while i < migrations.len() {
            if migrations[i].version > max {
                max = migrations[i].version;
            }
            i += 1;
        }
        max
    }
}

289#[cfg(not(feature = "postgres"))]
290pub use stub::{Migration, max_version};