1#[cfg(feature = "ssr")]
2pub use axum::extract::ws::Message as WSMessage;
3#[cfg(feature = "ssr")]
4pub use axum::extract::ws::WebSocket as AxumWS;
5#[cfg(feature = "ssr")]
6pub use flams_database::DBBackend;
7
/// Renders an arbitrary JS value as a Rust [`String`] for log messages.
///
/// The value is wrapped as a `js_sys::Object` and converted through that
/// type's `to_string` binding.
#[cfg(feature = "hydrate")]
fn js_to_string(e: leptos::wasm_bindgen::JsValue) -> String {
    use leptos::web_sys::js_sys::Object;
    let rendered = Object::from(e).to_string();
    String::from(rendered)
}
13
/// Browser-side half of a typed websocket connection.
///
/// Messages travel as JSON text frames: `ClientMsg` values are serialized and
/// sent to the server, incoming text is deserialized into `ServerMsg`.
/// Implementors only have to provide [`new`](Self::new) and
/// [`socket`](Self::socket); everything else has a default implementation.
#[cfg(feature = "hydrate")]
pub trait WebSocketClient<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
>: WebSocket<ClientMsg, ServerMsg>
{
    /// Wraps an already constructed browser websocket.
    fn new(ws: leptos::web_sys::WebSocket) -> Self;

    /// Returns the underlying browser websocket.
    fn socket(&mut self) -> &mut leptos::web_sys::WebSocket;

    /// Serializes `msg` to JSON and sends it as a text frame.
    ///
    /// Failures (serialization or sending) are logged via `tracing::error!`
    /// and otherwise ignored; nothing is returned to the caller.
    fn send(&mut self, msg: &ClientMsg) {
        let payload = match serde_json::to_string(msg) {
            Ok(payload) => payload,
            Err(e) => {
                tracing::error!("Error serializing websocket message: {e}");
                return;
            }
        };
        if let Err(e) = self.socket().send_with_str(&payload) {
            tracing::error!("Error sending websocket message: {}", js_to_string(e));
        }
    }

    /// Handles one incoming `message` event.
    ///
    /// * Non-string payloads are logged and dropped.
    /// * The literal text `"ping"` is answered with the literal text `"pong"`.
    /// * Anything else is parsed as a JSON `ServerMsg` and passed to `handle`;
    ///   if `handle` returns a reply, it is serialized and sent back on `ws`.
    #[allow(clippy::cognitive_complexity)]
    fn callback(
        ws: &leptos::web_sys::WebSocket,
        handle: &mut impl (FnMut(ServerMsg) -> Option<ClientMsg>),
        event: leptos::web_sys::MessageEvent,
    ) {
        let Some(data) = event.data().as_string() else {
            tracing::error!("Not a string: {}", js_to_string(event.data()));
            return;
        };

        if data == "ping" {
            if let Err(e) = ws.send_with_str("pong") {
                tracing::error!("Error sending websocket message: {}", js_to_string(e));
            }
            return;
        }

        // The recursion limit is lifted so that deeply nested server messages
        // are not rejected by serde_json's default depth guard.
        let mut deserializer = serde_json::Deserializer::from_str(&data);
        deserializer.disable_recursion_limit();
        let incoming = match ServerMsg::deserialize(&mut deserializer) {
            Ok(msg) => msg,
            Err(e) => {
                tracing::error!("{e}");
                return;
            }
        };

        let Some(reply) = handle(incoming) else {
            return;
        };
        let payload = match serde_json::to_string(&reply) {
            Ok(payload) => payload,
            Err(e) => {
                tracing::error!("Error serializing websocket message: {e}");
                return;
            }
        };
        if let Err(e) = ws.send_with_str(&payload) {
            tracing::error!("Error sending websocket message: {}", js_to_string(e));
        }
    }

    /// Opens a websocket to `Self::SERVER_ENDPOINT` and wires up its handlers.
    ///
    /// `handle` is invoked (through [`callback`](Self::callback)) for every
    /// decoded server message. If [`on_open`](Self::on_open) yields a closure,
    /// it is installed as the socket's `onopen` handler.
    ///
    /// Returns `None` (after logging) if the browser refuses to create the
    /// socket.
    fn start(mut handle: impl (FnMut(ServerMsg) -> Option<ClientMsg>) + 'static) -> Option<Self> {
        use leptos::wasm_bindgen::JsCast;
        use leptos::wasm_bindgen::prelude::Closure;

        let ws = match leptos::web_sys::WebSocket::new(Self::SERVER_ENDPOINT) {
            Ok(ws) => ws,
            Err(e) => {
                tracing::error!("Error creating websocket: {}", js_to_string(e));
                return None;
            }
        };

        // The message handler needs its own handle to the socket to send replies.
        let ws2 = ws.clone();
        let callback =
            Closure::<dyn FnMut(_)>::new(move |event| Self::callback(&ws2, &mut handle, event));
        ws.set_onmessage(Some(callback.as_ref().unchecked_ref()));
        let mut r = Self::new(ws);
        // Deliberately leaked: the JS side keeps calling the closure after this
        // function has returned.
        callback.forget();

        if let Some(mut f) = r.on_open() {
            // The `open` event is a plain `Event`, not a `MessageEvent`.
            let callback = Closure::<dyn FnMut(_)>::new(move |_: leptos::web_sys::Event| {
                f();
            });
            r.socket()
                .set_onopen(Some(callback.as_ref().unchecked_ref()));
            callback.forget();
        }
        Some(r)
    }

    /// Optional hook: a closure to run when the socket's `open` event fires.
    ///
    /// The default implementation installs no handler.
    fn on_open(&self) -> Option<Box<dyn FnMut()>> {
        None
    }
}
104
/// Server-side half of a typed websocket connection, driven by axum.
///
/// Messages travel as JSON text frames: incoming text is deserialized into
/// `ClientMsg`, outgoing `ServerMsg` values are serialized to text.
#[cfg(feature = "ssr")]
#[async_trait::async_trait]
pub trait WebSocketServer<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
>: WebSocket<ClientMsg, ServerMsg>
{
    /// Creates the connection state for the given login and database backend.
    ///
    /// Returning `None` makes [`ws_handler`](Self::ws_handler) answer with
    /// `401 Unauthorized` instead of upgrading the connection.
    async fn new(account: crate::LoginState, db: flams_database::DBBackend) -> Option<Self>;

    /// Produces the next unsolicited message for the client.
    ///
    /// Returning `None` ends the connection loop in
    /// [`on_upgrade`](Self::on_upgrade).
    async fn next(&mut self) -> Option<ServerMsg>;

    /// Handles one decoded client message, optionally producing a reply.
    async fn handle_message(&mut self, msg: ClientMsg) -> Option<ServerMsg>;

    /// Hook that runs once, after the initial ping and before the message loop.
    async fn on_start(&mut self, _socket: &mut axum::extract::ws::WebSocket) {}

    /// Axum handler: derives the caller's login state and upgrades to a websocket.
    ///
    /// The login state is computed as follows:
    /// * no admin configured on the backend -> `LoginState::NoAccounts`
    /// * no user in the session -> `LoginState::None`
    /// * user with id `0` and username `"admin"` -> `LoginState::Admin`
    /// * any other user -> `LoginState::User { .. }`
    ///
    /// If [`new`](Self::new) rejects the login, an empty `401 Unauthorized`
    /// response is returned.
    async fn ws_handler(
        auth_session: axum_login::AuthSession<flams_database::DBBackend>,
        ws: axum::extract::WebSocketUpgrade,
    ) -> axum::response::Response
    where
        Self: Send,
    {
        let login = match &auth_session.backend.admin {
            None => crate::LoginState::NoAccounts,
            Some(_) => match auth_session.user {
                None => crate::LoginState::None,
                Some(flams_database::DBUser {
                    id: 0, username, ..
                }) if username == "admin" => crate::LoginState::Admin,
                Some(u) => crate::LoginState::User {
                    name: u.username,
                    avatar: u.avatar_url.unwrap_or_default(),
                    is_admin: u.is_admin,
                },
            },
        };
        match Self::new(login, auth_session.backend).await {
            Some(conn) => ws.on_upgrade(move |socket| conn.on_upgrade(socket)),
            None => {
                let mut res = axum::response::Response::new(axum::body::Body::empty());
                *(res.status_mut()) = axum::http::StatusCode::UNAUTHORIZED;
                res
            }
        }
    }

    /// Runs the connection until either side goes away.
    ///
    /// Sends an initial ping, calls [`on_start`](Self::on_start), then loops
    /// over three event sources:
    /// * `Self::TIMEOUT` seconds without any other event -> send a ping,
    /// * [`next`](Self::next) yields a message -> forward it to the client
    ///   (`None` ends the connection),
    /// * the client sends a frame -> pings are answered with a pong, text
    ///   frames are decoded and passed to
    ///   [`handle_message`](Self::handle_message); other frames, receive
    ///   errors and undecodable text are ignored.
    ///
    /// Any failed send terminates the loop. Values that fail to serialize are
    /// silently skipped.
    ///
    /// # Panics
    ///
    /// `Duration::from_secs_f32` panics if `Self::TIMEOUT` is negative or not
    /// finite.
    async fn on_upgrade(mut self, mut socket: axum::extract::ws::WebSocket)
    where
        Self: Send,
    {
        use axum::body::Bytes;
        use axum::extract::ws::Message;

        if socket.send(Message::Ping(Bytes::new())).await.is_err() {
            return;
        }
        let timeout = std::time::Duration::from_secs_f32(Self::TIMEOUT);
        self.on_start(&mut socket).await;
        loop {
            tokio::select! {
                // The sleep is re-armed on every iteration, so this only fires
                // after a full timeout without any other activity.
                () = tokio::time::sleep(timeout) => {
                    if socket.send(Message::Ping(Bytes::new())).await.is_err() {
                        return;
                    }
                },
                msg = self.next() => {
                    let Some(msg) = msg else { return };
                    if let Ok(msg) = serde_json::to_string(&msg) {
                        if socket.send(Message::Text(msg.into())).await.is_err() {
                            return;
                        }
                    }
                },
                o = socket.recv() => match o {
                    None => break,
                    Some(Ok(Message::Ping(payload))) => {
                        // A pong must carry the same application data as the
                        // ping it answers (RFC 6455, section 5.5.3).
                        // NOTE(review): axum documents that pings are answered
                        // automatically, so this reply may be redundant — confirm.
                        if socket.send(Message::Pong(payload)).await.is_err() {
                            break;
                        }
                    }
                    Some(Ok(Message::Text(msg))) => {
                        if let Ok(msg) = serde_json::from_str(&msg) {
                            if let Some(reply) = self.handle_message(msg).await {
                                if let Ok(reply) = serde_json::to_string(&reply) {
                                    if socket.send(Message::Text(reply.into())).await.is_err() {
                                        break;
                                    }
                                }
                            }
                        }
                    }
                    Some(_) => (),
                },
            }
        }
    }
}
200
201pub trait WebSocket<
202 ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
203 ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
204>: Sized + 'static
205{
206 const TIMEOUT: f32 = 10.0;
207 const SERVER_ENDPOINT: &'static str;
208
209 #[cfg(feature = "ssr")]
210 fn force_start_server() {
211 }
216
217 #[cfg(feature = "hydrate")]
218 fn force_start_client(
219 handle: impl (FnMut(ServerMsg) -> Option<ClientMsg>) + 'static + Clone,
220 mut on_start: impl FnMut(Self) + 'static,
221 ) where
222 Self: WebSocketClient<ClientMsg, ServerMsg>,
223 {
224 let _res = leptos::prelude::Effect::new(move |_| {
226 if let Some(r) = Self::start(handle.clone()) {
228 on_start(r);
229 }
230 });
231 }
232}