Skip to main content

subcog/mcp/
server.rs

1//! MCP server setup and lifecycle.
2//!
3//! Implements an rmcp-based MCP server over stdio or HTTP transport.
4//!
5//! ## Transport Security Model (COMP-CRIT-003)
6//!
7//! ### Stdio Transport (Default)
8//!
9//! The stdio transport is the default and recommended transport for local use with
10//! Claude Desktop and other MCP-compatible clients. Its security model is based on:
11//!
12//! **Trust Assumptions:**
13//! - **Process isolation**: The parent process (Claude Desktop) spawns subcog as a child
14//!   process. Communication happens exclusively over stdin/stdout pipes.
15//! - **No network exposure**: Stdio transport never binds to a network port, eliminating
16//!   remote attack vectors entirely.
17//! - **Same-user execution**: The spawned process inherits the parent's user context,
18//!   meaning file system and secret access is limited to what the invoking user can access.
19//! - **No authentication required**: Since only the parent process can communicate over
20//!   the pipes, authentication is implicit through OS process isolation.
21//!
22//! **Security Properties:**
23//! - **Confidentiality**: Data never leaves the local machine unless explicitly requested
24//!   (e.g., git sync to remote).
25//! - **Integrity**: OS guarantees pipe integrity; no MITM attacks possible.
26//! - **Availability**: Process lifecycle is controlled by the parent.
27//!
28//! **Threat Mitigations:**
29//! - **Malicious input**: Content is sanitized before storage (secrets detection, PII redaction).
30//! - **Resource exhaustion**: Rate limits and memory caps protect against `DoS`.
31//! - **Privilege escalation**: Subcog runs with user privileges, never elevated.
32//!
33//! ### HTTP Transport (Optional)
34//!
35//! The HTTP transport exposes subcog over a network socket and requires explicit security:
36//!
37//! - **JWT bearer token authentication** (SEC-H1): All requests must include a valid JWT.
38//!   Requires `SUBCOG_MCP_JWT_SECRET` environment variable (min 32 characters).
39//! - **Per-client rate limiting** (ARCH-H1): Prevents abuse via configurable request limits.
40//! - **CORS protection** (HIGH-SEC-006): Restrictive by default; origins must be explicitly allowed.
41//! - **Security headers**: X-Content-Type-Options, X-Frame-Options, CSP, no-cache directives.
42//!
43//! **When to use HTTP transport:**
44//! - Remote access to subcog server (e.g., from containerized environments)
45//! - Shared team server with multi-user access
46//! - Integration with web-based MCP clients
47//!
48//! **Configuration:**
49//! ```bash
50//! # Required for HTTP transport
51//! export SUBCOG_MCP_JWT_SECRET="your-32-char-minimum-secret-key"
52//!
53//! # Optional: customize rate limits
54//! export SUBCOG_MCP_RATE_LIMIT_MAX_REQUESTS=1000
55//! export SUBCOG_MCP_RATE_LIMIT_WINDOW_SECS=60
56//!
57//! # Optional: configure CORS for web clients
58//! export SUBCOG_MCP_CORS_ALLOWED_ORIGINS="https://your-app.com"
59//! ```
60
61use crate::mcp::{
62    ResourceContent, ResourceDefinition, ResourceHandler, ToolContent, ToolDefinition,
63    ToolRegistry, ToolResult,
64    prompts::{PromptContent, PromptDefinition, PromptRegistry},
65};
66use crate::models::{EventMeta, MemoryEvent};
67use crate::observability::{
68    RequestContext as ObsRequestContext, current_request_id, flush_metrics, scope_request_context,
69};
70use crate::security::record_event;
71use crate::services::ServiceContainer;
72use crate::{Error, Result as SubcogResult};
73#[cfg(feature = "http")]
74use axum::extract::{Request, State};
75#[cfg(feature = "http")]
76use axum::http::{Method, StatusCode, header};
77#[cfg(feature = "http")]
78use axum::middleware::Next;
79#[cfg(feature = "http")]
80use axum::response::{IntoResponse, Response};
81#[cfg(feature = "http")]
82use axum::routing::any_service;
83#[cfg(feature = "http")]
84use axum::{Json, Router};
85use rmcp::model::{
86    AnnotateAble, CallToolRequestParams, CallToolResult, Content, GetPromptRequestParams,
87    GetPromptResult, Implementation, ListPromptsResult, ListResourceTemplatesResult,
88    ListResourcesResult, ListToolsResult, PaginatedRequestParams, Prompt,
89    PromptArgument as RmcpPromptArgument, PromptMessage as RmcpPromptMessage, PromptMessageContent,
90    PromptMessageRole, RawResource, Resource, ResourceContents, ServerCapabilities, ServerInfo,
91    Tool,
92};
93use rmcp::service::RequestContext;
94use rmcp::transport::stdio;
95#[cfg(feature = "http")]
96use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
97#[cfg(feature = "http")]
98use rmcp::transport::streamable_http_server::tower::{
99    StreamableHttpServerConfig, StreamableHttpService,
100};
101use rmcp::{ErrorData as McpError, RoleServer, ServerHandler, ServiceExt};
102use serde::{Deserialize, Serialize};
103use serde_json::{Map, Value};
104use std::borrow::Cow;
105#[cfg(feature = "http")]
106use std::collections::HashMap;
107use std::sync::Arc;
108use std::sync::atomic::{AtomicBool, Ordering};
109use std::time::{Duration, Instant};
110use tokio::sync::Mutex;
111#[cfg(feature = "http")]
112use tower_http::cors::CorsLayer;
113#[cfg(feature = "http")]
114use tower_http::set_header::SetResponseHeaderLayer;
115#[cfg(feature = "http")]
116use tower_http::trace::TraceLayer;
117use tracing::{Instrument, info_span};
118
/// Result type returned by the MCP handler methods in this module; the error
/// side is rmcp's `ErrorData` (imported here as `McpError`).
type McpResult<T> = std::result::Result<T, McpError>;
120
121/// Converts a task join error to an MCP internal error.
122fn join_error_to_mcp(e: &tokio::task::JoinError) -> McpError {
123    McpError::internal_error(format!("Task join error: {e}"), None)
124}
125
126fn record_mcp_metrics<T>(operation: &'static str, start: Instant, result: &McpResult<T>) {
127    let status = if result.is_ok() { "success" } else { "error" };
128    metrics::counter!(
129        "mcp_requests_total",
130        "operation" => operation,
131        "status" => status
132    )
133    .increment(1);
134    if result.is_err() {
135        metrics::counter!("mcp_request_errors_total", "operation" => operation).increment(1);
136    }
137    metrics::histogram!("mcp_request_duration_ms", "operation" => operation)
138        .record(start.elapsed().as_secs_f64() * 1000.0);
139}
140
141async fn run_mcp_with_context<T, F, Fut>(
142    request_context: Option<ObsRequestContext>,
143    span: tracing::Span,
144    operation: &'static str,
145    f: F,
146) -> McpResult<T>
147where
148    F: FnOnce(Instant) -> Fut,
149    Fut: std::future::Future<Output = McpResult<T>>,
150{
151    let run = async move {
152        let _span_guard = span.enter();
153        let start = Instant::now();
154        let result = f(start).await;
155        record_mcp_metrics(operation, start, &result);
156        result
157    };
158
159    if let Some(context) = request_context {
160        scope_request_context(context, run).await
161    } else {
162        run.await
163    }
164}
165
166fn init_request_context() -> (Option<ObsRequestContext>, String) {
167    let context = ObsRequestContext::new();
168    let request_id = context.request_id().to_string();
169    (Some(context), request_id)
170}
171
172async fn await_shutdown(cancel_token: rmcp::service::RunningServiceCancellationToken) {
173    while !is_shutdown_requested() {
174        tokio::time::sleep(Duration::from_millis(200)).await;
175    }
176    cancel_token.cancel();
177}
178
179fn execute_call_tool(
180    state: &McpState,
181    request: CallToolRequestParams,
182    start: Instant,
183) -> McpResult<CallToolResult> {
184    let arguments = match request.arguments {
185        Some(args) => Value::Object(args),
186        None => Value::Object(Map::new()),
187    };
188
189    let result = match state.tools.execute(&request.name, arguments) {
190        Ok(result) => result,
191        Err(err) => {
192            record_event(MemoryEvent::McpRequestError {
193                meta: EventMeta::new("mcp", current_request_id()),
194                operation: "call_tool".to_string(),
195                error: err.to_string(),
196            });
197            return Err(McpError::invalid_params(err.to_string(), None));
198        },
199    };
200
201    let status = if result.is_error { "error" } else { "success" };
202    let duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
203    record_event(MemoryEvent::McpToolExecuted {
204        meta: EventMeta::new("mcp", current_request_id()),
205        tool_name: request.name.to_string(),
206        status: status.to_string(),
207        duration_ms,
208        error: result
209            .is_error
210            .then_some("tool execution returned error".to_string()),
211    });
212
213    Ok(tool_result_to_rmcp(result))
214}
215
/// Process-wide flag that flips to `true` once a graceful shutdown has been
/// requested (RES-M4). It is never reset.
static SHUTDOWN_REQUESTED: AtomicBool = AtomicBool::new(false);

