diff options
| -rw-r--r-- | Cargo.lock | 1 | ||||
| -rw-r--r-- | Cargo.toml | 2 | ||||
| -rw-r--r-- | src/main.rs | 351 |
3 files changed, 265 insertions, 89 deletions
@@ -126,6 +126,7 @@ dependencies = [ "axum", "axum-core 0.5.0", "bytes", + "cookie", "futures-util", "http", "http-body", @@ -6,7 +6,7 @@ edition = "2021" [dependencies] anyhow = "1.0.95" axum = { version = "0.8.1", features = ["ws"] } -axum-extra = "0.10.0" +axum-extra = { version = "0.10.0", features = ["cookie"] } axum-macros = "0.5.0" bytes = "1.9.0" cookie = { version = "0.18.1", features = ["percent-encode"] } diff --git a/src/main.rs b/src/main.rs index e164981..bb3e9b3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,17 +1,23 @@ -use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration}; +use std::{ + collections::{HashMap, HashSet}, + net::SocketAddr, + sync::Arc, +}; -use anyhow::{Context, Result}; +use anyhow::{bail, Context, Result}; use axum::{ body::Body, extract::{ ws::{Message, WebSocket}, - Path, State, WebSocketUpgrade, + Path, Request, State, WebSocketUpgrade, }, http::StatusCode, + middleware::{self, Next}, response::{IntoResponse, Response}, routing::{get, post}, - Form, Router, + Extension, Form, Router, }; +use axum_extra::extract::CookieJar; use maud::{html, Markup, PreEscaped, DOCTYPE}; use rand::random; use serde::Deserialize; @@ -41,7 +47,7 @@ fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> Response { (StatusCode::INTERNAL_SERVER_ERROR).into_response() } -#[derive(Deserialize)] +#[derive(Deserialize, Debug, Clone)] struct NoteDatum { // Epoch time in ms time: u64, @@ -49,7 +55,7 @@ struct NoteDatum { y: f64, } -#[derive(Deserialize)] +#[derive(Deserialize, Debug, Clone)] struct Notes { notes: Vec<NoteDatum>, } @@ -59,52 +65,124 @@ struct NotesForm { notes: String, } +struct Round { + author: UserId, + notes: Notes, + guesses: HashMap<UserId, String>, +} + struct Game { id: u64, + // If we've started or not is_started: bool, - submitted: Option<Notes>, + /// All users that are in the game. + active_users: HashSet<UserId>, + /// All submissions + submissions: HashMap<UserId, Notes>, + /// Stack for UserIds whose rounds have not been played + next_rounds: Vec<UserId>, + /// Data for each round + rounds: Vec<Round>, /// Sender to broadcast things broadcast_tx: Sender<String>, /// Broadcast receiver to send to WebSocket broadcast_rx: Receiver<String>, + + senders: HashMap<UserId, Sender<String>>, } impl Game { - /// Return the appropriate HTML for the game state. - fn get_html(&self) -> Markup { - let gid = self.id; + /// Game starts. + fn start(&mut self) { + assert!(!self.is_started, "Cannot re-start a game"); + self.is_started = true; + self.active_users = self.senders.keys().cloned().collect(); + self.broadcast_screen(self.screen_compose()) + .context("Broadcast compose screen") + .unwrap(); + } - if self.submitted.is_some() { - return html! { - h1 { "Submitted!" } - }; + /// User submits a melody. + fn submit(&mut self, uid: UserId, notes: Notes) { + if !self.active_users.contains(&uid) { + warn!(uid=?uid, "Inactive user tried to submit"); + return; + } + if self.submissions.contains_key(&uid) { + warn!(uid=?uid, "User tried to submit twice"); + return; } - if self.is_started { - return html! { - script { ({PreEscaped("document.notes = []; ")})} - h1 { "Playing Game" } - p { "Were playing here" } - - canvas - #canvas - height=(256) width=(256) - style="position:relative; background: cyan;" - {}; - script { (PreEscaped(r#"document.getElementById("canvas").onclick = (e) => { - const notes = JSON.parse(e.target.dataset.notes ?? "[]"); - notes.push({ time: Date.now(), y: e.layerY / e.target.height}); - e.target.dataset.notes = JSON.stringify(notes); - }"#))} - - button hx-post=(format!("/game/{gid}/submit")) - hx-vals=r#"js:{ "notes": document.getElementById("canvas").dataset.notes }"# - { "Submit" } - }; + + self.submissions.insert(uid, notes); + + self.message_screen(uid, self.screen_submitted()) + .context("message submit screen") + .unwrap(); + + if self.active_users.len() == self.submissions.len() { + info!("All users have submitted"); + let users = self.submissions.keys().cloned(); + self.next_rounds = users.collect(); + self.new_round(); } + } + fn new_round(&mut self) { + let Some(uid) = self.next_rounds.pop() else { + error!("No more rounds"); + return; + }; + self.rounds.push(Round { + author: uid, + notes: self + .submissions + .get(&uid) + .context("get users submission") + .unwrap() + .clone(), + guesses: HashMap::new(), + }); + + self.broadcast_screen(self.screen_round()) + .context("broadcast screen round") + .unwrap(); + } + + fn screen_submitted(&self) -> Markup { + html! { + h1 { "Submitted!" } + } + } + + fn screen_compose(&self) -> Markup { + let gid = self.id; + html! { + script { ({PreEscaped("document.notes = []; ")})} + h1 { "Playing Game" } + p { "Were playing here" } + + canvas + #canvas + height=(256) width=(256) + style="position:relative; background: cyan;" + {}; + script { (PreEscaped(r#"document.getElementById("canvas").onclick = (e) => { + const notes = JSON.parse(e.target.dataset.notes ?? "[]"); + notes.push({ time: Date.now(), y: e.layerY / e.target.height}); + e.target.dataset.notes = JSON.stringify(notes); + }"#))} + + button hx-post=(format!("/game/{gid}/submit")) + hx-vals=r#"js:{ "notes": document.getElementById("canvas").dataset.notes }"# + { "Submit" } + } + } + + fn screen_lobby(&self) -> Markup { + let gid = self.id; html! { h1 { "Game" } p { "Current game id is " (gid)} @@ -112,25 +190,46 @@ impl Game { button hx-post=(format!("/game/{gid}/start")) { "Start" } } } -} -async fn submit_game_answer(lock: Arc<Mutex<Game>>, notes: Notes) { - lock.lock().await.submitted = Some(notes); + fn screen_round(&self) -> Markup { + let i = self.rounds.len(); + let n = self.rounds.len() + self.next_rounds.len(); + let Some(r) = self.rounds.last() else { + return html! { h1 { "No more rounds" }}; + }; - tokio::spawn(async move { - let _ = tokio::time::sleep(Duration::from_secs(2)).await; - let mut g = lock.lock().await; - g.is_started = false; - g.submitted = None; - g.broadcast_tx.send( - html! { - section #content { - (g.get_html()) + html! { + h1 { (format!("Round {i}/{n}"))} + p { "Author was " (r.author.0)} + } + } + + fn broadcast_screen(&self, html: Markup) -> Result<()> { + self.broadcast_tx + .send( + html! { + section #content { (html) } } + .into_string(), + ) + .map(|_| ()) + .context("Failed to send") + } + + /// Message a single user a screen. + fn message_screen(&self, user: UserId, html: Markup) -> Result<()> { + let Some(send) = self.senders.get(&user) else { + bail!("No such user") + }; + send.send( + html! { + section #content { (html) } } .into_string(), ) - }); + .map(|_| ()) + .context("Failed to send") + } } struct ServerState { @@ -143,9 +242,13 @@ impl ServerState { let g = Game { id: gid, is_started: false, - submitted: None, + active_users: HashSet::new(), + submissions: HashMap::new(), + next_rounds: Vec::new(), + rounds: Vec::new(), broadcast_rx: rx, broadcast_tx: tx, + senders: HashMap::new(), }; self.games.insert(gid, Arc::new(Mutex::new(g))); } @@ -221,19 +324,16 @@ async fn start_game(Path(gid): Path<u64>, State(st): State<Server>) -> Response } }; - let mut g = lock.lock().await; - - g.is_started = true; - let html = g.get_html(); - - g.broadcast_tx - .send(html! { section #content { (html) } }.into_string()) - .expect("failed to send"); + lock.lock().await.start(); StatusCode::OK.into_response() } -async fn get_game(Path(gid): Path<u64>, State(st): State<Server>) -> Response { +async fn get_game( + Path(gid): Path<u64>, + Extension(uid): Extension<UserId>, + State(st): State<Server>, +) -> Response { let game = { if let Some(g) = st.lock().await.games.get(&gid) { g.clone() @@ -242,7 +342,22 @@ async fn get_game(Path(gid): Path<u64>, State(st): State<Server>) -> Response { } }; - let html = game.lock().await.get_html(); + let html = { + let g = game.lock().await; + if !g.is_started { + g.screen_lobby() + } else if g.rounds.len() == 0 { + if g.submissions.contains_key(&uid) { + g.screen_submitted() + } else { + g.screen_compose() + } + } else { + // TODO + warn!("Wrong screen"); + g.screen_lobby() + } + }; html_response(html! { (DOCTYPE) @@ -260,7 +375,7 @@ async fn get_game(Path(gid): Path<u64>, State(st): State<Server>) -> Response { section #content { (html) } - footer { span { "Good Game" } } + footer { span { "You are " (uid.0) } } } }) } @@ -268,6 +383,7 @@ async fn get_game(Path(gid): Path<u64>, State(st): State<Server>) -> Response { async fn submit_game( Path(gid): Path<u64>, State(st): State<Server>, + Extension(uid): Extension<UserId>, Form(form): Form<NotesForm>, ) -> Response { let glock = { @@ -283,18 +399,16 @@ async fn submit_game( return (StatusCode::UNPROCESSABLE_ENTITY).into_response(); }; - submit_game_answer(glock.clone(), Notes { notes }).await; - - let g = glock.lock().await; - let html = g.get_html(); - g.broadcast_tx - .send(html! { section #content { (html) } }.into_string()) - .expect("failed to send"); - + glock.lock().await.submit(uid, Notes { notes }); StatusCode::OK.into_response() } -async fn game_ws(ws: WebSocketUpgrade, Path(gid): Path<u64>, State(st): State<Server>) -> Response { +async fn game_ws( + ws: WebSocketUpgrade, + Path(gid): Path<u64>, + Extension(uid): Extension<UserId>, + State(st): State<Server>, +) -> Response { let game = { let s = st.lock().await; if let Some(g) = s.games.get(&gid) { @@ -304,11 +418,11 @@ async fn game_ws(ws: WebSocketUpgrade, Path(gid): Path<u64>, State(st): State<Se } }; - ws.on_upgrade(move |socket| handle_socket(socket, game)) + ws.on_upgrade(move |socket| handle_socket(socket, uid, game)) .into_response() } -async fn handle_socket(mut socket: WebSocket, game: Arc<Mutex<Game>>) { +async fn handle_socket(mut socket: WebSocket, uid: UserId, game: Arc<Mutex<Game>>) { if socket .send(Message::Ping(bytes::Bytes::from("wat"))) .await @@ -319,25 +433,83 @@ async fn handle_socket(mut socket: WebSocket, game: Arc<Mutex<Game>>) { println!("Could not send ping!"); return; } - { - game.lock() - .await - .broadcast_tx - .send( - html! { - div #messages hx-swap-oob="beforeend" { p { "Someone joined" } } + let mut single_rx = { + let mut g = game.lock().await; + // g.broadcast_tx + // .send( + // html! { + // div #messages hx-swap-oob="beforeend" { p { "Someone joined" } } + // } + // .into_string(), + // ) + // .unwrap(); + let (tx, rx) = broadcast::channel(16); + g.senders.insert(uid, tx); + rx + }; + + let mut broadcast_rx = { game.lock().await.broadcast_rx.resubscribe() }; + + loop { + tokio::select! { + res = single_rx.recv() => { + if let Ok(msg) = res { + let res = socket.send(msg.into()).await; + if res.is_err() { + warn!(res=?res, "Error from socket.send"); + break; + } + + } else { + warn!(res=?res, "Error from single_rx"); + break; } - .into_string(), - ) - .unwrap(); + }, + res = broadcast_rx.recv() => { + if let Ok(msg) = res { + let res = socket.send(msg.into()).await; + if res.is_err() { + warn!(res=?res, "Error from socket.send"); + break; + } + + } else { + warn!(res=?res, "Error from broadcast_rx"); + break; + } + }, + } } +} + +#[derive(Hash, Eq, PartialEq, Copy, Clone, Debug)] +struct UserId(u64); + +async fn identity_middleware(jar: CookieJar, mut req: Request, next: Next) -> Response { + const UID: &'static str = "uid"; + + let opt_uid = jar.get(UID); + let uid = opt_uid + .and_then(|c| c.value().parse::<u64>().ok()) + .unwrap_or_else(|| random::<u64>()); + let uid = UserId(uid); - let mut rx = { game.lock().await.broadcast_rx.resubscribe() }; - while let Ok(a) = rx.recv().await { - trace!(a, "send "); - let res = socket.send(a.into()).await; - trace!(?res); + req.extensions_mut().get_or_insert::<UserId>(uid); + + let mut res = next.run(req).await; + + if opt_uid.is_none() { + let cookie = cookie::Cookie::build(("uid", uid.0.to_string())) + .same_site(cookie::SameSite::Strict) + .path("/") + .http_only(true) + .secure(true) + .build(); + res.headers_mut() + .append("Set-Cookie", cookie.encoded().to_string().parse().unwrap()); } + + res } #[tokio::main] @@ -353,6 +525,8 @@ async fn main() -> Result<()> { }; server_state.new_game(123); + let st = Arc::new(Mutex::new(server_state)); + let app = Router::new() .route("/", get(route_index)) .route("/game", post(create_game)) @@ -367,9 +541,10 @@ async fn main() -> Result<()> { .make_span_with(DefaultMakeSpan::new().level(Level::INFO)) .on_response(DefaultOnResponse::new().level(Level::INFO)), ) - .with_state(Arc::new(Mutex::new(server_state))); + .layer(middleware::from_fn(identity_middleware)) + .with_state(st.clone()); - let addr = "192.168.0.106:4800"; + let addr = "0.0.0.0:4800"; info!("Listening on {addr}"); let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); axum::serve( |
