flams_lsp/
ws.rs

1use super::{FLAMSLSPServer, ServerWrapper};
2use async_lsp::{
3    client_monitor::ClientProcessMonitorLayer, concurrency::ConcurrencyLayer,
4    panic::CatchUnwindLayer, server::LifecycleLayer, tracing::TracingLayer, ClientSocket,
5    LspService, MainLoop,
6};
7use axum::extract::ws::Message;
8use std::pin::Pin;
9use std::{
10    io::{self, ErrorKind},
11    task::{Context, Poll},
12};
13use tower::ServiceBuilder;
14
15pub fn upgrade<T: FLAMSLSPServer + Send + 'static>(
16    ws: axum::extract::WebSocketUpgrade,
17    new: impl FnOnce(ClientSocket) -> T + Send + 'static,
18) -> axum::response::Response {
19    ws.on_upgrade(|ws| {
20        let (server, _) = async_lsp::MainLoop::new_server(|client| {
21            //let server = ServerWrapper::new(new(client.clone()));
22            ServiceBuilder::new()
23                .layer(TracingLayer::default())
24                .layer(LifecycleLayer::default())
25                .layer(CatchUnwindLayer::default())
26                .layer(ConcurrencyLayer::default())
27                .layer(ClientProcessMonitorLayer::new(client.clone()))
28                .service(ServerWrapper::new(new(client)).router())
29        });
30        let socket = SocketWrapper {
31            inner: std::sync::Arc::new(parking_lot::Mutex::new(ws)),
32            read_buf: Vec::new(),
33        };
34        run(socket, server)
35    })
36}
37
38#[allow(clippy::future_not_send)]
39async fn run<T: LspService<Response = serde_json::value::Value>>(
40    socket: SocketWrapper,
41    main: MainLoop<T>,
42) where
43    async_lsp::ResponseError: From<T::Error>,
44{
45    if let Err(e) = main.run_buffered(socket.clone(), socket).await {
46        tracing::error!("Error: {:?}", e);
47    }
48}
49
50#[derive(Clone)]
51struct SocketWrapper {
52    inner: std::sync::Arc<parking_lot::Mutex<axum::extract::ws::WebSocket>>,
53    read_buf: Vec<u8>,
54}
55
56impl SocketWrapper {
57    fn poll_read_internal(
58        self: Pin<&mut Self>,
59        cx: &mut Context<'_>,
60        buf: &mut [u8],
61    ) -> Poll<io::Result<usize>> {
62        let this = self.get_mut();
63        let mut lock = this.inner.lock();
64        let inner = Pin::new(&mut *lock);
65        let r = futures::Stream::poll_next(inner, cx);
66        drop(lock);
67
68        r.map(|result| match result {
69            None => Ok(0),
70            Some(Err(e)) => Err(io::Error::new(ErrorKind::Other, e)),
71            Some(Ok(message)) => match message {
72                Message::Text(text) => Ok(Self::handle_incoming_data(
73                    buf,
74                    text.as_bytes(),
75                    &mut this.read_buf,
76                )),
77                Message::Binary(binary) => {
78                    Ok(Self::handle_incoming_data(buf, &binary, &mut this.read_buf))
79                }
80                Message::Close(_) => Err(io::Error::new(ErrorKind::BrokenPipe, "WebSocket closed")),
81                Message::Ping(_) | Message::Pong(_) => Ok(0), // Ignore control frames
82            },
83        })
84    }
85
86    fn handle_incoming_data(buf: &mut [u8], data: &[u8], read_buf: &mut Vec<u8>) -> usize {
87        let data_len = data.len();
88        let buf_len = buf.len();
89        if data_len > buf_len {
90            buf.copy_from_slice(&data[..buf_len]);
91            read_buf.extend_from_slice(&data[buf_len..]);
92            buf_len
93        } else {
94            buf[..data_len].copy_from_slice(data);
95            data_len
96        }
97    }
98}
99
100impl futures::AsyncRead for SocketWrapper {
101    fn poll_read(
102        mut self: Pin<&mut Self>,
103        cx: &mut Context<'_>,
104        buf: &mut [u8],
105    ) -> Poll<io::Result<usize>> {
106        if !self.read_buf.is_empty() {
107            let to_copy = std::cmp::min(buf.len(), self.read_buf.len());
108            buf[..to_copy].copy_from_slice(&self.read_buf[..to_copy]);
109            self.read_buf.drain(..to_copy);
110            return Poll::Ready(Ok(to_copy));
111        }
112        self.as_mut().poll_read_internal(cx, buf)
113    }
114}
115
116impl futures::AsyncBufRead for SocketWrapper {
117    fn poll_fill_buf(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
118        if self.read_buf.is_empty() {
119            match self.as_mut().poll_read_internal(cx, &mut []) {
120                Poll::Ready(Ok(0)) | Poll::Pending => (),
121                Poll::Ready(Ok(_)) => unreachable!(),
122                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
123            }
124        }
125
126        let this = self.into_ref().get_ref();
127        if this.read_buf.is_empty() {
128            Poll::Pending
129        } else {
130            Poll::Ready(Ok(&this.read_buf))
131        }
132    }
133
134    fn consume(self: Pin<&mut Self>, amt: usize) {
135        let this = self.get_mut();
136        this.read_buf
137            .drain(..std::cmp::min(amt, this.read_buf.len()));
138    }
139}
140
141impl futures::AsyncWrite for SocketWrapper {
142    fn poll_write(
143        self: Pin<&mut Self>,
144        cx: &mut Context<'_>,
145        buf: &[u8],
146    ) -> Poll<io::Result<usize>> {
147        let msg = match buf.to_vec().try_into() {
148            Ok(m) => m,
149            Err(e) => return Poll::Ready(Err(io::Error::new(ErrorKind::Other, e))),
150        };
151        let message = Message::Text(msg);
152        let mut lock = self.inner.lock();
153        let inner = Pin::new(&mut *lock);
154
155        match futures::Sink::poll_ready(inner, cx) {
156            Poll::Pending => Poll::Pending,
157            Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::new(ErrorKind::Other, e))),
158            Poll::Ready(Ok(())) => Poll::Ready(
159                futures::Sink::start_send(Pin::new(&mut *lock), message)
160                    .map(|()| buf.len())
161                    .map_err(|e| io::Error::new(ErrorKind::Other, e)),
162            ),
163        }
164    }
165
166    #[allow(clippy::significant_drop_tightening)]
167    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
168        let mut lock = self.inner.lock();
169        let inner = Pin::new(&mut *lock);
170        futures::Sink::poll_flush(inner, cx).map_err(|e| io::Error::new(ErrorKind::Other, e))
171    }
172
173    #[allow(clippy::significant_drop_tightening)]
174    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
175        let mut lock = self.inner.lock();
176        let inner = Pin::new(&mut *lock);
177        futures::Sink::poll_close(inner, cx).map_err(|e| io::Error::new(ErrorKind::Other, e))
178    }
179}