// subcog/cli/llm_factory.rs
use std::sync::Arc;

use crate::config::LlmConfig;
use crate::llm::{
    AnthropicClient, LlmHttpConfig, LlmProvider, LlmResilienceConfig, LmStudioClient, OllamaClient,
    OpenAiClient, ResilientLlmProvider,
};

#[must_use]
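/// Builds the HTTP configuration for LLM clients, applying environment-variable overrides.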
pub fn build_http_config(llm_config: &LlmConfig) -> LlmHttpConfig {
    LlmHttpConfig::from_config(llm_config).with_env_overrides()
}

#[must_use]
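/// Builds the resilience configuration (circuit-breaker settings and the like) for LLM clients, applying environment-variable overrides.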
pub fn build_resilience_config(llm_config: &LlmConfig) -> LlmResilienceConfig {
    LlmResilienceConfig::from_config(llm_config).with_env_overrides()
}

#[must_use]
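/// Builds an OpenAI client, applying the API key, model, endpoint, and max-tokens settings when present.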
pub fn build_openai_client(llm_config: &LlmConfig) -> OpenAiClient {
    let mut client = OpenAiClient::new();
    if let Some(ref api_key) = llm_config.api_key {
        client = client.with_api_key(api_key);
    }
    if let Some(ref model) = llm_config.model {
        client = client.with_model(model);
    }
    if let Some(ref base_url) = llm_config.base_url {
        client = client.with_endpoint(base_url);
    }
    if let Some(max_tokens) = llm_config.max_tokens {
        client = client.with_max_tokens(max_tokens);
    }
    client.with_http_config(build_http_config(llm_config))
}

#[must_use]
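/// Builds an Anthropic client, applying the API key, model, and endpoint settings when present.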
pub fn build_anthropic_client(llm_config: &LlmConfig) -> AnthropicClient {
    let mut client = AnthropicClient::new();
    if let Some(ref api_key) = llm_config.api_key {
        client = client.with_api_key(api_key);
    }
    if let Some(ref model) = llm_config.model {
        client = client.with_model(model);
    }
    if let Some(ref base_url) = llm_config.base_url {
        client = client.with_endpoint(base_url);
    }
    client.with_http_config(build_http_config(llm_config))
}

#[must_use]
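/// Builds an Ollama client, applying the model and endpoint settings when present.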
pub fn build_ollama_client(llm_config: &LlmConfig) -> OllamaClient {
    let mut client = OllamaClient::new();
    if let Some(ref model) = llm_config.model {
        client = client.with_model(model);
    }
    if let Some(ref base_url) = llm_config.base_url {
        client = client.with_endpoint(base_url);
    }
    client.with_http_config(build_http_config(llm_config))
}

#[must_use]
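/// Builds an LM Studio client, applying the model and endpoint settings when present.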
pub fn build_lmstudio_client(llm_config: &LlmConfig) -> LmStudioClient {
    let mut client = LmStudioClient::new();
    if let Some(ref model) = llm_config.model {
        client = client.with_model(model);
    }
    if let Some(ref base_url) = llm_config.base_url {
        client = client.with_endpoint(base_url);
    }
    client.with_http_config(build_http_config(llm_config))
}

#[must_use]
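/// Builds an LLM provider for search-intent hooks.
///
/// Returns `None` when `search_intent.use_llm` is disabled or the configured
/// provider is `None`; otherwise the client is wrapped in a `ResilientLlmProvider`.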
pub fn build_hook_llm_provider(
    config: &crate::config::SubcogConfig,
) -> Option<Arc<dyn LlmProvider>> {
    use crate::config::LlmProvider as Provider;

    if !config.search_intent.use_llm {
        return None;
    }

    let llm_config = &config.llm;
    let provider: Arc<dyn LlmProvider> = match llm_config.provider {
        Provider::OpenAi => {
            let resilience_config = build_resilience_config(llm_config);
            Arc::new(ResilientLlmProvider::new(
                build_openai_client(llm_config),
                resilience_config,
            ))
        },
        Provider::Anthropic => {
            let resilience_config = build_resilience_config(llm_config);
            Arc::new(ResilientLlmProvider::new(
                build_anthropic_client(llm_config),
                resilience_config,
            ))
        },
        Provider::Ollama => {
            let resilience_config = build_resilience_config(llm_config);
            Arc::new(ResilientLlmProvider::new(
                build_ollama_client(llm_config),
                resilience_config,
            ))
        },
        Provider::LmStudio => {
            let resilience_config = build_resilience_config(llm_config);
            Arc::new(ResilientLlmProvider::new(
                build_lmstudio_client(llm_config),
                resilience_config,
            ))
        },
        Provider::None => return None,
    };

    Some(provider)
}

#[must_use]
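/// Builds an LLM provider for entity extraction.
///
/// Unlike [`build_llm_provider`], this clones the LLM config and overrides its
/// timeout with the `EntityExtraction` operation timeout, so extraction calls
/// do not inherit the general-purpose LLM timeout.
///
/// Returns `None` when LLM features are disabled or the provider is `None`.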
pub fn build_llm_provider_for_entity_extraction(
    config: &crate::config::SubcogConfig,
) -> Option<Arc<dyn LlmProvider>> {
    use crate::config::{LlmProvider as Provider, OperationType};

    tracing::debug!(
        llm_features = config.features.llm_features,
        provider = ?config.llm.provider,
        "build_llm_provider_for_entity_extraction called"
    );

    if !config.features.llm_features {
        tracing::debug!("LLM features disabled in config, returning None");
        return None;
    }

    let entity_timeout_ms = u64::try_from(
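        // The millisecond count may exceed u64; clamp to u64::MAX on overflow.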
        config
            .timeouts
            .get(OperationType::EntityExtraction)
            .as_millis(),
    )
    .unwrap_or(u64::MAX);
    let mut llm_config = config.llm.clone();
    llm_config.timeout_ms = Some(entity_timeout_ms);

    tracing::debug!(
        entity_timeout_ms = entity_timeout_ms,
        "Using entity extraction timeout for LLM"
    );

    let provider: Arc<dyn LlmProvider> = match llm_config.provider {
        Provider::OpenAi => {
            let resilience_config = build_resilience_config(&llm_config);
            Arc::new(ResilientLlmProvider::new(
                build_openai_client(&llm_config),
                resilience_config,
            ))
        },
        Provider::Anthropic => {
            let resilience_config = build_resilience_config(&llm_config);
            Arc::new(ResilientLlmProvider::new(
                build_anthropic_client(&llm_config),
                resilience_config,
            ))
        },
        Provider::Ollama => {
            let resilience_config = build_resilience_config(&llm_config);
            Arc::new(ResilientLlmProvider::new(
                build_ollama_client(&llm_config),
                resilience_config,
            ))
        },
        Provider::LmStudio => {
            let resilience_config = build_resilience_config(&llm_config);
            Arc::new(ResilientLlmProvider::new(
                build_lmstudio_client(&llm_config),
                resilience_config,
            ))
        },
        Provider::None => {
            tracing::debug!("LLM provider is None, returning None");
            return None;
        },
    };

    tracing::debug!(
        provider_type = ?llm_config.provider,
        timeout_ms = entity_timeout_ms,
        "LLM provider for entity extraction built successfully"
    );
    Some(provider)
}

#[must_use]
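/// Builds the general-purpose LLM provider from the top-level config.
///
/// Returns `None` when `features.llm_features` is disabled or the configured
/// provider is `None`; otherwise the selected client is wrapped in a
/// `ResilientLlmProvider`. A minimal usage sketch (the field defaults assumed
/// here are illustrative, not guaranteed):
///
/// ```ignore
/// let mut config = SubcogConfig::default();
/// config.features.llm_features = true; // assumed disabled by default
/// config.llm.provider = crate::config::LlmProvider::Ollama;
/// let provider = build_llm_provider(&config).expect("provider configured");
/// ```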
pub fn build_llm_provider(config: &crate::config::SubcogConfig) -> Option<Arc<dyn LlmProvider>> {
    use crate::config::LlmProvider as Provider;

    tracing::debug!(
        llm_features = config.features.llm_features,
        provider = ?config.llm.provider,
        "build_llm_provider called"
    );

    if !config.features.llm_features {
        tracing::debug!("LLM features disabled in config, returning None");
        return None;
    }

    let llm_config = &config.llm;
    let provider: Arc<dyn LlmProvider> = match llm_config.provider {
        Provider::OpenAi => {
            let resilience_config = build_resilience_config(llm_config);
            Arc::new(ResilientLlmProvider::new(
                build_openai_client(llm_config),
                resilience_config,
            ))
        },
        Provider::Anthropic => {
            let resilience_config = build_resilience_config(llm_config);
            Arc::new(ResilientLlmProvider::new(
                build_anthropic_client(llm_config),
                resilience_config,
            ))
        },
        Provider::Ollama => {
            let resilience_config = build_resilience_config(llm_config);
            Arc::new(ResilientLlmProvider::new(
                build_ollama_client(llm_config),
                resilience_config,
            ))
        },
        Provider::LmStudio => {
            let resilience_config = build_resilience_config(llm_config);
            Arc::new(ResilientLlmProvider::new(
                build_lmstudio_client(llm_config),
                resilience_config,
            ))
        },
        Provider::None => {
            tracing::debug!("LLM provider is None, returning None");
            return None;
        },
    };

    tracing::debug!(provider_type = ?llm_config.provider, "LLM provider built successfully");
    Some(provider)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{LlmConfig, LlmProvider as Provider, SubcogConfig};

    #[test]
    fn test_build_http_config_with_defaults() {
        let llm_config = LlmConfig::default();
        let http_config = build_http_config(&llm_config);

        assert!(http_config.connect_timeout_ms > 0);
        assert!(http_config.timeout_ms > 0);
    }

    #[test]
    fn test_build_resilience_config_with_defaults() {
        let llm_config = LlmConfig::default();
        let resilience_config = build_resilience_config(&llm_config);

        assert!(resilience_config.breaker_failure_threshold > 0);
    }

    #[test]
    fn test_build_openai_client_with_config() {
        let llm_config = LlmConfig {
            api_key: Some("test-api-key".to_string()),
            model: Some("gpt-4".to_string()),
            base_url: Some("https://custom.openai.com".to_string()),
            ..Default::default()
        };

        let client = build_openai_client(&llm_config);
        assert_eq!(client.name(), "openai");
    }

    #[test]
    fn test_build_anthropic_client_with_config() {
        let llm_config = LlmConfig {
            api_key: Some("sk-ant-test-key".to_string()),
            model: Some("claude-3-opus".to_string()),
            ..Default::default()
        };

        let client = build_anthropic_client(&llm_config);
        assert_eq!(client.name(), "anthropic");
    }

    #[test]
    fn test_build_ollama_client_with_config() {
        let llm_config = LlmConfig {
            model: Some("llama2".to_string()),
            base_url: Some("http://localhost:11434".to_string()),
            ..Default::default()
        };

        let client = build_ollama_client(&llm_config);
        assert_eq!(client.name(), "ollama");
    }

    #[test]
    fn test_build_lmstudio_client_with_config() {
        let llm_config = LlmConfig {
            model: Some("local-model".to_string()),
            base_url: Some("http://localhost:1234".to_string()),
            ..Default::default()
        };

        let client = build_lmstudio_client(&llm_config);
        assert_eq!(client.name(), "lmstudio");
    }

    #[test]
    fn test_build_hook_llm_provider_disabled() {
        let mut config = SubcogConfig::default();
        config.search_intent.use_llm = false;

        let provider = build_hook_llm_provider(&config);
        assert!(provider.is_none());
    }

    #[test]
    fn test_build_hook_llm_provider_openai() {
        let mut config = SubcogConfig::default();
        config.search_intent.use_llm = true;
        config.llm.provider = Provider::OpenAi;
        config.llm.api_key = Some("test-key".to_string());

        let provider = build_hook_llm_provider(&config);
        assert!(provider.is_some());
    }

    #[test]
    fn test_build_hook_llm_provider_anthropic() {
        let mut config = SubcogConfig::default();
        config.search_intent.use_llm = true;
        config.llm.provider = Provider::Anthropic;

        let provider = build_hook_llm_provider(&config);
        assert!(provider.is_some());
    }

    #[test]
    fn test_build_hook_llm_provider_ollama() {
        let mut config = SubcogConfig::default();
        config.search_intent.use_llm = true;
        config.llm.provider = Provider::Ollama;

        let provider = build_hook_llm_provider(&config);
        assert!(provider.is_some());
    }

    #[test]
    fn test_build_hook_llm_provider_lmstudio() {
        let mut config = SubcogConfig::default();
        config.search_intent.use_llm = true;
        config.llm.provider = Provider::LmStudio;

        let provider = build_hook_llm_provider(&config);
        assert!(provider.is_some());
    }
}