/// Reports whether [`request_shutdown`] has been called, for example by the
/// handler installed with [`setup_signal_handler`].
#[must_use]
pub fn is_shutdown_requested() -> bool {
    let flag = &SHUTDOWN_REQUESTED;
    flag.load(Ordering::SeqCst)
}

/// Marks the process as shutting down; from then on
/// [`is_shutdown_requested`] returns `true`. Calling it again is harmless.
pub fn request_shutdown() {
    let flag = &SHUTDOWN_REQUESTED;
    flag.store(true, Ordering::SeqCst);
}
229
230/// Sets up the signal handler for graceful shutdown (RES-M4).
231///
232/// Installs a handler for SIGINT (Ctrl+C) and SIGTERM that:
233/// 1. Sets the shutdown flag
234/// 2. Logs the shutdown request
235/// 3. Flushes metrics
236///
237/// # Errors
238///
239/// Returns an error if the signal handler cannot be installed.
240pub fn setup_signal_handler() -> SubcogResult<()> {
241    ctrlc::set_handler(move || {
242        tracing::info!("Shutdown signal received, initiating graceful shutdown");
243        request_shutdown();
244
245        // Flush metrics immediately
246        flush_metrics();
247
248        metrics::counter!("mcp_shutdown_signals_total").increment(1);
249    })
250    .map_err(|e| Error::OperationFailed {
251        operation: "setup_signal_handler".to_string(),
252        cause: e.to_string(),
253    })?;
254
255    tracing::debug!("Signal handler installed for graceful shutdown");
256    Ok(())
257}
258
259#[cfg(feature = "http")]
260use crate::mcp::auth::{Claims, JwtAuthenticator, JwtConfig, ToolAuthorization};
261
/// Default maximum requests per rate limit window.
///
/// Used by `RateLimitConfig::default()` and as the fallback when
/// `SUBCOG_MCP_RATE_LIMIT_MAX_REQUESTS` is unset or cannot be parsed.
const DEFAULT_RATE_LIMIT_MAX_REQUESTS: usize = 1000;

/// Default rate limit window duration, in seconds (1 minute).
///
/// Fallback for `SUBCOG_MCP_RATE_LIMIT_WINDOW_SECS`.
const DEFAULT_RATE_LIMIT_WINDOW_SECS: u64 = 60;

/// Value assumed for `SUBCOG_MCP_CORS_ALLOWED_ORIGINS` when it is unset.
///
/// Empty, so no origins are configured unless set explicitly (HIGH-SEC-006).
#[cfg(feature = "http")]
const DEFAULT_CORS_ALLOWED_ORIGIN: &str = "";
271
/// CORS configuration for the HTTP transport (HIGH-SEC-006).
///
/// The default configuration:
/// - lists no origins;
/// - disables credentials;
/// - caches preflight responses for 3600 seconds.
#[cfg(feature = "http")]
#[derive(Debug, Clone)]
pub struct CorsConfig {
    /// Allowed origins, one entry per origin (e.g. `https://your-app.com`).
    ///
    /// `CorsConfig::from_env` fills this by splitting the comma-separated
    /// `SUBCOG_MCP_CORS_ALLOWED_ORIGINS` variable. An empty list means no
    /// cross-origin access is granted.
    pub allowed_origins: Vec<String>,
    /// Whether to allow credentials (cookies, auth headers).
    pub allow_credentials: bool,
    /// Max age for the preflight cache, in seconds.
    pub max_age_secs: u64,
}

#[cfg(feature = "http")]
impl Default for CorsConfig {
    fn default() -> Self {
        Self {
            allowed_origins: Vec::new(), // Deny all by default
            allow_credentials: false,
            max_age_secs: 3600,
        }
    }
}
294
/// Per-client bookkeeping for the fixed-window rate limiter in `auth_middleware`.
#[cfg(feature = "http")]
#[derive(Clone)]
struct RateLimitEntry {
    /// Requests accepted in the current window.
    count: usize,
    /// When the current window started; `auth_middleware` resets `count` once
    /// the configured window has elapsed since this instant.
    window_start: Instant,
}
301
/// Shared state for `auth_middleware`: the JWT validator, the rate limit
/// settings, and the per-client rate limit table.
#[cfg(feature = "http")]
#[derive(Clone)]
struct HttpAuthState {
    /// Validates `Authorization` header values and yields the caller's claims.
    authenticator: JwtAuthenticator,
    /// Limits applied to each client.
    rate_limit: RateLimitConfig,
    /// Rate limit entries keyed by client ID (the JWT `sub` claim), shared by
    /// all clones through the `Arc`.
    /// NOTE(review): nothing in this section removes entries, so the map grows
    /// with the number of distinct clients — confirm whether that is bounded.
    rate_limits: Arc<Mutex<HashMap<String, RateLimitEntry>>>,
}
309
/// Axum middleware that authenticates HTTP MCP requests and enforces
/// per-client rate limits (SEC-H1, ARCH-H1).
///
/// Each request runs inside a fresh observability request context and a
/// `subcog.mcp.auth` span.
///
/// - The request must carry an `Authorization` header accepted by
///   `JwtAuthenticator::validate_header`. Otherwise the middleware returns a
///   401 JSON error body and records an `McpAuthFailed` event.
/// - Clients are identified by the JWT `sub` claim. A client that exceeds the
///   configured request count within the window receives a 429.
/// - Accepted requests have their `Claims` inserted into the request
///   extensions before being forwarded to `next`.
///
/// The span is attached with `Instrument::instrument` instead of being entered
/// with `span.enter()`. Holding an `Entered` guard across the `.await` points
/// below (the rate limit lock and `next.run`) produces incorrect traces.
#[cfg(feature = "http")]
async fn auth_middleware(
    State(state): State<HttpAuthState>,
    req: Request,
    next: Next,
) -> Response {
    let request_context = ObsRequestContext::new();
    let request_id = request_context.request_id().to_string();
    scope_request_context(request_context, async move {
        let span = info_span!(
            "subcog.mcp.auth",
            request_id = %request_id,
            component = "mcp",
            operation = "auth"
        );
        authenticate_and_forward(state, req, next)
            .instrument(span)
            .await
    })
    .await
}

