1use super::HookHandler;
11use crate::Result;
12use crate::observability::current_request_id;
13use crate::services::{ContextBuilderService, MemoryStatistics};
14use std::fmt::Write;
15use std::time::{Duration, Instant};
16use tracing::instrument;
17
/// Minimum accepted session ID length, compared against `str::len()` (bytes).
/// Shorter IDs are reported as `SessionIdValidation::TooShort`.
const MIN_SESSION_ID_LENGTH: usize = 16;

/// Maximum accepted session ID length, compared against `str::len()` (bytes).
/// Longer IDs are reported as `SessionIdValidation::TooLong`.
const MAX_SESSION_ID_LENGTH: usize = 256;

/// Minimum number of distinct characters a session ID must contain; fewer
/// than this is treated as low entropy.
const MIN_UNIQUE_CHARS: usize = 4;

/// Length of an ascending or descending run of consecutive ASCII
/// alphanumerics (e.g. `abcdefgh`) that marks an ID as predictable.
const MIN_SEQUENTIAL_RUN: usize = 8;

/// Default deadline for loading session context, in milliseconds.
const DEFAULT_CONTEXT_TIMEOUT_MS: u64 = 500;
32
/// Handler for the `SessionStart` hook: assembles the context block that is
/// returned when a session begins.
pub struct SessionStartHandler {
    /// Optional source of memory context and statistics. When `None`, only the
    /// session wrapper and guidance are emitted.
    context_builder: Option<ContextBuilderService>,
    /// Base token budget; halved for `Minimal` guidance and doubled for
    /// `Detailed` guidance.
    max_context_tokens: usize,
    /// Selects which guidance block, if any, is appended to the context.
    guidance_level: GuidanceLevel,
    /// Deadline for context loading, in milliseconds.
    context_timeout_ms: u64,
}
46
/// How much usage guidance is appended to the session context.
///
/// The level also scales the context token budget: `Minimal` halves it,
/// `Standard` uses it unchanged and `Detailed` doubles it.
///
/// Derives the common comparison traits so levels can be compared with `==`,
/// used in `assert_eq!`, and stored in hashed collections.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum GuidanceLevel {
    /// No guidance block is emitted.
    Minimal,
    /// Short guidance block (the default).
    #[default]
    Standard,
    /// Extended guidance block that also lists the available namespaces.
    Detailed,
}
58
/// Outcome of checking a session ID with `validate_session_id`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionIdValidation {
    /// The ID passed every check.
    Valid,
    /// The ID is shorter than the minimum length.
    TooShort,
    /// The ID is longer than the maximum length.
    TooLong,
    /// The ID is long enough but follows a predictable pattern.
    LowEntropy,
    /// The ID is empty or the `"unknown"` placeholder.
    Missing,
}

