1use crate::mcp::{
62 ResourceContent, ResourceDefinition, ResourceHandler, ToolContent, ToolDefinition,
63 ToolRegistry, ToolResult,
64 prompts::{PromptContent, PromptDefinition, PromptRegistry},
65};
66use crate::models::{EventMeta, MemoryEvent};
67use crate::observability::{
68 RequestContext as ObsRequestContext, current_request_id, flush_metrics, scope_request_context,
69};
70use crate::security::record_event;
71use crate::services::ServiceContainer;
72use crate::{Error, Result as SubcogResult};
73#[cfg(feature = "http")]
74use axum::extract::{Request, State};
75#[cfg(feature = "http")]
76use axum::http::{Method, StatusCode, header};
77#[cfg(feature = "http")]
78use axum::middleware::Next;
79#[cfg(feature = "http")]
80use axum::response::{IntoResponse, Response};
81#[cfg(feature = "http")]
82use axum::routing::any_service;
83#[cfg(feature = "http")]
84use axum::{Json, Router};
85use rmcp::model::{
86 AnnotateAble, CallToolRequestParams, CallToolResult, Content, GetPromptRequestParams,
87 GetPromptResult, Implementation, ListPromptsResult, ListResourceTemplatesResult,
88 ListResourcesResult, ListToolsResult, PaginatedRequestParams, Prompt,
89 PromptArgument as RmcpPromptArgument, PromptMessage as RmcpPromptMessage, PromptMessageContent,
90 PromptMessageRole, RawResource, Resource, ResourceContents, ServerCapabilities, ServerInfo,
91 Tool,
92};
93use rmcp::service::RequestContext;
94use rmcp::transport::stdio;
95#[cfg(feature = "http")]
96use rmcp::transport::streamable_http_server::session::local::LocalSessionManager;
97#[cfg(feature = "http")]
98use rmcp::transport::streamable_http_server::tower::{
99 StreamableHttpServerConfig, StreamableHttpService,
100};
101use rmcp::{ErrorData as McpError, RoleServer, ServerHandler, ServiceExt};
102use serde::{Deserialize, Serialize};
103use serde_json::{Map, Value};
104use std::borrow::Cow;
105#[cfg(feature = "http")]
106use std::collections::HashMap;
107use std::sync::Arc;
108use std::sync::atomic::{AtomicBool, Ordering};
109use std::time::{Duration, Instant};
110use tokio::sync::Mutex;
111#[cfg(feature = "http")]
112use tower_http::cors::CorsLayer;
113#[cfg(feature = "http")]
114use tower_http::set_header::SetResponseHeaderLayer;
115#[cfg(feature = "http")]
116use tower_http::trace::TraceLayer;
117use tracing::{Instrument, info_span};
118
/// Shorthand for a `Result` whose error type is the rmcp protocol error ([`McpError`]).
type McpResult<T> = std::result::Result<T, McpError>;
120
121fn join_error_to_mcp(e: &tokio::task::JoinError) -> McpError {
123 McpError::internal_error(format!("Task join error: {e}"), None)
124}
125
126fn record_mcp_metrics<T>(operation: &'static str, start: Instant, result: &McpResult<T>) {
127 let status = if result.is_ok() { "success" } else { "error" };
128 metrics::counter!(
129 "mcp_requests_total",
130 "operation" => operation,
131 "status" => status
132 )
133 .increment(1);
134 if result.is_err() {
135 metrics::counter!("mcp_request_errors_total", "operation" => operation).increment(1);
136 }
137 metrics::histogram!("mcp_request_duration_ms", "operation" => operation)
138 .record(start.elapsed().as_secs_f64() * 1000.0);
139}
140
141async fn run_mcp_with_context<T, F, Fut>(
142 request_context: Option<ObsRequestContext>,
143 span: tracing::Span,
144 operation: &'static str,
145 f: F,
146) -> McpResult<T>
147where
148 F: FnOnce(Instant) -> Fut,
149 Fut: std::future::Future<Output = McpResult<T>>,
150{
151 let run = async move {
152 let _span_guard = span.enter();
153 let start = Instant::now();
154 let result = f(start).await;
155 record_mcp_metrics(operation, start, &result);
156 result
157 };
158
159 if let Some(context) = request_context {
160 scope_request_context(context, run).await
161 } else {
162 run.await
163 }
164}
165
166fn init_request_context() -> (Option<ObsRequestContext>, String) {
167 let context = ObsRequestContext::new();
168 let request_id = context.request_id().to_string();
169 (Some(context), request_id)
170}
171
172async fn await_shutdown(cancel_token: rmcp::service::RunningServiceCancellationToken) {
173 while !is_shutdown_requested() {
174 tokio::time::sleep(Duration::from_millis(200)).await;
175 }
176 cancel_token.cancel();
177}
178
179fn execute_call_tool(
180 state: &McpState,
181 request: CallToolRequestParams,
182 start: Instant,
183) -> McpResult<CallToolResult> {
184 let arguments = match request.arguments {
185 Some(args) => Value::Object(args),
186 None => Value::Object(Map::new()),
187 };
188
189 let result = match state.tools.execute(&request.name, arguments) {
190 Ok(result) => result,
191 Err(err) => {
192 record_event(MemoryEvent::McpRequestError {
193 meta: EventMeta::new("mcp", current_request_id()),
194 operation: "call_tool".to_string(),
195 error: err.to_string(),
196 });
197 return Err(McpError::invalid_params(err.to_string(), None));
198 },
199 };
200
201 let status = if result.is_error { "error" } else { "success" };
202 let duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
203 record_event(MemoryEvent::McpToolExecuted {
204 meta: EventMeta::new("mcp", current_request_id()),
205 tool_name: request.name.to_string(),
206 status: status.to_string(),
207 duration_ms,
208 error: result
209 .is_error
210 .then_some("tool execution returned error".to_string()),
211 });
212
213 Ok(tool_result_to_rmcp(result))
214}
215
/// Process-wide flag; `false` until [`request_shutdown`] is called.
static SHUTDOWN_REQUESTED: AtomicBool = AtomicBool::new(false);

/// Reports whether a shutdown has been requested.
///
/// Returns `true` once [`request_shutdown`] has been called; nothing in this
/// block resets the flag.
#[must_use]
pub fn is_shutdown_requested() -> bool {
    SHUTDOWN_REQUESTED.load(Ordering::SeqCst)
}

