use std::{ collections::{HashMap, HashSet}, net::SocketAddr, sync::Arc, }; use anyhow::{bail, Context, Result}; use axum::{ body::Body, extract::{ ws::{Message, WebSocket}, Path, Request, State, WebSocketUpgrade, }, http::StatusCode, middleware::{self, Next}, response::{IntoResponse, Response}, routing::{get, post}, Extension, Form, Router, }; use axum_extra::extract::CookieJar; use maud::{html, Markup, PreEscaped, DOCTYPE}; use rand::random; use serde::Deserialize; use tokio::sync::{ broadcast::{self, Receiver, Sender}, Mutex, }; use tower_http::{ catch_panic::CatchPanicLayer, services::ServeDir, trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer}, }; use tracing::{error, info, trace, warn, Level}; use tracing_subscriber::FmtSubscriber; /// If a handler panics we run this function and return the [Response]. fn handle_panic(err: Box) -> Response { let details = if let Some(s) = err.downcast_ref::() { s.clone() } else if let Some(s) = err.downcast_ref::<&str>() { s.to_string() } else { "Unknown panic message".to_string() }; error!(details = details, "Handler paniced"); (StatusCode::INTERNAL_SERVER_ERROR).into_response() } #[derive(Deserialize, Debug, Clone)] struct NoteDatum { // Epoch time in ms time: u64, // Fraction up/down y: f64, } #[derive(Deserialize, Debug, Clone)] struct Notes { notes: Vec, } #[derive(Deserialize)] struct NotesForm { notes: String, } struct Round { author: UserId, notes: Notes, guesses: HashMap, } struct Game { id: u64, // If we've started or not is_started: bool, /// All users that are in the game. active_users: HashSet, /// All submissions submissions: HashMap, /// Stack for UserIds whose rounds have not been played next_rounds: Vec, /// Data for each round rounds: Vec, /// Sender to broadcast things broadcast_tx: Sender, /// Broadcast receiver to send to WebSocket broadcast_rx: Receiver, senders: HashMap>, } impl Game { /// Game starts. fn start(&mut self) { assert!(!self.is_started, "Cannot re-start a game"); self.is_started = true; self.active_users = self.senders.keys().cloned().collect(); self.broadcast_screen(self.screen_compose()) .context("Broadcast compose screen") .unwrap(); } /// User submits a melody. fn submit(&mut self, uid: UserId, notes: Notes) { if !self.active_users.contains(&uid) { warn!(uid=?uid, "Inactive user tried to submit"); return; } if self.submissions.contains_key(&uid) { warn!(uid=?uid, "User tried to submit twice"); return; } self.submissions.insert(uid, notes); self.message_screen(uid, self.screen_submitted()) .context("message submit screen") .unwrap(); if self.active_users.len() == self.submissions.len() { info!("All users have submitted"); let users = self.submissions.keys().cloned(); self.next_rounds = users.collect(); self.new_round(); } } fn new_round(&mut self) { let Some(uid) = self.next_rounds.pop() else { error!("No more rounds"); return; }; self.rounds.push(Round { author: uid, notes: self .submissions .get(&uid) .context("get users submission") .unwrap() .clone(), guesses: HashMap::new(), }); self.broadcast_screen(self.screen_round()) .context("broadcast screen round") .unwrap(); } fn screen_submitted(&self) -> Markup { html! { h1 { "Submitted!" } } } fn screen_compose(&self) -> Markup { let gid = self.id; html! { script { ({PreEscaped("document.notes = []; ")})} h1 { "Playing Game" } p { "Were playing here" } canvas #canvas height=(256) width=(256) style="position:relative; background: cyan;" {}; script { (PreEscaped(r#"document.getElementById("canvas").onclick = (e) => { const notes = JSON.parse(e.target.dataset.notes ?? "[]"); notes.push({ time: Date.now(), y: e.layerY / e.target.height}); e.target.dataset.notes = JSON.stringify(notes); }"#))} button hx-post=(format!("/game/{gid}/submit")) hx-vals=r#"js:{ "notes": document.getElementById("canvas").dataset.notes }"# { "Submit" } } } fn screen_lobby(&self) -> Markup { let gid = self.id; html! { h1 { "Game" } p { "Current game id is " (gid)} div #messages {}; button hx-post=(format!("/game/{gid}/start")) { "Start" } } } fn screen_round(&self) -> Markup { let i = self.rounds.len(); let n = self.rounds.len() + self.next_rounds.len(); let Some(r) = self.rounds.last() else { return html! { h1 { "No more rounds" }}; }; html! { h1 { (format!("Round {i}/{n}"))} p { "Author was " (r.author.0)} } } fn broadcast_screen(&self, html: Markup) -> Result<()> { self.broadcast_tx .send( html! { section #content { (html) } } .into_string(), ) .map(|_| ()) .context("Failed to send") } /// Message a single user a screen. fn message_screen(&self, user: UserId, html: Markup) -> Result<()> { let Some(send) = self.senders.get(&user) else { bail!("No such user") }; send.send( html! { section #content { (html) } } .into_string(), ) .map(|_| ()) .context("Failed to send") } } struct ServerState { games: HashMap>>, } impl ServerState { fn new_game(&mut self, gid: u64) { let (tx, rx) = broadcast::channel(16); let g = Game { id: gid, is_started: false, active_users: HashSet::new(), submissions: HashMap::new(), next_rounds: Vec::new(), rounds: Vec::new(), broadcast_rx: rx, broadcast_tx: tx, senders: HashMap::new(), }; self.games.insert(gid, Arc::new(Mutex::new(g))); } } type Server = Arc>; async fn route_index(State(st): State) -> Response { let game_ids = st.lock().await.games.keys().cloned().collect::>(); let game_list = if game_ids.is_empty() { html! { "No current games"} } else { html! { h3 { "Games" } ul { @for id in &game_ids { li { a href=(format!("/game/{id}")) { (id) } } } } } }; let html = html! { h1 { "Good game "} (game_list) button hx-post="/game" { "Create game" } }; html_response(html! { (DOCTYPE) head { meta charset="utf-8"; title { "Good Game" }; meta name="viewport" content="width=device-width, initial-scale=1" {}; script src="https://unpkg.com/htmx.org@2.0.4" {}; link rel="stylesheet" type="text/css" href="static/style.css"; } body { (html) } }) } fn html_response(html: PreEscaped) -> Response { Response::builder() .status(StatusCode::OK) .body(Body::from(html.into_string())) .expect("Failed to set body") } async fn create_game(State(st): State) -> Response { let gid = random::(); let _ = { let mut s = st.lock().await; s.new_game(gid); }; Response::builder() .header("hx-redirect", format!("/game/{gid}")) .body(Body::empty()) .expect("failed to build response") } async fn start_game(Path(gid): Path, State(st): State) -> Response { let lock = { if let Some(g) = st.lock().await.games.get(&gid) { g.clone() } else { return StatusCode::NOT_FOUND.into_response(); } }; lock.lock().await.start(); StatusCode::OK.into_response() } async fn get_game( Path(gid): Path, Extension(uid): Extension, State(st): State, ) -> Response { let game = { if let Some(g) = st.lock().await.games.get(&gid) { g.clone() } else { return (StatusCode::NOT_FOUND).into_response(); } }; let html = { let g = game.lock().await; if !g.is_started { g.screen_lobby() } else if g.rounds.len() == 0 { if g.submissions.contains_key(&uid) { g.screen_submitted() } else { g.screen_compose() } } else { // TODO warn!("Wrong screen"); g.screen_lobby() } }; html_response(html! { (DOCTYPE) head { meta charset="utf-8"; title { "Good Game" }; meta name="viewport" content="width=device-width, initial-scale=1" {}; script src="https://unpkg.com/htmx.org@2.0.4" {}; script src="https://unpkg.com/htmx-ext-ws@2.0.1/ws.js" {}; link rel="stylesheet" type="text/css" href="/static/style.css"; } body hx-ext="ws" ws-connect=(format!("/game/{gid}/ws")) { footer { } section #content { (html) } footer { span { "You are " (uid.0) } } } }) } async fn submit_game( Path(gid): Path, State(st): State, Extension(uid): Extension, Form(form): Form, ) -> Response { let glock = { if let Some(g) = st.lock().await.games.get(&gid) { g.clone() } else { return (StatusCode::NOT_FOUND).into_response(); } }; let Ok(notes) = serde_json::from_str(&form.notes) else { warn!(notes = form.notes); return (StatusCode::UNPROCESSABLE_ENTITY).into_response(); }; glock.lock().await.submit(uid, Notes { notes }); StatusCode::OK.into_response() } async fn game_ws( ws: WebSocketUpgrade, Path(gid): Path, Extension(uid): Extension, State(st): State, ) -> Response { let game = { let s = st.lock().await; if let Some(g) = s.games.get(&gid) { g.clone() } else { return StatusCode::NOT_FOUND.into_response(); } }; ws.on_upgrade(move |socket| handle_socket(socket, uid, game)) .into_response() } async fn handle_socket(mut socket: WebSocket, uid: UserId, game: Arc>) { if socket .send(Message::Ping(bytes::Bytes::from("wat"))) .await .is_ok() { trace!("Pinged ws"); } else { println!("Could not send ping!"); return; } let mut single_rx = { let mut g = game.lock().await; // g.broadcast_tx // .send( // html! { // div #messages hx-swap-oob="beforeend" { p { "Someone joined" } } // } // .into_string(), // ) // .unwrap(); let (tx, rx) = broadcast::channel(16); g.senders.insert(uid, tx); rx }; let mut broadcast_rx = { game.lock().await.broadcast_rx.resubscribe() }; loop { tokio::select! { res = single_rx.recv() => { if let Ok(msg) = res { let res = socket.send(msg.into()).await; if res.is_err() { warn!(res=?res, "Error from socket.send"); break; } } else { warn!(res=?res, "Error from single_rx"); break; } }, res = broadcast_rx.recv() => { if let Ok(msg) = res { let res = socket.send(msg.into()).await; if res.is_err() { warn!(res=?res, "Error from socket.send"); break; } } else { warn!(res=?res, "Error from broadcast_rx"); break; } }, } } } #[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)] struct UserId(u64); async fn identity_middleware(jar: CookieJar, mut req: Request, next: Next) -> Response { const UID: &'static str = "uid"; let opt_uid = jar.get(UID); let uid = opt_uid .and_then(|c| c.value().parse::().ok()) .unwrap_or_else(|| random::()); let uid = UserId(uid); req.extensions_mut().get_or_insert::(uid); let mut res = next.run(req).await; if opt_uid.is_none() { let cookie = cookie::Cookie::build(("uid", uid.0.to_string())) .same_site(cookie::SameSite::Strict) .path("/") .http_only(true) .secure(true) .build(); res.headers_mut() .append("Set-Cookie", cookie.encoded().to_string().parse().unwrap()); } res } #[tokio::main] async fn main() -> Result<()> { let subscriber = FmtSubscriber::builder() .with_max_level(Level::TRACE) .finish(); tracing::subscriber::set_global_default(subscriber) .expect("set global default subscriber failed"); let mut server_state = ServerState { games: HashMap::new(), }; server_state.new_game(123); let st = Arc::new(Mutex::new(server_state)); let app = Router::new() .route("/", get(route_index)) .route("/game", post(create_game)) .route("/game/{gid}", get(get_game)) .route("/game/{gid}/start", post(start_game)) .route("/game/{gid}/submit", post(submit_game)) .route("/game/{gid}/ws", get(game_ws)) .nest_service("/static", ServeDir::new("static")) .layer(CatchPanicLayer::custom(handle_panic)) .layer( TraceLayer::new_for_http() .make_span_with(DefaultMakeSpan::new().level(Level::INFO)) .on_response(DefaultOnResponse::new().level(Level::INFO)), ) .layer(middleware::from_fn(identity_middleware)) .with_state(st.clone()); let addr = "0.0.0.0:4800"; info!("Listening on {addr}"); let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); axum::serve( listener, app.into_make_service_with_connect_info::(), ) .await .context("serve") }