impl SessionIdValidation {
    /// Returns a short human-readable explanation of the outcome, suitable
    /// for log fields and metric labels.
    pub const fn description(self) -> &'static str {
        match self {
            Self::Missing => "missing or empty",
            Self::TooShort => "too short (minimum 16 characters)",
            Self::TooLong => "too long (maximum 256 characters)",
            Self::LowEntropy => "low entropy (predictable pattern detected)",
            Self::Valid => "valid",
        }
    }
}
86
87pub fn validate_session_id(session_id: &str) -> SessionIdValidation {
101 if session_id.is_empty() || session_id == "unknown" {
103 return SessionIdValidation::Missing;
104 }
105
106 if session_id.len() < MIN_SESSION_ID_LENGTH {
108 return SessionIdValidation::TooShort;
109 }
110
111 if session_id.len() > MAX_SESSION_ID_LENGTH {
113 return SessionIdValidation::TooLong;
114 }
115
116 if has_low_entropy(session_id) {
118 return SessionIdValidation::LowEntropy;
119 }
120
121 SessionIdValidation::Valid
122}
123
124fn has_low_entropy(session_id: &str) -> bool {
126 let unique_chars: std::collections::HashSet<char> = session_id.chars().collect();
128 if unique_chars.len() < MIN_UNIQUE_CHARS {
129 return true;
130 }
131
132 let chars: Vec<char> = session_id.chars().collect();
134
135 if chars.iter().all(|&c| c == chars[0]) {
137 return true;
138 }
139
140 for pattern_len in 1..=4 {
142 if chars.len() >= pattern_len * 3 {
143 let pattern = &chars[..pattern_len];
144 let is_repeating = chars
145 .chunks(pattern_len)
146 .all(|chunk| chunk == pattern || chunk.len() < pattern_len);
147 if is_repeating {
148 return true;
149 }
150 }
151 }
152
153 if has_long_sequential_run(session_id) {
156 return true;
157 }
158
159 false
160}
161
162fn has_long_sequential_run(s: &str) -> bool {
168 if s.len() < MIN_SEQUENTIAL_RUN {
169 return false;
170 }
171
172 let chars: Vec<i32> = s
174 .chars()
175 .filter(char::is_ascii_alphanumeric)
176 .map(|c| c as i32)
177 .collect();
178
179 if chars.len() < MIN_SEQUENTIAL_RUN {
180 return false;
181 }
182
183 let mut ascending_run = 1;
185 for window in chars.windows(2) {
186 if window[1] == window[0] + 1 {
187 ascending_run += 1;
188 if ascending_run >= MIN_SEQUENTIAL_RUN {
189 return true;
190 }
191 } else {
192 ascending_run = 1;
193 }
194 }
195
196 let mut descending_run = 1;
198 for window in chars.windows(2) {
199 if window[0] == window[1] + 1 {
200 descending_run += 1;
201 if descending_run >= MIN_SEQUENTIAL_RUN {
202 return true;
203 }
204 } else {
205 descending_run = 1;
206 }
207 }
208
209 false
210}
211
/// Result of `SessionStartHandler::build_session_context`.
#[derive(Debug, Clone)]
struct SessionContext {
    /// The assembled context text, wrapped in `<subcog_session>` tags.
    content: String,
    /// `total_count` from the statistics when available; otherwise 1 if the
    /// builder produced non-empty context and 0 if not.
    memory_count: usize,
    /// Token estimate for `content`, from `ContextBuilderService::estimate_tokens`.
    token_estimate: usize,
    /// `true` when the estimate exceeds the token budget or the deadline was hit.
    was_truncated: bool,
    /// Memory statistics, when they were fetched successfully within the deadline.
    statistics: Option<MemoryStatistics>,
}
226
227impl SessionStartHandler {
228 #[must_use]
230 pub fn new() -> Self {
231 Self {
232 context_builder: None,
233 max_context_tokens: 2000,
234 guidance_level: GuidanceLevel::default(),
235 context_timeout_ms: DEFAULT_CONTEXT_TIMEOUT_MS,
236 }
237 }
238
239 #[must_use]
241 pub fn with_context_builder(mut self, builder: ContextBuilderService) -> Self {
242 self.context_builder = Some(builder);
243 self
244 }
245
246 #[must_use]
248 pub const fn with_max_tokens(mut self, tokens: usize) -> Self {
249 self.max_context_tokens = tokens;
250 self
251 }
252
253 #[must_use]
255 pub const fn with_guidance_level(mut self, level: GuidanceLevel) -> Self {
256 self.guidance_level = level;
257 self
258 }
259
260 #[must_use]
265 pub const fn with_context_timeout_ms(mut self, timeout_ms: u64) -> Self {
266 self.context_timeout_ms = timeout_ms;
267 self
268 }
269
270 fn build_context_from_builder(
274 &self,
275 max_tokens: usize,
276 start: Instant,
277 deadline: Duration,
278 ) -> Result<(Option<String>, Option<MemoryStatistics>, usize)> {
279 let Some(ref builder) = self.context_builder else {
280 return Ok((None, None, 0));
281 };
282
283 let context = builder.build_context(max_tokens)?;
284 let ctx = if context.is_empty() {
285 None
286 } else {
287 Some(context)
288 };
289
290 if start.elapsed() >= deadline {
292 tracing::debug!(
293 elapsed_ms = start.elapsed().as_millis(),
294 deadline_ms = self.context_timeout_ms,
295 "Skipping statistics due to timeout (PERF-M3)"
296 );
297 let count = usize::from(ctx.is_some());
298 return Ok((ctx, None, count));
299 }
300
301 let has_context = ctx.is_some();
302 let (stats, count) = match builder.get_statistics() {
303 Ok(s) => {
304 let c = s.total_count;
305 (Some(s), c)
306 },
307 Err(_) => (None, usize::from(has_context)),
308 };
309
310 Ok((ctx, stats, count))
311 }
312
313 fn add_guidance(&self, context_parts: &mut Vec<String>) {
315 match self.guidance_level {
316 GuidanceLevel::Minimal => {
317 },
319 GuidanceLevel::Standard => {
320 context_parts.push(Self::standard_guidance());
321 },
322 GuidanceLevel::Detailed => {
323 context_parts.push(Self::detailed_guidance());
324 },
325 }
326 }
327
328 fn build_session_context(&self, session_id: &str, cwd: &str) -> Result<SessionContext> {
333 let start = Instant::now();
334 let deadline = Duration::from_millis(self.context_timeout_ms);
335 let mut context_parts = Vec::new();
336 let mut memory_count = 0;
337 let mut statistics: Option<MemoryStatistics> = None;
338 let mut timed_out = false;
339
340 context_parts.push(format!(
342 "<subcog_session id=\"{session_id}\" cwd=\"{cwd}\">"
343 ));
344
345 let max_tokens = match self.guidance_level {
347 GuidanceLevel::Minimal => self.max_context_tokens / 2,
348 GuidanceLevel::Standard => self.max_context_tokens,
349 GuidanceLevel::Detailed => self.max_context_tokens * 2,
350 };
351
352 let within_deadline = start.elapsed() < deadline;
354 if !within_deadline {
355 timed_out = true;
356 tracing::warn!(
357 elapsed_ms = start.elapsed().as_millis(),
358 deadline_ms = self.context_timeout_ms,
359 "Context loading timed out, using minimal context (PERF-M3)"
360 );
361 metrics::counter!("session_context_timeout_total", "reason" => "deadline_exceeded")
362 .increment(1);
363 }
364
365 if within_deadline {
367 let (ctx, stats, count) =
368 self.build_context_from_builder(max_tokens, start, deadline)?;
369 if let Some(c) = ctx {
370 context_parts.push(c);
371 }
372 if let Some(s) = stats.as_ref() {
373 add_statistics_if_present(&mut context_parts, s);
374 }
375 statistics = stats;
376 memory_count = count;
377 timed_out = start.elapsed() >= deadline;
378 }
379
380 if !timed_out && start.elapsed() < deadline {
382 self.add_guidance(&mut context_parts);
383 }
384
385 context_parts.push("</subcog_session>".to_string());
387 let content = context_parts.join("");
388 let token_estimate = ContextBuilderService::estimate_tokens(&content);
389
390 if timed_out {
392 metrics::histogram!(
393 "session_context_build_duration_ms",
394 "status" => "timeout"
395 )
396 .record(start.elapsed().as_secs_f64() * 1000.0);
397 } else {
398 metrics::histogram!(
399 "session_context_build_duration_ms",
400 "status" => "success"
401 )
402 .record(start.elapsed().as_secs_f64() * 1000.0);
403 }
404
405 Ok(SessionContext {
406 content,
407 memory_count,
408 token_estimate,
409 was_truncated: token_estimate > max_tokens || timed_out,
410 statistics,
411 })
412 }
413
414 fn format_statistics(stats: &MemoryStatistics) -> String {
416 let mut xml = format!("<stats total=\"{}\">", stats.total_count);
417
418 if !stats.namespace_counts.is_empty() {
420 xml.push_str("<namespaces>");
421 let mut sorted: Vec<_> = stats.namespace_counts.iter().collect();
422 sorted.sort_by(|a, b| b.1.cmp(a.1));
423 for (ns, count) in sorted.iter().take(6) {
424 let _ = write!(xml, "<ns name=\"{ns}\" count=\"{count}\"/>");
425 }
426 xml.push_str("</namespaces>");
427 }
428
429 if !stats.top_tags.is_empty() {
431 xml.push_str("<tags>");
432 for (tag, count) in stats.top_tags.iter().take(8) {
433 let tag_escaped = tag
434 .replace('&', "&")
435 .replace('<', "<")
436 .replace('>', ">")
437 .replace('"', """);
438 let _ = write!(xml, "<tag name=\"{tag_escaped}\" count=\"{count}\"/>");
439 }
440 xml.push_str("</tags>");
441 }
442
443 if !stats.recent_topics.is_empty() {
445 xml.push_str("<topics>");
446 for topic in stats.recent_topics.iter().take(5) {
447 let topic_escaped = topic
448 .replace('&', "&")
449 .replace('<', "<")
450 .replace('>', ">");
451 let _ = write!(xml, "<topic>{topic_escaped}</topic>");
452 }
453 xml.push_str("</topics>");
454 }
455
456 xml.push_str("</stats>");
457 xml
458 }
459
460 fn standard_guidance() -> String {
462 "<guidance level=\"standard\"><tip>Use prompt_understanding for full docs</tip><steps><step>Call subcog_recall before responses</step><step>Capture decisions/patterns/learnings immediately</step></steps></guidance>".to_string()
463 }
464
465 fn detailed_guidance() -> String {
467 "<guidance level=\"detailed\"><tip>prompt_understanding has full protocol</tip><steps><step>Call subcog_recall before responses</step><step>Capture decisions/patterns/learnings immediately</step><step>Use namespaces: decisions, patterns, learnings, context, tech-debt, apis, config, security, performance, testing</step></steps></guidance>".to_string()
468 }
469
470 fn is_first_session(&self) -> bool {
472 self.context_builder
474 .as_ref()
475 .and_then(|builder| builder.build_context(100).ok())
476 .is_none_or(|context| context.is_empty())
477 }
478}
479
/// `Default` delegates to `SessionStartHandler::new`.
impl Default for SessionStartHandler {
    fn default() -> Self {
        Self::new()
    }
}
485
486impl HookHandler for SessionStartHandler {
487 fn event_type(&self) -> &'static str {
488 "SessionStart"
489 }
490
491 #[instrument(
492 name = "subcog.hook.session_start",
493 skip(self, input),
494 fields(
495 request_id = tracing::field::Empty,
496 component = "hooks",
497 operation = "session_start",
498 hook = "SessionStart",
499 session_id = tracing::field::Empty,
500 cwd = tracing::field::Empty
501 )
502 )]
503 fn handle(&self, input: &str) -> Result<String> {
504 let start = Instant::now();
505 let mut token_estimate: Option<usize> = None;
506 if let Some(request_id) = current_request_id() {
507 tracing::Span::current().record("request_id", request_id.as_str());
508 }
509
510 tracing::info!(hook = "SessionStart", "Processing session start hook");
511
512 let result = (|| {
513 let input_json: serde_json::Value =
515 serde_json::from_str(input).unwrap_or_else(|_| serde_json::json!({}));
516
517 let session_id = input_json
519 .get("session_id")
520 .and_then(|v| v.as_str())
521 .unwrap_or("unknown");
522
523 let cwd = input_json
524 .get("cwd")
525 .and_then(|v| v.as_str())
526 .unwrap_or(".");
527 let span = tracing::Span::current();
528 span.record("session_id", session_id);
529 span.record("cwd", cwd);
530
531 let validation = validate_session_id(session_id);
533 if validation != SessionIdValidation::Valid {
534 tracing::warn!(
535 session_id = session_id,
536 validation = validation.description(),
537 "Session ID validation warning"
538 );
539 metrics::counter!(
540 "session_id_validation_warnings_total",
541 "reason" => validation.description()
542 )
543 .increment(1);
544 }
545
546 let session_context = self.build_session_context(session_id, cwd)?;
548 token_estimate = Some(session_context.token_estimate);
549
550 let is_first = self.is_first_session();
552
553 let mut metadata = serde_json::json!({
555 "memory_count": session_context.memory_count,
556 "token_estimate": session_context.token_estimate,
557 "was_truncated": session_context.was_truncated,
558 "guidance_level": format!("{:?}", self.guidance_level),
559 });
560
561 if let Some(ref stats) = session_context.statistics {
563 metadata["statistics"] = serde_json::json!({
564 "total_count": stats.total_count,
565 "namespace_counts": stats.namespace_counts,
566 "top_tags": stats.top_tags,
567 "recent_topics": stats.recent_topics
568 });
569 }
570
571 if is_first {
573 metadata["tutorial_invitation"] = serde_json::json!({
574 "prompt_name": "subcog_tutorial",
575 "message": "Welcome to Subcog! Use the subcog_tutorial prompt to get started."
576 });
577 }
578
579 let response = if session_context.content.is_empty() {
582 serde_json::json!({})
584 } else {
585 let metadata_str = serde_json::to_string(&metadata).unwrap_or_default();
587 let context_with_metadata = format!(
588 "{}\n\n<!-- subcog-metadata: {} -->",
589 session_context.content, metadata_str
590 );
591 serde_json::json!({
592 "hookSpecificOutput": {
593 "hookEventName": "SessionStart",
594 "additionalContext": context_with_metadata
595 }
596 })
597 };
598
599 serde_json::to_string(&response).map_err(|e| crate::Error::OperationFailed {
600 operation: "serialize_response".to_string(),
601 cause: e.to_string(),
602 })
603 })();
604
605 let status = if result.is_ok() { "success" } else { "error" };
606 metrics::counter!(
607 "hook_executions_total",
608 "hook_type" => "SessionStart",
609 "status" => status
610 )
611 .increment(1);
612 metrics::histogram!("hook_duration_ms", "hook_type" => "SessionStart")
613 .record(start.elapsed().as_secs_f64() * 1000.0);
614 if let Some(tokens) = token_estimate {
615 let tokens = u32::try_from(tokens).unwrap_or(u32::MAX);
616 metrics::histogram!("hook_context_tokens_estimate", "hook_type" => "SessionStart")
617 .record(f64::from(tokens));
618 }
619
620 result
621 }
622}
623
624fn add_statistics_if_present(context_parts: &mut Vec<String>, stats: &MemoryStatistics) {
626 if stats.total_count > 0 {
627 context_parts.push(SessionStartHandler::format_statistics(stats));
628 }
629}
630
// Unit tests for the handler, session ID validation and the context timeout.
// Every handler here is built without a context builder, so no storage is touched.
#[cfg(test)]
mod tests {
    use super::*;

    // The default handler reports the `SessionStart` event type.
    #[test]
    fn test_handler_creation() {
        let handler = SessionStartHandler::default();
        assert_eq!(handler.event_type(), "SessionStart");
    }

    // The builder method stores the requested guidance level.
    #[test]
    fn test_guidance_levels() {
        let handler = SessionStartHandler::new().with_guidance_level(GuidanceLevel::Minimal);
        assert!(matches!(handler.guidance_level, GuidanceLevel::Minimal));

        let handler = SessionStartHandler::new().with_guidance_level(GuidanceLevel::Detailed);
        assert!(matches!(handler.guidance_level, GuidanceLevel::Detailed));
    }

    // A well-formed request yields a hook response carrying the session
    // wrapper, the session ID and the metadata comment.
    #[test]
    fn test_handle_basic() {
        let handler = SessionStartHandler::default();

        let input = r#"{"session_id": "test-session-abc123def456", "cwd": "/path/to/project"}"#;

        let result = handler.handle(input);
        assert!(result.is_ok());

        let response: serde_json::Value = serde_json::from_str(&result.unwrap()).unwrap();
        let hook_output = response.get("hookSpecificOutput").unwrap();
        assert_eq!(
            hook_output.get("hookEventName"),
            Some(&serde_json::Value::String("SessionStart".to_string()))
        );
        let context = hook_output
            .get("additionalContext")
            .unwrap()
            .as_str()
            .unwrap();
        assert!(context.contains("<subcog_session"));
        assert!(context.contains("test-session-abc123def456"));
        assert!(context.contains("subcog-metadata"));
    }

    // Missing `session_id` and `cwd` fall back to defaults instead of failing.
    #[test]
    fn test_handle_missing_fields() {
        let handler = SessionStartHandler::default();

        let input = "{}";

        let result = handler.handle(input);
        assert!(result.is_ok());
    }

    // Without a context builder every session counts as a first session.
    #[test]
    fn test_first_session_detection() {
        let handler = SessionStartHandler::default();
        assert!(handler.is_first_session());
    }

    #[test]
    fn test_standard_guidance() {
        let guidance = SessionStartHandler::standard_guidance();
        assert!(guidance.contains("prompt_understanding"));
        assert!(guidance.contains("<guidance"));
        assert!(guidance.contains("subcog_recall"));
    }

    #[test]
    fn test_detailed_guidance() {
        let guidance = SessionStartHandler::detailed_guidance();
        assert!(guidance.contains("prompt_understanding"));
        assert!(guidance.contains("<guidance"));
        assert!(guidance.contains("namespaces"));
    }

    #[test]
    fn test_max_tokens_configuration() {
        let handler = SessionStartHandler::default().with_max_tokens(5000);
        assert_eq!(handler.max_context_tokens, 5000);
    }

    // The session ID is embedded in the generated context.
    #[test]
    fn test_build_session_context() {
        let handler = SessionStartHandler::default();
        let result = handler.build_session_context("test-session", "/project");

        assert!(result.is_ok());
        let context = result.unwrap();
        assert!(context.content.contains("test-session"));
    }

    // Mixed alphanumeric IDs, UUIDs and dated IDs all validate.
    #[test]
    fn test_session_id_validation_valid() {
        assert_eq!(
            validate_session_id("abc123def456ghi789"),
            SessionIdValidation::Valid
        );
        assert_eq!(
            validate_session_id("f0504ebb-ca72-4d1a-8b7c-53fc85a1a8ba"),
            SessionIdValidation::Valid
        );
        assert_eq!(
            validate_session_id("session_2024_01_03_xyz"),
            SessionIdValidation::Valid
        );
    }

    // The empty string and the "unknown" placeholder are both `Missing`.
    #[test]
    fn test_session_id_validation_missing() {
        assert_eq!(validate_session_id(""), SessionIdValidation::Missing);
        assert_eq!(validate_session_id("unknown"), SessionIdValidation::Missing);
    }

    // 15 characters is one below the minimum length.
    #[test]
    fn test_session_id_validation_too_short() {
        assert_eq!(validate_session_id("short"), SessionIdValidation::TooShort);
        assert_eq!(
            validate_session_id("123456789012345"),
            SessionIdValidation::TooShort
        );
    }

    // 257 characters is one above the maximum length.
    #[test]
    fn test_session_id_validation_too_long() {
        let long_id = "x".repeat(257);
        assert_eq!(validate_session_id(&long_id), SessionIdValidation::TooLong);
    }

    // Single repeated character, short repeating pattern, sequential run.
    #[test]
    fn test_session_id_validation_low_entropy() {
        assert_eq!(
            validate_session_id("aaaaaaaaaaaaaaaaaaaaaaaaa"),
            SessionIdValidation::LowEntropy
        );

        assert_eq!(
            validate_session_id("abababababababababab"),
            SessionIdValidation::LowEntropy
        );

        assert_eq!(
            validate_session_id("abcdefghijklmnop"),
            SessionIdValidation::LowEntropy
        );
    }

    #[test]
    fn test_session_id_validation_description() {
        assert_eq!(SessionIdValidation::Valid.description(), "valid");
        assert!(
            SessionIdValidation::TooShort
                .description()
                .contains("minimum")
        );
        assert!(
            SessionIdValidation::TooLong
                .description()
                .contains("maximum")
        );
        assert!(
            SessionIdValidation::LowEntropy
                .description()
                .contains("entropy")
        );
        assert!(
            SessionIdValidation::Missing
                .description()
                .contains("missing")
        );
    }

    // One, two and three distinct characters are all below the minimum of four.
    #[test]
    fn test_has_low_entropy_few_unique_chars() {
        assert!(has_low_entropy("aaa"));
        assert!(has_low_entropy("aabb"));
        assert!(has_low_entropy("aaabbbccc"));
    }

    #[test]
    fn test_has_long_sequential_run_ascending() {
        assert!(has_long_sequential_run("abcdefgh"));
        assert!(has_long_sequential_run("12345678"));
        assert!(has_long_sequential_run("abcdefghijklmnop"));
    }

    #[test]
    fn test_has_long_sequential_run_descending() {
        assert!(has_long_sequential_run("hgfedcba"));
        assert!(has_long_sequential_run("87654321"));
    }

    // Shuffled strings and UUIDs contain no run of eight.
    #[test]
    fn test_has_long_sequential_run_non_sequential() {
        assert!(!has_long_sequential_run("axbyczdwev"));
        assert!(!has_long_sequential_run("8372619450"));
        assert!(!has_long_sequential_run(
            "f0504ebb-ca72-4d1a-8b7c-53fc85a1a8ba"
        ));
        assert!(!has_long_sequential_run(
            "550e8400-e29b-41d4-a716-446655440000"
        ));
    }

    // Runs shorter than eight are acceptable.
    #[test]
    fn test_has_long_sequential_run_short_sequences_ok() {
        assert!(!has_long_sequential_run("abc123xyz"));
        assert!(!has_long_sequential_run("1234abc5678"));
        assert!(!has_long_sequential_run("abcdefg"));
    }

    #[test]
    fn test_context_timeout_configuration() {
        let handler = SessionStartHandler::new();
        assert_eq!(handler.context_timeout_ms, DEFAULT_CONTEXT_TIMEOUT_MS);

        let handler = SessionStartHandler::new().with_context_timeout_ms(1000);
        assert_eq!(handler.context_timeout_ms, 1000);
    }

    // A zero timeout still returns the session wrapper, flagged as truncated.
    #[test]
    fn test_context_timeout_zero_still_works() {
        let handler = SessionStartHandler::new().with_context_timeout_ms(0);
        let result = handler.build_session_context("test-session", "/project");

        assert!(result.is_ok());
        let context = result.unwrap();
        assert!(context.content.contains("test-session"));
        assert!(context.was_truncated);
    }

    // With a generous timeout the context is built in full.
    #[test]
    fn test_context_timeout_large_value() {
        let handler = SessionStartHandler::new().with_context_timeout_ms(60_000);
        let result = handler.build_session_context("test-session", "/project");

        assert!(result.is_ok());
        let context = result.unwrap();
        assert!(context.content.to_lowercase().contains("subcog"));
    }

    // Hitting the deadline is reported through `was_truncated`.
    #[test]
    fn test_build_context_records_was_truncated_on_timeout() {
        let handler = SessionStartHandler::new().with_context_timeout_ms(0);
        let result = handler.build_session_context("test", "/path");

        assert!(result.is_ok());
        let context = result.unwrap();
        assert!(context.was_truncated);
    }
}