Skip to content

Commit dab1b2c

Browse files
committed
refactor: support future llms
1 parent 38ba6fd commit dab1b2c

8 files changed

Lines changed: 92 additions & 47 deletions

File tree

src/app.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
use std;
22
use std::collections::HashMap;
33

4-
use crate::config::AppConfig;
4+
use crate::config::Config;
55
use crate::notification::Notification;
66
use crossterm::event::KeyCode;
77

8+
use std::sync::Arc;
9+
810
pub type AppResult<T> = std::result::Result<T, Box<dyn std::error::Error>>;
911

1012
#[derive(Debug)]
@@ -31,16 +33,16 @@ pub struct App {
3133
pub previous_key: KeyCode,
3234
pub focused_block: FocusedBlock,
3335
pub show_help_popup: bool,
34-
pub gpt_messages: Vec<HashMap<String, String>>,
36+
pub llm_messages: Vec<HashMap<String, String>>,
3537
pub history: Vec<Vec<String>>,
3638
pub show_history_popup: bool,
3739
pub history_thread_index: usize,
38-
pub config: AppConfig,
40+
pub config: Arc<Config>,
3941
pub notifications: Vec<Notification>,
4042
}
4143

42-
impl Default for App {
43-
fn default() -> Self {
44+
impl App {
45+
pub fn new(config: Arc<Config>) -> Self {
4446
Self {
4547
running: true,
4648
prompt: String::from(">_ "),
@@ -50,20 +52,14 @@ impl Default for App {
5052
previous_key: KeyCode::Null,
5153
focused_block: FocusedBlock::Prompt,
5254
show_help_popup: false,
53-
gpt_messages: Vec::new(),
55+
llm_messages: Vec::new(),
5456
history: Vec::new(),
5557
show_history_popup: false,
5658
history_thread_index: 0,
57-
config: AppConfig::load(),
59+
config,
5860
notifications: Vec::new(),
5961
}
6062
}
61-
}
62-
63-
impl App {
64-
pub fn new() -> Self {
65-
Self::default()
66-
}
6763

6864
pub fn tick(&mut self) {
6965
self.notifications.retain(|n| n.ttl > 0);

src/config.rs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,42 +4,50 @@ use dirs;
44
use serde::Deserialize;
55

66
#[derive(Deserialize, Debug)]
7-
pub struct AppConfig {
7+
pub struct Config {
88
#[serde(default = "default_archive_file_name")]
99
pub archive_file_name: String,
1010

1111
#[serde(default)]
1212
pub key_bindings: KeyBindings,
1313

1414
#[serde(default)]
15-
pub gpt: GPT,
15+
pub gpt: GPTConfig,
1616
}
1717

1818
pub fn default_archive_file_name() -> String {
1919
String::from("tenere.archive")
2020
}
2121

2222
#[derive(Deserialize, Debug, Clone)]
23-
pub struct GPT {
23+
pub struct GPTConfig {
2424
pub openai_api_key: Option<String>,
2525

26-
#[serde(default = "GPT::default_model")]
26+
#[serde(default = "GPTConfig::default_model")]
2727
pub model: String,
28+
29+
#[serde(default = "GPTConfig::default_url")]
30+
pub url: String,
2831
}
2932

30-
impl Default for GPT {
33+
impl Default for GPTConfig {
3134
fn default() -> Self {
3235
Self {
3336
openai_api_key: None,
3437
model: String::from("gpt-3.5-turbo"),
38+
url: String::from("https://api.openai.com/v1/chat/completions"),
3539
}
3640
}
3741
}
3842

39-
impl GPT {
43+
impl GPTConfig {
4044
pub fn default_model() -> String {
4145
String::from("gpt-3.5-turbo")
4246
}
47+
48+
pub fn default_url() -> String {
49+
String::from("https://api.openai.com/v1/chat/completions")
50+
}
4351
}
4452

4553
#[derive(Deserialize, Debug)]
@@ -86,15 +94,15 @@ impl KeyBindings {
8694
}
8795
}
8896

89-
impl AppConfig {
97+
impl Config {
9098
pub fn load() -> Self {
9199
let conf_path = dirs::config_dir()
92100
.unwrap()
93101
.join("tenere")
94102
.join("config.toml");
95103

96104
let config = std::fs::read_to_string(conf_path).unwrap_or(String::new());
97-
let app_config: AppConfig = toml::from_str(&config).unwrap();
105+
let app_config: Config = toml::from_str(&config).unwrap();
98106
app_config
99107
}
100108
}

src/event.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pub enum Event {
1111
Key(KeyEvent),
1212
Mouse(MouseEvent),
1313
Resize(u16, u16),
14-
GPTResponse(String),
14+
LLMAnswer(String),
1515
Notification(Notification),
1616
}
1717

src/gpt.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use crate::config::GPTConfig;
2+
use crate::llm::LLM;
13
use reqwest::header::HeaderMap;
24
use serde_json::{json, Value};
35
use std;
@@ -7,13 +9,15 @@ use std::collections::HashMap;
79
pub struct GPT {
810
client: reqwest::blocking::Client,
911
openai_api_key: String,
12+
url: String,
1013
}
1114

1215
impl GPT {
13-
pub fn new(api_key: Option<String>) -> Self {
16+
pub fn new(config: GPTConfig) -> Self {
1417
let openai_api_key = match std::env::var("OPENAI_API_KEY") {
1518
Ok(key) => key,
16-
Err(_) => api_key
19+
Err(_) => config
20+
.openai_api_key
1721
.ok_or_else(|| {
1822
eprintln!(
1923
r#"Can not find the openai api key
@@ -28,15 +32,16 @@ You need to define one wether in the configuration file or as an environment var
2832
Self {
2933
client: reqwest::blocking::Client::new(),
3034
openai_api_key,
35+
url: config.url,
3136
}
3237
}
38+
}
3339

34-
pub fn ask(
40+
impl LLM for GPT {
41+
fn ask(
3542
&self,
3643
chat_messages: Vec<HashMap<String, String>>,
3744
) -> Result<String, Box<dyn std::error::Error>> {
38-
let url = "https://api.openai.com/v1/chat/completions";
39-
4045
let mut headers = HeaderMap::new();
4146
headers.insert("Content-Type", "application/json".parse().unwrap());
4247
headers.insert(
@@ -61,7 +66,12 @@ You need to define one wether in the configuration file or as an environment var
6166
"messages": messages
6267
});
6368

64-
let response = self.client.post(url).headers(headers).json(&body).send()?;
69+
let response = self
70+
.client
71+
.post(&self.url)
72+
.headers(headers)
73+
.json(&body)
74+
.send()?;
6575

6676
match response.error_for_status() {
6777
Ok(res) => {

src/handler.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use crate::{
22
app::{App, AppResult, FocusedBlock, Mode},
33
event::Event,
4-
gpt::GPT,
54
};
5+
6+
use crate::llm::LLM;
67
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
78
use std::sync::mpsc::Sender;
89
use std::{collections::HashMap, thread};
@@ -13,7 +14,7 @@ use std::sync::Arc;
1314
pub fn handle_key_events(
1415
key_event: KeyEvent,
1516
app: &mut App,
16-
gpt: Arc<GPT>,
17+
llm: Arc<impl LLM + 'static>,
1718
sender: Sender<Event>,
1819
) -> AppResult<()> {
1920
match app.mode {
@@ -30,25 +31,25 @@ pub fn handle_key_events(
3031
}
3132

3233
KeyCode::Enter => {
33-
let mut conv: HashMap<String, String> = HashMap::new();
34-
3534
let user_input: String = app.prompt.drain(3..).collect();
3635
let user_input = user_input.trim();
3736
if user_input.is_empty() {
3837
return Ok(());
3938
}
4039
app.chat.push(format!(" : {}\n", user_input));
4140

42-
conv.insert("role".to_string(), "user".to_string());
43-
conv.insert("content".to_string(), user_input.to_string());
44-
app.gpt_messages.push(conv.clone());
41+
let conv = HashMap::from([
42+
("role".into(), "user".into()),
43+
("content".into(), user_input.into()),
44+
]);
45+
app.llm_messages.push(conv);
4546

46-
let gpt_messages = app.gpt_messages.clone();
47+
let llm_messages = app.llm_messages.clone();
4748

4849
thread::spawn(move || {
49-
let response = gpt.ask(gpt_messages.to_vec());
50+
let response = llm.ask(llm_messages.to_vec());
5051
sender
51-
.send(Event::GPTResponse(match response {
52+
.send(Event::LLMAnswer(match response {
5253
Ok(answer) => answer,
5354
Err(e) => e.to_string(),
5455
}))
@@ -89,7 +90,7 @@ pub fn handle_key_events(
8990
app.prompt = String::from(">_ ");
9091
app.history.push(app.chat.clone());
9192
app.chat = Vec::new();
92-
app.gpt_messages = Vec::new();
93+
app.llm_messages = Vec::new();
9394
app.scroll = 0;
9495
}
9596

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@ pub mod config;
1515
pub mod ui;
1616

1717
pub mod notification;
18+
19+
pub mod llm;

src/llm.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
use crate::config::Config;
2+
use crate::gpt::GPT;
3+
use std::collections::HashMap;
4+
5+
use std::sync::Arc;
6+
pub trait LLM: Send + Sync {
7+
fn ask(
8+
&self,
9+
chat_messages: Vec<HashMap<String, String>>,
10+
) -> Result<String, Box<dyn std::error::Error>>;
11+
}
12+
13+
pub enum LLMBackend {
14+
ChatGPT,
15+
}
16+
17+
pub struct LLMModel {}
18+
19+
impl LLMModel {
20+
pub fn init(model: LLMBackend, config: Arc<Config>) -> impl LLM {
21+
match model {
22+
LLMBackend::ChatGPT => GPT::new(config.gpt.clone()),
23+
}
24+
}
25+
}

src/main.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,25 @@ use std::collections::HashMap;
22
use std::{env, io};
33
use tenere::app::{App, AppResult};
44
use tenere::cli;
5+
use tenere::config::Config;
56
use tenere::event::{Event, EventHandler};
6-
use tenere::gpt::GPT;
77
use tenere::handler::handle_key_events;
88
use tenere::tui::Tui;
99
use tui::backend::CrosstermBackend;
1010
use tui::Terminal;
1111

12+
use tenere::llm::{LLMBackend, LLMModel};
13+
1214
use std::sync::Arc;
1315

1416
use clap::crate_version;
1517

1618
fn main() -> AppResult<()> {
1719
cli::cli().version(crate_version!()).get_matches();
1820

19-
let mut app = App::new();
20-
let gpt = Arc::new(GPT::new(app.config.gpt.openai_api_key.clone()));
21+
let config = Arc::new(Config::load());
22+
let mut app = App::new(config.clone());
23+
let llm = Arc::new(LLMModel::init(LLMBackend::ChatGPT, config));
2124

2225
let backend = CrosstermBackend::new(io::stderr());
2326
let terminal = Terminal::new(backend)?;
@@ -30,18 +33,18 @@ fn main() -> AppResult<()> {
3033
match tui.events.next()? {
3134
Event::Tick => app.tick(),
3235
Event::Key(key_event) => {
33-
handle_key_events(key_event, &mut app, gpt.clone(), tui.events.sender.clone())?
36+
handle_key_events(key_event, &mut app, llm.clone(), tui.events.sender.clone())?
3437
}
3538
Event::Mouse(_) => {}
3639
Event::Resize(_, _) => {}
37-
Event::GPTResponse(response) => {
40+
Event::LLMAnswer(answer) => {
3841
app.chat.pop();
39-
app.chat.push(format!("🤖: {}\n", response));
42+
app.chat.push(format!("🤖: {}\n", answer));
4043
app.chat.push("\n".to_string());
4144
let mut conv: HashMap<String, String> = HashMap::new();
4245
conv.insert("role".to_string(), "user".to_string());
43-
conv.insert("content".to_string(), response.clone());
44-
app.gpt_messages.push(conv);
46+
conv.insert("content".to_string(), answer);
47+
app.llm_messages.push(conv);
4548
}
4649
Event::Notification(notification) => {
4750
app.notifications.push(notification);

0 commit comments

Comments
 (0)