1use flams_utils::parking_lot;
2use std::marker::PhantomData;
3
4#[cfg(feature = "ssr")]
5pub use axum::extract::ws::Message as WSMessage;
6#[cfg(feature = "ssr")]
7pub use axum::extract::ws::WebSocket as AxumWS;
8#[cfg(feature = "ssr")]
9pub use flams_database::DBBackend;
10
/// Browser-side websocket handle that sends `ClientMsg`s as JSON text frames.
///
/// `ServerMsg` is the type incoming frames are decoded into by the handler
/// installed at construction time.
#[cfg(feature = "hydrate")]
#[derive(Debug)]
pub struct WSClient<ClientMsg, ServerMsg>
where
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
{
    /// The underlying browser websocket.
    socket: leptos::web_sys::WebSocket,
    /// Marks which message types this client speaks; carries no data.
    _phantom: PhantomData<(ClientMsg, ServerMsg)>,
    /// Outgoing-message buffer shared between clones. `send` pushes into it
    /// while it is `Some`; the `onopen` handler `take`s it (leaving `None`) and
    /// sends the buffered strings.
    queue: std::sync::Arc<parking_lot::Mutex<Option<Vec<String>>>>,
}
21
#[cfg(feature = "hydrate")]
impl<ClientMsg, ServerMsg> Clone for WSClient<ClientMsg, ServerMsg>
where
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
{
    /// Creates another handle sharing the same socket object and the same
    /// outgoing-message queue (the queue is behind an `Arc`).
    ///
    /// Implemented by hand so that no `Clone` bound is required on the message
    /// type parameters.
    #[inline]
    fn clone(&self) -> Self {
        let queue = std::sync::Arc::clone(&self.queue);
        let socket = self.socket.clone();
        Self {
            socket,
            queue,
            _phantom: PhantomData,
        }
    }
}
37
#[cfg(feature = "hydrate")]
impl<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
> WSClient<ClientMsg, ServerMsg>
{
    /// Serializes `msg` to JSON and sends it as a text frame.
    ///
    /// While the connection has not opened yet, the message is appended to the
    /// shared queue instead; the `onopen` handler installed by [`Self::new`]
    /// flushes the queue in order. Serialization and send failures are logged,
    /// not returned.
    pub fn send(&self, msg: &ClientMsg) {
        let Ok(s) = serde_json::to_string(msg) else {
            tracing::error!("Error serializing websocket message");
            return;
        };
        {
            // `Some` means the socket has not opened yet: buffer instead of
            // calling `send_with_str`, which would throw while CONNECTING.
            if let Some(pending) = &mut *self.queue.lock() {
                pending.push(s);
                return;
            }
        }

        if let Err(e) = self.socket.send_with_str(&s) {
            tracing::error!("Error sending websocket message: {}", js_to_string(e));
        }
    }

    /// Opens a websocket to `endpoint` and decodes every non-`"ping"` text
    /// frame as JSON into a `ServerMsg`, passing it to `handler`.
    ///
    /// Deserialization failures are logged and the frame is dropped. Returns
    /// `None` (after logging) if the browser refuses to create the socket.
    #[inline]
    pub fn new(endpoint: &str, mut handler: impl FnMut(ServerMsg) + 'static) -> Option<Self> {
        Self::new_i(
            endpoint,
            Box::new(move |s| {
                let mut deserializer = serde_json::Deserializer::from_str(&s);
                // Server messages may be deeply nested; don't cap recursion.
                deserializer.disable_recursion_limit();
                let value = ServerMsg::deserialize(&mut deserializer);
                match value {
                    Ok(msg) => handler(msg),
                    Err(e) => {
                        tracing::error!("{e}");
                    }
                }
            }),
        )
    }

    /// Non-generic part of [`Self::new`]: creates the socket and installs the
    /// `onmessage` and `onopen` handlers. Both closures are leaked
    /// (`forget`) so they live as long as the socket.
    fn new_i(endpoint: &str, mut handler: Box<dyn FnMut(String) + 'static>) -> Option<Self> {
        use leptos::wasm_bindgen::JsCast;
        use leptos::wasm_bindgen::prelude::Closure;
        let ws = match leptos::web_sys::WebSocket::new(endpoint) {
            Ok(ws) => ws,
            Err(e) => {
                tracing::error!("Error creating websocket: {}", js_to_string(e));
                return None;
            }
        };
        let ws2 = ws.clone();
        let on_message =
            Closure::<dyn FnMut(_)>::new(move |event| callback(&ws2, &mut *handler, event));
        ws.set_onmessage(Some(on_message.as_ref().unchecked_ref()));
        on_message.forget();

        let r = Self {
            socket: ws,
            // Start with an empty buffer so `send` queues messages until the
            // socket opens. (`Arc::default()` would yield `None`, which
            // disabled buffering entirely.)
            queue: std::sync::Arc::new(parking_lot::Mutex::new(Some(Vec::new()))),
            _phantom: PhantomData,
        };
        let ws = r.socket.clone();
        let queue = r.queue.clone();
        let on_open = Closure::<dyn FnMut(_)>::new(move |_: leptos::wasm_bindgen::JsValue| {
            // Take the buffer, leaving `None`, so later `send`s go out directly.
            let pending = queue.lock().take();
            for s in pending.into_iter().flatten() {
                if let Err(e) = ws.send_with_str(&s) {
                    tracing::error!("Error sending websocket message: {}", js_to_string(e));
                }
            }
        });
        r.socket.set_onopen(Some(on_open.as_ref().unchecked_ref()));
        on_open.forget();
        Some(r)
    }
}
113
/// `onmessage` handler used by `WSClient`.
///
/// Non-string payloads are logged and ignored. The literal text `"ping"` is
/// answered with `"pong"` on `ws`; every other text payload is forwarded
/// to `handler` unchanged.
#[cfg(feature = "hydrate")]
#[allow(clippy::needless_pass_by_value)]
fn callback(
    ws: &leptos::web_sys::WebSocket,
    handler: &mut dyn FnMut(String),
    event: leptos::web_sys::MessageEvent,
) {
    let raw = event.data();
    let Some(text) = raw.as_string() else {
        tracing::error!("Not a string: {}", js_to_string(raw));
        return;
    };
    if text != "ping" {
        handler(text);
        return;
    }
    // Application-level heartbeat: reply so the peer sees us as alive.
    if let Err(e) = ws.send_with_str("pong") {
        tracing::error!("Error sending websocket message: {}", js_to_string(e));
    }
}
133
/// Server-side handle for pushing `ServerMsg`s to one connected websocket.
///
/// Messages are sent into an unbounded channel whose receiving end is drained
/// by the connection loop in `WSServerSocket::handler`.
#[cfg(feature = "ssr")]
pub struct WSSocket<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
> {
    socket: tokio::sync::mpsc::UnboundedSender<ServerMsg>,
    _phantom: PhantomData<(ClientMsg, ServerMsg)>,
}