/// Authenticates `req`, applies the per-client rate limit, and forwards it to
/// `next`; returns an error response instead when either check fails.
#[cfg(feature = "http")]
async fn authenticate_and_forward(state: HttpAuthState, mut req: Request, next: Next) -> Response {
    let auth_header = req
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|h| h.to_str().ok());

    let Some(header_value) = auth_header else {
        record_event(MemoryEvent::McpAuthFailed {
            meta: EventMeta::new("mcp", current_request_id()),
            client_id: None,
            reason: "missing authorization header".to_string(),
        });
        return auth_error_response(StatusCode::UNAUTHORIZED, "Authentication required");
    };

    let claims = match state.authenticator.validate_header(header_value) {
        Ok(claims) => claims,
        Err(e) => {
            tracing::warn!(error = %e, "JWT authentication failed");
            record_event(MemoryEvent::McpAuthFailed {
                meta: EventMeta::new("mcp", current_request_id()),
                client_id: None,
                reason: e.to_string(),
            });
            return auth_error_response(
                StatusCode::UNAUTHORIZED,
                &format!("Authentication failed: {e}"),
            );
        },
    };

    if let Some(rejection) = apply_client_rate_limit(&state, &claims.sub).await {
        return rejection;
    }

    req.extensions_mut().insert(claims);

    next.run(req).await
}

/// Counts one request against `client_id`'s fixed rate limit window.
///
/// Returns `Some(429 response)` when the client has already used its quota for
/// the current window, or `None` when the request may proceed.
#[cfg(feature = "http")]
async fn apply_client_rate_limit(state: &HttpAuthState, client_id: &str) -> Option<Response> {
    let mut rate_limits = state.rate_limits.lock().await;
    let entry = rate_limits
        .entry(client_id.to_string())
        .or_insert_with(|| RateLimitEntry {
            count: 0,
            window_start: Instant::now(),
        });

    // Start a new window once the current one has fully elapsed.
    if entry.window_start.elapsed() > state.rate_limit.window {
        entry.count = 0;
        entry.window_start = Instant::now();
    }

    if entry.count >= state.rate_limit.max_requests {
        tracing::warn!(
            client = %client_id,
            requests = entry.count,
            "Per-client rate limit exceeded"
        );
        let message = format!(
            "Rate limit exceeded: max {} requests per {:?}",
            state.rate_limit.max_requests, state.rate_limit.window
        );
        return Some(auth_error_response(StatusCode::TOO_MANY_REQUESTS, &message));
    }

    entry.count += 1;
    None
}

/// Builds the JSON error body `{"error": {"code": -32000, "message": ...}}`
/// returned by the auth middleware, with the given HTTP status.
#[cfg(feature = "http")]
fn auth_error_response(status: StatusCode, message: &str) -> Response {
    (
        status,
        Json(serde_json::json!({
            "error": {
                "code": -32000,
                "message": message
            }
        })),
    )
        .into_response()
}
417
/// Rewrites a `202 Accepted` response from the inner service to
/// `204 No Content`; every other status passes through unchanged.
#[cfg(feature = "http")]
async fn map_notification_status(req: Request, next: Next) -> Response {
    let mut response = next.run(req).await;
    let status = response.status_mut();
    if *status == StatusCode::ACCEPTED {
        *status = StatusCode::NO_CONTENT;
    }
    response
}
426
/// Builds the CORS layer for the HTTP transport from `config` (HIGH-SEC-006).
///
/// - With no configured origins this returns `CorsLayer::new()`, which grants
///   no cross-origin access.
/// - Otherwise the layer:
///   - allows GET/POST/DELETE/OPTIONS from *every* configured origin;
///   - allows credentials when `allow_credentials` is set;
///   - caches preflight results for `max_age_secs` seconds.
///
/// An entry of `*` allows any origin.
///
/// # Errors
///
/// Returns `Error::OperationFailed` (operation `cors_origin`) if:
/// - an origin is not a valid header value, or
/// - `*` is combined with `allow_credentials` (tower-http would otherwise
///   panic when the layer is applied).
#[cfg(feature = "http")]
fn build_cors_layer(config: &CorsConfig) -> SubcogResult<CorsLayer> {
    if config.allowed_origins.is_empty() {
        return Ok(CorsLayer::new());
    }

    let origins = config
        .allowed_origins
        .iter()
        .map(|origin| {
            origin
                .parse::<header::HeaderValue>()
                .map_err(|e| Error::OperationFailed {
                    operation: "cors_origin".to_string(),
                    cause: e.to_string(),
                })
        })
        .collect::<SubcogResult<Vec<_>>>()?;

    // `CorsLayer::allow_origin` *replaces* any previously configured origin,
    // so all origins must be supplied in one call; calling it once per origin
    // would leave only the last origin allowed.
    let allow_origin = if config.allowed_origins.iter().any(|origin| origin == "*") {
        if config.allow_credentials {
            return Err(Error::OperationFailed {
                operation: "cors_origin".to_string(),
                cause: "wildcard origin '*' cannot be combined with allow_credentials"
                    .to_string(),
            });
        }
        tower_http::cors::AllowOrigin::any()
    } else {
        // `AllowOrigin::list` panics on `*`, which is handled above.
        tower_http::cors::AllowOrigin::list(origins)
    };

    let mut cors = CorsLayer::new()
        .allow_methods([
            Method::GET,
            Method::POST,
            Method::DELETE,
            Method::OPTIONS,
        ])
        .allow_origin(allow_origin);

    if config.allow_credentials {
        cors = cors.allow_credentials(true);
    }

    Ok(cors.max_age(Duration::from_secs(config.max_age_secs)))
}
457
458#[cfg(feature = "http")]
459fn ensure_tool_authorized(
460    tool_auth: &ToolAuthorization,
461    context: &RequestContext<RoleServer>,
462    tool_name: &str,
463) -> McpResult<()> {
464    if let Some(claims) = context.extensions.get::<Claims>()
465        && !tool_auth.is_authorized(claims, tool_name)
466    {
467        let required_scope = tool_auth.required_scope(tool_name);
468        let scope_str = required_scope.unwrap_or("unknown");
469        return Err(McpError::invalid_params(
470            format!("Forbidden: tool '{tool_name}' requires '{scope_str}' scope"),
471            None,
472        ));
473    }
474    Ok(())
475}
476
#[cfg(feature = "http")]
impl CorsConfig {
    /// Builds a configuration from the process environment.
    ///
    /// - `SUBCOG_MCP_CORS_ALLOWED_ORIGINS`: comma-separated origins. Entries are
    ///   trimmed and empty ones dropped; unset means no origins.
    /// - `SUBCOG_MCP_CORS_ALLOW_CREDENTIALS`: parsed as a `bool`. Unset or
    ///   unparsable means `false`.
    /// - `SUBCOG_MCP_CORS_MAX_AGE_SECS`: preflight cache lifetime in seconds.
    ///   Unset or unparsable means 3600.
    #[must_use]
    pub fn from_env() -> Self {
        let raw_origins = std::env::var("SUBCOG_MCP_CORS_ALLOWED_ORIGINS")
            .unwrap_or_else(|_| DEFAULT_CORS_ALLOWED_ORIGIN.to_string());
        let allowed_origins: Vec<String> = raw_origins
            .split(',')
            .filter_map(|part| {
                let trimmed = part.trim();
                (!trimmed.is_empty()).then(|| trimmed.to_string())
            })
            .collect();

        let allow_credentials = std::env::var("SUBCOG_MCP_CORS_ALLOW_CREDENTIALS")
            .ok()
            .and_then(|raw| raw.parse::<bool>().ok())
            .unwrap_or(false);

        let max_age_secs = std::env::var("SUBCOG_MCP_CORS_MAX_AGE_SECS")
            .ok()
            .and_then(|raw| raw.parse::<u64>().ok())
            .unwrap_or(3600);

        Self {
            allowed_origins,
            allow_credentials,
            max_age_secs,
        }
    }