/// Marks the process as shutting down.
///
/// Calling this more than once has no additional effect.
pub fn request_shutdown() {
    SHUTDOWN_REQUESTED.store(true, Ordering::SeqCst);
}
229
230pub fn setup_signal_handler() -> SubcogResult<()> {
241 ctrlc::set_handler(move || {
242 tracing::info!("Shutdown signal received, initiating graceful shutdown");
243 request_shutdown();
244
245 flush_metrics();
247
248 metrics::counter!("mcp_shutdown_signals_total").increment(1);
249 })
250 .map_err(|e| Error::OperationFailed {
251 operation: "setup_signal_handler".to_string(),
252 cause: e.to_string(),
253 })?;
254
255 tracing::debug!("Signal handler installed for graceful shutdown");
256 Ok(())
257}
258
259#[cfg(feature = "http")]
260use crate::mcp::auth::{Claims, JwtAuthenticator, JwtConfig, ToolAuthorization};
261
/// Default maximum number of requests allowed per rate-limit window.
const DEFAULT_RATE_LIMIT_MAX_REQUESTS: usize = 1000;
264
/// Default length of the rate-limit window, in seconds.
const DEFAULT_RATE_LIMIT_WINDOW_SECS: u64 = 60;
267
#[cfg(feature = "http")]
/// Fallback used when `SUBCOG_MCP_CORS_ALLOWED_ORIGINS` is not set.
///
/// The empty string yields an empty origin list once it is split and filtered.
const DEFAULT_CORS_ALLOWED_ORIGIN: &str = "";
271
#[cfg(feature = "http")]
/// CORS settings for the HTTP transport.
#[derive(Debug, Clone)]
pub struct CorsConfig {
    /// Origins to allow. When empty, `build_cors_layer` returns a plain
    /// `CorsLayer::new()` with nothing configured on it.
    pub allowed_origins: Vec<String>,
    /// When `true`, `allow_credentials(true)` is set on the CORS layer.
    pub allow_credentials: bool,
    /// Value handed to the CORS layer's `max_age`, in seconds.
    pub max_age_secs: u64,
}
283
#[cfg(feature = "http")]
impl Default for CorsConfig {
    /// Defaults: no allowed origins, credentials disabled, max age of 3600 seconds.
    fn default() -> Self {
        let allowed_origins = Vec::new();
        Self {
            allowed_origins,
            allow_credentials: false,
            max_age_secs: 3600,
        }
    }
}
294
#[cfg(feature = "http")]
/// Per-client counter for the current rate-limit window.
#[derive(Clone)]
struct RateLimitEntry {
    /// Requests admitted since `window_start`.
    count: usize,
    /// Instant at which the current window began.
    window_start: Instant,
}
301
#[cfg(feature = "http")]
/// State shared with `auth_middleware`.
#[derive(Clone)]
struct HttpAuthState {
    /// Validates the `Authorization` header and yields the caller's claims.
    authenticator: JwtAuthenticator,
    /// Limit (max requests and window) applied to each client.
    rate_limit: RateLimitConfig,
    /// Rate-limit counters keyed by the `sub` claim of the authenticated client.
    rate_limits: Arc<Mutex<HashMap<String, RateLimitEntry>>>,
}
309
#[cfg(feature = "http")]
/// Authenticates and rate-limits an HTTP request before passing it on.
///
/// Steps, all executed inside a fresh request context and the
/// `subcog.mcp.auth` span:
///
/// 1. Read the `Authorization` header; a missing (or non-UTF-8) header yields
///    `401` and an `McpAuthFailed` event.
/// 2. Validate the header with the JWT authenticator; failure yields `401` and
///    an `McpAuthFailed` event.
/// 3. Apply the per-client rate limit keyed by the `sub` claim; exceeding it
///    yields `429`.
/// 4. Insert the claims into the request extensions and call `next`.
///
/// NOTE(review): entries in the rate-limit map are never removed in this
/// function, so the map grows with the number of distinct `sub` values.
async fn auth_middleware(
    State(state): State<HttpAuthState>,
    mut req: Request,
    next: Next,
) -> Response {
    let request_context = ObsRequestContext::new();
    let request_id = request_context.request_id().to_string();
    scope_request_context(request_context, async move {
        let span = info_span!(
            "subcog.mcp.auth",
            request_id = %request_id,
            component = "mcp",
            operation = "auth"
        );

        // The body awaits (lock acquisition, downstream handler), so the span
        // is attached with `Instrument` instead of a `Span::enter()` guard,
        // which must not be held across `.await`.
        async move {
            let auth_header = req
                .headers()
                .get(header::AUTHORIZATION)
                .and_then(|h| h.to_str().ok());

            let Some(header_value) = auth_header else {
                record_event(MemoryEvent::McpAuthFailed {
                    meta: EventMeta::new("mcp", current_request_id()),
                    client_id: None,
                    reason: "missing authorization header".to_string(),
                });
                return auth_error_response(StatusCode::UNAUTHORIZED, "Authentication required");
            };

            let claims = match state.authenticator.validate_header(header_value) {
                Ok(claims) => claims,
                Err(e) => {
                    tracing::warn!(error = %e, "JWT authentication failed");
                    record_event(MemoryEvent::McpAuthFailed {
                        meta: EventMeta::new("mcp", current_request_id()),
                        client_id: None,
                        reason: e.to_string(),
                    });
                    return auth_error_response(
                        StatusCode::UNAUTHORIZED,
                        &format!("Authentication failed: {e}"),
                    );
                },
            };

            let client_id = claims.sub.clone();
            {
                // The guard lives only for this block, so the lock is released
                // before the downstream handler runs.
                let mut rate_limits = state.rate_limits.lock().await;
                let entry = rate_limits
                    .entry(client_id.clone())
                    .or_insert_with(|| RateLimitEntry {
                        count: 0,
                        window_start: Instant::now(),
                    });

                // Start a new window once the previous one has fully elapsed.
                if entry.window_start.elapsed() > state.rate_limit.window {
                    entry.count = 0;
                    entry.window_start = Instant::now();
                }

                if entry.count >= state.rate_limit.max_requests {
                    tracing::warn!(
                        client = %client_id,
                        requests = entry.count,
                        "Per-client rate limit exceeded"
                    );
                    return auth_error_response(
                        StatusCode::TOO_MANY_REQUESTS,
                        &format!(
                            "Rate limit exceeded: max {} requests per {:?}",
                            state.rate_limit.max_requests, state.rate_limit.window
                        ),
                    );
                }

                entry.count += 1;
            }

            // Make the validated claims available to downstream handlers.
            req.extensions_mut().insert(claims);

            next.run(req).await
        }
        .instrument(span)
        .await
    })
    .await
}

