flams_router_base/
ws.rs

1#[cfg(feature = "ssr")]
2pub use axum::extract::ws::Message as WSMessage;
3#[cfg(feature = "ssr")]
4pub use axum::extract::ws::WebSocket as AxumWS;
5#[cfg(feature = "ssr")]
6pub use flams_database::DBBackend;
7
/// Renders an arbitrary JS value (typically an error such as a `DOMException`)
/// as a Rust `String` by invoking its JS `toString` method.
#[cfg(feature = "hydrate")]
fn js_to_string(value: leptos::wasm_bindgen::JsValue) -> String {
    let object = leptos::web_sys::js_sys::Object::from(value);
    String::from(object.to_string())
}
13
/// Browser-side half of a websocket endpoint.
///
/// Implementors wrap a [`leptos::web_sys::WebSocket`]. The default methods:
/// - JSON-encode outgoing `ClientMsg`s and decode incoming `ServerMsg`s;
/// - answer a textual `"ping"` from the server with `"pong"`;
/// - install the socket's `onmessage` / `onopen` handlers.
///
/// All errors (serialization, transport, decoding) are logged via `tracing`
/// rather than returned.
#[cfg(feature = "hydrate")]
pub trait WebSocketClient<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
>: WebSocket<ClientMsg, ServerMsg>
{
    /// Wraps a freshly created socket (it may not be open yet).
    fn new(ws: leptos::web_sys::WebSocket) -> Self;

    /// Gives mutable access to the wrapped socket.
    fn socket(&mut self) -> &mut leptos::web_sys::WebSocket;

    /// Serializes `msg` as JSON and sends it as a text frame.
    ///
    /// Failures are logged, not returned.
    fn send(&mut self, msg: &ClientMsg) {
        send_ws_json(self.socket(), msg);
    }

    /// Handles a single incoming `MessageEvent` on `ws`.
    ///
    /// - Non-string payloads are logged and dropped.
    /// - The literal text `"ping"` is answered with `"pong"`.
    /// - Anything else is decoded as a JSON `ServerMsg` and passed to `handle`.
    ///   If `handle` returns `Some(reply)`, the reply is sent back as JSON.
    fn callback(
        ws: &leptos::web_sys::WebSocket,
        handle: &mut impl (FnMut(ServerMsg) -> Option<ClientMsg>),
        event: leptos::web_sys::MessageEvent,
    ) {
        let Some(data) = event.data().as_string() else {
            tracing::error!("Not a string: {}", js_to_string(event.data()));
            return;
        };
        if data == "ping" {
            send_ws_text(ws, "pong");
            return;
        }
        // serde_json's recursion limit is lifted here.
        // NOTE(review): presumably because server messages can nest deeply — confirm.
        let mut deserializer = serde_json::Deserializer::from_str(&data);
        deserializer.disable_recursion_limit();
        let msg = match ServerMsg::deserialize(&mut deserializer) {
            Ok(msg) => msg,
            Err(e) => {
                tracing::error!("Error deserializing websocket message: {e}");
                return;
            }
        };
        if let Some(reply) = handle(msg) {
            send_ws_json(ws, &reply);
        }
    }

    /// Opens a socket to `Self::SERVER_ENDPOINT` and routes every incoming
    /// message through [`WebSocketClient::callback`] with `handle`.
    ///
    /// If [`WebSocketClient::on_open`] yields a hook, it is installed as the
    /// socket's `onopen` handler.
    ///
    /// Returns `None` (after logging) if the socket cannot be created.
    fn start(mut handle: impl (FnMut(ServerMsg) -> Option<ClientMsg>) + 'static) -> Option<Self> {
        use leptos::wasm_bindgen::JsCast;
        use leptos::wasm_bindgen::prelude::Closure;

        let ws = match leptos::web_sys::WebSocket::new(Self::SERVER_ENDPOINT) {
            Ok(ws) => ws,
            Err(e) => {
                tracing::error!("Error creating websocket: {}", js_to_string(e));
                return None;
            }
        };

        let ws2 = ws.clone();
        let on_message = Closure::<dyn FnMut(leptos::web_sys::MessageEvent)>::new(move |event| {
            Self::callback(&ws2, &mut handle, event);
        });
        ws.set_onmessage(Some(on_message.as_ref().unchecked_ref()));
        // Leak the closure so the JS handler stays valid; it is never freed.
        on_message.forget();

        let mut r = Self::new(ws);
        if let Some(mut on_open) = r.on_open() {
            // `open` delivers a plain `Event`, not a `MessageEvent`. The argument is
            // unused, so accept it untyped instead of relying on an unchecked cast.
            let on_open_cb = Closure::<dyn FnMut(leptos::wasm_bindgen::JsValue)>::new(
                move |_: leptos::wasm_bindgen::JsValue| on_open(),
            );
            r.socket()
                .set_onopen(Some(on_open_cb.as_ref().unchecked_ref()));
            on_open_cb.forget();
        }
        Some(r)
    }

    /// Optional hook installed as the socket's `onopen` handler by `start`.
    /// Defaults to none.
    fn on_open(&self) -> Option<Box<dyn FnMut()>> {
        None
    }
}

