subcog/observability/request_context.rs

1//! Request context propagation for correlation IDs.
2
3use std::cell::RefCell;
4use std::future::Future;
5use uuid::Uuid;
6
/// Correlation data carried along with a single request.
///
/// Currently holds only the request ID, which [`current_request_id`] reports while the
/// context is active.
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Identifier used to correlate all work performed on behalf of one request.
    request_id: String,
}
12
13impl RequestContext {
14    /// Creates a new request context with a generated ID.
15    #[must_use]
16    pub fn new() -> Self {
17        Self {
18            request_id: Uuid::new_v4().to_string(),
19        }
20    }
21
22    /// Creates a new request context with an existing request ID.
23    #[must_use]
24    pub fn from_id(request_id: impl Into<String>) -> Self {
25        Self {
26            request_id: request_id.into(),
27        }
28    }
29
30    /// Returns the request ID.
31    #[must_use]
32    pub fn request_id(&self) -> &str {
33        &self.request_id
34    }
35}
36
37impl Default for RequestContext {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43tokio::task_local! {
44    static TASK_CONTEXT: RequestContext;
45}
46
47thread_local! {
48    static THREAD_CONTEXT: RefCell<Option<RequestContext>> = const { RefCell::new(None) };
49}
50
51/// Guard that restores the previous thread-local context on drop.
52pub struct RequestContextGuard {
53    previous: Option<RequestContext>,
54}
55
56impl Drop for RequestContextGuard {
57    fn drop(&mut self) {
58        THREAD_CONTEXT.with(|slot| {
59            *slot.borrow_mut() = self.previous.take();
60        });
61    }
62}
63
64/// Enters a request context for synchronous flows.
65#[must_use]
66pub fn enter_request_context(context: RequestContext) -> RequestContextGuard {
67    let previous = THREAD_CONTEXT.with(|slot| slot.borrow_mut().replace(context));
68    RequestContextGuard { previous }
69}
70
71/// Scopes a request context across an async future.
72pub async fn scope_request_context<F, T>(context: RequestContext, fut: F) -> T
73where
74    F: Future<Output = T>,
75{
76    TASK_CONTEXT
77        .scope(context.clone(), async move {
78            let _guard = enter_request_context(context);
79            fut.await
80        })
81        .await
82}
83
84/// Returns the current request ID, if set.
85#[must_use]
86pub fn current_request_id() -> Option<String> {
87    if let Ok(id) = TASK_CONTEXT.try_with(|ctx| ctx.request_id.clone()) {
88        return Some(id);
89    }
90
91    THREAD_CONTEXT.with(|slot| slot.borrow().as_ref().map(|ctx| ctx.request_id.clone()))
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thread_context_guard_propagates_request_id() {
        let context = RequestContext::from_id("thread-test");
        let _guard = enter_request_context(context);
        assert_eq!(current_request_id().as_deref(), Some("thread-test"));
    }

    /// Nested guards must each restore exactly the context that was active before them.
    #[test]
    fn test_guard_restores_previous_context_on_drop() {
        let before = current_request_id();
        {
            let _outer = enter_request_context(RequestContext::from_id("outer"));
            {
                let _inner = enter_request_context(RequestContext::from_id("inner"));
                assert_eq!(current_request_id().as_deref(), Some("inner"));
            }
            assert_eq!(current_request_id().as_deref(), Some("outer"));
        }
        assert_eq!(current_request_id(), before);
    }

    #[test]
    fn test_new_generates_distinct_ids_and_from_id_keeps_id() {
        let first = RequestContext::new();
        let second = RequestContext::new();
        assert_ne!(first.request_id(), second.request_id());
        assert_eq!(RequestContext::from_id("abc").request_id(), "abc");
    }

    #[tokio::test]
    async fn test_scope_request_context_propagates_across_await() {
        let before = current_request_id();
        let context = RequestContext::from_id("async-test");
        let observed = scope_request_context(context, async {
            tokio::task::yield_now().await;
            current_request_id()
        })
        .await;
        assert_eq!(observed.as_deref(), Some("async-test"));
        // Once the scope has completed, no trace of the request may remain.
        assert_eq!(current_request_id(), before);
    }
}