#[cfg(feature = "http")]
/// Builds the JSON error response used by `auth_middleware` for rejected requests.
///
/// The body is `{"error": {"code": -32000, "message": <message>}}`.
fn auth_error_response(status: StatusCode, message: &str) -> Response {
    (
        status,
        Json(serde_json::json!({
            "error": {
                "code": -32000,
                "message": message
            }
        })),
    )
        .into_response()
}
417
#[cfg(feature = "http")]
/// Rewrites a downstream `202 Accepted` response to `204 No Content`.
///
/// Responses with any other status pass through untouched.
async fn map_notification_status(req: Request, next: Next) -> Response {
    let mut response = next.run(req).await;
    let status = response.status_mut();
    if *status == StatusCode::ACCEPTED {
        *status = StatusCode::NO_CONTENT;
    }
    response
}
426
#[cfg(feature = "http")]
/// Builds the CORS layer for the HTTP transport from `config`.
///
/// * No configured origins: returns a plain `CorsLayer::new()`.
/// * Otherwise: allows `GET`, `POST`, `DELETE` and `OPTIONS`, every configured
///   origin, optionally credentials, and sets `max_age` from `max_age_secs`.
/// * A `*` entry allows any origin.
///
/// # Errors
///
/// Returns [`Error::OperationFailed`] (operation `"cors_origin"`) when an
/// origin is not a valid header value, or when `*` is combined with
/// `allow_credentials`, a combination the CORS layer rejects.
fn build_cors_layer(config: &CorsConfig) -> SubcogResult<CorsLayer> {
    if config.allowed_origins.is_empty() {
        return Ok(CorsLayer::new());
    }

    let allow_any = config
        .allowed_origins
        .iter()
        .any(|origin| origin.trim() == "*");

    // Fail with an error here instead of letting the layer panic later.
    if allow_any && config.allow_credentials {
        return Err(Error::OperationFailed {
            operation: "cors_origin".to_string(),
            cause: "wildcard origin '*' cannot be combined with allow_credentials".to_string(),
        });
    }

    let cors = CorsLayer::new().allow_methods([
        Method::GET,
        Method::POST,
        Method::DELETE,
        Method::OPTIONS,
    ]);

    let cors = if allow_any {
        cors.allow_origin(tower_http::cors::Any)
    } else {
        // `allow_origin` REPLACES the previous setting, so calling it once per
        // origin keeps only the last one. Pass the whole list in a single call.
        let origins = config
            .allowed_origins
            .iter()
            .map(|origin| {
                origin
                    .parse::<header::HeaderValue>()
                    .map_err(|e| Error::OperationFailed {
                        operation: "cors_origin".to_string(),
                        cause: e.to_string(),
                    })
            })
            .collect::<SubcogResult<Vec<_>>>()?;
        cors.allow_origin(tower_http::cors::AllowOrigin::list(origins))
    };

    let cors = if config.allow_credentials {
        cors.allow_credentials(true)
    } else {
        cors
    };

    Ok(cors.max_age(Duration::from_secs(config.max_age_secs)))
}
457
#[cfg(feature = "http")]
/// Checks that the caller's claims permit invoking `tool_name`.
///
/// When the request context carries no `Claims` extension the call is allowed.
/// NOTE(review): confirm that skipping the check when claims are absent is
/// intended for every transport that reaches this function.
///
/// # Errors
///
/// Returns an `invalid_params` MCP error naming the required scope (or
/// `"unknown"` when the tool has none registered) if the claims do not
/// authorize the tool.
fn ensure_tool_authorized(
    tool_auth: &ToolAuthorization,
    context: &RequestContext<RoleServer>,
    tool_name: &str,
) -> McpResult<()> {
    let Some(claims) = context.extensions.get::<Claims>() else {
        return Ok(());
    };

    if tool_auth.is_authorized(claims, tool_name) {
        return Ok(());
    }

    let scope_str = tool_auth.required_scope(tool_name).unwrap_or("unknown");
    let message = format!("Forbidden: tool '{tool_name}' requires '{scope_str}' scope");
    Err(McpError::invalid_params(message, None))
}
476
#[cfg(feature = "http")]
impl CorsConfig {
    /// Reads the CORS configuration from environment variables.
    ///
    /// * `SUBCOG_MCP_CORS_ALLOWED_ORIGINS` - comma-separated origins; entries
    ///   are trimmed and empty entries dropped. Unset means no origins.
    /// * `SUBCOG_MCP_CORS_ALLOW_CREDENTIALS` - parsed as `bool`; unset or
    ///   unparsable means `false`.
    /// * `SUBCOG_MCP_CORS_MAX_AGE_SECS` - parsed as `u64`; unset or
    ///   unparsable means `3600`.
    #[must_use]
    pub fn from_env() -> Self {
        let raw_origins = std::env::var("SUBCOG_MCP_CORS_ALLOWED_ORIGINS")
            .unwrap_or_else(|_| DEFAULT_CORS_ALLOWED_ORIGIN.to_string());
        let allowed_origins = raw_origins
            .split(',')
            .map(str::trim)
            .filter(|s| !s.is_empty())
            .map(String::from)
            .collect();

        let allow_credentials = std::env::var("SUBCOG_MCP_CORS_ALLOW_CREDENTIALS")
            .ok()
            .and_then(|raw| raw.parse::<bool>().ok())
            .unwrap_or(false);

        let max_age_secs = std::env::var("SUBCOG_MCP_CORS_MAX_AGE_SECS")
            .ok()
            .and_then(|raw| raw.parse::<u64>().ok())
            .unwrap_or(3600);

        Self {
            allowed_origins,
            allow_credentials,
            max_age_secs,
        }
    }

