flams_router_base/
ws.rs

1use flams_utils::parking_lot;
2use std::marker::PhantomData;
3
4#[cfg(feature = "ssr")]
5pub use axum::extract::ws::Message as WSMessage;
6#[cfg(feature = "ssr")]
7pub use axum::extract::ws::WebSocket as AxumWS;
8#[cfg(feature = "ssr")]
9pub use flams_database::DBBackend;
10
#[cfg(feature = "hydrate")]
/// Browser-side typed websocket client.
///
/// `ClientMsg`s are sent as JSON text frames; incoming text frames are decoded into
/// `ServerMsg` and handed to the handler given to [`WSClient::new`].
#[derive(Debug)]
pub struct WSClient<ClientMsg, ServerMsg>
where
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
{
    /// The underlying browser socket (shared with the installed JS callbacks).
    socket: leptos::web_sys::WebSocket,
    /// Ties the message types to the client without storing any values of them.
    _phantom: PhantomData<(ClientMsg, ServerMsg)>,
    /// Outgoing JSON strings: while this holds `Some`, `send` appends here instead of
    /// writing to the socket; the `onopen` callback takes (and thereby clears) it.
    queue: std::sync::Arc<parking_lot::Mutex<Option<Vec<String>>>>,
}
21
#[cfg(feature = "hydrate")]
/// Clones share the same socket handle and the same `Arc`-backed outgoing queue.
/// Written by hand so that the message types themselves need not be `Clone`.
impl<ClientMsg, ServerMsg> Clone for WSClient<ClientMsg, ServerMsg>
where
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
{
    #[inline]
    fn clone(&self) -> Self {
        let socket = self.socket.clone();
        let queue = std::sync::Arc::clone(&self.queue);
        Self {
            socket,
            queue,
            _phantom: PhantomData,
        }
    }
}
37
#[cfg(feature = "hydrate")]
impl<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
> WSClient<ClientMsg, ServerMsg>
{
    /// Serializes `msg` to JSON and sends it as a text frame.
    ///
    /// Until the socket has opened, messages are buffered in `queue` and flushed in order
    /// by the `onopen` callback installed in [`Self::new`]. Serialization and send failures
    /// are logged, not returned.
    pub fn send(&self, msg: &ClientMsg) {
        let Ok(s) = serde_json::to_string(msg) else {
            tracing::error!("Error serializing websocket message");
            return;
        };
        // The lock guard is a temporary of this `if let`; it is released before the socket
        // is used below.
        if let Some(pending) = self.queue.lock().as_mut() {
            pending.push(s);
            return;
        }

        if let Err(e) = self.socket.send_with_str(&s) {
            tracing::error!("Error sending websocket message: {}", js_to_string(e));
        }
    }

    /// Connects to `endpoint`. Each incoming text frame is decoded as JSON into `ServerMsg`
    /// and passed to `handler`; decode errors are logged and the frame is dropped.
    ///
    /// Returns `None` (after logging) if the browser refuses to create the socket.
    #[inline]
    pub fn new(endpoint: &str, mut handler: impl FnMut(ServerMsg) + 'static) -> Option<Self> {
        Self::new_i(
            endpoint,
            Box::new(move |s| {
                let mut deserializer = serde_json::Deserializer::from_str(&s);
                // Lift serde_json's recursion-depth limit so deeply nested messages decode.
                deserializer.disable_recursion_limit();
                let value = ServerMsg::deserialize(&mut deserializer);
                match value {
                    Ok(msg) => handler(msg),
                    Err(e) => {
                        tracing::error!("{e}");
                    }
                }
            }),
        )
    }

    /// Non-generic part of [`Self::new`]: creates the socket and installs the `onmessage`
    /// and `onopen` callbacks.
    fn new_i(endpoint: &str, mut handler: Box<dyn FnMut(String) + 'static>) -> Option<Self> {
        use leptos::wasm_bindgen::JsCast;
        use leptos::wasm_bindgen::prelude::Closure;
        let ws = match leptos::web_sys::WebSocket::new(endpoint) {
            Ok(ws) => ws,
            Err(e) => {
                tracing::error!("Error creating websocket: {}", js_to_string(e));
                return None;
            }
        };
        let ws2 = ws.clone();
        let on_message =
            Closure::<dyn FnMut(_)>::new(move |event| callback(&ws2, &mut *handler, event));
        ws.set_onmessage(Some(on_message.as_ref().unchecked_ref()));
        // The JS side keeps calling the closure, so it must never be dropped.
        on_message.forget();

        let r = Self {
            socket: ws,
            // Start in buffering mode: the new socket is still CONNECTING, and sending on it
            // would fail. `send` queues until `onopen` flushes and clears the buffer.
            // (This used to be `Arc::default()`, i.e. `None`, which disabled the queue.)
            queue: std::sync::Arc::new(parking_lot::Mutex::new(Some(Vec::new()))),
            _phantom: PhantomData,
        };
        let ws = r.socket.clone();
        let queue = r.queue.clone();
        let on_open = Closure::<dyn FnMut(_)>::new(move |_: leptos::web_sys::MessageEvent| {
            // Take the buffer (leaving `None`) so that every later `send` goes out directly.
            let pending = queue.lock().take();
            for s in pending.into_iter().flatten() {
                if let Err(e) = ws.send_with_str(&s) {
                    tracing::error!("Error sending websocket message: {}", js_to_string(e));
                }
            }
        });
        r.socket.set_onopen(Some(on_open.as_ref().unchecked_ref()));
        on_open.forget();
        Some(r)
    }
}
113
#[cfg(feature = "hydrate")]
/// `onmessage` handler used by [`WSClient`].
///
/// A literal `"ping"` text frame is answered with `"pong"` on `ws`. Any other text frame is
/// forwarded to `handler`. Non-string payloads are logged and ignored.
// `event` is taken by value because that is how the JS closure receives it.
#[allow(clippy::needless_pass_by_value)]
fn callback(
    ws: &leptos::web_sys::WebSocket,
    handler: &mut dyn FnMut(String),
    event: leptos::web_sys::MessageEvent,
) {
    match event.data().as_string() {
        None => tracing::error!("Not a string: {}", js_to_string(event.data())),
        Some(text) if text == "ping" => {
            if let Err(err) = ws.send_with_str("pong") {
                tracing::error!("Error sending websocket message: {}", js_to_string(err));
            }
        }
        Some(text) => handler(text),
    }
}
133
#[cfg(feature = "ssr")]
/// Server-side handle for pushing `ServerMsg`s to one websocket connection.
///
/// Messages go through an unbounded channel. The connection loop in
/// [`WSServerSocket::handler`] drains that channel and writes each message to the socket
/// as JSON.
pub struct WSSocket<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
> {
    socket: tokio::sync::mpsc::UnboundedSender<ServerMsg>,
    _phantom: PhantomData<(ClientMsg, ServerMsg)>,
}

