diff options
Diffstat (limited to 'src/hot_reload.rs')
| -rw-r--r-- | src/hot_reload.rs | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/src/hot_reload.rs b/src/hot_reload.rs new file mode 100644 index 0000000..88cfc21 --- /dev/null +++ b/src/hot_reload.rs @@ -0,0 +1,109 @@ +use std::{ + env, + net::SocketAddr, + path::{Path, PathBuf}, + sync::Arc, +}; + +use anyhow::Context; +use axum::{ + extract::{ + ws::{Message, WebSocket}, + ConnectInfo, State, WebSocketUpgrade, + }, + response::IntoResponse, + routing, Router, +}; +use bytes::Bytes; +use notify::{Event, FsEventWatcher, Watcher}; +use tracing::trace; + +pub struct HotReload { + watcher: FsEventWatcher, + + tx: tokio::sync::broadcast::Sender<String>, + rx: tokio::sync::broadcast::Receiver<String>, +} + +impl HotReload { + pub fn new(base_path: impl AsRef<Path>) -> Arc<Self> { + Arc::new_cyclic(|hr| { + let weak = hr.clone(); + let static_dir = PathBuf::from(base_path.as_ref()) + .canonicalize() + .context("canonicalize static_dir") + .unwrap(); + + let (_tx, rx) = tokio::sync::broadcast::channel(16); + + let tx = _tx.clone(); + + let mut watcher = notify::recommended_watcher( + move |res: std::result::Result<Event, notify::Error>| { + let Some(hot) = weak.upgrade() else { + return; + }; + match res { + Ok(event) => { + for path in &event.paths { + let Ok(p) = path.strip_prefix(&static_dir) else { + continue; + }; + if p.extension().is_some_and(|o| o == "css" || o == "js") { + let s = p.file_name().unwrap().to_string_lossy().to_string(); + tx.send(s).expect("Failed to send to channel"); + } + } + } + Err(e) => println!("watch error: {:?}", e), + } + }, + ) + .context("create watcher") + .unwrap(); + + watcher + .watch(std::path::Path::new("."), notify::RecursiveMode::Recursive) + .context("watcher.watch") + .unwrap(); + + HotReload { + watcher, + tx: _tx, + rx, + } + }) + } +} + +pub fn router(watch_dir: impl AsRef<Path>) -> Router<()> { + let hrl = HotReload::new(watch_dir); + Router::new() + .route("/ws", routing::get(handler_ws)) + .with_state(hrl) +} + +pub async fn handler_ws( + ws: WebSocketUpgrade, + State(st): State<Arc<HotReload>>, + ConnectInfo(addr): ConnectInfo<SocketAddr>, +) -> impl IntoResponse { + trace!("Connected to ws"); + ws.on_upgrade(move |socket| handle_socket(socket, addr, st)) +} + +async fn handle_socket(mut socket: WebSocket, _: SocketAddr, st: Arc<HotReload>) { + if socket.send(Message::Ping(Bytes::new())).await.is_ok() { + trace!("Pinged ws"); + } else { + println!("Could not send ping!"); + return; + } + + let mut rx = st.rx.resubscribe(); + while let Ok(a) = rx.recv().await { + trace!(a, "send "); + let res = socket.send(a.into()).await; + trace!(?res); + } +} |