    /// Returns the configuration with its allowed origins replaced by `origins`.
    #[must_use]
    pub fn with_origins(self, origins: Vec<String>) -> Self {
        Self {
            allowed_origins: origins,
            ..self
        }
    }

    /// Returns the configuration with the credentials flag set to `allow`.
    #[must_use]
    pub const fn with_credentials(mut self, allow: bool) -> Self {
        self.allow_credentials = allow;
        self
    }
}
524
/// Rate-limit settings: at most `max_requests` requests per `window`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Maximum number of requests admitted within one window.
    pub max_requests: usize,
    /// Length of the rate-limit window.
    pub window: Duration,
}
533
534impl Default for RateLimitConfig {
535 fn default() -> Self {
536 Self {
537 max_requests: DEFAULT_RATE_LIMIT_MAX_REQUESTS,
538 window: Duration::from_secs(DEFAULT_RATE_LIMIT_WINDOW_SECS),
539 }
540 }
541}
542
543impl RateLimitConfig {
544 #[must_use]
546 pub fn from_env() -> Self {
547 let max_requests = std::env::var("SUBCOG_MCP_RATE_LIMIT_MAX_REQUESTS")
548 .ok()
549 .and_then(|v| v.parse().ok())
550 .unwrap_or(DEFAULT_RATE_LIMIT_MAX_REQUESTS);
551
552 let window_secs = std::env::var("SUBCOG_MCP_RATE_LIMIT_WINDOW_SECS")
553 .ok()
554 .and_then(|v| v.parse().ok())
555 .unwrap_or(DEFAULT_RATE_LIMIT_WINDOW_SECS);
556
557 Self {
558 max_requests,
559 window: Duration::from_secs(window_secs),
560 }
561 }
562
563 #[must_use]
565 pub const fn with_max_requests(mut self, max_requests: usize) -> Self {
566 self.max_requests = max_requests;
567 self
568 }
569
570 #[must_use]
572 pub const fn with_window_secs(mut self, secs: u64) -> Self {
573 self.window = Duration::from_secs(secs);
574 self
575 }
576}
577
/// Transport over which the MCP server is served.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Transport {
    /// Serve over standard input/output (the default).
    #[default]
    Stdio,
    /// Serve over HTTP.
    Http,
}
587
/// State shared by all clones of `McpHandler`.
struct McpState {
    /// Registry used to list and execute tools.
    tools: ToolRegistry,
    /// Resource handler, guarded by an async mutex.
    resources: Mutex<ResourceHandler>,
    /// Registry used to list and resolve prompts.
    prompts: PromptRegistry,
    /// Scope rules consulted before a tool call is executed.
    #[cfg(feature = "http")]
    tool_auth: ToolAuthorization,
}
595
/// rmcp server handler; cloning shares the same underlying state.
#[derive(Clone)]
struct McpHandler {
    /// Shared handler state.
    state: Arc<McpState>,
}
600
601impl McpHandler {
602 fn new(tools: ToolRegistry, resources: ResourceHandler, prompts: PromptRegistry) -> Self {
603 Self {
604 state: Arc::new(McpState {
605 tools,
606 resources: Mutex::new(resources),
607 prompts,
608 #[cfg(feature = "http")]
609 tool_auth: ToolAuthorization::default(),
610 }),
611 }
612 }
613}
614
615impl ServerHandler for McpHandler {
616 fn get_info(&self) -> ServerInfo {
617 ServerInfo {
618 protocol_version: rmcp::model::ProtocolVersion::default(),
619 capabilities: ServerCapabilities::builder()
620 .enable_tools()
621 .enable_resources()
622 .enable_prompts()
623 .build(),
624 server_info: Implementation::from_build_env(),
625 instructions: Some("Subcog MCP server".to_string()),
626 }
627 }
628
629 fn list_tools(
630 &self,
631 _request: Option<PaginatedRequestParams>,
632 _context: RequestContext<RoleServer>,
633 ) -> impl std::future::Future<Output = McpResult<ListToolsResult>> + Send + '_ {
634 let state = self.state.clone();
635 let (request_context, request_id) = init_request_context();
636
637 async move {
638 let span = info_span!(
639 parent: None,
640 "subcog.mcp.request",
641 request_id = %request_id,
642 component = "mcp",
643 origin = "mcp",
644 entrypoint = "tools/list",
645 operation = "tools/list"
646 );
647
648 run_mcp_with_context(request_context, span, "list_tools", |_start| async move {
649 let tools = state
650 .tools
651 .list_tools()
652 .into_iter()
653 .map(tool_definition_to_rmcp)
654 .collect();
655 Ok(ListToolsResult::with_all_items(tools))
656 })
657 .await
658 }
659 }
660
661 fn call_tool(
662 &self,
663 request: CallToolRequestParams,
664 context: RequestContext<RoleServer>,
665 ) -> impl std::future::Future<Output = McpResult<CallToolResult>> + Send + '_ {
666 let state = self.state.clone();
667 let (request_context, request_id) = init_request_context();
668 let tool_name = request.name.clone();
669 async move {
670 let span = info_span!(
671 parent: None,
672 "subcog.mcp.request",
673 request_id = %request_id,
674 component = "mcp",
675 origin = "mcp",
676 entrypoint = "tools/call",
677 operation = "tools/call",
678 tool_name = %tool_name
679 );
680
681 #[cfg(feature = "http")]
683 if let Err(err) = ensure_tool_authorized(&state.tool_auth, &context, &tool_name) {
684 record_event(MemoryEvent::McpRequestError {
685 meta: EventMeta::new("mcp", current_request_id()),
686 operation: "call_tool".to_string(),
687 error: err.to_string(),
688 });
689 return Err(err);
690 }
691 #[cfg(not(feature = "http"))]
692 let _ = &context;
693
694 run_mcp_with_context(
695 request_context,
696 span,
697 "call_tool",
698 move |start| async move {
699 tokio::task::spawn_blocking(move || execute_call_tool(&state, request, start))
702 .await
703 .map_err(|e| join_error_to_mcp(&e))?
704 },
705 )
706 .await
707 }
708 }
709
710 fn list_resources(
711 &self,
712 _request: Option<PaginatedRequestParams>,
713 _context: RequestContext<RoleServer>,
714 ) -> impl std::future::Future<Output = McpResult<ListResourcesResult>> + Send + '_ {
715 let state = self.state.clone();
716 let (request_context, request_id) = init_request_context();
717 async move {
718 let span = info_span!(
719 parent: None,
720 "subcog.mcp.request",
721 request_id = %request_id,
722 component = "mcp",
723 origin = "mcp",
724 entrypoint = "resources/list",
725 operation = "resources/list"
726 );
727
728 run_mcp_with_context(
729 request_context,
730 span,
731 "list_resources",
732 |_start| async move {
733 let resources = state
734 .resources
735 .lock()
736 .await
737 .list_resources()
738 .into_iter()
739 .map(resource_definition_to_rmcp)
740 .collect();
741 Ok(ListResourcesResult::with_all_items(resources))
742 },
743 )
744 .await
745 }
746 }
747
748 fn list_resource_templates(
749 &self,
750 _request: Option<PaginatedRequestParams>,
751 _context: RequestContext<RoleServer>,
752 ) -> impl std::future::Future<Output = McpResult<ListResourceTemplatesResult>> + Send + '_ {
753 let (request_context, request_id) = init_request_context();
754
755 async move {
756 let span = info_span!(
757 parent: None,
758 "subcog.mcp.request",
759 request_id = %request_id,
760 component = "mcp",
761 origin = "mcp",
762 entrypoint = "resources/list_templates",
763 operation = "resources/list_templates"
764 );
765
766 run_mcp_with_context(
767 request_context,
768 span,
769 "list_resource_templates",
770 |_start| async move { Ok(ListResourceTemplatesResult::with_all_items(Vec::new())) },
771 )
772 .await
773 }
774 }
775
776 fn read_resource(
777 &self,
778 request: rmcp::model::ReadResourceRequestParams,
779 _context: RequestContext<RoleServer>,
780 ) -> impl std::future::Future<Output = McpResult<rmcp::model::ReadResourceResult>> + Send + '_
781 {
782 let state = self.state.clone();
783 let (request_context, request_id) = init_request_context();
784 let resource_uri = request.uri.clone();
785 async move {
786 let span = info_span!(
787 parent: None,
788 "subcog.mcp.request",
789 request_id = %request_id,
790 component = "mcp",
791 origin = "mcp",
792 entrypoint = "resources/read",
793 operation = "resources/read",
794 resource_uri = %resource_uri
795 );
796
797 run_mcp_with_context(
798 request_context,
799 span,
800 "read_resource",
801 |_start| async move {
802 let content = state
803 .resources
804 .lock()
805 .await
806 .get_resource(&request.uri)
807 .map_err(|e| McpError::resource_not_found(e.to_string(), None))?;
808
809 let contents = vec![resource_content_to_rmcp(content)];
810 Ok(rmcp::model::ReadResourceResult { contents })
811 },
812 )
813 .await
814 }
815 }
816
817 fn list_prompts(
818 &self,
819 _request: Option<PaginatedRequestParams>,
820 _context: RequestContext<RoleServer>,
821 ) -> impl std::future::Future<Output = McpResult<ListPromptsResult>> + Send + '_ {
822 let state = self.state.clone();
823 let (request_context, request_id) = init_request_context();
824 async move {
825 let span = info_span!(
826 parent: None,
827 "subcog.mcp.request",
828 request_id = %request_id,
829 component = "mcp",
830 origin = "mcp",
831 entrypoint = "prompts/list",
832 operation = "prompts/list"
833 );
834
835 run_mcp_with_context(request_context, span, "list_prompts", |_start| async move {
836 let prompts = state
837 .prompts
838 .list_prompts()
839 .into_iter()
840 .map(prompt_definition_to_rmcp)
841 .collect();
842 Ok(ListPromptsResult::with_all_items(prompts))
843 })
844 .await
845 }
846 }
847
848 fn get_prompt(
849 &self,
850 request: GetPromptRequestParams,
851 _context: RequestContext<RoleServer>,
852 ) -> impl std::future::Future<Output = McpResult<GetPromptResult>> + Send + '_ {
853 let state = self.state.clone();
854 let (request_context, request_id) = init_request_context();
855 let prompt_name = request.name.clone();
856 let arguments = request.arguments.map_or_else(
857 || serde_json::Value::Object(serde_json::Map::new()),
858 serde_json::Value::Object,
859 );
860 async move {
861 let span = info_span!(
862 parent: None,
863 "subcog.mcp.request",
864 request_id = %request_id,
865 component = "mcp",
866 origin = "mcp",
867 entrypoint = "prompts/get",
868 operation = "prompts/get",
869 prompt = %prompt_name
870 );
871
872 run_mcp_with_context(request_context, span, "get_prompt", |_start| async move {
873 resolve_prompt(&state.prompts, &prompt_name, &arguments)
874 })
875 .await
876 }
877 }
878}
879
880fn resolve_prompt(
882 registry: &PromptRegistry,
883 name: &str,
884 arguments: &serde_json::Value,
885) -> McpResult<GetPromptResult> {
886 let prompt_def = registry.get_prompt(name);
887 let messages = registry.get_prompt_messages(name, arguments);
888
889 let Some(messages) = messages else {
890 let err = format!("Unknown prompt: {name}");
891 return Err(McpError::invalid_params(err, None));
892 };
893
894 let rmcp_messages: Vec<RmcpPromptMessage> =
895 messages.into_iter().map(prompt_message_to_rmcp).collect();
896
897 Ok(GetPromptResult {
898 description: prompt_def.and_then(|p| p.description.clone()),
899 messages: rmcp_messages,
900 })
901}
902
903fn tool_definition_to_rmcp(def: &ToolDefinition) -> Tool {
904 let schema = def.input_schema.as_object().cloned().unwrap_or_default();
905
906 Tool {
907 name: Cow::Owned(def.name.clone()),
908 title: None,
909 description: Some(Cow::Owned(def.description.clone())),
910 input_schema: Arc::new(schema),
911 output_schema: None,
912 annotations: None,
913 icons: None,
914 meta: None,
915 }
916}
917
918fn tool_content_to_rmcp(content: ToolContent) -> Content {
919 match content {
920 ToolContent::Text { text } => Content::text(text),
921 ToolContent::Image { data, mime_type } => Content::image(data, mime_type),
922 }
923}
924
925fn tool_result_to_rmcp(result: ToolResult) -> CallToolResult {
926 let contents = result
927 .content
928 .into_iter()
929 .map(tool_content_to_rmcp)
930 .collect();
931 if result.is_error {
932 CallToolResult::error(contents)
933 } else {
934 CallToolResult::success(contents)
935 }
936}
937
938fn resource_definition_to_rmcp(def: ResourceDefinition) -> Resource {
939 RawResource {
940 uri: def.uri,
941 name: def.name,
942 title: None,
943 description: def.description,
944 mime_type: def.mime_type,
945 size: None,
946 icons: None,
947 meta: None,
948 }
949 .no_annotation()
950}
951
952fn prompt_definition_to_rmcp(def: &PromptDefinition) -> Prompt {
953 let arguments = if def.arguments.is_empty() {
954 None
955 } else {
956 Some(
957 def.arguments
958 .iter()
959 .map(|arg| RmcpPromptArgument {
960 name: arg.name.clone(),
961 title: None,
962 description: arg.description.clone(),
963 required: Some(arg.required),
964 })
965 .collect(),
966 )
967 };
968
969 Prompt {
970 name: def.name.clone(),
971 title: None,
972 description: def.description.clone(),
973 arguments,
974 icons: None,
975 meta: None,
976 }
977}
978
979fn prompt_message_to_rmcp(msg: crate::mcp::prompts::PromptMessage) -> RmcpPromptMessage {
980 let role = match msg.role.as_str() {
981 "assistant" => PromptMessageRole::Assistant,
982 _ => PromptMessageRole::User,
983 };
984
985 let content = match msg.content {
986 PromptContent::Text { text } => PromptMessageContent::Text { text },
987 PromptContent::Image { data, mime_type } => {
988 PromptMessageContent::Text {
991 text: format!("[Image: {mime_type}, {} bytes]", data.len()),
992 }
993 },
994 PromptContent::Resource { uri } => {
995 PromptMessageContent::Text {
998 text: format!("[Resource: {uri}]"),
999 }
1000 },
1001 };
1002
1003 RmcpPromptMessage { role, content }
1004}
1005
1006fn resource_content_to_rmcp(content: ResourceContent) -> ResourceContents {
1007 if let Some(text) = content.text {
1008 ResourceContents::TextResourceContents {
1009 uri: content.uri,
1010 mime_type: content.mime_type,
1011 text,
1012 meta: None,
1013 }
1014 } else {
1015 ResourceContents::BlobResourceContents {
1016 uri: content.uri,
1017 mime_type: content.mime_type,
1018 blob: content.blob.unwrap_or_default(),
1019 meta: None,
1020 }
1021 }
1022}
1023
/// MCP server configuration and entry point.
pub struct McpServer {
    /// Tool registry; moved into the handler when the server starts.
    tools: ToolRegistry,
    /// Resource handler; moved into the handler when the server starts.
    resources: ResourceHandler,
    /// Transport the server will be served over.
    transport: Transport,
    /// Port the HTTP transport binds to.
    port: u16,
    /// Per-client rate limit handed to the HTTP auth middleware.
    rate_limit: RateLimitConfig,
    /// JWT authenticator; the HTTP transport refuses to start without one.
    #[cfg(feature = "http")]
    jwt_authenticator: Option<JwtAuthenticator>,
    /// CORS settings for the HTTP transport.
    #[cfg(feature = "http")]
    cors_config: CorsConfig,
}
1043
1044impl McpServer {
1045 #[must_use]
1047 pub fn new() -> Self {
1048 let resources = Self::try_init_resources();
1050
1051 Self {
1052 tools: ToolRegistry::new(),
1053 resources,
1054 transport: Transport::Stdio,
1055 port: 3000,
1056 rate_limit: RateLimitConfig::from_env(),
1057 #[cfg(feature = "http")]
1058 jwt_authenticator: None,
1059 #[cfg(feature = "http")]
1060 cors_config: CorsConfig::from_env(),
1061 }
1062 }
1063
1064 #[cfg(feature = "http")]
1069 #[must_use]
1070 pub fn with_cors_config(mut self, config: CorsConfig) -> Self {
1071 self.cors_config = config;
1072 self
1073 }
1074
1075 #[cfg(feature = "http")]
1081 #[must_use]
1082 pub fn with_jwt_authenticator(mut self, authenticator: JwtAuthenticator) -> Self {
1083 self.jwt_authenticator = Some(authenticator);
1084 self
1085 }
1086
1087 #[cfg(feature = "http")]
1096 pub fn with_jwt_from_env(self) -> SubcogResult<Self> {
1097 let config = JwtConfig::from_env()?;
1098 let authenticator = JwtAuthenticator::new(&config);
1099 Ok(self.with_jwt_authenticator(authenticator))
1100 }
1101
1102 #[must_use]
1108 pub const fn with_rate_limit(mut self, config: RateLimitConfig) -> Self {
1109 self.rate_limit = config;
1110 self
1111 }
1112
1113 fn try_init_resources() -> ResourceHandler {
1117 use crate::config::SubcogConfig;
1118 use crate::services::PromptService;
1119
1120 let mut handler = ResourceHandler::new();
1121
1122 if let Ok(services) = ServiceContainer::from_current_dir_or_user() {
1124 if let Ok(recall) = services.recall() {
1125 handler = handler.with_recall_service(recall);
1126 }
1127
1128 if let Some(repo_path) = services.repo_path() {
1131 let config = SubcogConfig::load_default().with_repo_path(repo_path);
1132 let prompt_service =
1133 PromptService::with_subcog_config(config).with_repo_path(repo_path);
1134 handler = handler.with_prompt_service(prompt_service);
1135 } else {
1136 let config = SubcogConfig::load_default();
1138 let prompt_service = PromptService::with_subcog_config(config);
1139 handler = handler.with_prompt_service(prompt_service);
1140 }
1141 }
1142
1143 handler
1144 }
1145
1146 #[must_use]
1148 pub const fn with_transport(mut self, transport: Transport) -> Self {
1149 self.transport = transport;
1150 self
1151 }
1152
1153 #[must_use]
1155 pub const fn with_port(mut self, port: u16) -> Self {
1156 self.port = port;
1157 self
1158 }
1159
1160 pub async fn start(&mut self) -> SubcogResult<()> {
1169 setup_signal_handler()?;
1171
1172 let (transport, port) = match self.transport {
1173 Transport::Stdio => ("stdio", None),
1174 Transport::Http => ("http", Some(self.port)),
1175 };
1176 record_event(MemoryEvent::McpStarted {
1177 meta: EventMeta::new("mcp", current_request_id()),
1178 transport: transport.to_string(),
1179 port,
1180 });
1181
1182 match self.transport {
1183 Transport::Stdio => self.run_stdio().await,
1184 Transport::Http => self.run_http().await,
1185 }
1186 }
1187
1188 fn build_handler(&mut self) -> McpHandler {
1189 let tools = std::mem::take(&mut self.tools);
1190 let resources = std::mem::take(&mut self.resources);
1191 let prompts = PromptRegistry::new();
1192 McpHandler::new(tools, resources, prompts)
1193 }
1194
1195 async fn run_stdio(&mut self) -> SubcogResult<()> {
1197 let handler = self.build_handler();
1198 let service = handler
1199 .serve(stdio())
1200 .await
1201 .map_err(|e| Error::OperationFailed {
1202 operation: "serve_stdio".to_string(),
1203 cause: e.to_string(),
1204 })?;
1205
1206 let cancel_token = service.cancellation_token();
1207 let span = tracing::Span::current();
1208 let request_context = current_request_id().map(ObsRequestContext::from_id);
1209 tokio::spawn(
1210 async move {
1211 let run = await_shutdown(cancel_token);
1212 if let Some(context) = request_context {
1213 scope_request_context(context, run).await;
1214 } else {
1215 run.await;
1216 }
1217 }
1218 .instrument(span),
1219 );
1220
1221 service
1222 .waiting()
1223 .await
1224 .map_err(|e| Error::OperationFailed {
1225 operation: "wait_stdio".to_string(),
1226 cause: e.to_string(),
1227 })?;
1228
1229 Ok(())
1230 }
1231
1232 #[allow(dead_code)]
1234 fn graceful_shutdown(&self) {
1235 let start = Instant::now();
1236 tracing::info!("Starting graceful shutdown sequence");
1237
1238 flush_metrics();
1240
1241 metrics::counter!("mcp_graceful_shutdown_total").increment(1);
1243 metrics::histogram!("mcp_shutdown_duration_ms")
1244 .record(start.elapsed().as_secs_f64() * 1000.0);
1245
1246 tracing::info!(
1247 duration_ms = start.elapsed().as_millis(),
1248 "Graceful shutdown completed"
1249 );
1250 }
1251
1252 #[cfg(feature = "http")]
1256 async fn run_http(&mut self) -> SubcogResult<()> {
1257 let authenticator = self.jwt_authenticator.clone().ok_or_else(|| {
1259 Error::OperationFailed {
1260 operation: "run_http".to_string(),
1261 cause: "JWT authenticator not configured. Set SUBCOG_MCP_JWT_SECRET or call with_jwt_authenticator()".to_string(),
1262 }
1263 })?;
1264
1265 let handler = self.build_handler();
1266 let session_manager = Arc::new(LocalSessionManager::default());
1267 let streamable = StreamableHttpService::new(
1268 move || Ok(handler.clone()),
1269 session_manager,
1270 StreamableHttpServerConfig::default(),
1271 );
1272
1273 let auth_state = HttpAuthState {
1274 authenticator,
1275 rate_limit: self.rate_limit.clone(),
1276 rate_limits: Arc::new(Mutex::new(HashMap::new())),
1277 };
1278
1279 let cors_layer = build_cors_layer(&self.cors_config)?;
1281
1282 let app = Router::new()
1283 .route_service("/mcp", any_service(streamable))
1284 .layer(axum::middleware::from_fn_with_state(
1285 auth_state.clone(),
1286 auth_middleware,
1287 ))
1288 .layer(axum::middleware::from_fn(map_notification_status))
1289 .layer(cors_layer)
1291 .layer(SetResponseHeaderLayer::overriding(
1293 header::X_CONTENT_TYPE_OPTIONS,
1294 header::HeaderValue::from_static("nosniff"),
1295 ))
1296 .layer(SetResponseHeaderLayer::overriding(
1297 header::X_FRAME_OPTIONS,
1298 header::HeaderValue::from_static("DENY"),
1299 ))
1300 .layer(SetResponseHeaderLayer::overriding(
1301 header::CONTENT_SECURITY_POLICY,
1302 header::HeaderValue::from_static("default-src 'none'; frame-ancestors 'none'"),
1303 ))
1304 .layer(SetResponseHeaderLayer::overriding(
1305 header::CACHE_CONTROL,
1306 header::HeaderValue::from_static("no-store"),
1307 ))
1308 .layer(SetResponseHeaderLayer::overriding(
1309 header::HeaderName::from_static("x-permitted-cross-domain-policies"),
1310 header::HeaderValue::from_static("none"),
1311 ))
1312 .layer(TraceLayer::new_for_http());
1313
1314 let addr = std::net::SocketAddr::from(([0, 0, 0, 0], self.port));
1315 tracing::info!(port = self.port, "Starting MCP HTTP server with JWT auth");
1316
1317 let listener =
1318 tokio::net::TcpListener::bind(addr)
1319 .await
1320 .map_err(|e| Error::OperationFailed {
1321 operation: "bind".to_string(),
1322 cause: e.to_string(),
1323 })?;
1324
1325 axum::serve(listener, app)
1326 .await
1327 .map_err(|e| Error::OperationFailed {
1328 operation: "serve".to_string(),
1329 cause: e.to_string(),
1330 })
1331 }
1332
1333 #[cfg(not(feature = "http"))]
1335 #[allow(clippy::unused_async)] async fn run_http(&self) -> SubcogResult<()> {
1337 Err(Error::FeatureNotEnabled("http".to_string()))
1338 }
1339}
1340
1341impl Default for McpServer {
1342 fn default() -> Self {
1343 Self::new()
1344 }
1345}
1346
#[cfg(test)]
mod tests {
    use super::*;