/// Sends `text` over `ws` as a text frame, logging (not returning) any error.
#[cfg(feature = "hydrate")]
fn send_ws_text(ws: &leptos::web_sys::WebSocket, text: &str) {
    if let Err(e) = ws.send_with_str(text) {
        tracing::error!("Error sending websocket message: {}", js_to_string(e));
    }
}

/// Serializes `msg` as JSON and sends it over `ws`, logging (not returning)
/// serialization or transport errors.
#[cfg(feature = "hydrate")]
fn send_ws_json<T: serde::Serialize + ?Sized>(ws: &leptos::web_sys::WebSocket, msg: &T) {
    match serde_json::to_string(msg) {
        Ok(text) => send_ws_text(ws, &text),
        Err(e) => tracing::error!("Error serializing websocket message: {e}"),
    }
}
104
/// Server-side half of a websocket endpoint, driven by axum.
///
/// Messages are exchanged as JSON text frames:
/// - `ClientMsg` is received from the browser;
/// - `ServerMsg` is pushed to it.
#[cfg(feature = "ssr")]
#[async_trait::async_trait]
pub trait WebSocketServer<
    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
>: WebSocket<ClientMsg, ServerMsg>
{
    /// Builds the per-connection state for `account`.
    ///
    /// Returning `None` makes [`WebSocketServer::ws_handler`] answer
    /// `401 Unauthorized` instead of upgrading the connection.
    async fn new(account: crate::LoginState, db: flams_database::DBBackend) -> Option<Self>;

    /// Produces the next message to push to the client.
    /// Returning `None` ends the connection.
    async fn next(&mut self) -> Option<ServerMsg>;

    /// Handles a decoded client message.
    /// A returned `Some(reply)` is sent back to the client.
    async fn handle_message(&mut self, msg: ClientMsg) -> Option<ServerMsg>;

    /// Hook run once after the initial ping and before the message loop.
    /// Default: no-op.
    async fn on_start(&mut self, _socket: &mut axum::extract::ws::WebSocket) {}

    /// Axum handler: derives a `LoginState` from the session, then either
    /// upgrades to a websocket or responds `401 Unauthorized`.
    ///
    /// Login state resolution:
    /// - no admin configured → `NoAccounts`;
    /// - no logged-in user → `None`;
    /// - user id 0 named `"admin"` → `Admin`;
    /// - otherwise → `User`.
    async fn ws_handler(
        auth_session: axum_login::AuthSession<flams_database::DBBackend>,
        ws: axum::extract::WebSocketUpgrade,
    ) -> axum::response::Response
    where
        Self: Send,
    {
        let login = match &auth_session.backend.admin {
            None => crate::LoginState::NoAccounts,
            Some(_) => match auth_session.user {
                None => crate::LoginState::None,
                Some(flams_database::DBUser {
                    id: 0, username, ..
                }) if username == "admin" => crate::LoginState::Admin,
                Some(u) => crate::LoginState::User {
                    name: u.username,
                    avatar: u.avatar_url.unwrap_or_default(),
                    is_admin: u.is_admin,
                },
            },
        };
        Self::new(login, auth_session.backend).await.map_or_else(
            || {
                let mut res = axum::response::Response::new(axum::body::Body::empty());
                *(res.status_mut()) = axum::http::StatusCode::UNAUTHORIZED;
                res
            },
            |conn| ws.on_upgrade(move |socket| conn.on_upgrade(socket)),
        )
    }

    /// Runs the connection until either side goes away.
    ///
    /// Sends an initial ping, calls `on_start`, then loops over three events:
    /// - no activity for `Self::TIMEOUT` seconds → send a keep-alive ping;
    /// - `next()` yields a message → push it to the client (`None` ends the connection);
    /// - a client frame arrives → answer pings, and decode text frames as
    ///   `ClientMsg` and pass them to `handle_message`.
    ///
    /// Any send failure or receive error ends the connection.
    async fn on_upgrade(mut self, mut socket: axum::extract::ws::WebSocket)
    where
        Self: Send,
    {
        use axum::body::Bytes;
        use axum::extract::ws::Message;

        if socket.send(Message::Ping(Bytes::new())).await.is_err() {
            return;
        }
        let timeout = std::time::Duration::from_secs_f32(Self::TIMEOUT);
        self.on_start(&mut socket).await;
        loop {
            // The sleep is recreated on every iteration, so the keep-alive ping
            // only fires after `timeout` with no other event.
            tokio::select! {
                () = tokio::time::sleep(timeout) => {
                    if socket.send(Message::Ping(Bytes::new())).await.is_err() {
                        return;
                    }
                },
                outgoing = self.next() => {
                    let Some(outgoing) = outgoing else { return };
                    if let Some(frame) = encode_ws_message(&outgoing) {
                        if socket.send(frame).await.is_err() {
                            return;
                        }
                    }
                },
                incoming = socket.recv() => match incoming {
                    None | Some(Err(_)) => return,
                    Some(Ok(Message::Ping(payload))) => {
                        // RFC 6455 §5.5.3: a Pong must echo the Ping's application data.
                        if socket.send(Message::Pong(payload)).await.is_err() {
                            return;
                        }
                    },
                    Some(Ok(Message::Text(text))) => {
                        match serde_json::from_str::<ClientMsg>(&text) {
                            Ok(msg) => {
                                if let Some(reply) = self.handle_message(msg).await {
                                    if let Some(frame) = encode_ws_message(&reply) {
                                        if socket.send(frame).await.is_err() {
                                            return;
                                        }
                                    }
                                }
                            }
                            Err(e) => tracing::warn!("Error deserializing websocket message: {e}"),
                        }
                    },
                    // Close, Pong and Binary frames are ignored. After a Close we keep
                    // polling `recv`, which is expected to yield `None` once the closing
                    // handshake is done.
                    // NOTE(review): relies on the underlying websocket stream's behaviour.
                    Some(Ok(_)) => (),
                },
            }
        }
    }
}

