//! Memory import service.
//!
//! Reads records from an [`ImportSource`], validates them with
//! [`ImportValidator`], optionally skips duplicates by content hash, and
//! persists the survivors through [`CaptureService`].

// Module-level lint relaxations, each tied to a specific item in this file.
#![allow(
    // `ImportOptions` is taken by value by `import_from_file` / `import_from_reader`.
    clippy::needless_pass_by_value,
    // `ImportProgress::percent_complete` converts `usize` counters to `f32`.
    clippy::cast_precision_loss,
    // `memory_exists_with_tag` is a stub: it ignores `self` and never fails.
    clippy::unused_self,
    clippy::unnecessary_wraps
)]
11
12use crate::io::formats::{Format, create_import_source};
13use crate::io::traits::ImportSource;
14use crate::io::validation::{ImportValidator, ValidationSeverity};
15use crate::models::{Domain, Namespace};
16use crate::services::CaptureService;
17use crate::services::deduplication::ContentHasher;
18use crate::{Error, Result};
19use std::io::BufRead;
20use std::path::Path;
21use std::sync::Arc;
22
23#[derive(Debug, Clone)]
25pub struct ImportOptions {
26 pub format: Format,
28 pub default_namespace: Namespace,
30 pub default_domain: Domain,
32 pub skip_duplicates: bool,
34 pub skip_invalid: bool,
36 pub dry_run: bool,
38}
39
40impl Default for ImportOptions {
41 fn default() -> Self {
42 Self {
43 format: Format::Json,
44 default_namespace: Namespace::Decisions,
45 default_domain: Domain::new(),
46 skip_duplicates: true,
47 skip_invalid: true,
48 dry_run: false,
49 }
50 }
51}
52
53impl ImportOptions {
54 #[must_use]
56 pub const fn with_format(mut self, format: Format) -> Self {
57 self.format = format;
58 self
59 }
60
61 #[must_use]
63 pub const fn with_default_namespace(mut self, namespace: Namespace) -> Self {
64 self.default_namespace = namespace;
65 self
66 }
67
68 #[must_use]
70 pub fn with_default_domain(mut self, domain: Domain) -> Self {
71 self.default_domain = domain;
72 self
73 }
74
75 #[must_use]
77 pub const fn with_skip_duplicates(mut self, skip: bool) -> Self {
78 self.skip_duplicates = skip;
79 self
80 }
81
82 #[must_use]
84 pub const fn with_dry_run(mut self, dry_run: bool) -> Self {
85 self.dry_run = dry_run;
86 self
87 }
88}
89
90pub type ProgressCallback = Box<dyn Fn(&ImportProgress) + Send>;
92
/// Running counters reported to the progress callback during an import.
#[derive(Debug, Clone, Default)]
pub struct ImportProgress {
    /// Number of records read from the source so far.
    pub processed: usize,
    /// Records imported so far (in a dry run: records that would be imported).
    pub imported: usize,
    /// Records skipped as duplicates.
    pub skipped_duplicates: usize,
    /// Records skipped because they failed validation or capture.
    pub skipped_invalid: usize,
    /// Estimated total number of records, if the source provides one.
    pub total_estimate: Option<usize>,
    /// 1-based number of the record currently being handled.
    pub current: usize,
}

impl ImportProgress {
    /// Returns the completion percentage in `0.0..=100.0`, or `None` when no
    /// total estimate is available.
    ///
    /// An estimate of zero counts as fully complete. Because the total is only
    /// an estimate, `processed` can exceed it; the result is clamped to `100.0`
    /// instead of reporting more than 100 %.
    #[must_use]
    pub fn percent_complete(&self) -> Option<f32> {
        self.total_estimate.map(|total| {
            if total == 0 {
                return 100.0;
            }
            ((self.processed as f32 / total as f32) * 100.0).min(100.0)
        })
    }
}
123
/// Summary of a finished (or dry-run) import.
#[derive(Debug, Clone, Default)]
pub struct ImportResult {
    /// Records captured (in a dry run: records that would have been captured).
    pub imported: usize,
    /// Records skipped as duplicates.
    pub skipped_duplicates: usize,
    /// Records skipped because they failed validation or capture.
    pub skipped_invalid: usize,
    /// Every record read from the source, whatever its outcome.
    pub total_processed: usize,
    /// Validation warnings, each prefixed with its record number.
    pub warnings: Vec<String>,
    /// Validation/capture errors for skipped records, each prefixed with its record number.
    pub errors: Vec<String>,
}