    /// A server built with `new()` reports the stdio transport.
    #[test]
    fn test_mcp_server_creation() {
        let fresh = McpServer::new();
        assert_eq!(fresh.transport, Transport::Stdio);
    }

    /// The builder methods store the transport and port they were given.
    #[test]
    fn test_with_transport() {
        let port = 8080;
        let configured = McpServer::new()
            .with_transport(Transport::Http)
            .with_port(port);

        assert_eq!(configured.transport, Transport::Http);
        assert_eq!(configured.port, port);
    }

    /// Converting a registered tool definition keeps the tool's name.
    #[test]
    fn test_tool_definition_mapping() {
        let tool_name = "subcog_status";
        let registry = ToolRegistry::new();
        let definition = registry.get_tool(tool_name).unwrap();
        let converted = tool_definition_to_rmcp(definition);
        assert_eq!(converted.name, tool_name);
    }
}
1374
#[cfg(all(test, feature = "http"))]
mod cors_tests {
    use super::*;

    /// The default CORS config has no origins, credentials off, and a
    /// max age of 3600 seconds.
    #[test]
    fn test_cors_config_default() {
        let defaults = CorsConfig::default();
        assert_eq!(defaults.max_age_secs, 3600);
        assert!(!defaults.allow_credentials);
        assert!(defaults.allowed_origins.is_empty());
    }