/// JSON-encodes `msg` into a websocket text frame.
/// Returns `None` (after logging) if serialization fails.
#[cfg(feature = "ssr")]
fn encode_ws_message<T: serde::Serialize + ?Sized>(msg: &T) -> Option<axum::extract::ws::Message> {
    match serde_json::to_string(msg) {
        Ok(text) => Some(axum::extract::ws::Message::Text(text.into())),
        Err(e) => {
            tracing::error!("Error serializing websocket message: {e}");
            None
        }
    }
}
200
201pub trait WebSocket<
202    ClientMsg: serde::Serialize + for<'a> serde::Deserialize<'a> + Send,
203    ServerMsg: serde::Serialize + std::fmt::Debug + for<'a> serde::Deserialize<'a> + Send,
204>: Sized + 'static
205{
206    const TIMEOUT: f32 = 10.0;
207    const SERVER_ENDPOINT: &'static str;
208
209    #[cfg(feature = "ssr")]
210    fn force_start_server() {
211        //let (signal_read,_) = signal(false);
212        //let _res = Effect::new(move |_| {
213        //    let _ = signal_read.get();
214        //});
215    }
216
217    #[cfg(feature = "hydrate")]
218    fn force_start_client(
219        handle: impl (FnMut(ServerMsg) -> Option<ClientMsg>) + 'static + Clone,
220        mut on_start: impl FnMut(Self) + 'static,
221    ) where
222        Self: WebSocketClient<ClientMsg, ServerMsg>,
223    {
224        //let (signal_read,_) = signal(false);
225        let _res = leptos::prelude::Effect::new(move |_| {
226            //let _ = signal_read.get();
227            if let Some(r) = Self::start(handle.clone()) {
228                on_start(r);
229            }
230        });
231    }
232}