// Manual impl instead of `#[derive(Clone)]`: the derive would add
// `ClientMsg: Clone, ServerMsg: Clone` bounds, although the sender and the
// `PhantomData` marker are cloneable for any message types.
#[cfg(feature = "ssr")]
impl<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
> Clone for WSSocket<ClientMsg, ServerMsg>
{
    #[inline]
    fn clone(&self) -> Self {
        Self {
            socket: self.socket.clone(),
            _phantom: PhantomData,
        }
    }
}
143
/// Server side of a websocket whose messages are handled concurrently.
///
/// Implementors build their state from a [`WSSocket`] (used to push
/// `ServerMsg`s) and handle each incoming `ClientMsg` in its own spawned task.
#[cfg(feature = "ssr")]
pub trait WSServerSocket<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
>: Sized + Sync + 'static
{
    /// Seconds without any activity after which the server sends a ping.
    const TIMEOUT: f32 = 10.0;

    /// Builds the per-connection state; `socket` pushes messages to the client.
    fn new(socket: WSSocket<ClientMsg, ServerMsg>) -> impl Future<Output = Self> + Send;
    /// Handles one client message. Returning `false` requests that the
    /// connection be closed (checked at the top of the next loop iteration).
    fn handle(&self, msg: ClientMsg) -> impl Future<Output = bool> + Send;
    /// Optional tracing span the connection loop runs in.
    fn span(&self) -> Option<&'static tracing::Span>;

    /// Decides whether a user in `_state` may open the socket; allows everyone
    /// by default.
    fn allow_user(_state: crate::LoginState) -> bool {
        true
    }

    /// Axum handler: derives the login state from `auth_session`, rejects it
    /// with `401 UNAUTHORIZED` if [`Self::allow_user`] says no, otherwise
    /// upgrades and runs the connection loop.
    async fn handler(
        auth_session: axum_login::AuthSession<flams_database::DBBackend>,
        ws: axum::extract::WebSocketUpgrade,
    ) -> axum::response::Response
    where
        Self: Send,
    {
        let login = match &auth_session.backend.admin {
            None => crate::LoginState::NoAccounts,
            Some(_) => match auth_session.user {
                None => crate::LoginState::None,
                Some(flams_database::DBUser {
                    id: 0, username, ..
                }) if username == "admin" => crate::LoginState::Admin,
                Some(u) => crate::LoginState::User {
                    name: u.username,
                    avatar: u.avatar_url.unwrap_or_default(),
                    is_admin: u.is_admin,
                },
            },
        };
        if !Self::allow_user(login) {
            let mut res = axum::response::Response::new(axum::body::Body::empty());
            *(res.status_mut()) = axum::http::StatusCode::UNAUTHORIZED;
            return res;
        }
        ws.on_upgrade(move |mut ws| async move {
            use axum::body::Bytes;
            use tracing::Instrument;
            let timeout = std::time::Duration::from_secs_f32(Self::TIMEOUT);
            let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();

            let socket = WSSocket {
                socket: sender.clone(),
                _phantom: PhantomData,
            };
            let slf = std::sync::Arc::new(Self::new(socket).await);
            // Attach the span to the future instead of holding a `Span::enter`
            // guard across `.await`s: the guard would mis-attribute events
            // from other tasks polled on this thread, and it is `!Send`.
            let span = slf.span().cloned().unwrap_or_else(tracing::Span::none);
            let cancel = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
            async move {
                loop {
                    // Set by a handler task whose `handle` returned `false`.
                    if cancel.load(std::sync::atomic::Ordering::Acquire) {
                        tracing::info!("Dropping websocket");
                        return;
                    }
                    tokio::select! {
                        () = tokio::time::sleep(timeout) => {
                            if ws.send(axum::extract::ws::Message::Ping(Bytes::new())).await.is_err() {
                                tracing::info!("Error sending ping; Dropping websocket");
                                return;
                            }
                        },
                        msg = receiver.recv() => {
                            let Some(msg) = msg else {
                                tracing::info!("Receiver closed; Dropping websocket");
                                return;
                            };
                            match serde_json::to_string(&msg) {
                                Ok(json) => {
                                    tracing::info!("Returning {}", json);
                                    if ws.send(axum::extract::ws::Message::Text(json.into())).await.is_err() {
                                        tracing::info!("Error sending result; Dropping websocket");
                                        return;
                                    }
                                }
                                Err(e) => {
                                    // Skip the unserializable message but keep the connection.
                                    tracing::error!("Error serializing result {msg:?}: {e}");
                                }
                            }
                        }
                        msg = ws.recv() => {
                            match msg {
                                None => {
                                    tracing::info!("Received None-message from client; Dropping websocket");
                                    return;
                                },
                                Some(Ok(axum::extract::ws::Message::Ping(payload))) => {
                                    // RFC 6455: a Pong echoes the Ping's payload.
                                    if ws.send(axum::extract::ws::Message::Pong(payload)).await.is_err() {
                                        tracing::info!("Ping not returned; Dropping websocket");
                                        return;
                                    }
                                }
                                Some(Ok(axum::extract::ws::Message::Text(msg))) => {
                                    tracing::info!("Received {}", msg);
                                    let cancel = cancel.clone();
                                    let slf = slf.clone();
                                    // Handle each message in its own task so slow
                                    // handlers don't block the connection loop.
                                    #[allow(clippy::let_underscore_future)]
                                    let _ = tokio::task::spawn(
                                        async move {
                                            match serde_json::from_str(&msg) {
                                                Ok(msg) => {
                                                    if !slf.handle(msg).await {
                                                        tracing::info!("handle returned false; Dropping websocket");
                                                        cancel.store(true, std::sync::atomic::Ordering::Release);
                                                    }
                                                }
                                                Err(e) => {
                                                    tracing::error!("Error: {e:?}");
                                                }
                                            }
                                        }
                                        .in_current_span(),
                                    );
                                }
                                Some(Err(e)) => {
                                    // The next `recv` yields `None` after a stream error.
                                    tracing::info!("Websocket error: {e}");
                                }
                                Some(Ok(_)) => (),
                            }
                        }
                    }
                }
            }
            .instrument(span)
            .await;
            // `sender` is deliberately kept alive until the loop ends, as before;
            // as a result `receiver.recv()` does not return `None` while it runs.
            drop(sender);
        })
    }
}
266
#[cfg(feature = "ssr")]
impl<ClientMsg, ServerMsg> WSSocket<ClientMsg, ServerMsg>
where
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
{
    /// Enqueues `msg` for delivery to the connected client.
    ///
    /// Best-effort: if the receiving side of the channel is gone, the message
    /// is discarded without reporting an error.
    #[inline]
    pub fn send(&self, msg: ServerMsg) {
        // A failed send only means the connection loop has already ended.
        self.socket.send(msg).ok();
    }
}
291
/// Renders a JS value (a thrown error, a non-string message payload, ...) as a
/// Rust `String` for logging, using the value's own `toString()`.
///
/// `null` and `undefined` have no `toString()`; calling it would throw a JS
/// `TypeError` out of this logging helper. Those two values are formatted
/// through `JsValue`'s `Debug` impl instead.
#[cfg(feature = "hydrate")]
fn js_to_string(e: leptos::wasm_bindgen::JsValue) -> String {
    use leptos::web_sys::js_sys::Object;
    if e.is_null() || e.is_undefined() {
        return format!("{e:?}");
    }
    Object::from(e).to_string().into()
}
297
/// Browser side of a [`WebSocket`] protocol: JSON-encoded `ClientMsg`s go out,
/// JSON-encoded `ServerMsg`s come in and may trigger an immediate reply.
#[cfg(feature = "hydrate")]
pub trait WebSocketClient<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
>: WebSocket<ClientMsg, ServerMsg>
{
    /// Wraps a freshly created browser websocket into the implementing type.
    fn new(ws: leptos::web_sys::WebSocket) -> Self;
    /// Mutable access to the wrapped browser websocket.
    fn socket(&mut self) -> &mut leptos::web_sys::WebSocket;

    /// Serializes `msg` to JSON and sends it; failures are logged only.
    fn send(&mut self, msg: &ClientMsg) {
        match serde_json::to_string(msg) {
            Err(_) => {
                tracing::error!("Error serializing websocket message");
            }
            Ok(text) => {
                if let Err(err) = self.socket().send_with_str(&text) {
                    tracing::error!("Error sending websocket message: {}", js_to_string(err));
                }
            }
        }
    }

    /// `onmessage` handler installed by [`Self::start`].
    ///
    /// The literal text `"ping"` is answered with `"pong"`. Any other text
    /// frame is decoded as a `ServerMsg` and handed to `handle`. If `handle`
    /// returns a reply, the reply is serialized and sent back on `ws`.
    /// Non-string frames and (de)serialization/send failures are logged.
    #[allow(clippy::cognitive_complexity)]
    fn callback(
        ws: &leptos::web_sys::WebSocket,
        handle: &mut impl FnMut(ServerMsg) -> Option<ClientMsg>,
        event: leptos::web_sys::MessageEvent,
    ) {
        let Some(data) = event.data().as_string() else {
            tracing::error!("Not a string: {}", js_to_string(event.data()));
            return;
        };
        if data == "ping" {
            if let Err(e) = ws.send_with_str("pong") {
                tracing::error!("Error sending websocket message: {}", js_to_string(e));
            }
            return;
        }

        let mut de = serde_json::Deserializer::from_str(&data);
        // Server payloads may nest deeply; lift serde_json's default cap.
        de.disable_recursion_limit();
        let incoming = match ServerMsg::deserialize(&mut de) {
            Ok(parsed) => parsed,
            Err(e) => {
                tracing::error!("{e}");
                return;
            }
        };

        let Some(reply) = handle(incoming) else {
            return;
        };
        let Ok(text) = serde_json::to_string(&reply) else {
            tracing::error!("Error serializing websocket message");
            return;
        };
        if let Err(e) = ws.send_with_str(&text) {
            tracing::error!("Error sending websocket message: {}", js_to_string(e));
        }
    }

    /// Connects to [`WebSocket::SERVER_ENDPOINT`] and wires `handle` to
    /// incoming messages via [`Self::callback`]. If [`Self::on_open`] returns a
    /// closure, that closure is registered as the `onopen` handler. The JS
    /// closures are leaked (`forget`). Returns `None` (after logging) if the
    /// socket cannot be created.
    fn start(mut handle: impl (FnMut(ServerMsg) -> Option<ClientMsg>) + 'static) -> Option<Self> {
        use leptos::wasm_bindgen::JsCast;
        use leptos::wasm_bindgen::prelude::Closure;
        let socket = match leptos::web_sys::WebSocket::new(Self::SERVER_ENDPOINT) {
            Ok(created) => created,
            Err(e) => {
                tracing::error!("Error creating websocket: {}", js_to_string(e));
                return None;
            }
        };

        let reply_socket = socket.clone();
        let on_message = Closure::<dyn FnMut(_)>::new(move |event| {
            Self::callback(&reply_socket, &mut handle, event)
        });
        socket.set_onmessage(Some(on_message.as_ref().unchecked_ref()));
        on_message.forget();

        let mut client = Self::new(socket);
        if let Some(mut f) = client.on_open() {
            let on_open =
                Closure::<dyn FnMut(_)>::new(move |_: leptos::web_sys::MessageEvent| f());
            client
                .socket()
                .set_onopen(Some(on_open.as_ref().unchecked_ref()));
            on_open.forget();
        }
        Some(client)
    }

    /// Optional hook run when the socket opens; none by default.
    fn on_open(&self) -> Option<Box<dyn FnMut()>> {
        None
    }
}
385
/// Server side of a [`WebSocket`] protocol with sequential message handling:
/// each client message is handled inline and may produce a direct reply,
/// while [`Self::next`] yields server-initiated messages.
#[cfg(feature = "ssr")]
#[async_trait::async_trait]
pub trait WebSocketServer<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
>: WebSocket<ClientMsg, ServerMsg>
{
    /// Creates per-connection state; `None` rejects the connection (401).
    async fn new(account: crate::LoginState, db: flams_database::DBBackend) -> Option<Self>;
    /// Next server-initiated message; `None` ends the connection.
    async fn next(&mut self) -> Option<ServerMsg>;
    /// Handles one client message, optionally returning a direct reply.
    async fn handle_message(&mut self, msg: ClientMsg) -> Option<ServerMsg>;
    /// Hook run once after the initial ping, before the message loop.
    async fn on_start(&mut self, _socket: &mut axum::extract::ws::WebSocket) {}

    /// Axum handler: derives the login state from `auth_session` and, if
    /// [`Self::new`] accepts it, upgrades and runs [`Self::on_upgrade`];
    /// otherwise responds `401 UNAUTHORIZED`.
    async fn ws_handler(
        auth_session: axum_login::AuthSession<flams_database::DBBackend>,
        ws: axum::extract::WebSocketUpgrade,
    ) -> axum::response::Response
    where
        Self: Send,
    {
        let login = match &auth_session.backend.admin {
            None => crate::LoginState::NoAccounts,
            Some(_) => match auth_session.user {
                None => crate::LoginState::None,
                Some(flams_database::DBUser {
                    id: 0, username, ..
                }) if username == "admin" => crate::LoginState::Admin,
                Some(u) => crate::LoginState::User {
                    name: u.username,
                    avatar: u.avatar_url.unwrap_or_default(),
                    is_admin: u.is_admin,
                },
            },
        };
        Self::new(login, auth_session.backend).await.map_or_else(
            || {
                let mut res = axum::response::Response::new(axum::body::Body::empty());
                *(res.status_mut()) = axum::http::StatusCode::UNAUTHORIZED;
                res
            },
            |conn| ws.on_upgrade(move |socket| conn.on_upgrade(socket)),
        )
    }

    /// Connection loop. It multiplexes idle pings (every `TIMEOUT` seconds),
    /// server-initiated messages from [`Self::next`], and client messages
    /// passed to [`Self::handle_message`]. The loop ends on any send failure,
    /// when `next` yields `None`, or when the client stream ends.
    /// (De)serialization failures are logged and the offending message is
    /// skipped.
    async fn on_upgrade(mut self, mut socket: axum::extract::ws::WebSocket)
    where
        Self: Send,
    {
        use axum::body::Bytes;
        if socket
            .send(axum::extract::ws::Message::Ping(Bytes::new()))
            .await
            .is_err()
        {
            return;
        }
        let timeout = std::time::Duration::from_secs_f32(Self::TIMEOUT);
        self.on_start(&mut socket).await;
        loop {
            tokio::select! {
                () = tokio::time::sleep(timeout) => {
                    if socket.send(axum::extract::ws::Message::Ping(Bytes::new())).await.is_err() {
                        return
                    }
                },
                msg = self.next() => {
                    let Some(msg) = msg else { return };
                    match serde_json::to_string(&msg) {
                        Ok(json) => {
                            if socket.send(axum::extract::ws::Message::Text(json.into())).await.is_err() {
                                return
                            }
                        }
                        Err(e) => {
                            tracing::error!("Error serializing websocket message {msg:?}: {e}");
                        }
                    }
                },
                o = socket.recv() => {
                    match o {
                        None => break,
                        Some(Ok(axum::extract::ws::Message::Ping(payload))) => {
                            // RFC 6455: a Pong echoes the Ping's payload.
                            if socket.send(axum::extract::ws::Message::Pong(payload)).await.is_err() {
                                break
                            }
                        },
                        Some(Ok(axum::extract::ws::Message::Text(msg))) => {
                            match serde_json::from_str(&msg) {
                                Ok(msg) => {
                                    if let Some(reply) = self.handle_message(msg).await {
                                        match serde_json::to_string(&reply) {
                                            Ok(json) => {
                                                if socket.send(axum::extract::ws::Message::Text(json.into())).await.is_err() {
                                                    break
                                                }
                                            }
                                            Err(e) => {
                                                tracing::error!("Error serializing websocket message {reply:?}: {e}");
                                            }
                                        }
                                    }
                                }
                                Err(e) => {
                                    tracing::error!("Error deserializing websocket message: {e}");
                                }
                            }
                        },
                        Some(_) => (),
                    }
                },
            }
        }
    }
}
489
/// Shared description of a websocket protocol exchanging JSON-encoded
/// `ClientMsg` (client -> server) and `ServerMsg` (server -> client) values.
///
/// The browser side is `WebSocketClient` (feature `hydrate`) and the server
/// side is `WebSocketServer` (feature `ssr`); both have this trait as a
/// supertrait.
pub trait WebSocket<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
>: Sized + 'static
{
    /// Idle interval in seconds; the server loop converts it with
    /// `Duration::from_secs_f32` and sends a ping after that much inactivity.
    const TIMEOUT: f32 = 10.0;
    /// URL passed to `web_sys::WebSocket::new` when the client connects.
    const SERVER_ENDPOINT: &'static str;

    /// Server-side start hook; the visible default body does nothing.
    #[cfg(feature = "ssr")]
    fn force_start_server() {
    }

    /// Starts the client connection from inside a Leptos `Effect`.
    ///
    /// The effect calls `WebSocketClient::start` with a clone of `handle`;
    /// if a socket was created, the resulting value is passed to `on_start`.
    /// NOTE(review): if `start` or `on_start` read reactive signals, the effect
    /// may re-run and open an additional socket each time — confirm intended.
    #[cfg(feature = "hydrate")]
    fn force_start_client(
        handle: impl (FnMut(ServerMsg) -> Option<ClientMsg>) + 'static + Clone,
        mut on_start: impl FnMut(Self) + 'static,
    ) where
        Self: WebSocketClient<ClientMsg, ServerMsg>,
    {
        let _res = leptos::prelude::Effect::new(move |_| {
            if let Some(r) = Self::start(handle.clone()) {
                on_start(r);
            }
        });
    }
}