    /// `with_origins` and `with_credentials` store the supplied values.
    #[test]
    fn test_cors_config_with_origins() {
        let origin = "https://example.com";
        let config = CorsConfig::default()
            .with_origins(vec![origin.to_string()])
            .with_credentials(true);

        let origins = &config.allowed_origins;
        assert_eq!(origins.len(), 1);
        assert_eq!(origins[0], origin);
        assert!(config.allow_credentials);
    }

    /// `from_env` is expected to produce the same max age and credentials
    /// flag as the defaults.
    ///
    /// NOTE(review): presumably relies on the related environment variables
    /// being unset when the test runs — confirm against `CorsConfig::from_env`.
    #[test]
    fn test_cors_config_from_env_defaults() {
        let from_env = CorsConfig::from_env();
        assert!(!from_env.allow_credentials);
        assert_eq!(from_env.max_age_secs, 3600);
    }

    /// Splitting a comma-separated origin list trims whitespace and drops
    /// empty entries (including the one left by a trailing comma).
    #[test]
    fn test_cors_origin_parsing() {
        let raw = "https://a.com, https://b.com, ";
        let parsed: Vec<String> = raw
            .split(',')
            .filter_map(|piece| {
                let piece = piece.trim();
                (!piece.is_empty()).then(|| String::from(piece))
            })
            .collect();

        assert_eq!(parsed, ["https://a.com", "https://b.com"]);
    }

