Skip to main content

subcog/storage/prompt/
redis.rs

1//! Redis-based prompt storage using `RediSearch`.
2//!
3//! Stores prompts in Redis with full-text search support via `RediSearch`.
4
#[cfg(feature = "redis")]
mod implementation {
    use std::collections::HashMap;

    use crate::models::PromptTemplate;
    use crate::storage::prompt::PromptStorage;
    use crate::{Error, Result};
    use redis::{Client, Commands, Connection};

    /// Redis URL used by [`RedisPromptStorage::with_defaults`].
    const DEFAULT_URL: &str = "redis://localhost:6379";

    /// `RediSearch` index name used by [`RedisPromptStorage::with_defaults`].
    const DEFAULT_INDEX: &str = "subcog_prompts";

    /// Key prefix of every prompt hash; the index is declared over this prefix.
    ///
    /// NOTE(review): the prefix is shared by all index names, so two storages
    /// created with different `index_name`s index the same set of hashes.
    const KEY_PREFIX: &str = "prompt:";

    /// Maximum number of prompts returned by a single `list` call.
    const MAX_LIST_RESULTS: usize = 1000;

    /// `FT.CREATE ... SCHEMA` arguments describing how prompt hash fields are indexed:
    /// weighted TEXT for name/description/content, TAG for tags/author and
    /// sortable NUMERIC for the counter and timestamps.
    #[rustfmt::skip]
    const INDEX_SCHEMA: &[&str] = &[
        "name", "TEXT", "WEIGHT", "2.0",
        "description", "TEXT", "WEIGHT", "1.0",
        "content", "TEXT", "WEIGHT", "0.5",
        "tags", "TAG",
        "author", "TAG",
        "usage_count", "NUMERIC", "SORTABLE",
        "created_at", "NUMERIC", "SORTABLE",
        "updated_at", "NUMERIC", "SORTABLE",
    ];

    /// Lua script that bumps `usage_count` and refreshes `updated_at` in one
    /// atomic step, but only when the prompt hash already exists.
    ///
    /// Returns the new count, or `-1` when the key is missing, so that a usage
    /// bump never creates a phantom, name-less prompt hash.
    const INCREMENT_USAGE_SCRIPT: &str = r"
if redis.call('EXISTS', KEYS[1]) == 0 then
    return -1
end
local count = redis.call('HINCRBY', KEYS[1], 'usage_count', 1)
redis.call('HSET', KEYS[1], 'updated_at', ARGV[1])
return count
";

    /// Wraps a Redis or serialization failure into the crate's error type.
    fn op_failed(operation: &str, cause: impl std::fmt::Display) -> Error {
        Error::OperationFailed {
            operation: operation.to_string(),
            cause: cause.to_string(),
        }
    }

    /// Redis-based prompt storage using `RediSearch`.
    ///
    /// Each prompt is stored as a Redis hash under `prompt:<name>` and is
    /// searchable through a `RediSearch` index over that key prefix.
    pub struct RedisPromptStorage {
        /// Redis client; a new connection is opened for every operation.
        client: Client,
        /// Index name for prompts.
        index_name: String,
    }

    impl RedisPromptStorage {
        /// Creates a new Redis prompt storage and makes sure its index exists.
        ///
        /// # Errors
        ///
        /// Returns an error if the connection URL is invalid, the Redis server
        /// cannot be reached, or the `RediSearch` index cannot be created.
        pub fn new(connection_url: &str, index_name: impl Into<String>) -> Result<Self> {
            let client =
                Client::open(connection_url).map_err(|e| op_failed("redis_connect", e))?;

            let storage = Self {
                client,
                index_name: index_name.into(),
            };
            storage.ensure_index()?;

            Ok(storage)
        }

        /// Creates a storage on `redis://localhost:6379` with index `subcog_prompts`.
        ///
        /// # Errors
        ///
        /// Returns an error if the Redis connection or index creation fails.
        pub fn with_defaults() -> Result<Self> {
            Self::new(DEFAULT_URL, DEFAULT_INDEX)
        }

