Skip to content

Commit 74b9146

Browse files
Sanae6roobscoob
authored andcommitted
switch DiscordClient to interior mutability model
1 parent 63e5d86 commit 74b9146

File tree

3 files changed

+65
-58
lines changed

3 files changed

+65
-58
lines changed

src/discord/src/channel/mod.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,20 @@ use crate::{client::DiscordClient, message::{author::{DiscordMessageAuthor, Disp
88
pub struct DiscordChannel {
99
channel_id: Snowflake,
1010
receiver: broadcast::Receiver<DiscordMessage>,
11-
client: Arc<RwLock<DiscordClient>>
11+
client: Arc<DiscordClient>,
1212
}
1313

1414
impl DiscordChannel {
15-
pub async fn new(client: Arc<RwLock<DiscordClient>>, channel_id: Snowflake) -> Self {
15+
pub async fn new(client: Arc<DiscordClient>, channel_id: Snowflake) -> Self {
1616
let (sender, receiver) = broadcast::channel(10);
1717

18-
client.write().await.add_channel_message_sender(channel_id, sender).await;
18+
client.add_channel_message_sender(channel_id, sender).await;
1919

20-
DiscordChannel { channel_id, receiver, client }
20+
DiscordChannel {
21+
channel_id,
22+
receiver,
23+
client,
24+
}
2125
}
2226
}
2327

@@ -35,7 +39,7 @@ impl Channel for DiscordChannel {
3539
let sent_nonce = nonce.clone();
3640

3741
tokio::spawn(async move {
38-
client.write().await.send_message(channel_id, sent_content, sent_nonce).await;
42+
client.send_message(channel_id, sent_content, sent_nonce).await;
3943
});
4044

4145
DiscordMessage {
@@ -52,7 +56,7 @@ impl Clone for DiscordChannel {
5256
Self {
5357
channel_id: self.channel_id.clone(),
5458
receiver: self.receiver.resubscribe(),
55-
client: self.client.clone()
59+
client: self.client.clone(),
5660
}
5761
}
5862
}

src/discord/src/client.rs

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use std::{collections::HashMap, rc::Rc, sync::Arc};
1+
use std::{
2+
collections::HashMap, rc::Rc, sync::{Arc, OnceLock}
3+
};
24

35
use serenity::{
46
all::{Context, EventHandler, GatewayIntents, Message},
@@ -18,47 +20,48 @@ use crate::{
1820

1921
#[derive(Default)]
2022
pub struct DiscordClient {
21-
channel_message_event_handlers: HashMap<Snowflake, Vec<broadcast::Sender<DiscordMessage>>>,
22-
client: Option<serenity::Client>
23+
channel_message_event_handlers: RwLock<HashMap<Snowflake, Vec<broadcast::Sender<DiscordMessage>>>>,
24+
client: OnceLock<serenity::Client>,
2325
}
2426

2527
impl DiscordClient {
26-
pub fn new(token: String) -> Arc<RwLock<DiscordClient>> {
27-
let client = Arc::new(RwLock::new(DiscordClient::default()));
28-
let remote = RemoteDiscordClient(client.clone());
29-
let async_client = client.clone();
30-
31-
tokio::spawn(async move {
32-
let mut client = serenity::Client::builder(token, GatewayIntents::all())
33-
.event_handler(remote)
34-
.await
35-
.expect("Error creating client");
36-
37-
if let Err(why) = client.start().await {
38-
panic!("Client error: {why:?}");
39-
}
28+
pub async fn new(token: String) -> Arc<DiscordClient> {
29+
let client = Arc::new(DiscordClient::default());
30+
31+
let mut discord = serenity::Client::builder(token, GatewayIntents::all()).event_handler_arc(client.clone()).await.expect("Error creating client");
32+
33+
if let Err(why) = discord.start().await {
34+
panic!("Client error: {why:?}");
35+
}
4036

41-
async_client.write().await.client = Some(client);
42-
});
37+
let _ = client.client.set(discord);
4338

4439
client
4540
}
4641

47-
pub async fn add_channel_message_sender(&mut self, channel: Snowflake, sender: broadcast::Sender<DiscordMessage>) {
48-
self.channel_message_event_handlers.entry(channel).or_default().push(sender);
42+
fn discord(&self) -> &serenity::Client {
43+
self.client.get().unwrap()
44+
}
45+
46+
pub async fn add_channel_message_sender(&self, channel: Snowflake, sender: broadcast::Sender<DiscordMessage>) {
47+
self.channel_message_event_handlers.write().await.entry(channel).or_default().push(sender);
4948
}
5049

51-
pub async fn send_message(&mut self, channel_id: Snowflake, content: String, nonce: String) {
50+
pub async fn send_message(&self, channel_id: Snowflake, content: String, nonce: String) {
5251
println!("All the way to discord~! {:?} {:?}", channel_id, content);
53-
ChannelId::new(channel_id.content).send_message(self.client.as_ref().unwrap().http.clone(), CreateMessage::new().content(content).enforce_nonce(true).nonce(serenity::all::Nonce::String(nonce))).await.unwrap();
52+
ChannelId::new(channel_id.content)
53+
.send_message(
54+
self.discord().http.clone(),
55+
CreateMessage::new().content(content).enforce_nonce(true).nonce(serenity::all::Nonce::String(nonce)),
56+
)
57+
.await
58+
.unwrap();
5459
}
5560
}
5661

57-
struct RemoteDiscordClient(Arc<RwLock<DiscordClient>>);
58-
5962
#[async_trait]
60-
impl EventHandler for RemoteDiscordClient {
61-
async fn ready(&self, ctx: Context, data_about_bot: serenity::model::prelude::Ready) {
63+
impl EventHandler for DiscordClient {
64+
async fn ready(&self, _: Context, data_about_bot: serenity::model::prelude::Ready) {
6265
println!("Ready! {:?}", data_about_bot);
6366
}
6467

@@ -69,7 +72,7 @@ impl EventHandler for RemoteDiscordClient {
6972
content: msg.channel_id.get(),
7073
};
7174

72-
if let Some(vec) = self.0.read().await.channel_message_event_handlers.get(&snowflake) {
75+
if let Some(vec) = self.channel_message_event_handlers.read().await.get(&snowflake) {
7376
for sender in vec {
7477
println!("Sending to sender!");
7578

@@ -86,7 +89,7 @@ impl EventHandler for RemoteDiscordClient {
8689
nonce: msg.nonce.clone().map(|n| match n {
8790
Nonce::Number(n) => n.to_string(),
8891
Nonce::String(s) => s,
89-
})
92+
}),
9093
});
9194
}
9295
}

src/ui/src/main.rs

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
pub mod channel;
21
pub mod app_state;
2+
pub mod channel;
33

44
use std::{fs, path::PathBuf, sync::Arc};
55

@@ -30,44 +30,44 @@ impl AssetSource for Assets {
3030
actions!(main_menu, [Quit]);
3131

3232
fn init(app_state: Arc<AppState>, cx: &mut AppContext) -> Result<()> {
33-
components::init(cx);
33+
components::init(cx);
3434

35-
cx.bind_keys([KeyBinding::new("cmd-q", Quit, None)]);
35+
cx.bind_keys([KeyBinding::new("cmd-q", Quit, None)]);
3636

37-
Ok(())
37+
Ok(())
3838
}
3939

4040
#[tokio::main]
4141
async fn main() {
4242
env_logger::init();
4343

44-
let app_state = Arc::new(AppState {});
44+
let app_state = Arc::new(AppState {});
4545

4646
let token = dotenv::var("DISCORD_TOKEN").expect("Must provide DISCORD_TOKEN in .env");
4747
let demo_channel_id = dotenv::var("DEMO_CHANNEL_ID").expect("Must provide DEMO_CHANNEL_ID in .env");
4848

4949
let mut client = DiscordClient::new(token);
5050

51-
let mut channel = DiscordChannel::new(client.clone(), Snowflake { content: demo_channel_id.parse().unwrap() }).await;
51+
let channel = DiscordChannel::new(
52+
client.clone(),
53+
Snowflake {
54+
content: demo_channel_id.parse().unwrap(),
55+
},
56+
)
57+
.await;
5258

53-
App::new()
54-
.with_assets(Assets {
55-
base: PathBuf::from("img"),
56-
})
57-
.with_http_client(Arc::new(reqwest_client::ReqwestClient::new()))
58-
.run(move |cx: &mut AppContext| {
59-
AppState::set_global(Arc::downgrade(&app_state), cx);
59+
App::new().with_assets(Assets { base: PathBuf::from("img") }).with_http_client(Arc::new(reqwest_client::ReqwestClient::new())).run(
60+
move |cx: &mut AppContext| {
61+
AppState::set_global(Arc::downgrade(&app_state), cx);
6062

61-
if let Err(e) = init(app_state.clone(), cx) {
62-
log::error!("{}", e);
63-
return;
64-
}
63+
if let Err(e) = init(app_state.clone(), cx) {
64+
log::error!("{}", e);
65+
return;
66+
}
6567

66-
Theme::sync_system_appearance(cx);
68+
Theme::sync_system_appearance(cx);
6769

68-
let window = cx.open_window(WindowOptions::default(), |cx| {
69-
ChannelView::<DiscordMessage>::create(cx, channel)
70-
})
71-
.unwrap();
72-
});
70+
let window = cx.open_window(WindowOptions::default(), |cx| ChannelView::<DiscordMessage>::create(cx, channel)).unwrap();
71+
},
72+
);
7373
}

0 commit comments

Comments
 (0)