    /// `with_cors_config` stores the supplied CORS configuration on the server.
    #[test]
    fn test_mcp_server_with_cors_config() {
        let trusted = "https://trusted.com";
        let server = McpServer::new()
            .with_cors_config(CorsConfig::default().with_origins(vec![trusted.to_string()]));

        let origins = &server.cors_config.allowed_origins;
        assert_eq!(origins.len(), 1);
        assert_eq!(origins[0], trusted);
    }
}
1435
#[cfg(all(test, feature = "http"))]
mod auth_tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::http::{Request, StatusCode, header};
    use axum::middleware;
    use axum::routing::get;
    use chrono::Utc;
    use jsonwebtoken::{EncodingKey, Header};
    use tower::util::ServiceExt;

    /// Secret used both to configure the authenticator and to sign test tokens.
    const TEST_JWT_SECRET: &str = "a-very-long-secret-key-that-is-at-least-32-chars";

    /// Builds a router with a single `GET /` route returning `"ok"`, wrapped
    /// in `auth_middleware` using the given state.
    fn build_app(state: HttpAuthState) -> Router {
        let guarded = middleware::from_fn_with_state(state, auth_middleware);
        Router::new().route("/", get(|| async { "ok" })).layer(guarded)
    }

    /// Builds auth state that allows `max_requests` per 60-second window and
    /// starts with an empty rate-limit table.
    fn build_state(max_requests: usize) -> HttpAuthState {
        let rate_limit = RateLimitConfig {
            max_requests,
            window: Duration::from_secs(60),
        };
        let authenticator = JwtAuthenticator::new(&JwtConfig::new(TEST_JWT_SECRET));

        HttpAuthState {
            authenticator,
            rate_limit,
            rate_limits: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Encodes a token for `sub` signed with [`TEST_JWT_SECRET`], expiring one
    /// hour from now and carrying the single scope `"read"`.
    #[allow(clippy::expect_used)]
    fn create_test_token(sub: &str) -> String {
        // Timestamps that do not fit in `usize` collapse to 0.
        let as_claim = |timestamp: i64| usize::try_from(timestamp).unwrap_or(0);

        let exp = as_claim((Utc::now() + chrono::Duration::hours(1)).timestamp());
        let iat = as_claim(Utc::now().timestamp());
        let claims = Claims {
            sub: sub.to_string(),
            exp,
            iat,
            iss: None,
            aud: None,
            scopes: vec!["read".to_string()],
        };

        let signing_key = EncodingKey::from_secret(TEST_JWT_SECRET.as_bytes());
        jsonwebtoken::encode(&Header::default(), &claims, &signing_key)
            .expect("Failed to encode test token")
    }

    /// A request without an `Authorization` header is answered with 401.
    #[tokio::test]
    async fn test_auth_middleware_missing_header() {
        let unauthenticated = Request::builder()
            .uri("/")
            .body(Body::empty())
            .expect("request");

        let response = build_app(build_state(5))
            .oneshot(unauthenticated)
            .await
            .expect("response");

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    /// With a limit of one request, the first call using a token succeeds and
    /// a second call with the same token is answered with 429.
    #[tokio::test]
    async fn test_auth_middleware_rate_limit_exceeded() {
        let app = build_app(build_state(1));
        let auth_header = format!("Bearer {}", create_test_token("client-a"));

        // Both requests are identical, so build them from one recipe.
        let authorized_request = || {
            Request::builder()
                .uri("/")
                .header(header::AUTHORIZATION, auth_header.as_str())
                .body(Body::empty())
                .expect("request")
        };

        let first = app
            .clone()
            .oneshot(authorized_request())
            .await
            .expect("response");
        assert_eq!(first.status(), StatusCode::OK);

        let second = app.oneshot(authorized_request()).await.expect("response");
        assert_eq!(second.status(), StatusCode::TOO_MANY_REQUESTS);
    }
}