1#[cfg(feature = "redis")]
6mod implementation {
7 use crate::models::PromptTemplate;
8 use crate::storage::prompt::PromptStorage;
9 use crate::{Error, Result};
10 use redis::{Client, Commands, Connection};
11
12 pub struct RedisPromptStorage {
14 client: Client,
16 index_name: String,
18 }
19
20 impl RedisPromptStorage {
21 pub fn new(connection_url: &str, index_name: impl Into<String>) -> Result<Self> {
27 let client = Client::open(connection_url).map_err(|e| Error::OperationFailed {
28 operation: "redis_connect".to_string(),
29 cause: e.to_string(),
30 })?;
31
32 let storage = Self {
33 client,
34 index_name: index_name.into(),
35 };
36
37 storage.ensure_index()?;
39
40 Ok(storage)
41 }
42
43 pub fn with_defaults() -> Result<Self> {
49 Self::new("redis://localhost:6379", "subcog_prompts")
50 }
51
52 fn get_connection(&self) -> Result<Connection> {
54 self.client
55 .get_connection()
56 .map_err(|e| Error::OperationFailed {
57 operation: "redis_get_connection".to_string(),
58 cause: e.to_string(),
59 })
60 }
61
62 fn ensure_index(&self) -> Result<()> {
64 let mut conn = self.get_connection()?;
65
66 if self.index_exists(&mut conn) {
68 return Ok(());
69 }
70
71 self.create_index(&mut conn)
73 }
74
75 fn index_exists(&self, conn: &mut Connection) -> bool {
77 let result: redis::RedisResult<Vec<String>> = redis::cmd("FT._LIST").query(conn);
78
79 result
80 .map(|indices| indices.iter().any(|i| i == &self.index_name))
81 .unwrap_or(false)
82 }
83
84 fn create_index(&self, conn: &mut Connection) -> Result<()> {
86 let result: redis::RedisResult<String> = redis::cmd("FT.CREATE")
87 .arg(&self.index_name)
88 .arg("ON")
89 .arg("HASH")
90 .arg("PREFIX")
91 .arg(1)
92 .arg("prompt:")
93 .arg("SCHEMA")
94 .arg("name")
95 .arg("TEXT")
96 .arg("WEIGHT")
97 .arg("2.0")
98 .arg("description")
99 .arg("TEXT")
100 .arg("WEIGHT")
101 .arg("1.0")
102 .arg("content")
103 .arg("TEXT")
104 .arg("WEIGHT")
105 .arg("0.5")
106 .arg("tags")
107 .arg("TAG")
108 .arg("author")
109 .arg("TAG")
110 .arg("usage_count")
111 .arg("NUMERIC")
112 .arg("SORTABLE")
113 .arg("created_at")
114 .arg("NUMERIC")
115 .arg("SORTABLE")
116 .arg("updated_at")
117 .arg("NUMERIC")
118 .arg("SORTABLE")
119 .query(conn);
120
121 match result {
122 Ok(_) => Ok(()),
123 Err(e) if e.to_string().contains("Index already exists") => Ok(()),
124 Err(e) => Err(Error::OperationFailed {
125 operation: "redis_create_prompt_index".to_string(),
126 cause: e.to_string(),
127 }),
128 }
129 }
130
/// Builds the Redis key under which the prompt called `name` is stored.
fn prompt_key(name: &str) -> String {
    ["prompt:", name].concat()
}
135
136 fn serialize_prompt(template: &PromptTemplate) -> Vec<(String, String)> {
138 let variables_json = serde_json::to_string(&template.variables).unwrap_or_default();
139 let tags_str = template.tags.join(",");
140
141 vec![
142 ("name".to_string(), template.name.clone()),
143 ("description".to_string(), template.description.clone()),
144 ("content".to_string(), template.content.clone()),
145 ("variables".to_string(), variables_json),
146 ("tags".to_string(), tags_str),
147 (
148 "author".to_string(),
149 template.author.clone().unwrap_or_default(),
150 ),
151 ("usage_count".to_string(), template.usage_count.to_string()),
152 ("created_at".to_string(), template.created_at.to_string()),
153 ("updated_at".to_string(), template.updated_at.to_string()),
154 ]
155 }
156
157 fn deserialize_prompt(
159 fields: &std::collections::HashMap<String, String>,
160 ) -> Option<PromptTemplate> {
161 let name = fields.get("name")?.clone();
162 let description = fields.get("description").cloned().unwrap_or_default();
163 let content = fields.get("content").cloned().unwrap_or_default();
164 let variables_json = fields.get("variables").cloned().unwrap_or_default();
165 let tags_str = fields.get("tags").cloned().unwrap_or_default();
166 let author = fields.get("author").cloned().filter(|s| !s.is_empty());
167 let usage_count: u64 = fields
168 .get("usage_count")
169 .and_then(|s| s.parse().ok())
170 .unwrap_or(0);
171 let created_at: u64 = fields
172 .get("created_at")
173 .and_then(|s| s.parse().ok())
174 .unwrap_or(0);
175 let updated_at: u64 = fields
176 .get("updated_at")
177 .and_then(|s| s.parse().ok())
178 .unwrap_or(0);
179
180 let variables = serde_json::from_str(&variables_json).unwrap_or_default();
181 let tags: Vec<String> = tags_str
182 .split(',')
183 .map(|s| s.trim().to_string())
184 .filter(|s| !s.is_empty())
185 .collect();
186
187 Some(PromptTemplate {
188 name,
189 description,
190 content,
191 variables,
192 tags,
193 author,
194 usage_count,
195 created_at,
196 updated_at,
197 })
198 }
199 }
200
201 impl PromptStorage for RedisPromptStorage {
202 fn save(&self, template: &PromptTemplate) -> Result<String> {
203 let mut conn = self.get_connection()?;
204 let key = Self::prompt_key(&template.name);
205
206 let fields = Self::serialize_prompt(template);
208 let field_refs: Vec<(&str, &str)> = fields
209 .iter()
210 .map(|(k, v)| (k.as_str(), v.as_str()))
211 .collect();
212
213 let _: () =
214 conn.hset_multiple(&key, &field_refs)
215 .map_err(|e| Error::OperationFailed {
216 operation: "redis_save_prompt".to_string(),
217 cause: e.to_string(),
218 })?;
219
220 Ok(format!("prompt_redis_{}", template.name))
221 }
222
223 fn get(&self, name: &str) -> Result<Option<PromptTemplate>> {
224 let mut conn = self.get_connection()?;
225 let key = Self::prompt_key(name);
226
227 let result: redis::RedisResult<std::collections::HashMap<String, String>> =
229 conn.hgetall(&key);
230
231 match result {
232 Ok(fields) if fields.is_empty() => Ok(None),
233 Ok(fields) => Ok(Self::deserialize_prompt(&fields)),
234 Err(e) => Err(Error::OperationFailed {
235 operation: "redis_get_prompt".to_string(),
236 cause: e.to_string(),
237 }),
238 }
239 }
240
241 fn list(
242 &self,
243 tags: Option<&[String]>,
244 name_pattern: Option<&str>,
245 ) -> Result<Vec<PromptTemplate>> {
246 let mut conn = self.get_connection()?;
247
248 let mut query_parts = vec!["*".to_string()];
250
251 if let Some(tag_list) = tags.filter(|t| !t.is_empty()) {
253 query_parts.extend(tag_list.iter().map(|tag| format!("@tags:{{{tag}}}")));
254 }
255
256 if let Some(pattern) = name_pattern {
258 query_parts.push(format!("@name:{pattern}"));
259 }
260
261 let query = if query_parts.len() == 1 {
262 "*".to_string()
263 } else {
264 query_parts[1..].join(" ")
265 };
266
267 let result: redis::RedisResult<Vec<redis::Value>> = redis::cmd("FT.SEARCH")
269 .arg(&self.index_name)
270 .arg(&query)
271 .arg("LIMIT")
272 .arg(0)
273 .arg(1000)
274 .arg("SORTBY")
275 .arg("usage_count")
276 .arg("DESC")
277 .query(&mut conn);
278
279 match result {
280 Ok(values) => Ok(self.parse_search_results(&values)),
281 Err(e) => Err(Error::OperationFailed {
282 operation: "redis_list_prompts".to_string(),
283 cause: e.to_string(),
284 }),
285 }
286 }
287
288 fn delete(&self, name: &str) -> Result<bool> {
289 let mut conn = self.get_connection()?;
290 let key = Self::prompt_key(name);
291
292 let deleted: i32 = conn.del(&key).map_err(|e| Error::OperationFailed {
293 operation: "redis_delete_prompt".to_string(),
294 cause: e.to_string(),
295 })?;
296
297 Ok(deleted > 0)
298 }
299
300 #[allow(clippy::cast_sign_loss)]
301 fn increment_usage(&self, name: &str) -> Result<u64> {
302 let mut conn = self.get_connection()?;
303 let key = Self::prompt_key(name);
304
305 let count: i64 =
307 conn.hincr(&key, "usage_count", 1)
308 .map_err(|e| Error::OperationFailed {
309 operation: "redis_increment_usage".to_string(),
310 cause: e.to_string(),
311 })?;
312
313 let now = std::time::SystemTime::now()
315 .duration_since(std::time::UNIX_EPOCH)
316 .map(|d| d.as_secs())
317 .unwrap_or(0);
318
319 let _: () = conn
320 .hset(&key, "updated_at", now)
321 .map_err(|e| Error::OperationFailed {
322 operation: "redis_update_timestamp".to_string(),
323 cause: e.to_string(),
324 })?;
325
326 Ok(count as u64)
327 }
328 }
329
330 impl RedisPromptStorage {
331 #[allow(clippy::excessive_nesting)]
333 fn parse_search_results(&self, values: &[redis::Value]) -> Vec<PromptTemplate> {
334 if values.is_empty() {
335 return Vec::new();
336 }
337
338 let mut results = Vec::new();
339 let mut i = 1; while i < values.len() {
341 i += 1;
343
344 if let Some(t) = self.try_parse_value_at(values, i) {
346 results.push(t);
347 }
348 i += 1;
349 }
350 results
351 }
352
353 fn try_parse_value_at(
355 &self,
356 values: &[redis::Value],
357 idx: usize,
358 ) -> Option<PromptTemplate> {
359 let value = values.get(idx)?;
360 match value {
361 redis::Value::Array(fields) => self.parse_field_array(fields),
362 _ => None,
363 }
364 }
365
366 #[allow(clippy::excessive_nesting)]
368 fn parse_field_array(&self, fields: &[redis::Value]) -> Option<PromptTemplate> {
369 let mut map = std::collections::HashMap::new();
370
371 for pair in fields.chunks(2) {
372 if let [
373 redis::Value::BulkString(key),
374 redis::Value::BulkString(value),
375 ] = pair
376 {
377 let key_str = String::from_utf8_lossy(key).to_string();
378 let value_str = String::from_utf8_lossy(value).to_string();
379 map.insert(key_str, value_str);
380 }
381 }
382
383 Self::deserialize_prompt(&map)
384 }
385 }
386}
387
388#[cfg(feature = "redis")]
389pub use implementation::RedisPromptStorage;
390
391#[cfg(not(feature = "redis"))]
392mod stub {
393 use crate::models::PromptTemplate;
394 use crate::storage::prompt::PromptStorage;
395 use crate::{Error, Result};
396
397 pub struct RedisPromptStorage;
399
400 impl RedisPromptStorage {
401 pub fn new(_connection_url: &str, _index_name: impl Into<String>) -> Result<Self> {
407 Err(Error::FeatureNotEnabled("redis".to_string()))
408 }
409
410 pub fn with_defaults() -> Result<Self> {
416 Err(Error::FeatureNotEnabled("redis".to_string()))
417 }
418 }
419
420 impl PromptStorage for RedisPromptStorage {
421 fn save(&self, _template: &PromptTemplate) -> Result<String> {
422 Err(Error::FeatureNotEnabled("redis".to_string()))
423 }
424
425 fn get(&self, _name: &str) -> Result<Option<PromptTemplate>> {
426 Err(Error::FeatureNotEnabled("redis".to_string()))
427 }
428
429 fn list(
430 &self,
431 _tags: Option<&[String]>,
432 _name_pattern: Option<&str>,
433 ) -> Result<Vec<PromptTemplate>> {
434 Err(Error::FeatureNotEnabled("redis".to_string()))
435 }
436
437 fn delete(&self, _name: &str) -> Result<bool> {
438 Err(Error::FeatureNotEnabled("redis".to_string()))
439 }
440
441 fn increment_usage(&self, _name: &str) -> Result<u64> {
442 Err(Error::FeatureNotEnabled("redis".to_string()))
443 }
444 }
445}
446
447#[cfg(not(feature = "redis"))]
448pub use stub::RedisPromptStorage;