use std::{ collections::{HashMap, HashSet}, net::SocketAddr, sync::Arc, }; use anyhow::{bail, Context, Result}; use axum::{ body::Body, extract::{ ws::{Message, WebSocket}, Path, Request, State, WebSocketUpgrade, }, http::StatusCode, middleware::{self, Next}, response::{IntoResponse, Response}, routing::{get, post}, Extension, Form, Router, }; use axum_extra::extract::CookieJar; use maud::{html, Markup, PreEscaped, DOCTYPE}; use rand::{random, seq::SliceRandom}; use serde::{Deserialize, Serialize}; use tokio::sync::{ broadcast::{self, Receiver, Sender}, Mutex, }; use tower_http::{ catch_panic::CatchPanicLayer, services::ServeDir, trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer}, }; use tracing::{error, info, trace, warn, Level}; use tracing_subscriber::FmtSubscriber; /// If a handler panics we run this function and return the [Response]. fn handle_panic(err: Box) -> Response { let details = if let Some(s) = err.downcast_ref::() { s.clone() } else if let Some(s) = err.downcast_ref::<&str>() { s.to_string() } else { "Unknown panic message".to_string() }; error!(details = details, "Handler paniced"); (StatusCode::INTERNAL_SERVER_ERROR).into_response() } #[derive(Hash, Eq, PartialEq, Copy, Clone, Debug, PartialOrd, Ord, Deserialize)] struct UserId(u64); #[derive(Serialize, Deserialize, Debug, Clone)] struct NoteDatum { // Epoch time in ms time: u64, // Fraction up/down y: f64, } #[derive(Deserialize, Debug, Clone)] struct Notes { notes: Vec, } struct Round { author: UserId, guesses: HashMap, /// [true] means the judge has accepted it. Missing means they haven't done anything correct: HashMap, sound_path: String, } struct Game { id: u64, // If we've started or not is_started: bool, is_finished: bool, /// All users that are in the game. active_users: HashSet, /// All submissions submissions: HashMap, /// Stack for UserIds whose rounds have not been played next_rounds: Vec, /// Data for each round rounds: Vec, user_to_sound: HashMap, /// Sender to broadcast things broadcast_tx: Sender, /// Broadcast receiver to send to WebSocket broadcast_rx: Receiver, senders: HashMap>, } impl Game { /// Game starts. fn start(&mut self) { assert!(!self.is_started, "Cannot re-start a game"); self.is_started = true; self.active_users = self.senders.keys().cloned().collect(); self.broadcast_screen(self.screen_compose()) .context("Broadcast compose screen") .unwrap(); } fn restart(&mut self) { assert!( self.is_finished, "Game needs to be finished to be restarted" ); self.is_started = true; self.is_finished = false; self.active_users = self.senders.keys().cloned().collect(); self.submissions = HashMap::new(); self.next_rounds = Vec::new(); self.rounds = Vec::new(); self.broadcast_screen(self.screen_compose()) .context("Broadcast compose screen") .unwrap(); } /// User submits a melody. fn submit(&mut self, uid: UserId, notes: Notes) { if !self.active_users.contains(&uid) { warn!(uid=?uid, "Inactive user tried to submit"); return; } if self.submissions.contains_key(&uid) { warn!(uid=?uid, "User tried to submit twice"); return; } self.submissions.insert(uid, notes); if self.active_users.len() == self.submissions.len() { info!("All users have submitted"); let users = self.submissions.keys().cloned(); self.next_rounds = users.collect(); let mut rng = rand::thread_rng(); let sound_fx = [ "cough.mp3", "wow.mp3", "cat.mp3", "pop.mp3", "car.mp3", "bike.mp3", "dog.mp3", "o.mp3", ] .choose_multiple(&mut rng, self.next_rounds.len()); self.user_to_sound = self .next_rounds .iter() .zip(sound_fx) .map(|(u, s)| (*u, s.to_string())) .collect(); self.new_round(); } else { self.message_screen(uid, self.screen_submitted()) .context("message submit screen") .unwrap(); } } fn new_round(&mut self) { let Some(uid) = self.next_rounds.pop() else { error!("No more rounds"); return; }; let sound_fx = self .user_to_sound .get(&uid) .cloned() .unwrap_or_else(|| format!("cough.mp3")); self.rounds.push(Round { author: uid, guesses: HashMap::new(), correct: HashMap::new(), sound_path: format!("/static/{}", sound_fx), }); for &u in &self.active_users { if u != uid { self.message_screen(u, self.screen_round()) .context("message screen round") .unwrap(); } } self.message_screen(uid, self.screen_judge()) .context("message screen judge") .unwrap(); } fn user_guess(&mut self, uid: UserId, guess: String) { let r = self.rounds.last_mut().context("No current round").unwrap(); if r.guesses.contains_key(&uid) { warn!("User tried to guess again"); return; } r.guesses.insert(uid, guess.clone()); if let Some(s) = self.senders.get(&r.author) { s.send( html! { li id=(format!("guess-{}", uid.0)) hx-swap-oob="outerHTML" { (self.guess_li(uid)) } } .into_string(), ) .context("Send screen to author") .unwrap(); } else { warn!("Missing sender for author"); } } fn mark_guess(&mut self, uid: UserId, mark: bool) { let r = self.rounds.last_mut().context("No round").unwrap(); r.correct.insert(uid, mark); } fn scoring(&mut self) { self.broadcast_screen(self.screen_scoring()) .context("broadcast scoring screen") .unwrap(); self.is_finished = true; } fn screen_submitted(&self) -> Markup { html! { h1 { "Submitted!" } } } fn screen_compose(&self) -> Markup { let gid = self.id; html! { script { ({PreEscaped("document.notes = []; ")})} h1 { "Playing Game" } p { "Were playing here" } canvas #canvas height=(256) width=(256) style="position:relative; background: cyan;" {}; script { (PreEscaped(r#"document.getElementById("canvas").onclick = (e) => { const notes = JSON.parse(e.target.dataset.notes ?? "[]"); notes.push({ time: Date.now(), y: e.layerY / e.target.height}); e.target.dataset.notes = JSON.stringify(notes); }"#))} button hx-post=(format!("/game/{gid}/submit")) hx-vals=r#"js:{ "notes": document.getElementById("canvas").dataset.notes }"# { "Submit" } } } fn screen_lobby(&self) -> Markup { let gid = self.id; html! { h1 { "Game" } p { "Current game id is " (gid)} div #messages {}; button hx-post=(format!("/game/{gid}/start")) { "Start" } } } /// Players guess the melody fn screen_round(&self) -> Markup { let i = self.rounds.len(); let n = self.rounds.len() + self.next_rounds.len(); let Some(r) = self.rounds.last() else { return html! { h1 { "No more rounds" }}; }; let gid = self.id; let sound = &r.sound_path; let Some(notes) = self.submissions.get(&r.author) else { return html! { h1 { "No song!" } p { "This user didn't submit anything!" } }; }; let notes_json = serde_json::to_value(¬es.notes) .context("Convert to json should work") .unwrap() .to_string(); html! { h1 { (format!("Round {i}/{n}"))} p { "Author was " (r.author.0)} form #form-guess { input placeholder="Your guess..." name="guess" {}; input type="submit" hx-post=(format!("/game/{gid}/guess")) hx-target="#form-guess" {}; } script { (PreEscaped(format!(r#"{{ const notes = {notes_json}; const soundUrl = "{sound}"; {} }}"#, r#" loadSample(soundUrl).then((sample) => { const t0 = notes[0].time - 1_000; notes.forEach(({time, y}) => { const pitch = 60 + 12 * (1.0 - y); setTimeout(() => playSoundSample(sample, 60, pitch), time - t0); }); }); "#)))} } } /// Return the `
  • ` for the guess of `uid` so that its status is marked correctly fn guess_li(&self, uid: UserId) -> Markup { let r = self.rounds.last().unwrap(); let Some(g) = r.guesses.get(&uid) else { return html! { li id=(format!("guess-{}", uid.0)) style="display: none;" { } }; }; let gid = self.id; let curr = r.correct.get(&uid).cloned(); let accept = match curr { None => "none", Some(false) => "false", Some(true) => "true", }; let post = format!( "/game/{gid}/mark/{}/{}", uid.0, !curr.unwrap_or(false) // None should mark as True when clicked ); html! { li id=(format!("guess-{}", uid.0)) data-accept=(accept) hx-post=(post) hx-swap="outerHTML" { (g) " (click to toggle)"} } } /// Author of the melody marks correct answers fn screen_judge(&self) -> Markup { let mut v = self.active_users.iter().collect::>(); v.sort(); html! { h1 { "Mark correct answers" } ul #guesslist { @for &u in v { (self.guess_li(u)) } } button hx-post=(format!("/game/{}/judge", self.id)) { "Done" } } } fn screen_scoring(&self) -> Markup { let mut scores = HashMap::new(); for u in &self.active_users { scores.insert(u, 0); } for r in &self.rounds { for (u, yes) in r.correct.iter() { if *yes { if let Some(n) = scores.get_mut(u) { *n += 1; } else { warn!(u=?u, "Tried to give points to inactive user"); } } } } let mut v = scores.into_iter().collect::>(); v.sort_by_key(|t| -t.1); html! { h1 { "Leaderboard" } ol { @for (uid, score) in v { li { span { (uid.0) ": " (score) " points" } } } } a href="/" { "Back" } button hx-post=(format!("/game/{}/restart", self.id)) { "Start new round" } } } fn broadcast_screen(&self, html: Markup) -> Result<()> { self.broadcast_tx .send( html! { section #content { (html) } } .into_string(), ) .map(|_| ()) .context("Failed to send") } /// Message a single user a screen. fn message_screen(&self, user: UserId, html: Markup) -> Result<()> { let Some(send) = self.senders.get(&user) else { bail!("No such user") }; send.send( html! { section #content { (html) } } .into_string(), ) .map(|_| ()) .context("Failed to send") } } struct ServerState { games: HashMap>>, } impl ServerState { fn new_game(&mut self, gid: u64) { let (tx, rx) = broadcast::channel(16); let g = Game { id: gid, is_started: false, is_finished: false, active_users: HashSet::new(), submissions: HashMap::new(), next_rounds: Vec::new(), user_to_sound: HashMap::new(), rounds: Vec::new(), broadcast_rx: rx, broadcast_tx: tx, senders: HashMap::new(), }; self.games.insert(gid, Arc::new(Mutex::new(g))); } } type Server = Arc>; async fn route_index(State(st): State) -> Response { let game_ids = st.lock().await.games.keys().cloned().collect::>(); let game_list = if game_ids.is_empty() { html! { "No current games"} } else { html! { h3 { "Games" } ul { @for id in &game_ids { li { a href=(format!("/game/{id}")) { (id) } } } } } }; let html = html! { h1 { "Good game "} (game_list) button hx-post="/game" { "Create game" } }; html_response(html! { (DOCTYPE) head { meta charset="utf-8"; title { "Good Game" }; meta name="viewport" content="width=device-width, initial-scale=1" {}; script src="https://unpkg.com/htmx.org@2.0.4" {}; link rel="stylesheet" type="text/css" href="static/style.css"; } body { (html) } }) } fn html_response(html: PreEscaped) -> Response { Response::builder() .status(StatusCode::OK) .body(Body::from(html.into_string())) .expect("Failed to set body") } async fn create_game(State(st): State) -> Response { let gid = random::(); let _ = { let mut s = st.lock().await; s.new_game(gid); }; Response::builder() .header("hx-redirect", format!("/game/{gid}")) .body(Body::empty()) .expect("failed to build response") } async fn start_game(Path(gid): Path, State(st): State) -> Response { let lock = { if let Some(g) = st.lock().await.games.get(&gid) { g.clone() } else { return StatusCode::NOT_FOUND.into_response(); } }; lock.lock().await.start(); StatusCode::OK.into_response() } async fn restart_game(Path(gid): Path, State(st): State) -> Response { let lock = { if let Some(g) = st.lock().await.games.get(&gid) { g.clone() } else { return StatusCode::NOT_FOUND.into_response(); } }; lock.lock().await.restart(); StatusCode::OK.into_response() } async fn get_game( Path(gid): Path, Extension(uid): Extension, State(st): State, ) -> Response { let game = { if let Some(g) = st.lock().await.games.get(&gid) { g.clone() } else { return (StatusCode::NOT_FOUND).into_response(); } }; let html = { let g = game.lock().await; if !g.is_started { g.screen_lobby() } else if g.is_finished { g.screen_scoring() } else if g.rounds.len() == 0 { if g.submissions.contains_key(&uid) { g.screen_submitted() } else { g.screen_compose() } } else { if let Some(r) = g.rounds.last() { if r.author == uid { g.screen_judge() } else { g.screen_round() } } else { // TODO warn!("Missing screen, default to lobby"); g.screen_lobby() } } }; html_response(html! { (DOCTYPE) head { meta charset="utf-8"; title { "Good Game" }; meta name="viewport" content="width=device-width, initial-scale=1" {}; script src="https://unpkg.com/htmx.org@2.0.4" {}; script src="https://unpkg.com/htmx-ext-ws@2.0.1/ws.js" {}; link rel="stylesheet" type="text/css" href="/static/style.css"; script src="/static/main.js" {}; } body hx-ext="ws" ws-connect=(format!("/game/{gid}/ws")) { footer { } section #content { (html) } footer { span { "You are " (uid.0) } } } }) } #[derive(Deserialize)] struct NotesForm { notes: String, } async fn submit_game( Path(gid): Path, State(st): State, Extension(uid): Extension, Form(form): Form, ) -> Response { let glock = { if let Some(g) = st.lock().await.games.get(&gid) { g.clone() } else { return (StatusCode::NOT_FOUND).into_response(); } }; let Ok(notes) = serde_json::from_str(&form.notes) else { warn!(notes = form.notes); return (StatusCode::UNPROCESSABLE_ENTITY).into_response(); }; glock.lock().await.submit(uid, Notes { notes }); StatusCode::OK.into_response() } #[derive(Deserialize)] struct GuessTune { guess: String, } async fn guess_tune( Path(gid): Path, State(st): State, Extension(uid): Extension, Form(form): Form, ) -> Response { let glock = { if let Some(g) = st.lock().await.games.get(&gid) { g.clone() } else { return (StatusCode::NOT_FOUND).into_response(); } }; glock.lock().await.user_guess(uid, form.guess.clone()); html! { span { (format!(r#"You guessed "{}" "#, form.guess)) } } .into_string() .into_response() } async fn mark_guess( Path((gid, guess_id, mark)): Path<(u64, UserId, bool)>, State(st): State, Extension(uid): Extension, ) -> Response { let glock = { if let Some(g) = st.lock().await.games.get(&gid) { g.clone() } else { return (StatusCode::NOT_FOUND).into_response(); } }; let mut g = glock.lock().await; let allowed = g.rounds.last().map(|r| r.author == uid).unwrap_or(false); if !allowed { return StatusCode::UNAUTHORIZED.into_response(); } g.mark_guess(guess_id, mark); g.guess_li(guess_id).into_string().into_response() } /// Judge is done judging all submissions async fn submit_judge(Path(gid): Path, State(st): State) -> Response { let glock = { if let Some(g) = st.lock().await.games.get(&gid) { g.clone() } else { return (StatusCode::NOT_FOUND).into_response(); } }; let mut g = glock.lock().await; if g.next_rounds.is_empty() { g.scoring(); } else { g.new_round(); } StatusCode::OK.into_response() } async fn game_ws( ws: WebSocketUpgrade, Path(gid): Path, Extension(uid): Extension, State(st): State, ) -> Response { let game = { let s = st.lock().await; if let Some(g) = s.games.get(&gid) { g.clone() } else { return StatusCode::NOT_FOUND.into_response(); } }; ws.on_upgrade(move |socket| handle_socket(socket, uid, game)) .into_response() } async fn handle_socket(mut socket: WebSocket, uid: UserId, game: Arc>) { if socket .send(Message::Ping(bytes::Bytes::from("wat"))) .await .is_ok() { trace!("Pinged ws"); } else { println!("Could not send ping!"); return; } let mut single_rx = { let mut g = game.lock().await; let (tx, rx) = broadcast::channel(16); g.senders.insert(uid, tx); rx }; let mut broadcast_rx = { game.lock().await.broadcast_rx.resubscribe() }; loop { tokio::select! { res = single_rx.recv() => { if let Ok(msg) = res { let res = socket.send(msg.into()).await; if res.is_err() { warn!(res=?res, "Error from socket.send"); break; } } else { warn!(res=?res, "Error from single_rx"); break; } }, res = broadcast_rx.recv() => { if let Ok(msg) = res { let res = socket.send(msg.into()).await; if res.is_err() { warn!(res=?res, "Error from socket.send"); break; } } else { warn!(res=?res, "Error from broadcast_rx"); break; } }, } } } async fn identity_middleware(jar: CookieJar, mut req: Request, next: Next) -> Response { const UID: &'static str = "uid"; let opt_uid = jar.get(UID); let uid = opt_uid .and_then(|c| c.value().parse::().ok()) .unwrap_or_else(|| random::()); let uid = UserId(uid); req.extensions_mut().get_or_insert::(uid); let mut res = next.run(req).await; if opt_uid.is_none() { let cookie = cookie::Cookie::build(("uid", uid.0.to_string())) .same_site(cookie::SameSite::Strict) .path("/") .http_only(true) .secure(true) .build(); res.headers_mut() .append("Set-Cookie", cookie.encoded().to_string().parse().unwrap()); } res } #[tokio::main] async fn main() -> Result<()> { let subscriber = FmtSubscriber::builder() .with_max_level(Level::TRACE) .finish(); tracing::subscriber::set_global_default(subscriber) .expect("set global default subscriber failed"); let mut server_state = ServerState { games: HashMap::new(), }; server_state.new_game(123); let st = Arc::new(Mutex::new(server_state)); let app = Router::new() .route("/", get(route_index)) .route("/game", post(create_game)) .route("/game/{gid}", get(get_game)) .route("/game/{gid}/start", post(start_game)) .route("/game/{gid}/restart", post(restart_game)) .route("/game/{gid}/guess", post(guess_tune)) .route("/game/{gid}/submit", post(submit_game)) .route("/game/{gid}/mark/{uid}/{status}", post(mark_guess)) .route("/game/{gid}/judge", post(submit_judge)) .route("/game/{gid}/ws", get(game_ws)) .nest_service("/static", ServeDir::new("static")) .layer(CatchPanicLayer::custom(handle_panic)) .layer( TraceLayer::new_for_http() .make_span_with(DefaultMakeSpan::new().level(Level::INFO)) .on_response(DefaultOnResponse::new().level(Level::INFO)), ) .layer(middleware::from_fn(identity_middleware)) .with_state(st.clone()); let addr = "0.0.0.0:4800"; info!("Listening on {addr}"); let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); axum::serve( listener, app.into_make_service_with_connect_info::(), ) .await .context("serve") }