#[cfg(feature = "ssr")]
// Hand-written instead of `#[derive(Clone)]`: the derive would require `ClientMsg: Clone`
// and `ServerMsg: Clone`, although only the sender (which is always `Clone`) is cloned.
impl<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
> Clone for WSSocket<ClientMsg, ServerMsg>
{
    #[inline]
    fn clone(&self) -> Self {
        Self {
            socket: self.socket.clone(),
            _phantom: PhantomData,
        }
    }
}
143
#[cfg(feature = "ssr")]
/// Server side of a typed websocket endpoint.
///
/// - [`WSServerSocket::new`] creates one instance per connection and receives a [`WSSocket`]
///   for pushing `ServerMsg`s.
/// - Every decoded `ClientMsg` is passed to [`WSServerSocket::handle`] in its own spawned
///   task, so handlers for different messages may run concurrently.
/// - [`WSServerSocket::handler`] is the axum handler that authorizes the user and runs the
///   connection loop.
pub trait WSServerSocket<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
>: Sized + Sync + 'static
{
    /// Idle time in seconds after which the server sends a websocket `Ping` frame.
    const TIMEOUT: f32 = 10.0;

    /// Builds the per-connection state; `socket` pushes messages to this client.
    fn new(socket: WSSocket<ClientMsg, ServerMsg>) -> impl Future<Output = Self> + Send;
    /// Handles one client message. Returning `false` closes the connection.
    fn handle(&self, msg: ClientMsg) -> impl Future<Output = bool> + Send;
    /// Optional span attached to the connection loop and to its handler tasks.
    fn span(&self) -> Option<&'static tracing::Span>;

    /// Authorization hook checked before upgrading. Returning `false` answers
    /// `401 Unauthorized`. Allows everyone by default.
    fn allow_user(_state: crate::LoginState) -> bool {
        true
    }

    /// Axum handler that runs the whole connection.
    ///
    /// It derives the [`crate::LoginState`] from the session and rejects the request if
    /// [`Self::allow_user`] returns `false`. Otherwise it upgrades to a websocket and serves
    /// the connection until one of these happens:
    /// - the client disconnects,
    /// - a send fails,
    /// - [`Self::handle`] returns `false`.
    async fn handler(
        auth_session: axum_login::AuthSession<flams_database::DBBackend>,
        ws: axum::extract::WebSocketUpgrade,
    ) -> axum::response::Response
    where
        Self: Send,
    {
        let login = match &auth_session.backend.admin {
            None => crate::LoginState::NoAccounts,
            Some(_) => match auth_session.user {
                None => crate::LoginState::None,
                Some(flams_database::DBUser {
                    id: 0, username, ..
                }) if username == "admin" => crate::LoginState::Admin,
                Some(u) => crate::LoginState::User {
                    name: u.username,
                    avatar: u.avatar_url.unwrap_or_default(),
                    is_admin: u.is_admin,
                },
            },
        };
        if !Self::allow_user(login) {
            let mut res = axum::response::Response::new(axum::body::Body::empty());
            *(res.status_mut()) = axum::http::StatusCode::UNAUTHORIZED;
            return res;
        }
        ws.on_upgrade(move |mut ws| async move {
            use axum::body::Bytes;
            use tracing::Instrument;
            let timeout = std::time::Duration::from_secs_f32(Self::TIMEOUT);
            let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();

            // `sender` itself stays alive for the whole connection, so `receiver.recv()`
            // does not yield `None` merely because the implementor dropped its `WSSocket`.
            let socket = WSSocket {
                socket: sender.clone(),
                _phantom: PhantomData,
            };
            let slf = std::sync::Arc::new(Self::new(socket).await);
            // Attach the span via `Instrument` rather than holding a `Span::enter` guard.
            // A guard held across `.await` stays entered while this task is suspended, so
            // unrelated work on the same thread would be recorded inside this span.
            let span = slf.span().cloned().unwrap_or_else(tracing::Span::none);
            // Signalled by a handler task whose `handle` returned `false`. `notify_one`
            // stores a permit if the loop is not currently waiting, so no signal is lost.
            let cancel = std::sync::Arc::new(tokio::sync::Notify::new());
            async move {
                loop {
                    tokio::select! {
                        () = cancel.notified() => {
                            tracing::info!("Dropping websocket");
                            return;
                        },
                        () = tokio::time::sleep(timeout) => {
                            if ws.send(axum::extract::ws::Message::Ping(Bytes::new())).await.is_err() {
                                tracing::info!("Error sending ping; Dropping websocket");
                                return;
                            }
                        },
                        msg = receiver.recv() => {
                            let Some(msg) = msg else {
                                tracing::info!("Receiver closed; Dropping websocket");
                                return;
                            };
                            match serde_json::to_string(&msg) {
                                Ok(msg) => {
                                    tracing::info!("Returning {}", msg);
                                    if ws.send(axum::extract::ws::Message::Text(msg.into())).await.is_err() {
                                        tracing::info!("Error sending message; Dropping websocket");
                                        return;
                                    }
                                }
                                Err(e) => {
                                    tracing::error!("Error serializing websocket message: {e}");
                                }
                            }
                        },
                        msg = ws.recv() => {
                            match msg {
                                None => {
                                    tracing::info!("Received None-message from client; Dropping websocket");
                                    return;
                                }
                                Some(Ok(axum::extract::ws::Message::Ping(_))) => {
                                    if ws.send(axum::extract::ws::Message::Pong(Bytes::new())).await.is_err() {
                                        tracing::info!("Ping not returned; Dropping websocket");
                                        return;
                                    }
                                }
                                Some(Ok(axum::extract::ws::Message::Text(msg))) => {
                                    tracing::info!("Received {}", msg);
                                    let cancel = cancel.clone();
                                    let slf = slf.clone();
                                    // Detached on purpose: the loop keeps serving the socket
                                    // while `handle` runs, and the result is not awaited.
                                    #[allow(clippy::let_underscore_future)]
                                    let _ = tokio::task::spawn(
                                        async move {
                                            match serde_json::from_str(&msg) {
                                                Ok(msg) => {
                                                    if !slf.handle(msg).await {
                                                        tracing::info!("handle returned false; Dropping websocket");
                                                        cancel.notify_one();
                                                    }
                                                }
                                                Err(e) => {
                                                    tracing::error!("Error: {e:?}");
                                                }
                                            }
                                        }
                                        // Carry the connection span into the handler task.
                                        .in_current_span(),
                                    );
                                }
                                _ => (),
                            }
                        },
                    }
                }
            }
            .instrument(span)
            .await;
        })
    }
}
266
267/*
268impl<
269    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
270    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
271> Clone for WSServer<ClientMsg,ServerMsg> {
272    fn clone(&self) -> Self {
273        Self {
274            socket:self.socket.clone()
275        }
276    }
277}
278*/
279
#[cfg(feature = "ssr")]
impl<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
> WSSocket<ClientMsg, ServerMsg>
{
    /// Queues `msg` for delivery to this connection's client.
    ///
    /// Best-effort: if the channel's receiver is gone (the connection loop has ended), the
    /// message is dropped and a debug event is emitted instead of failing.
    #[inline]
    pub fn send(&self, msg: ServerMsg) {
        if self.socket.send(msg).is_err() {
            tracing::debug!("Websocket connection closed; dropping outgoing message");
        }
    }
}
291
#[cfg(feature = "hydrate")]
/// Renders a JS value (e.g. a thrown error or a non-string message payload) for logging.
///
/// Uses the value's own `toString()`. `null` and `undefined` have no `toString`, so calling
/// it would throw a `TypeError` out of this logging helper; they are rendered literally
/// instead.
fn js_to_string(e: leptos::wasm_bindgen::JsValue) -> String {
    use leptos::web_sys::js_sys::Object;
    if e.is_null() {
        return "null".to_owned();
    }
    if e.is_undefined() {
        return "undefined".to_owned();
    }
    Object::from(e).to_string().into()
}
297
#[cfg(feature = "hydrate")]
/// Client half of a [`WebSocket`] protocol, backed by a browser `WebSocket`.
///
/// Implementors only provide construction from a raw socket and access to it. The provided
/// methods handle the rest:
/// - JSON encoding and decoding,
/// - answering `"ping"` with `"pong"`,
/// - wiring up the `onmessage`/`onopen` callbacks.
pub trait WebSocketClient<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
>: WebSocket<ClientMsg, ServerMsg>
{
    /// Wraps the socket created by [`Self::start`].
    fn new(ws: leptos::web_sys::WebSocket) -> Self;
    /// Mutable access to the wrapped socket.
    fn socket(&mut self) -> &mut leptos::web_sys::WebSocket;

    /// Serializes `msg` to JSON and sends it as a text frame; failures are only logged.
    fn send(&mut self, msg: &ClientMsg) {
        match serde_json::to_string(msg) {
            Err(_) => tracing::error!("Error serializing websocket message"),
            Ok(text) => {
                if let Err(err) = self.socket().send_with_str(&text) {
                    tracing::error!("Error sending websocket message: {}", js_to_string(err));
                }
            }
        }
    }

    /// `onmessage` handler.
    ///
    /// A `"ping"` frame is answered with `"pong"`. Any other text frame is decoded as
    /// `ServerMsg` and given to `handle`; a `ClientMsg` returned by `handle` is sent back.
    /// Every failure along the way is logged and ends processing of that frame.
    #[allow(clippy::cognitive_complexity)]
    fn callback(
        ws: &leptos::web_sys::WebSocket,
        handle: &mut impl FnMut(ServerMsg) -> Option<ClientMsg>,
        event: leptos::web_sys::MessageEvent,
    ) {
        let Some(data) = event.data().as_string() else {
            tracing::error!("Not a string: {}", js_to_string(event.data()));
            return;
        };
        if data == "ping" {
            if let Err(err) = ws.send_with_str("pong") {
                tracing::error!("Error sending websocket message: {}", js_to_string(err));
            }
            return;
        }

        let mut de = serde_json::Deserializer::from_str(&data);
        // Lift serde_json's recursion-depth limit so deeply nested messages still decode.
        de.disable_recursion_limit();
        let incoming = match ServerMsg::deserialize(&mut de) {
            Ok(decoded) => decoded,
            Err(err) => {
                tracing::error!("{err}");
                return;
            }
        };
        let Some(reply) = handle(incoming) else {
            return;
        };
        let Ok(text) = serde_json::to_string(&reply) else {
            tracing::error!("Error serializing websocket message");
            return;
        };
        if let Err(err) = ws.send_with_str(&text) {
            tracing::error!("Error sending websocket message: {}", js_to_string(err));
        }
    }

    /// Opens a socket to [`WebSocket::SERVER_ENDPOINT`] and wires up its callbacks.
    ///
    /// [`Self::callback`] (driven by `handle`) becomes the `onmessage` handler. If
    /// [`Self::on_open`] supplies a closure, it becomes the `onopen` handler. Returns
    /// `None` (after logging) if the socket cannot be created.
    fn start(mut handle: impl (FnMut(ServerMsg) -> Option<ClientMsg>) + 'static) -> Option<Self> {
        use leptos::wasm_bindgen::JsCast;
        use leptos::wasm_bindgen::prelude::Closure;
        let socket = leptos::web_sys::WebSocket::new(Self::SERVER_ENDPOINT)
            .map_err(|err| tracing::error!("Error creating websocket: {}", js_to_string(err)))
            .ok()?;

        let for_messages = socket.clone();
        let on_message = Closure::<dyn FnMut(_)>::new(move |event| {
            Self::callback(&for_messages, &mut handle, event)
        });
        socket.set_onmessage(Some(on_message.as_ref().unchecked_ref()));
        let mut client = Self::new(socket);
        // The JS side keeps calling these closures, so they are intentionally leaked.
        on_message.forget();

        if let Some(mut opened) = client.on_open() {
            let on_open = Closure::<dyn FnMut(_)>::new(move |_: leptos::web_sys::MessageEvent| {
                opened();
            });
            client
                .socket()
                .set_onopen(Some(on_open.as_ref().unchecked_ref()));
            on_open.forget();
        }
        Some(client)
    }

    /// Optional closure to run when the socket opens; `None` by default.
    fn on_open(&self) -> Option<Box<dyn FnMut()>> {
        None
    }
}
385
#[cfg(feature = "ssr")]
/// Server half of a [`WebSocket`] protocol (per-connection state object).
///
/// [`WebSocketServer::ws_handler`] authorizes via [`WebSocketServer::new`] and then drives
/// [`WebSocketServer::on_upgrade`] for the lifetime of the connection.
#[async_trait::async_trait]
pub trait WebSocketServer<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
>: WebSocket<ClientMsg, ServerMsg>
{
    /// Creates the connection state for `account`; `None` rejects the request with
    /// `401 Unauthorized`.
    async fn new(account: crate::LoginState, db: flams_database::DBBackend) -> Option<Self>;
    /// Next server-initiated message; `None` ends the connection.
    ///
    /// NOTE(review): this future is polled inside `tokio::select!` and dropped whenever
    /// another branch fires, so implementations should be cancel-safe. TODO confirm for
    /// existing implementors.
    async fn next(&mut self) -> Option<ServerMsg>;
    /// Handles one client message, optionally producing a direct reply.
    async fn handle_message(&mut self, msg: ClientMsg) -> Option<ServerMsg>;
    /// Hook run once, after the initial ping and before the message loop starts.
    async fn on_start(&mut self, _socket: &mut axum::extract::ws::WebSocket) {}

    /// Axum handler: derives the [`crate::LoginState`] from the session and asks
    /// [`Self::new`] for a connection. If `new` returns `None` the request is answered with
    /// `401 Unauthorized`; otherwise the socket is upgraded and handed to [`Self::on_upgrade`].
    async fn ws_handler(
        auth_session: axum_login::AuthSession<flams_database::DBBackend>,
        ws: axum::extract::WebSocketUpgrade,
    ) -> axum::response::Response
    where
        Self: Send,
    {
        let login = match &auth_session.backend.admin {
            None => crate::LoginState::NoAccounts,
            Some(_) => match auth_session.user {
                None => crate::LoginState::None,
                Some(flams_database::DBUser {
                    id: 0, username, ..
                }) if username == "admin" => crate::LoginState::Admin,
                Some(u) => crate::LoginState::User {
                    name: u.username,
                    avatar: u.avatar_url.unwrap_or_default(),
                    is_admin: u.is_admin,
                },
            },
        };
        Self::new(login, auth_session.backend).await.map_or_else(
            || {
                let mut res = axum::response::Response::new(axum::body::Body::empty());
                *(res.status_mut()) = axum::http::StatusCode::UNAUTHORIZED;
                res
            },
            |conn| ws.on_upgrade(move |socket| conn.on_upgrade(socket)),
        )
    }

    /// Connection loop.
    ///
    /// Sends an initial `Ping`, runs [`Self::on_start`], and then multiplexes three sources:
    /// - keep-alive pings after `TIMEOUT` seconds of idleness,
    /// - server messages from [`Self::next`],
    /// - client messages answered by [`Self::handle_message`].
    ///
    /// The loop ends when the client disconnects, a send fails, or `next` returns `None`.
    /// Messages that fail to (de)serialize are logged and skipped.
    async fn on_upgrade(mut self, mut socket: axum::extract::ws::WebSocket)
    where
        Self: Send,
    {
        use axum::body::Bytes;
        use axum::extract::ws::Message;
        if socket.send(Message::Ping(Bytes::new())).await.is_err() {
            return;
        }
        let timeout = std::time::Duration::from_secs_f32(Self::TIMEOUT);
        self.on_start(&mut socket).await;
        loop {
            tokio::select! {
                () = tokio::time::sleep(timeout) => {
                    if socket.send(Message::Ping(Bytes::new())).await.is_err() {
                        return;
                    }
                },
                msg = self.next() => {
                    let Some(msg) = msg else { return };
                    match serde_json::to_string(&msg) {
                        Ok(text) => {
                            if socket.send(Message::Text(text.into())).await.is_err() {
                                return;
                            }
                        }
                        Err(e) => tracing::error!("Error serializing websocket message: {e}"),
                    }
                },
                o = socket.recv() => {
                    match o {
                        None => break,
                        Some(Ok(Message::Ping(_))) => {
                            if socket.send(Message::Pong(Bytes::new())).await.is_err() {
                                break;
                            }
                        }
                        Some(Ok(Message::Text(msg))) => match serde_json::from_str::<ClientMsg>(&msg) {
                            Err(e) => tracing::error!("Error deserializing websocket message: {e}"),
                            Ok(msg) => {
                                if let Some(reply) = self.handle_message(msg).await {
                                    match serde_json::to_string(&reply) {
                                        Ok(reply) => {
                                            if socket.send(Message::Text(reply.into())).await.is_err() {
                                                break;
                                            }
                                        }
                                        Err(e) => tracing::error!("Error serializing websocket message: {e}"),
                                    }
                                }
                            }
                        },
                        _ => (),
                    }
                },
            }
        }
    }
}
489
490pub trait WebSocket<
491    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
492    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
493>: Sized + 'static
494{
495    const TIMEOUT: f32 = 10.0;
496    const SERVER_ENDPOINT: &'static str;
497
498    #[cfg(feature = "ssr")]
499    fn force_start_server() {
500        //let (signal_read,_) = signal(false);
501        //let _res = Effect::new(move |_| {
502        //    let _ = signal_read.get();
503        //});
504    }
505
506    #[cfg(feature = "hydrate")]
507    fn force_start_client(
508        handle: impl (FnMut(ServerMsg) -> Option<ClientMsg>) + 'static + Clone,
509        mut on_start: impl FnMut(Self) + 'static,
510    ) where
511        Self: WebSocketClient<ClientMsg, ServerMsg>,
512    {
513        //let (signal_read,_) = signal(false);
514        let _res = leptos::prelude::Effect::new(move |_| {
515            //let _ = signal_read.get();
516            if let Some(r) = Self::start(handle.clone()) {
517                on_start(r);
518            }
519        });
520    }
521}