    /// Replaces the allowed origin list.
    #[must_use]
    pub fn with_origins(self, origins: Vec<String>) -> Self {
        Self {
            allowed_origins: origins,
            ..self
        }
    }

    /// Enables or disables credentialed CORS requests.
    #[must_use]
    pub const fn with_credentials(mut self, allow: bool) -> Self {
        self.allow_credentials = allow;
        self
    }
}
524
525/// Rate limit configuration (ARCH-H1).
526#[derive(Debug, Clone, Serialize, Deserialize)]
527pub struct RateLimitConfig {
528    /// Maximum requests per window.
529    pub max_requests: usize,
530    /// Window duration.
531    pub window: Duration,
532}
533
534impl Default for RateLimitConfig {
535    fn default() -> Self {
536        Self {
537            max_requests: DEFAULT_RATE_LIMIT_MAX_REQUESTS,
538            window: Duration::from_secs(DEFAULT_RATE_LIMIT_WINDOW_SECS),
539        }
540    }
541}
542
543impl RateLimitConfig {
544    /// Creates config from environment variables.
545    #[must_use]
546    pub fn from_env() -> Self {
547        let max_requests = std::env::var("SUBCOG_MCP_RATE_LIMIT_MAX_REQUESTS")
548            .ok()
549            .and_then(|v| v.parse().ok())
550            .unwrap_or(DEFAULT_RATE_LIMIT_MAX_REQUESTS);
551
552        let window_secs = std::env::var("SUBCOG_MCP_RATE_LIMIT_WINDOW_SECS")
553            .ok()
554            .and_then(|v| v.parse().ok())
555            .unwrap_or(DEFAULT_RATE_LIMIT_WINDOW_SECS);
556
557        Self {
558            max_requests,
559            window: Duration::from_secs(window_secs),
560        }
561    }
562
563    /// Sets max requests.
564    #[must_use]
565    pub const fn with_max_requests(mut self, max_requests: usize) -> Self {
566        self.max_requests = max_requests;
567        self
568    }
569
570    /// Sets window duration in seconds.
571    #[must_use]
572    pub const fn with_window_secs(mut self, secs: u64) -> Self {
573        self.window = Duration::from_secs(secs);
574        self
575    }
576}
577
/// Which transport the MCP server communicates over.
///
/// Defaults to [`Transport::Stdio`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum Transport {
    /// Standard input/output — the default, used by Claude Desktop.
    #[default]
    Stdio,
    /// HTTP transport (JWT-authenticated; see the module documentation).
    Http,
}
587
/// State shared (via `Arc`) by every clone of `McpHandler`.
struct McpState {
    /// Registered tools; listed by `list_tools` and executed by `call_tool`.
    tools: ToolRegistry,
    /// Resource handler, guarded by a `tokio::sync::Mutex` and locked by
    /// `list_resources` and `read_resource`.
    resources: Mutex<ResourceHandler>,
    /// Registered prompts; used by `list_prompts` and `get_prompt`.
    prompts: PromptRegistry,
    /// Scope-based tool authorization checked in `call_tool` (HTTP builds only).
    #[cfg(feature = "http")]
    tool_auth: ToolAuthorization,
}
595
596#[derive(Clone)]
597struct McpHandler {
598    state: Arc<McpState>,
599}
600
601impl McpHandler {
602    fn new(tools: ToolRegistry, resources: ResourceHandler, prompts: PromptRegistry) -> Self {
603        Self {
604            state: Arc::new(McpState {
605                tools,
606                resources: Mutex::new(resources),
607                prompts,
608                #[cfg(feature = "http")]
609                tool_auth: ToolAuthorization::default(),
610            }),
611        }
612    }
613}
614
615impl ServerHandler for McpHandler {
616    fn get_info(&self) -> ServerInfo {
617        ServerInfo {
618            protocol_version: rmcp::model::ProtocolVersion::default(),
619            capabilities: ServerCapabilities::builder()
620                .enable_tools()
621                .enable_resources()
622                .enable_prompts()
623                .build(),
624            server_info: Implementation::from_build_env(),
625            instructions: Some("Subcog MCP server".to_string()),
626        }
627    }
628
629    fn list_tools(
630        &self,
631        _request: Option<PaginatedRequestParams>,
632        _context: RequestContext<RoleServer>,
633    ) -> impl std::future::Future<Output = McpResult<ListToolsResult>> + Send + '_ {
634        let state = self.state.clone();
635        let (request_context, request_id) = init_request_context();
636
637        async move {
638            let span = info_span!(
639                parent: None,
640                "subcog.mcp.request",
641                request_id = %request_id,
642                component = "mcp",
643                origin = "mcp",
644                entrypoint = "tools/list",
645                operation = "tools/list"
646            );
647
648            run_mcp_with_context(request_context, span, "list_tools", |_start| async move {
649                let tools = state
650                    .tools
651                    .list_tools()
652                    .into_iter()
653                    .map(tool_definition_to_rmcp)
654                    .collect();
655                Ok(ListToolsResult::with_all_items(tools))
656            })
657            .await
658        }
659    }
660
661    fn call_tool(
662        &self,
663        request: CallToolRequestParams,
664        context: RequestContext<RoleServer>,
665    ) -> impl std::future::Future<Output = McpResult<CallToolResult>> + Send + '_ {
666        let state = self.state.clone();
667        let (request_context, request_id) = init_request_context();
668        let tool_name = request.name.clone();
669        async move {
670            let span = info_span!(
671                parent: None,
672                "subcog.mcp.request",
673                request_id = %request_id,
674                component = "mcp",
675                origin = "mcp",
676                entrypoint = "tools/call",
677                operation = "tools/call",
678                tool_name = %tool_name
679            );
680
681            // HTTP authorization check (must happen before spawn_blocking)
682            #[cfg(feature = "http")]
683            if let Err(err) = ensure_tool_authorized(&state.tool_auth, &context, &tool_name) {
684                record_event(MemoryEvent::McpRequestError {
685                    meta: EventMeta::new("mcp", current_request_id()),
686                    operation: "call_tool".to_string(),
687                    error: err.to_string(),
688                });
689                return Err(err);
690            }
691            #[cfg(not(feature = "http"))]
692            let _ = &context;
693
694            run_mcp_with_context(
695                request_context,
696                span,
697                "call_tool",
698                move |start| async move {
699                    // Use spawn_blocking to run the potentially blocking tool execution
700                    // (e.g., LLM calls use reqwest::blocking::Client)
701                    tokio::task::spawn_blocking(move || execute_call_tool(&state, request, start))
702                        .await
703                        .map_err(|e| join_error_to_mcp(&e))?
704                },
705            )
706            .await
707        }
708    }
709
710    fn list_resources(
711        &self,
712        _request: Option<PaginatedRequestParams>,
713        _context: RequestContext<RoleServer>,
714    ) -> impl std::future::Future<Output = McpResult<ListResourcesResult>> + Send + '_ {
715        let state = self.state.clone();
716        let (request_context, request_id) = init_request_context();
717        async move {
718            let span = info_span!(
719                parent: None,
720                "subcog.mcp.request",
721                request_id = %request_id,
722                component = "mcp",
723                origin = "mcp",
724                entrypoint = "resources/list",
725                operation = "resources/list"
726            );
727
728            run_mcp_with_context(
729                request_context,
730                span,
731                "list_resources",
732                |_start| async move {
733                    let resources = state
734                        .resources
735                        .lock()
736                        .await
737                        .list_resources()
738                        .into_iter()
739                        .map(resource_definition_to_rmcp)
740                        .collect();
741                    Ok(ListResourcesResult::with_all_items(resources))
742                },
743            )
744            .await
745        }
746    }
747
748    fn list_resource_templates(
749        &self,
750        _request: Option<PaginatedRequestParams>,
751        _context: RequestContext<RoleServer>,
752    ) -> impl std::future::Future<Output = McpResult<ListResourceTemplatesResult>> + Send + '_ {
753        let (request_context, request_id) = init_request_context();
754
755        async move {
756            let span = info_span!(
757                parent: None,
758                "subcog.mcp.request",
759                request_id = %request_id,
760                component = "mcp",
761                origin = "mcp",
762                entrypoint = "resources/list_templates",
763                operation = "resources/list_templates"
764            );
765
766            run_mcp_with_context(
767                request_context,
768                span,
769                "list_resource_templates",
770                |_start| async move { Ok(ListResourceTemplatesResult::with_all_items(Vec::new())) },
771            )
772            .await
773        }
774    }
775
776    fn read_resource(
777        &self,
778        request: rmcp::model::ReadResourceRequestParams,
779        _context: RequestContext<RoleServer>,
780    ) -> impl std::future::Future<Output = McpResult<rmcp::model::ReadResourceResult>> + Send + '_
781    {
782        let state = self.state.clone();
783        let (request_context, request_id) = init_request_context();
784        let resource_uri = request.uri.clone();
785        async move {
786            let span = info_span!(
787                parent: None,
788                "subcog.mcp.request",
789                request_id = %request_id,
790                component = "mcp",
791                origin = "mcp",
792                entrypoint = "resources/read",
793                operation = "resources/read",
794                resource_uri = %resource_uri
795            );
796
797            run_mcp_with_context(
798                request_context,
799                span,
800                "read_resource",
801                |_start| async move {
802                    let content = state
803                        .resources
804                        .lock()
805                        .await
806                        .get_resource(&request.uri)
807                        .map_err(|e| McpError::resource_not_found(e.to_string(), None))?;
808
809                    let contents = vec![resource_content_to_rmcp(content)];
810                    Ok(rmcp::model::ReadResourceResult { contents })
811                },
812            )
813            .await
814        }
815    }
816
817    fn list_prompts(
818        &self,
819        _request: Option<PaginatedRequestParams>,
820        _context: RequestContext<RoleServer>,
821    ) -> impl std::future::Future<Output = McpResult<ListPromptsResult>> + Send + '_ {
822        let state = self.state.clone();
823        let (request_context, request_id) = init_request_context();
824        async move {
825            let span = info_span!(
826                parent: None,
827                "subcog.mcp.request",
828                request_id = %request_id,
829                component = "mcp",
830                origin = "mcp",
831                entrypoint = "prompts/list",
832                operation = "prompts/list"
833            );
834
835            run_mcp_with_context(request_context, span, "list_prompts", |_start| async move {
836                let prompts = state
837                    .prompts
838                    .list_prompts()
839                    .into_iter()
840                    .map(prompt_definition_to_rmcp)
841                    .collect();
842                Ok(ListPromptsResult::with_all_items(prompts))
843            })
844            .await
845        }
846    }
847
848    fn get_prompt(
849        &self,
850        request: GetPromptRequestParams,
851        _context: RequestContext<RoleServer>,
852    ) -> impl std::future::Future<Output = McpResult<GetPromptResult>> + Send + '_ {
853        let state = self.state.clone();
854        let (request_context, request_id) = init_request_context();
855        let prompt_name = request.name.clone();
856        let arguments = request.arguments.map_or_else(
857            || serde_json::Value::Object(serde_json::Map::new()),
858            serde_json::Value::Object,
859        );
860        async move {
861            let span = info_span!(
862                parent: None,
863                "subcog.mcp.request",
864                request_id = %request_id,
865                component = "mcp",
866                origin = "mcp",
867                entrypoint = "prompts/get",
868                operation = "prompts/get",
869                prompt = %prompt_name
870            );
871
872            run_mcp_with_context(request_context, span, "get_prompt", |_start| async move {
873                resolve_prompt(&state.prompts, &prompt_name, &arguments)
874            })
875            .await
876        }
877    }
878}
879
880/// Resolves a prompt by name and arguments into a `GetPromptResult`.
881fn resolve_prompt(
882    registry: &PromptRegistry,
883    name: &str,
884    arguments: &serde_json::Value,
885) -> McpResult<GetPromptResult> {
886    let prompt_def = registry.get_prompt(name);
887    let messages = registry.get_prompt_messages(name, arguments);
888
889    let Some(messages) = messages else {
890        let err = format!("Unknown prompt: {name}");
891        return Err(McpError::invalid_params(err, None));
892    };
893
894    let rmcp_messages: Vec<RmcpPromptMessage> =
895        messages.into_iter().map(prompt_message_to_rmcp).collect();
896
897    Ok(GetPromptResult {
898        description: prompt_def.and_then(|p| p.description.clone()),
899        messages: rmcp_messages,
900    })
901}
902
903fn tool_definition_to_rmcp(def: &ToolDefinition) -> Tool {
904    let schema = def.input_schema.as_object().cloned().unwrap_or_default();
905
906    Tool {
907        name: Cow::Owned(def.name.clone()),
908        title: None,
909        description: Some(Cow::Owned(def.description.clone())),
910        input_schema: Arc::new(schema),
911        output_schema: None,
912        annotations: None,
913        icons: None,
914        meta: None,
915    }
916}
917
918fn tool_content_to_rmcp(content: ToolContent) -> Content {
919    match content {
920        ToolContent::Text { text } => Content::text(text),
921        ToolContent::Image { data, mime_type } => Content::image(data, mime_type),
922    }
923}
924
925fn tool_result_to_rmcp(result: ToolResult) -> CallToolResult {
926    let contents = result
927        .content
928        .into_iter()
929        .map(tool_content_to_rmcp)
930        .collect();
931    if result.is_error {
932        CallToolResult::error(contents)
933    } else {
934        CallToolResult::success(contents)
935    }
936}
937
938fn resource_definition_to_rmcp(def: ResourceDefinition) -> Resource {
939    RawResource {
940        uri: def.uri,
941        name: def.name,
942        title: None,
943        description: def.description,
944        mime_type: def.mime_type,
945        size: None,
946        icons: None,
947        meta: None,
948    }
949    .no_annotation()
950}
951
952fn prompt_definition_to_rmcp(def: &PromptDefinition) -> Prompt {
953    let arguments = if def.arguments.is_empty() {
954        None
955    } else {
956        Some(
957            def.arguments
958                .iter()
959                .map(|arg| RmcpPromptArgument {
960                    name: arg.name.clone(),
961                    title: None,
962                    description: arg.description.clone(),
963                    required: Some(arg.required),
964                })
965                .collect(),
966        )
967    };
968
969    Prompt {
970        name: def.name.clone(),
971        title: None,
972        description: def.description.clone(),
973        arguments,
974        icons: None,
975        meta: None,
976    }
977}
978
979fn prompt_message_to_rmcp(msg: crate::mcp::prompts::PromptMessage) -> RmcpPromptMessage {
980    let role = match msg.role.as_str() {
981        "assistant" => PromptMessageRole::Assistant,
982        _ => PromptMessageRole::User,
983    };
984
985    let content = match msg.content {
986        PromptContent::Text { text } => PromptMessageContent::Text { text },
987        PromptContent::Image { data, mime_type } => {
988            // Convert to text representation for simplicity
989            // (our prompts rarely use images)
990            PromptMessageContent::Text {
991                text: format!("[Image: {mime_type}, {} bytes]", data.len()),
992            }
993        },
994        PromptContent::Resource { uri } => {
995            // Convert resource to text representation since PromptMessageContent
996            // doesn't have a direct resource variant in the same form
997            PromptMessageContent::Text {
998                text: format!("[Resource: {uri}]"),
999            }
1000        },
1001    };
1002
1003    RmcpPromptMessage { role, content }
1004}
1005
1006fn resource_content_to_rmcp(content: ResourceContent) -> ResourceContents {
1007    if let Some(text) = content.text {
1008        ResourceContents::TextResourceContents {
1009            uri: content.uri,
1010            mime_type: content.mime_type,
1011            text,
1012            meta: None,
1013        }
1014    } else {
1015        ResourceContents::BlobResourceContents {
1016            uri: content.uri,
1017            mime_type: content.mime_type,
1018            blob: content.blob.unwrap_or_default(),
1019            meta: None,
1020        }
1021    }
1022}
1023
1024/// MCP server for subcog.
1025pub struct McpServer {
1026    /// Tool registry.
1027    tools: ToolRegistry,
1028    /// Resource handler.
1029    resources: ResourceHandler,
1030    /// Transport type.
1031    transport: Transport,
1032    /// HTTP port (if using HTTP transport).
1033    port: u16,
1034    /// Rate limit configuration (ARCH-H1).
1035    rate_limit: RateLimitConfig,
1036    /// JWT authenticator for HTTP transport (SEC-H1).
1037    #[cfg(feature = "http")]
1038    jwt_authenticator: Option<JwtAuthenticator>,
1039    /// CORS configuration for HTTP transport (HIGH-SEC-006).
1040    #[cfg(feature = "http")]
1041    cors_config: CorsConfig,
1042}
1043
1044impl McpServer {
1045    /// Creates a new MCP server.
1046    #[must_use]
1047    pub fn new() -> Self {
1048        // Try to initialize RecallService for memory browsing
1049        let resources = Self::try_init_resources();
1050
1051        Self {
1052            tools: ToolRegistry::new(),
1053            resources,
1054            transport: Transport::Stdio,
1055            port: 3000,
1056            rate_limit: RateLimitConfig::from_env(),
1057            #[cfg(feature = "http")]
1058            jwt_authenticator: None,
1059            #[cfg(feature = "http")]
1060            cors_config: CorsConfig::from_env(),
1061        }
1062    }
1063
1064    /// Sets the CORS configuration for HTTP transport (HIGH-SEC-006).
1065    ///
1066    /// By default, no origins are allowed (deny all CORS requests).
1067    /// Use this to explicitly allow specific origins.
1068    #[cfg(feature = "http")]
1069    #[must_use]
1070    pub fn with_cors_config(mut self, config: CorsConfig) -> Self {
1071        self.cors_config = config;
1072        self
1073    }
1074
1075    /// Sets the JWT authenticator for HTTP transport (SEC-H1).
1076    ///
1077    /// # Arguments
1078    ///
1079    /// * `authenticator` - The JWT authenticator to use for validating bearer tokens.
1080    #[cfg(feature = "http")]
1081    #[must_use]
1082    pub fn with_jwt_authenticator(mut self, authenticator: JwtAuthenticator) -> Self {
1083        self.jwt_authenticator = Some(authenticator);
1084        self
1085    }
1086
1087    /// Initializes JWT authentication from environment variables.
1088    ///
1089    /// Reads `SUBCOG_MCP_JWT_SECRET`, `SUBCOG_MCP_JWT_ISSUER`, and
1090    /// `SUBCOG_MCP_JWT_AUDIENCE` from the environment.
1091    ///
1092    /// # Errors
1093    ///
1094    /// Returns an error if `SUBCOG_MCP_JWT_SECRET` is not set or too short.
1095    #[cfg(feature = "http")]
1096    pub fn with_jwt_from_env(self) -> SubcogResult<Self> {
1097        let config = JwtConfig::from_env()?;
1098        let authenticator = JwtAuthenticator::new(&config);
1099        Ok(self.with_jwt_authenticator(authenticator))
1100    }
1101
1102    /// Sets the rate limit configuration (ARCH-H1).
1103    ///
1104    /// # Arguments
1105    ///
1106    /// * `config` - The rate limit configuration.
1107    #[must_use]
1108    pub const fn with_rate_limit(mut self, config: RateLimitConfig) -> Self {
1109        self.rate_limit = config;
1110        self
1111    }
1112
1113    /// Tries to initialize `ResourceHandler` with services.
1114    ///
1115    /// Uses domain-scoped index (user-level index with project facets).
1116    fn try_init_resources() -> ResourceHandler {
1117        use crate::config::SubcogConfig;
1118        use crate::services::PromptService;
1119
1120        let mut handler = ResourceHandler::new();
1121
1122        // Try to add RecallService (works in both project and user scope)
1123        if let Ok(services) = ServiceContainer::from_current_dir_or_user() {
1124            if let Ok(recall) = services.recall() {
1125                handler = handler.with_recall_service(recall);
1126            }
1127
1128            // Try to add PromptService with full config (respects storage settings)
1129            // For user-scope, repo_path is None - PromptService still works with user storage
1130            if let Some(repo_path) = services.repo_path() {
1131                let config = SubcogConfig::load_default().with_repo_path(repo_path);
1132                let prompt_service =
1133                    PromptService::with_subcog_config(config).with_repo_path(repo_path);
1134                handler = handler.with_prompt_service(prompt_service);
1135            } else {
1136                // User-scope: create prompt service without repo path
1137                let config = SubcogConfig::load_default();
1138                let prompt_service = PromptService::with_subcog_config(config);
1139                handler = handler.with_prompt_service(prompt_service);
1140            }
1141        }
1142
1143        handler
1144    }
1145
1146    /// Sets the transport type.
1147    #[must_use]
1148    pub const fn with_transport(mut self, transport: Transport) -> Self {
1149        self.transport = transport;
1150        self
1151    }
1152
1153    /// Sets the HTTP port.
1154    #[must_use]
1155    pub const fn with_port(mut self, port: u16) -> Self {
1156        self.port = port;
1157        self
1158    }
1159
1160    /// Starts the MCP server with graceful shutdown handling (RES-M4).
1161    ///
1162    /// Sets up signal handlers for SIGINT/SIGTERM before starting the server.
1163    /// The server will gracefully shut down when a signal is received.
1164    ///
1165    /// # Errors
1166    ///
1167    /// Returns an error if the server fails to start or signal handler cannot be installed.
1168    pub async fn start(&mut self) -> SubcogResult<()> {
1169        // Set up signal handler for graceful shutdown (RES-M4)
1170        setup_signal_handler()?;
1171
1172        let (transport, port) = match self.transport {
1173            Transport::Stdio => ("stdio", None),
1174            Transport::Http => ("http", Some(self.port)),
1175        };
1176        record_event(MemoryEvent::McpStarted {
1177            meta: EventMeta::new("mcp", current_request_id()),
1178            transport: transport.to_string(),
1179            port,
1180        });
1181
1182        match self.transport {
1183            Transport::Stdio => self.run_stdio().await,
1184            Transport::Http => self.run_http().await,
1185        }
1186    }
1187
1188    fn build_handler(&mut self) -> McpHandler {
1189        let tools = std::mem::take(&mut self.tools);
1190        let resources = std::mem::take(&mut self.resources);
1191        let prompts = PromptRegistry::new();
1192        McpHandler::new(tools, resources, prompts)
1193    }
1194
1195    /// Runs the server over stdio with graceful shutdown (RES-M4).
1196    async fn run_stdio(&mut self) -> SubcogResult<()> {
1197        let handler = self.build_handler();
1198        let service = handler
1199            .serve(stdio())
1200            .await
1201            .map_err(|e| Error::OperationFailed {
1202                operation: "serve_stdio".to_string(),
1203                cause: e.to_string(),
1204            })?;
1205
1206        let cancel_token = service.cancellation_token();
1207        let span = tracing::Span::current();
1208        let request_context = current_request_id().map(ObsRequestContext::from_id);
1209        tokio::spawn(
1210            async move {
1211                let run = await_shutdown(cancel_token);
1212                if let Some(context) = request_context {
1213                    scope_request_context(context, run).await;
1214                } else {
1215                    run.await;
1216                }
1217            }
1218            .instrument(span),
1219        );
1220
1221        service
1222            .waiting()
1223            .await
1224            .map_err(|e| Error::OperationFailed {
1225                operation: "wait_stdio".to_string(),
1226                cause: e.to_string(),
1227            })?;
1228
1229        Ok(())
1230    }
1231
1232    /// Performs graceful shutdown cleanup (RES-M4).
1233    #[allow(dead_code)]
1234    fn graceful_shutdown(&self) {
1235        let start = Instant::now();
1236        tracing::info!("Starting graceful shutdown sequence");
1237
1238        // Flush any pending metrics
1239        flush_metrics();
1240
1241        // Record shutdown metrics
1242        metrics::counter!("mcp_graceful_shutdown_total").increment(1);
1243        metrics::histogram!("mcp_shutdown_duration_ms")
1244            .record(start.elapsed().as_secs_f64() * 1000.0);
1245
1246        tracing::info!(
1247            duration_ms = start.elapsed().as_millis(),
1248            "Graceful shutdown completed"
1249        );
1250    }
1251
1252    /// Runs the server over HTTP with JWT authentication (SEC-H1).
1253    ///
1254    /// Requires the `http` feature and `SUBCOG_MCP_JWT_SECRET` environment variable.
1255    #[cfg(feature = "http")]
1256    async fn run_http(&mut self) -> SubcogResult<()> {
1257        // Ensure JWT authenticator is configured
1258        let authenticator = self.jwt_authenticator.clone().ok_or_else(|| {
1259            Error::OperationFailed {
1260                operation: "run_http".to_string(),
1261                cause: "JWT authenticator not configured. Set SUBCOG_MCP_JWT_SECRET or call with_jwt_authenticator()".to_string(),
1262            }
1263        })?;
1264
1265        let handler = self.build_handler();
1266        let session_manager = Arc::new(LocalSessionManager::default());
1267        let streamable = StreamableHttpService::new(
1268            move || Ok(handler.clone()),
1269            session_manager,
1270            StreamableHttpServerConfig::default(),
1271        );
1272
1273        let auth_state = HttpAuthState {
1274            authenticator,
1275            rate_limit: self.rate_limit.clone(),
1276            rate_limits: Arc::new(Mutex::new(HashMap::new())),
1277        };
1278
1279        // Build CORS layer
1280        let cors_layer = build_cors_layer(&self.cors_config)?;
1281
1282        let app = Router::new()
1283            .route_service("/mcp", any_service(streamable))
1284            .layer(axum::middleware::from_fn_with_state(
1285                auth_state.clone(),
1286                auth_middleware,
1287            ))
1288            .layer(axum::middleware::from_fn(map_notification_status))
1289            // CORS layer (HIGH-SEC-006) - must be before other layers
1290            .layer(cors_layer)
1291            // Security headers (OWASP recommendations)
1292            .layer(SetResponseHeaderLayer::overriding(
1293                header::X_CONTENT_TYPE_OPTIONS,
1294                header::HeaderValue::from_static("nosniff"),
1295            ))
1296            .layer(SetResponseHeaderLayer::overriding(
1297                header::X_FRAME_OPTIONS,
1298                header::HeaderValue::from_static("DENY"),
1299            ))
1300            .layer(SetResponseHeaderLayer::overriding(
1301                header::CONTENT_SECURITY_POLICY,
1302                header::HeaderValue::from_static("default-src 'none'; frame-ancestors 'none'"),
1303            ))
1304            .layer(SetResponseHeaderLayer::overriding(
1305                header::CACHE_CONTROL,
1306                header::HeaderValue::from_static("no-store"),
1307            ))
1308            .layer(SetResponseHeaderLayer::overriding(
1309                header::HeaderName::from_static("x-permitted-cross-domain-policies"),
1310                header::HeaderValue::from_static("none"),
1311            ))
1312            .layer(TraceLayer::new_for_http());
1313
1314        let addr = std::net::SocketAddr::from(([0, 0, 0, 0], self.port));
1315        tracing::info!(port = self.port, "Starting MCP HTTP server with JWT auth");
1316
1317        let listener =
1318            tokio::net::TcpListener::bind(addr)
1319                .await
1320                .map_err(|e| Error::OperationFailed {
1321                    operation: "bind".to_string(),
1322                    cause: e.to_string(),
1323                })?;
1324
1325        axum::serve(listener, app)
1326            .await
1327            .map_err(|e| Error::OperationFailed {
1328                operation: "serve".to_string(),
1329                cause: e.to_string(),
1330            })
1331    }
1332
1333    /// Runs the server over HTTP (feature not enabled).
1334    #[cfg(not(feature = "http"))]
1335    #[allow(clippy::unused_async)] // Matches async signature of enabled version
1336    async fn run_http(&self) -> SubcogResult<()> {
1337        Err(Error::FeatureNotEnabled("http".to_string()))
1338    }
1339}
1340
1341impl Default for McpServer {
1342    fn default() -> Self {
1343        Self::new()
1344    }
1345}
1346
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built server defaults to the stdio transport.
    #[test]
    fn test_mcp_server_creation() {
        assert_eq!(McpServer::new().transport, Transport::Stdio);
    }

    /// Builder setters record the requested transport and port.
    #[test]
    fn test_with_transport() {
        let configured = McpServer::new()
            .with_port(8080)
            .with_transport(Transport::Http);

        assert_eq!(configured.port, 8080);
        assert_eq!(configured.transport, Transport::Http);
    }

    /// Registry tools keep their name when converted to rmcp descriptors.
    #[test]
    fn test_tool_definition_mapping() {
        let registry = ToolRegistry::new();
        let definition = registry.get_tool("subcog_status").unwrap();
        assert_eq!(tool_definition_to_rmcp(definition).name, "subcog_status");
    }
}
1374
#[cfg(all(test, feature = "http"))]
mod cors_tests {
    use super::*;

