summaryrefslogtreecommitdiff
path: root/src/hot_reload.rs
blob: 88cfc217ba9c5f7f96ea955891eafe53bc4cda3e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
use std::{
    env,
    net::SocketAddr,
    path::{Path, PathBuf},
    sync::Arc,
};

use anyhow::Context;
use axum::{
    extract::{
        ws::{Message, WebSocket},
        ConnectInfo, State, WebSocketUpgrade,
    },
    response::IntoResponse,
    routing, Router,
};
use bytes::Bytes;
use notify::{Event, FsEventWatcher, Watcher};
use tracing::trace;

/// Shared state behind the hot-reload WebSocket endpoint.
///
/// Owns the filesystem watcher whose callback pushes changed `.css`/`.js` file names
/// into a `tokio` broadcast channel, together with both ends of that channel.
pub struct HotReload {
    // Owned here so the watcher lives exactly as long as this struct.
    // NOTE(review): `FsEventWatcher` is notify's macOS (FSEvents) backend; on other
    // platforms `notify::RecommendedWatcher` would be needed — confirm target platforms.
    watcher: FsEventWatcher,

    // Sending half of the broadcast channel; the watcher callback sends on a clone of it.
    tx: tokio::sync::broadcast::Sender<String>,
    // Not read directly: `handle_socket` calls `resubscribe()` on it to obtain a
    // per-connection receiver. Holding it also keeps at least one receiver alive.
    rx: tokio::sync::broadcast::Receiver<String>,
}

impl HotReload {
    pub fn new(base_path: impl AsRef<Path>) -> Arc<Self> {
        Arc::new_cyclic(|hr| {
            let weak = hr.clone();
            let static_dir = PathBuf::from(base_path.as_ref())
                .canonicalize()
                .context("canonicalize static_dir")
                .unwrap();

            let (_tx, rx) = tokio::sync::broadcast::channel(16);

            let tx = _tx.clone();

            let mut watcher = notify::recommended_watcher(
                move |res: std::result::Result<Event, notify::Error>| {
                    let Some(hot) = weak.upgrade() else {
                        return;
                    };
                    match res {
                        Ok(event) => {
                            for path in &event.paths {
                                let Ok(p) = path.strip_prefix(&static_dir) else {
                                    continue;
                                };
                                if p.extension().is_some_and(|o| o == "css" || o == "js") {
                                    let s = p.file_name().unwrap().to_string_lossy().to_string();
                                    tx.send(s).expect("Failed to send to channel");
                                }
                            }
                        }
                        Err(e) => println!("watch error: {:?}", e),
                    }
                },
            )
            .context("create watcher")
            .unwrap();

            watcher
                .watch(std::path::Path::new("."), notify::RecursiveMode::Recursive)
                .context("watcher.watch")
                .unwrap();

            HotReload {
                watcher,
                tx: _tx,
                rx,
            }
        })
    }
}

/// Builds the hot-reload router.
///
/// Creates a [`HotReload`] watching `watch_dir` and exposes its WebSocket endpoint at
/// `/ws`, with the shared state already attached so the result can be nested or merged
/// into an application router.
pub fn router(watch_dir: impl AsRef<Path>) -> Router<()> {
    let state = HotReload::new(watch_dir);
    let ws_route = routing::get(handler_ws);
    Router::new().route("/ws", ws_route).with_state(state)
}

pub async fn handler_ws(
    ws: WebSocketUpgrade,
    State(st): State<Arc<HotReload>>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
    trace!("Connected to ws");
    ws.on_upgrade(move |socket| handle_socket(socket, addr, st))
}

async fn handle_socket(mut socket: WebSocket, _: SocketAddr, st: Arc<HotReload>) {
    if socket.send(Message::Ping(Bytes::new())).await.is_ok() {
        trace!("Pinged ws");
    } else {
        println!("Could not send ping!");
        return;
    }

    let mut rx = st.rx.resubscribe();
    while let Ok(a) = rx.recv().await {
        trace!(a, "send ");
        let res = socket.send(a.into()).await;
        trace!(?res);
    }
}