        /// Opens a new connection from the client.
        fn get_connection(&self) -> Result<Connection> {
            self.client
                .get_connection()
                .map_err(|e| op_failed("redis_get_connection", e))
        }

        /// Ensures the `RediSearch` index exists, creating it if necessary.
        fn ensure_index(&self) -> Result<()> {
            let mut conn = self.get_connection()?;

            if self.index_exists(&mut conn) {
                Ok(())
            } else {
                self.create_index(&mut conn)
            }
        }

        /// Checks whether the index is already listed by `FT._LIST`.
        ///
        /// Failures are deliberately treated as "not present": `create_index`
        /// tolerates an already-existing index, so a false negative is harmless.
        fn index_exists(&self, conn: &mut Connection) -> bool {
            redis::cmd("FT._LIST")
                .query::<Vec<String>>(conn)
                .map(|indices| indices.iter().any(|i| *i == self.index_name))
                .unwrap_or(false)
        }

        /// Creates the `RediSearch` index for prompt hashes.
        fn create_index(&self, conn: &mut Connection) -> Result<()> {
            let mut cmd = redis::cmd("FT.CREATE");
            cmd.arg(&self.index_name)
                .arg("ON")
                .arg("HASH")
                .arg("PREFIX")
                .arg(1)
                .arg(KEY_PREFIX)
                .arg("SCHEMA");
            for token in INDEX_SCHEMA {
                cmd.arg(*token);
            }

            let outcome: redis::RedisResult<String> = cmd.query(conn);
            match outcome {
                Ok(_) => Ok(()),
                // Another process may have created the index since `index_exists` ran.
                Err(e) if e.to_string().contains("Index already exists") => Ok(()),
                Err(e) => Err(op_failed("redis_create_prompt_index", e)),
            }
        }

        /// Builds the Redis key for a prompt.
        fn prompt_key(name: &str) -> String {
            format!("{KEY_PREFIX}{name}")
        }

        /// Backslash-escapes every character that is not alphanumeric or `_`,
        /// so an arbitrary tag can be embedded in a `@tags:{...}` query clause.
        fn escape_tag(tag: &str) -> String {
            let mut escaped = String::with_capacity(tag.len());
            for ch in tag.chars() {
                if !ch.is_alphanumeric() && ch != '_' {
                    escaped.push('\\');
                }
                escaped.push(ch);
            }
            escaped
        }

        /// Builds the `FT.SEARCH` query for `list`.
        ///
        /// Every non-blank tag becomes a separate `@tags:{...}` clause, so all
        /// of them must match. A non-blank name pattern is scoped to the `name`
        /// field as a whole. Without any filter the query matches everything (`*`).
        fn build_list_query(tags: Option<&[String]>, name_pattern: Option<&str>) -> String {
            let tag_clauses = tags
                .unwrap_or_default()
                .iter()
                .map(|t| t.trim())
                .filter(|t| !t.is_empty())
                .map(|t| format!("@tags:{{{}}}", Self::escape_tag(t)));

            let name_clause = name_pattern
                .map(str::trim)
                .filter(|p| !p.is_empty())
                .map(|p| format!("@name:({p})"));

            let clauses: Vec<String> = tag_clauses.chain(name_clause).collect();
            if clauses.is_empty() {
                "*".to_string()
            } else {
                clauses.join(" ")
            }
        }

        /// Serializes a prompt to Redis hash fields.
        ///
        /// Tags are stored comma-separated, matching the TAG field's default
        /// separator. A tag that itself contains `,` will be split on read.
        ///
        /// # Errors
        ///
        /// Returns an error if the prompt variables cannot be encoded as JSON.
        fn serialize_prompt(template: &PromptTemplate) -> Result<Vec<(&'static str, String)>> {
            let variables_json = serde_json::to_string(&template.variables)
                .map_err(|e| op_failed("redis_serialize_prompt", e))?;

            Ok(vec![
                ("name", template.name.clone()),
                ("description", template.description.clone()),
                ("content", template.content.clone()),
                ("variables", variables_json),
                ("tags", template.tags.join(",")),
                ("author", template.author.clone().unwrap_or_default()),
                ("usage_count", template.usage_count.to_string()),
                ("created_at", template.created_at.to_string()),
                ("updated_at", template.updated_at.to_string()),
            ])
        }