    /// The default policy allows no origins, disables credentials and caches
    /// preflight results for an hour.
    #[test]
    fn test_cors_config_default() {
        let defaults = CorsConfig::default();
        assert_eq!(defaults.max_age_secs, 3600);
        assert!(!defaults.allow_credentials);
        assert!(defaults.allowed_origins.is_empty());
    }

    /// Builder methods set the origin list and the credentials flag.
    #[test]
    fn test_cors_config_with_origins() {
        let origins = vec!["https://example.com".to_string()];
        let config = CorsConfig::default()
            .with_origins(origins)
            .with_credentials(true);

        assert!(config.allow_credentials);
        assert_eq!(config.allowed_origins.len(), 1);
        assert_eq!(config.allowed_origins[0], "https://example.com");
    }

    /// `from_env` falls back to defaults when no overrides are present.
    #[test]
    fn test_cors_config_from_env_defaults() {
        // Assumes the test environment leaves SUBCOG_MCP_CORS_* unset.
        let from_env = CorsConfig::from_env();
        assert!(!from_env.allow_credentials);
        assert_eq!(from_env.max_age_secs, 3600);
    }

    /// Comma-separated origins are trimmed and empty entries are dropped.
    #[test]
    fn test_cors_origin_parsing() {
        // Mirrors the parsing logic used in `CorsConfig::from_env`.
        let raw = "https://a.com, https://b.com, ";
        let origins: Vec<String> = raw
            .split(',')
            .map(str::trim)
            .filter(|entry| !entry.is_empty())
            .map(str::to_owned)
            .collect();

        assert_eq!(origins, ["https://a.com", "https://b.com"]);
    }

