Skip to content

Commit 1700d48

Browse files
committed
refactor: code refactoring
1 parent dab1b2c commit 1700d48

4 files changed

Lines changed: 26 additions & 16 deletions

File tree

src/gpt.rs renamed to src/chatgpt.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1-
use crate::config::GPTConfig;
1+
use crate::config::ChatGPTConfig;
22
use crate::llm::LLM;
33
use reqwest::header::HeaderMap;
44
use serde_json::{json, Value};
55
use std;
66
use std::collections::HashMap;
77

88
#[derive(Clone, Debug)]
9-
pub struct GPT {
9+
pub struct ChatGPT {
1010
client: reqwest::blocking::Client,
1111
openai_api_key: String,
1212
url: String,
1313
}
1414

15-
impl GPT {
16-
pub fn new(config: GPTConfig) -> Self {
15+
impl ChatGPT {
16+
pub fn new(config: ChatGPTConfig) -> Self {
1717
let openai_api_key = match std::env::var("OPENAI_API_KEY") {
1818
Ok(key) => key,
1919
Err(_) => config
@@ -37,7 +37,7 @@ You need to define one wether in the configuration file or as an environment var
3737
}
3838
}
3939

40-
impl LLM for GPT {
40+
impl LLM for ChatGPT {
4141
fn ask(
4242
&self,
4343
chat_messages: Vec<HashMap<String, String>>,

src/config.rs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::llm::LLMBackend;
12
use toml;
23

34
use dirs;
@@ -11,36 +12,43 @@ pub struct Config {
1112
#[serde(default)]
1213
pub key_bindings: KeyBindings,
1314

15+
#[serde(default = "default_llm_backend")]
16+
pub model: LLMBackend,
17+
1418
#[serde(default)]
15-
pub gpt: GPTConfig,
19+
pub chatgpt: ChatGPTConfig,
1620
}
1721

1822
pub fn default_archive_file_name() -> String {
1923
String::from("tenere.archive")
2024
}
2125

26+
pub fn default_llm_backend() -> LLMBackend {
27+
LLMBackend::ChatGPT
28+
}
29+
2230
#[derive(Deserialize, Debug, Clone)]
23-
pub struct GPTConfig {
31+
pub struct ChatGPTConfig {
2432
pub openai_api_key: Option<String>,
2533

26-
#[serde(default = "GPTConfig::default_model")]
34+
#[serde(default = "ChatGPTConfig::default_model")]
2735
pub model: String,
2836

29-
#[serde(default = "GPTConfig::default_url")]
37+
#[serde(default = "ChatGPTConfig::default_url")]
3038
pub url: String,
3139
}
3240

33-
impl Default for GPTConfig {
41+
impl Default for ChatGPTConfig {
3442
fn default() -> Self {
3543
Self {
3644
openai_api_key: None,
37-
model: String::from("gpt-3.5-turbo"),
38-
url: String::from("https://api.openai.com/v1/chat/completions"),
45+
model: Self::default_model(),
46+
url: Self::default_url(),
3947
}
4048
}
4149
}
4250

43-
impl GPTConfig {
51+
impl ChatGPTConfig {
4452
pub fn default_model() -> String {
4553
String::from("gpt-3.5-turbo")
4654
}

src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ pub mod tui;
66

77
pub mod handler;
88

9-
pub mod gpt;
9+
pub mod chatgpt;
1010

1111
pub mod cli;
1212

src/llm.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
use crate::chatgpt::ChatGPT;
12
use crate::config::Config;
2-
use crate::gpt::GPT;
3+
use serde::Deserialize;
34
use std::collections::HashMap;
45

56
use std::sync::Arc;
@@ -10,6 +11,7 @@ pub trait LLM: Send + Sync {
1011
) -> Result<String, Box<dyn std::error::Error>>;
1112
}
1213

14+
#[derive(Deserialize, Debug)]
1315
pub enum LLMBackend {
1416
ChatGPT,
1517
}
@@ -19,7 +21,7 @@ pub struct LLMModel {}
1921
impl LLMModel {
2022
pub fn init(model: LLMBackend, config: Arc<Config>) -> impl LLM {
2123
match model {
22-
LLMBackend::ChatGPT => GPT::new(config.gpt.clone()),
24+
LLMBackend::ChatGPT => ChatGPT::new(config.chatgpt.clone()),
2325
}
2426
}
2527
}

0 commit comments

Comments
 (0)