        /// Deserializes a prompt from Redis hash fields.
        ///
        /// Returns `None` when the `name` field is missing. Every other field
        /// falls back to its default when absent or unparsable.
        fn deserialize_prompt(fields: &HashMap<String, String>) -> Option<PromptTemplate> {
            let text = |key: &str| fields.get(key).cloned().unwrap_or_default();
            let number = |key: &str| -> u64 {
                fields.get(key).and_then(|s| s.parse().ok()).unwrap_or(0)
            };

            let name = fields.get("name")?.clone();
            let tags: Vec<String> = fields
                .get("tags")
                .map(String::as_str)
                .unwrap_or_default()
                .split(',')
                .map(str::trim)
                .filter(|s| !s.is_empty())
                .map(str::to_string)
                .collect();

            Some(PromptTemplate {
                name,
                description: text("description"),
                content: text("content"),
                variables: serde_json::from_str(&text("variables")).unwrap_or_default(),
                tags,
                author: fields.get("author").filter(|s| !s.is_empty()).cloned(),
                usage_count: number("usage_count"),
                created_at: number("created_at"),
                updated_at: number("updated_at"),
            })
        }

        /// Parses an `FT.SEARCH` reply into prompt templates.
        ///
        /// The reply layout is `[total, key1, fields1, key2, fields2, ...]`.
        /// Entries whose field list cannot be turned into a prompt are skipped.
        fn parse_search_results(values: &[redis::Value]) -> Vec<PromptTemplate> {
            values
                .get(1..)
                .unwrap_or_default()
                .chunks_exact(2)
                .filter_map(|entry| match entry {
                    [_key, redis::Value::Array(fields)] => Self::parse_field_array(fields),
                    _ => None,
                })
                .collect()
        }

        /// Parses a flat `[field, value, field, value, ...]` array from `FT.SEARCH`.
        fn parse_field_array(fields: &[redis::Value]) -> Option<PromptTemplate> {
            let map: HashMap<String, String> = fields
                .chunks(2)
                .filter_map(|pair| match pair {
                    [redis::Value::BulkString(key), redis::Value::BulkString(value)] => Some((
                        String::from_utf8_lossy(key).into_owned(),
                        String::from_utf8_lossy(value).into_owned(),
                    )),
                    _ => None,
                })
                .collect();

            Self::deserialize_prompt(&map)
        }
    }

    impl PromptStorage for RedisPromptStorage {
        /// Stores (or overwrites) the prompt hash and returns its storage id.
        fn save(&self, template: &PromptTemplate) -> Result<String> {
            let fields = Self::serialize_prompt(template)?;
            let key = Self::prompt_key(&template.name);
            let mut conn = self.get_connection()?;

            let _: () = conn
                .hset_multiple(&key, &fields)
                .map_err(|e| op_failed("redis_save_prompt", e))?;

            Ok(format!("prompt_redis_{}", template.name))
        }

        /// Loads a prompt by name; `Ok(None)` when no usable hash exists.
        fn get(&self, name: &str) -> Result<Option<PromptTemplate>> {
            let mut conn = self.get_connection()?;
            let fields: HashMap<String, String> = conn
                .hgetall(Self::prompt_key(name))
                .map_err(|e| op_failed("redis_get_prompt", e))?;

            // HGETALL on a missing key yields an empty map.
            if fields.is_empty() {
                return Ok(None);
            }
            Ok(Self::deserialize_prompt(&fields))
        }