    /// `with_cors_config` stores the supplied policy on the server.
    #[test]
    fn test_mcp_server_with_cors_config() {
        let trusted = CorsConfig::default().with_origins(vec!["https://trusted.com".to_string()]);
        let server = McpServer::new().with_cors_config(trusted);

        let origins = &server.cors_config.allowed_origins;
        assert_eq!(origins.len(), 1);
        assert_eq!(origins[0], "https://trusted.com");
    }
}
1435
#[cfg(all(test, feature = "http"))]
mod auth_tests {
    use super::*;
    use axum::{
        Router,
        body::Body,
        http::{Request, StatusCode, header},
        middleware,
        routing::get,
    };
    use chrono::Utc;
    use jsonwebtoken::{EncodingKey, Header};
    use tower::util::ServiceExt;

    /// HMAC secret shared by the token signer and the authenticator under test.
    const TEST_JWT_SECRET: &str = "a-very-long-secret-key-that-is-at-least-32-chars";

    /// Wraps a trivial `GET /` handler in the auth + rate-limit middleware.
    fn build_app(state: HttpAuthState) -> Router {
        let unprotected = Router::new().route("/", get(|| async { "ok" }));
        unprotected.layer(middleware::from_fn_with_state(state, auth_middleware))
    }

    /// Auth state that accepts tokens signed with `TEST_JWT_SECRET` and allows
    /// `max_requests` requests per 60-second window.
    fn build_state(max_requests: usize) -> HttpAuthState {
        let authenticator = JwtAuthenticator::new(&JwtConfig::new(TEST_JWT_SECRET));
        let rate_limit = RateLimitConfig {
            max_requests,
            window: Duration::from_secs(60),
        };
        HttpAuthState {
            authenticator,
            rate_limit,
            rate_limits: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Signs a token for `sub` with the default JWT header. The token is valid
    /// for one hour and carries the `read` scope.
    #[allow(clippy::expect_used)] // Test helper: a failed encode is a test bug.
    fn create_test_token(sub: &str) -> String {
        let now = Utc::now();
        let issued_at = usize::try_from(now.timestamp()).unwrap_or(0);
        let expires_at =
            usize::try_from((now + chrono::Duration::hours(1)).timestamp()).unwrap_or(0);

        let claims = Claims {
            sub: sub.to_string(),
            exp: expires_at,
            iat: issued_at,
            iss: None,
            aud: None,
            scopes: vec!["read".to_string()],
        };
        let key = EncodingKey::from_secret(TEST_JWT_SECRET.as_bytes());
        jsonwebtoken::encode(&Header::default(), &claims, &key)
            .expect("Failed to encode test token")
    }

    /// Builds a `GET /` request, optionally carrying an `Authorization` header.
    #[allow(clippy::expect_used)] // Test helper: a malformed request is a test bug.
    fn root_request(authorization: Option<&str>) -> Request<Body> {
        let builder = Request::builder().uri("/");
        let builder = match authorization {
            Some(value) => builder.header(header::AUTHORIZATION, value),
            None => builder,
        };
        builder.body(Body::empty()).expect("request")
    }

    #[tokio::test]
    async fn test_auth_middleware_missing_header() {
        let response = build_app(build_state(5))
            .oneshot(root_request(None))
            .await
            .expect("response");
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_rate_limit_exceeded() {
        let app = build_app(build_state(1));
        let bearer = format!("Bearer {}", create_test_token("client-a"));

        // The first request uses the only slot in the window...
        let first = app
            .clone()
            .oneshot(root_request(Some(&bearer)))
            .await
            .expect("response");
        assert_eq!(first.status(), StatusCode::OK);

        // ...so an immediate repeat with the same token is throttled.
        let second = app
            .oneshot(root_request(Some(&bearer)))
            .await
            .expect("response");
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
    }
}