subcog/storage/
migrations.rs1#[cfg(feature = "postgres")]
24#[allow(clippy::excessive_nesting)]
25mod implementation {
26 use crate::{Error, Result};
27 use deadpool_postgres::Pool;
28
/// One versioned schema change that a `MigrationRunner` can apply.
///
/// The type is plain `Copy` data: an integer plus two `'static` string slices.
#[derive(Clone, Copy, Debug)]
pub struct Migration {
    /// Version number. The runner applies a migration only when this value is
    /// greater than the highest version already recorded.
    pub version: i32,
    /// Human-readable summary; stored in the tracking table's `description`
    /// column and included in error labels when the migration fails.
    pub description: &'static str,
    /// SQL text of the migration. Every `{table}` placeholder is replaced
    /// with the runner's table name before execution.
    pub sql: &'static str,
}
40
41 pub struct MigrationRunner {
43 pool: Pool,
44 table_name: String,
45 }
46
47 impl MigrationRunner {
48 #[must_use]
50 pub fn new(pool: Pool, table_name: impl Into<String>) -> Self {
51 Self {
52 pool,
53 table_name: table_name.into(),
54 }
55 }
56
57 #[must_use]
59 pub fn table_name(&self) -> &str {
60 &self.table_name
61 }
62
63 pub async fn run(&self, migrations: &[Migration]) -> Result<()> {
69 let mut client = self.pool.get().await.map_err(|e| Error::OperationFailed {
70 operation: "migration_get_connection".to_string(),
71 cause: e.to_string(),
72 })?;
73
74 self.ensure_migrations_table(&client).await?;
76
77 let current_version = self.get_current_version(&client).await?;
79
80 for migration in migrations {
82 if migration.version > current_version {
83 self.apply_migration(&mut client, migration).await?;
84 }
85 }
86
87 Ok(())
88 }
89
90 pub async fn current_version(&self) -> Result<i32> {
96 let client = self.pool.get().await.map_err(|e| Error::OperationFailed {
97 operation: "migration_get_connection".to_string(),
98 cause: e.to_string(),
99 })?;
100
101 let migrations_table = self.migrations_table_name();
103 let exists = self.table_exists(&client, &migrations_table).await?;
104
105 if !exists {
106 return Ok(0);
107 }
108
109 self.get_current_version(&client).await
110 }
111
112 fn migrations_table_name(&self) -> String {
114 format!("{}_schema_migrations", self.table_name)
115 }
116
117 async fn ensure_migrations_table(&self, client: &deadpool_postgres::Object) -> Result<()> {
119 let migrations_table = self.migrations_table_name();
120
121 let sql = format!(
122 r"
123 CREATE TABLE IF NOT EXISTS {migrations_table} (
124 version INTEGER PRIMARY KEY,
125 description TEXT NOT NULL,
126 applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
127 )
128 "
129 );
130
131 client
132 .execute(&sql, &[])
133 .await
134 .map_err(|e| Error::OperationFailed {
135 operation: "create_migrations_table".to_string(),
136 cause: e.to_string(),
137 })?;
138
139 Ok(())
140 }
141
142 async fn table_exists(
144 &self,
145 client: &deadpool_postgres::Object,
146 table_name: &str,
147 ) -> Result<bool> {
148 let sql = r"
149 SELECT EXISTS (
150 SELECT FROM information_schema.tables
151 WHERE table_name = $1
152 )
153 ";
154
155 let exists: bool = client
156 .query_one(sql, &[&table_name])
157 .await
158 .map(|row| row.get(0))
159 .unwrap_or(false);
160
161 Ok(exists)
162 }
163
164 async fn get_current_version(&self, client: &deadpool_postgres::Object) -> Result<i32> {
166 let migrations_table = self.migrations_table_name();
167 let sql = format!("SELECT COALESCE(MAX(version), 0) FROM {migrations_table}");
168
169 let version: i32 = client
170 .query_one(&sql, &[])
171 .await
172 .map(|row| row.get(0))
173 .unwrap_or(0);
174
175 Ok(version)
176 }
177
178 async fn apply_migration(
187 &self,
188 client: &mut deadpool_postgres::Object,
189 migration: &Migration,
190 ) -> Result<()> {
191 let migrations_table = self.migrations_table_name();
192
193 let sql = migration.sql.replace("{table}", &self.table_name);
195
196 let tx = client
198 .transaction()
199 .await
200 .map_err(|e| Error::OperationFailed {
201 operation: format!("migration_v{}_begin_tx", migration.version),
202 cause: e.to_string(),
203 })?;
204
205 for statement in sql.split(';') {
207 let statement = statement.trim();
208 if statement.is_empty() {
209 continue;
210 }
211
212 tx.execute(statement, &[])
213 .await
214 .map_err(|e| Error::OperationFailed {
215 operation: format!(
216 "migration_v{}: {}",
217 migration.version, migration.description
218 ),
219 cause: e.to_string(),
220 })?;
221 }
222
223 let record_sql =
225 format!("INSERT INTO {migrations_table} (version, description) VALUES ($1, $2)");
226
227 tx.execute(&record_sql, &[&migration.version, &migration.description])
228 .await
229 .map_err(|e| Error::OperationFailed {
230 operation: "record_migration".to_string(),
231 cause: e.to_string(),
232 })?;
233
234 tx.commit().await.map_err(|e| Error::OperationFailed {
236 operation: format!("migration_v{}_commit", migration.version),
237 cause: e.to_string(),
238 })?;
239
240 tracing::info!(
241 version = migration.version,
242 description = migration.description,
243 table = self.table_name,
244 "Applied migration"
245 );
246
247 Ok(())
248 }
249 }
250
251 #[must_use]
253 pub fn max_version(migrations: &[Migration]) -> i32 {
254 migrations.iter().map(|m| m.version).max().unwrap_or(0)
255 }
256}
257
258#[cfg(feature = "postgres")]
259pub use implementation::{Migration, MigrationRunner, max_version};
260
#[cfg(not(feature = "postgres"))]
mod stub {
    /// One versioned schema change.
    ///
    /// This definition is compiled when the `postgres` feature is disabled.
    #[derive(Clone, Copy, Debug)]
    pub struct Migration {
        /// Version number of the migration.
        pub version: i32,
        /// Human-readable summary of the migration.
        pub description: &'static str,
        /// SQL text of the migration.
        pub sql: &'static str,
    }

    /// Returns the largest `version` among `migrations`.
    ///
    /// The running maximum starts at `0`, so the result is `0` for an empty
    /// slice and never negative.
    #[must_use]
    pub const fn max_version(migrations: &[Migration]) -> i32 {
        let mut highest = 0;
        let mut remaining = migrations;
        // Peel one element off the front per iteration; iterator adaptors
        // are not usable in `const fn`.
        while let [head, tail @ ..] = remaining {
            if head.version > highest {
                highest = head.version;
            }
            remaining = tail;
        }
        highest
    }
}
288
289#[cfg(not(feature = "postgres"))]
290pub use stub::{Migration, max_version};