        /// Searches prompts by tags and/or name, most used first.
        ///
        /// At most `MAX_LIST_RESULTS` (1000) prompts are returned.
        fn list(
            &self,
            tags: Option<&[String]>,
            name_pattern: Option<&str>,
        ) -> Result<Vec<PromptTemplate>> {
            let query = Self::build_list_query(tags, name_pattern);
            let mut conn = self.get_connection()?;

            // FT.SEARCH idx "query" LIMIT 0 1000 SORTBY usage_count DESC
            let values: Vec<redis::Value> = redis::cmd("FT.SEARCH")
                .arg(&self.index_name)
                .arg(&query)
                .arg("LIMIT")
                .arg(0)
                .arg(MAX_LIST_RESULTS)
                .arg("SORTBY")
                .arg("usage_count")
                .arg("DESC")
                .query(&mut conn)
                .map_err(|e| op_failed("redis_list_prompts", e))?;

            Ok(Self::parse_search_results(&values))
        }

        /// Deletes a prompt; returns whether a key was removed.
        fn delete(&self, name: &str) -> Result<bool> {
            let mut conn = self.get_connection()?;
            let removed: u64 = conn
                .del(Self::prompt_key(name))
                .map_err(|e| op_failed("redis_delete_prompt", e))?;

            Ok(removed > 0)
        }

        /// Atomically increments `usage_count`, refreshes `updated_at` (seconds
        /// since the Unix epoch) and returns the new count.
        ///
        /// # Errors
        ///
        /// Returns an error if Redis fails or if no prompt named `name` exists.
        fn increment_usage(&self, name: &str) -> Result<u64> {
            let mut conn = self.get_connection()?;
            let key = Self::prompt_key(name);

            let now = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map_or(0, |d| d.as_secs());

            let count: i64 = redis::cmd("EVAL")
                .arg(INCREMENT_USAGE_SCRIPT)
                .arg(1)
                .arg(&key)
                .arg(now)
                .query(&mut conn)
                .map_err(|e| op_failed("redis_increment_usage", e))?;

            // The script returns -1 when the prompt does not exist.
            u64::try_from(count).map_err(|_| {
                op_failed("redis_increment_usage", format!("prompt not found: {name}"))
            })
        }
    }
}
387
388#[cfg(feature = "redis")]
389pub use implementation::RedisPromptStorage;
390
391#[cfg(not(feature = "redis"))]
392mod stub {
393    use crate::models::PromptTemplate;
394    use crate::storage::prompt::PromptStorage;
395    use crate::{Error, Result};
396
397    /// Stub Redis prompt storage when feature is not enabled.
398    pub struct RedisPromptStorage;
399
400    impl RedisPromptStorage {
401        /// Creates a new Redis prompt storage (stub).
402        ///
403        /// # Errors
404        ///
405        /// Always returns an error because the feature is not enabled.
406        pub fn new(_connection_url: &str, _index_name: impl Into<String>) -> Result<Self> {
407            Err(Error::FeatureNotEnabled("redis".to_string()))
408        }
409
410        /// Creates a storage with default settings (stub).
411        ///
412        /// # Errors
413        ///
414        /// Always returns an error because the feature is not enabled.
415        pub fn with_defaults() -> Result<Self> {
416            Err(Error::FeatureNotEnabled("redis".to_string()))
417        }
418    }
419
420    impl PromptStorage for RedisPromptStorage {
421        fn save(&self, _template: &PromptTemplate) -> Result<String> {
422            Err(Error::FeatureNotEnabled("redis".to_string()))
423        }
424
425        fn get(&self, _name: &str) -> Result<Option<PromptTemplate>> {
426            Err(Error::FeatureNotEnabled("redis".to_string()))
427        }
428
429        fn list(
430            &self,
431            _tags: Option<&[String]>,
432            _name_pattern: Option<&str>,
433        ) -> Result<Vec<PromptTemplate>> {
434            Err(Error::FeatureNotEnabled("redis".to_string()))
435        }
436
437        fn delete(&self, _name: &str) -> Result<bool> {
438            Err(Error::FeatureNotEnabled("redis".to_string()))
439        }
440
441        fn increment_usage(&self, _name: &str) -> Result<u64> {
442            Err(Error::FeatureNotEnabled("redis".to_string()))
443        }
444    }
445}
446
447#[cfg(not(feature = "redis"))]
448pub use stub::RedisPromptStorage;