impl ImportResult {
    /// Creates an empty result: all counters zero, no warnings or errors.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            warnings: Vec::new(),
            errors: Vec::new(),
            imported: 0,
            skipped_duplicates: 0,
            skipped_invalid: 0,
            total_processed: 0,
        }
    }

    /// Returns `true` once at least one record has been imported.
    #[must_use]
    pub const fn has_imports(&self) -> bool {
        self.imported != 0
    }

    /// Returns `true` if any error message was recorded.
    #[must_use]
    pub const fn has_errors(&self) -> bool {
        !self.errors.is_empty()
    }
}
173
174pub struct ImportService {
176 capture_service: Arc<CaptureService>,
178}
179
180impl ImportService {
181 #[must_use]
183 pub const fn new(capture_service: Arc<CaptureService>) -> Self {
184 Self { capture_service }
185 }
186
187 pub fn import_from_file(
193 &self,
194 path: &Path,
195 options: ImportOptions,
196 progress: Option<ProgressCallback>,
197 ) -> Result<ImportResult> {
198 let format = if options.format == Format::Json {
199 Format::from_path(path).unwrap_or(Format::Json)
201 } else {
202 options.format
203 };
204
205 let file = std::fs::File::open(path).map_err(|e| Error::OperationFailed {
206 operation: "open_import_file".to_string(),
207 cause: e.to_string(),
208 })?;
209 let reader = std::io::BufReader::new(file);
210
211 self.import_from_reader(reader, options.with_format(format), progress)
212 }
213
214 pub fn import_from_reader<R: BufRead + 'static>(
220 &self,
221 reader: R,
222 options: ImportOptions,
223 progress: Option<ProgressCallback>,
224 ) -> Result<ImportResult> {
225 let mut source = create_import_source(reader, options.format)?;
226 self.import_from_source(source.as_mut(), &options, progress)
227 }
228
229 #[allow(clippy::excessive_nesting)]
235 pub fn import_from_source(
236 &self,
237 source: &mut dyn ImportSource,
238 options: &ImportOptions,
239 progress: Option<ProgressCallback>,
240 ) -> Result<ImportResult> {
241 let validator = ImportValidator::new()
242 .with_default_namespace(options.default_namespace)
243 .with_default_domain(options.default_domain.clone());
244
245 let mut result = ImportResult::new();
246 let mut prog = ImportProgress {
247 total_estimate: source.size_hint(),
248 ..Default::default()
249 };
250
251 let mut seen_hashes = std::collections::HashSet::new();
253
254 while let Some(imported) = source.next()? {
255 prog.current += 1;
256 prog.processed += 1;
257 result.total_processed += 1;
258
259 let validation = validator.validate(&imported);
261
262 for issue in &validation.issues {
264 if issue.severity == ValidationSeverity::Warning {
265 result.warnings.push(format!(
266 "Record {}: {}: {}",
267 prog.current, issue.field, issue.message
268 ));
269 }
270 }
271
272 if !validation.is_valid {
274 if options.skip_invalid {
275 prog.skipped_invalid += 1;
276 result.skipped_invalid += 1;
277 for issue in &validation.issues {
278 if issue.severity == ValidationSeverity::Error {
279 result.errors.push(format!(
280 "Record {}: {}: {}",
281 prog.current, issue.field, issue.message
282 ));
283 }
284 }
285 if let Some(ref cb) = progress {
286 cb(&prog);
287 }
288 continue;
289 }
290 return Err(Error::InvalidInput(format!(
291 "Record {}: validation failed",
292 prog.current
293 )));
294 }
295
296 let content_hash = ContentHasher::hash(&imported.content);
298 if options.skip_duplicates {
299 if seen_hashes.contains(&content_hash) {
301 prog.skipped_duplicates += 1;
302 result.skipped_duplicates += 1;
303 if let Some(ref cb) = progress {
304 cb(&prog);
305 }
306 continue;
307 }
308
309 let hash_tag = ContentHasher::content_to_tag(&imported.content);
311 if self.memory_exists_with_tag(&hash_tag)? {
312 prog.skipped_duplicates += 1;
313 result.skipped_duplicates += 1;
314 seen_hashes.insert(content_hash);
315 if let Some(ref cb) = progress {
316 cb(&prog);
317 }
318 continue;
319 }
320
321 seen_hashes.insert(content_hash);
322 }
323
324 if options.dry_run {
326 prog.imported += 1;
328 result.imported += 1;
329 } else {
330 let request = validator.to_capture_request(imported);
331 match self.capture_service.capture(request) {
332 Ok(_) => {
333 prog.imported += 1;
334 result.imported += 1;
335 },
336 Err(e) => {
337 if options.skip_invalid {
338 result
339 .errors
340 .push(format!("Record {}: capture failed: {}", prog.current, e));
341 prog.skipped_invalid += 1;
342 result.skipped_invalid += 1;
343 } else {
344 return Err(e);
345 }
346 },
347 }
348 }
349
350 if let Some(ref cb) = progress {
351 cb(&prog);
352 }
353 }
354
355 Ok(result)
356 }
357
358 const fn memory_exists_with_tag(&self, _hash_tag: &str) -> Result<bool> {
360 Ok(false)
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use std::io::Cursor;

    /// Builds a `CaptureService` from the default `Config` for use in tests.
    fn test_capture_service() -> Arc<CaptureService> {
        Arc::new(CaptureService::new(Config::default()))
    }

    /// Defaults: JSON input, duplicates and invalid records skipped, real import.
    #[test]
    fn test_import_options_defaults() {
        let options = ImportOptions::default();
        assert_eq!(options.format, Format::Json);
        assert!(options.skip_duplicates);
        assert!(options.skip_invalid);
        assert!(!options.dry_run);
    }

    /// Percentage follows `processed / total_estimate`; `None` without an estimate.
    #[test]
    fn test_import_progress_percent() {
        let progress = ImportProgress {
            processed: 50,
            total_estimate: Some(100),
            ..Default::default()
        };
        assert_eq!(progress.percent_complete(), Some(50.0));

        let unknown = ImportProgress::default();
        assert!(unknown.percent_complete().is_none());
    }

    /// `has_imports` becomes true once at least one record is counted.
    #[test]
    fn test_import_result_has_imports() {
        let mut result = ImportResult::new();
        assert!(!result.has_imports());

        result.imported = 1;
        assert!(result.has_imports());
    }

    /// A dry run counts a single valid JSON record as imported without capturing it.
    #[test]
    fn test_dry_run_import() {
        let service = ImportService::new(test_capture_service());
        let input = r#"{"content": "Test memory"}"#;

        let result = service
            .import_from_reader(
                Cursor::new(input),
                ImportOptions::default().with_dry_run(true),
                None,
            )
            .unwrap();

        assert_eq!(result.imported, 1);
        assert_eq!(result.total_processed, 1);
    }

    /// With the default `skip_invalid`, a record with empty content is expected
    /// to be skipped while the following valid record is still imported.
    #[test]
    fn test_import_with_invalid_record() {
        let service = ImportService::new(test_capture_service());
        // Two records, one JSON object per line; the first has empty content.
        let input = r#"{"content": ""}
{"content": "Valid memory"}"#;

        let result = service
            .import_from_reader(
                Cursor::new(input),
                ImportOptions::default().with_dry_run(true),
                None,
            )
            .unwrap();

        assert_eq!(result.skipped_invalid, 1);
        assert_eq!(result.imported, 1);
    }
}