1use super::{FLAMSLSPServer, ServerWrapper};
2use async_lsp::{
3 client_monitor::ClientProcessMonitorLayer, concurrency::ConcurrencyLayer,
4 panic::CatchUnwindLayer, server::LifecycleLayer, tracing::TracingLayer, ClientSocket,
5 LspService, MainLoop,
6};
7use axum::extract::ws::Message;
8use std::pin::Pin;
9use std::{
10 io::{self, ErrorKind},
11 task::{Context, Poll},
12};
13use tower::ServiceBuilder;
14
15pub fn upgrade<T: FLAMSLSPServer + Send + 'static>(
16 ws: axum::extract::WebSocketUpgrade,
17 new: impl FnOnce(ClientSocket) -> T + Send + 'static,
18) -> axum::response::Response {
19 ws.on_upgrade(|ws| {
20 let (server, _) = async_lsp::MainLoop::new_server(|client| {
21 ServiceBuilder::new()
23 .layer(TracingLayer::default())
24 .layer(LifecycleLayer::default())
25 .layer(CatchUnwindLayer::default())
26 .layer(ConcurrencyLayer::default())
27 .layer(ClientProcessMonitorLayer::new(client.clone()))
28 .service(ServerWrapper::new(new(client)).router())
29 });
30 let socket = SocketWrapper {
31 inner: std::sync::Arc::new(parking_lot::Mutex::new(ws)),
32 read_buf: Vec::new(),
33 };
34 run(socket, server)
35 })
36}
37
38#[allow(clippy::future_not_send)]
39async fn run<T: LspService<Response = serde_json::value::Value>>(
40 socket: SocketWrapper,
41 main: MainLoop<T>,
42) where
43 async_lsp::ResponseError: From<T::Error>,
44{
45 if let Err(e) = main.run_buffered(socket.clone(), socket).await {
46 tracing::error!("Error: {:?}", e);
47 }
48}
49
50#[derive(Clone)]
51struct SocketWrapper {
52 inner: std::sync::Arc<parking_lot::Mutex<axum::extract::ws::WebSocket>>,
53 read_buf: Vec<u8>,
54}
55
56impl SocketWrapper {
57 fn poll_read_internal(
58 self: Pin<&mut Self>,
59 cx: &mut Context<'_>,
60 buf: &mut [u8],
61 ) -> Poll<io::Result<usize>> {
62 let this = self.get_mut();
63 let mut lock = this.inner.lock();
64 let inner = Pin::new(&mut *lock);
65 let r = futures::Stream::poll_next(inner, cx);
66 drop(lock);
67
68 r.map(|result| match result {
69 None => Ok(0),
70 Some(Err(e)) => Err(io::Error::new(ErrorKind::Other, e)),
71 Some(Ok(message)) => match message {
72 Message::Text(text) => Ok(Self::handle_incoming_data(
73 buf,
74 text.as_bytes(),
75 &mut this.read_buf,
76 )),
77 Message::Binary(binary) => {
78 Ok(Self::handle_incoming_data(buf, &binary, &mut this.read_buf))
79 }
80 Message::Close(_) => Err(io::Error::new(ErrorKind::BrokenPipe, "WebSocket closed")),
81 Message::Ping(_) | Message::Pong(_) => Ok(0), },
83 })
84 }
85
86 fn handle_incoming_data(buf: &mut [u8], data: &[u8], read_buf: &mut Vec<u8>) -> usize {
87 let data_len = data.len();
88 let buf_len = buf.len();
89 if data_len > buf_len {
90 buf.copy_from_slice(&data[..buf_len]);
91 read_buf.extend_from_slice(&data[buf_len..]);
92 buf_len
93 } else {
94 buf[..data_len].copy_from_slice(data);
95 data_len
96 }
97 }
98}
99
100impl futures::AsyncRead for SocketWrapper {
101 fn poll_read(
102 mut self: Pin<&mut Self>,
103 cx: &mut Context<'_>,
104 buf: &mut [u8],
105 ) -> Poll<io::Result<usize>> {
106 if !self.read_buf.is_empty() {
107 let to_copy = std::cmp::min(buf.len(), self.read_buf.len());
108 buf[..to_copy].copy_from_slice(&self.read_buf[..to_copy]);
109 self.read_buf.drain(..to_copy);
110 return Poll::Ready(Ok(to_copy));
111 }
112 self.as_mut().poll_read_internal(cx, buf)
113 }
114}
115
116impl futures::AsyncBufRead for SocketWrapper {
117 fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
118 if self.read_buf.is_empty() {
119 match self.as_mut().poll_read_internal(cx, &mut []) {
120 Poll::Ready(Ok(0)) | Poll::Pending => (),
121 Poll::Ready(Ok(_)) => unreachable!(),
122 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
123 }
124 }
125
126 let this = self.into_ref().get_ref();
127 if this.read_buf.is_empty() {
128 Poll::Pending
129 } else {
130 Poll::Ready(Ok(&this.read_buf))
131 }
132 }
133
134 fn consume(self: Pin<&mut Self>, amt: usize) {
135 let this = self.get_mut();
136 this.read_buf
137 .drain(..std::cmp::min(amt, this.read_buf.len()));
138 }
139}
140
141impl futures::AsyncWrite for SocketWrapper {
142 fn poll_write(
143 self: Pin<&mut Self>,
144 cx: &mut Context<'_>,
145 buf: &[u8],
146 ) -> Poll<io::Result<usize>> {
147 let msg = match buf.to_vec().try_into() {
148 Ok(m) => m,
149 Err(e) => return Poll::Ready(Err(io::Error::new(ErrorKind::Other, e))),
150 };
151 let message = Message::Text(msg);
152 let mut lock = self.inner.lock();
153 let inner = Pin::new(&mut *lock);
154
155 match futures::Sink::poll_ready(inner, cx) {
156 Poll::Pending => Poll::Pending,
157 Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::new(ErrorKind::Other, e))),
158 Poll::Ready(Ok(())) => Poll::Ready(
159 futures::Sink::start_send(Pin::new(&mut *lock), message)
160 .map(|()| buf.len())
161 .map_err(|e| io::Error::new(ErrorKind::Other, e)),
162 ),
163 }
164 }
165
166 #[allow(clippy::significant_drop_tightening)]
167 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
168 let mut lock = self.inner.lock();
169 let inner = Pin::new(&mut *lock);
170 futures::Sink::poll_flush(inner, cx).map_err(|e| io::Error::new(ErrorKind::Other, e))
171 }
172
173 #[allow(clippy::significant_drop_tightening)]
174 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
175 let mut lock = self.inner.lock();
176 let inner = Pin::new(&mut *lock);
177 futures::Sink::poll_close(inner, cx).map_err(|e| io::Error::new(ErrorKind::Other, e))
178 }
179}