From ae3692879138b956e78463b46cb501acc6b6465a Mon Sep 17 00:00:00 2001 From: JMARyA Date: Mon, 16 Dec 2024 21:16:02 +0100 Subject: [PATCH] implement user auth --- Cargo.lock | 51 ++++++++++ Cargo.toml | 3 + migrations/003_add_users.sql | 14 +++ src/library/mod.rs | 36 +++++++ src/library/user.rs | 179 +++++++++++++++++++++++++++++++++++ src/main.rs | 10 +- src/pages/components.rs | 39 ++++++-- src/pages/index.rs | 19 +++- src/pages/mod.rs | 1 + src/pages/user.rs | 64 +++++++++++++ src/pages/watch.rs | 9 +- 11 files changed, 410 insertions(+), 15 deletions(-) create mode 100644 migrations/003_add_users.sql create mode 100644 src/library/user.rs create mode 100644 src/pages/user.rs diff --git a/Cargo.lock b/Cargo.lock index 35d5005..44488db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -149,6 +149,19 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "bcrypt" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b1866ecef4f2d06a0bb77880015fdf2b89e25a1c2e5addacb87e459c86dc67e" +dependencies = [ + "base64", + "blowfish", + "getrandom", + "subtle", + "zeroize", +] + [[package]] name = "binascii" version = "0.1.4" @@ -173,6 +186,16 @@ dependencies = [ "generic-array", ] +[[package]] +name = "blowfish" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e412e2cd0f2b2d93e02543ceae7917b3c70331573df19ee046bcbc35e45e87d7" +dependencies = [ + "byteorder", + "cipher", +] + [[package]] name = "bumpalo" version = "3.16.0" @@ -227,6 +250,16 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "concurrent-queue" version = "2.5.0" @@ -337,6 +370,12 @@ dependencies = [ "typenum", ] +[[package]] +name = "data-encoding" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" + [[package]] name = "der" version = "0.7.9" @@ -1044,6 +1083,15 @@ version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8fae54786f62fb2918dcfae3d568594e50eb9b5c25bf04371af6fe7516452fb" +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + [[package]] name = "is-terminal" version = "0.4.13" @@ -2864,12 +2912,15 @@ checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" name = "watchdogs" version = "0.1.0" dependencies = [ + "bcrypt", "chrono", + "data-encoding", "env_logger", "futures", "hex", "log", "maud", + "rand", "rayon", "regex", "ring", diff --git a/Cargo.toml b/Cargo.toml index a040f9b..e5e5bf7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,3 +21,6 @@ tokio = { version = "1.35.1", features = ["full"] } uuid = { version = "1.8.0", features = ["v4", "serde"] } sqlx = { version = "0.8", features = ["postgres", "runtime-tokio-native-tls", "derive", "uuid", "chrono", "json"] } maud = "0.26.0" +rand = "0.8.5" +data-encoding = "2.6.0" +bcrypt = "0.16.0" diff --git a/migrations/003_add_users.sql b/migrations/003_add_users.sql new file mode 100644 index 0000000..cd9fa12 --- /dev/null +++ b/migrations/003_add_users.sql @@ -0,0 +1,14 @@ +CREATE TYPE user_role AS ENUM ('regular', 'admin'); + +CREATE TABLE IF NOT EXISTS users ( + username VARCHAR(255) NOT NULL PRIMARY KEY, + "password" text NOT NULL, + user_role user_role NOT NULL DEFAULT 'regular' +); + +CREATE TABLE IF NOT EXISTS user_session ( + id UUID NOT NULL PRIMARY KEY DEFAULT gen_random_uuid(), + token text NOT NULL, + "user" varchar(255) NOT NULL, + FOREIGN KEY("user") REFERENCES users(username) +); diff --git a/src/library/mod.rs b/src/library/mod.rs index d31ec94..cdd7b3a 100644 --- a/src/library/mod.rs +++ b/src/library/mod.rs @@ -1,3 +1,4 @@ +use serde_json::json; use std::path::Path; use std::path::PathBuf; use std::str::FromStr; @@ -6,6 +7,7 @@ use walkdir::WalkDir; use func::is_video_file; pub use video::Video; mod func; +pub mod user; mod video; #[derive(Debug, Clone)] @@ -208,3 +210,37 @@ impl Library { videos } } + +/// A trait to generate a Model API representation in JSON format. +pub trait ToAPI: Sized { + /// Generate public API JSON + fn api(&self) -> impl std::future::Future; +} + +/// Converts a slice of items implementing the `ToAPI` trait into a `Vec` of JSON values. +pub async fn vec_to_api(items: &[impl ToAPI]) -> Vec { + let mut ret = Vec::with_capacity(items.len()); + + for e in items { + ret.push(e.api().await); + } + + ret +} + +pub fn to_uuid(id: &str) -> Result { + uuid::Uuid::from_str(id).map_err(|_| no_uuid_error()) +} + +type ApiError = rocket::response::status::BadRequest; +type FallibleApiResponse = Result; + +pub fn no_uuid_error() -> ApiError { + api_error("No valid UUID") +} + +pub fn api_error(msg: &str) -> ApiError { + rocket::response::status::BadRequest(json!({ + "error": msg + })) +} diff --git a/src/library/user.rs b/src/library/user.rs new file mode 100644 index 0000000..f26ef41 --- /dev/null +++ b/src/library/user.rs @@ -0,0 +1,179 @@ +use std::str::FromStr; + +use data_encoding::HEXUPPER; +use rand::RngCore; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use sqlx::FromRow; + +use crate::pages::ToAPI; + +fn gen_token(token_length: usize) -> String { + let mut token_bytes = vec![0u8; token_length]; + + rand::thread_rng().fill_bytes(&mut token_bytes); + + HEXUPPER.encode(&token_bytes) +} + +#[derive(Debug, Clone, Serialize, Deserialize, FromRow)] +pub struct User { + /// The username chosen by the user + pub username: String, + /// The hashed password for the user + pub password: String, + /// The role of the user + pub user_role: UserRole, +} + +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::Type)] +#[sqlx(type_name = "user_role", rename_all = "lowercase")] +pub enum UserRole { + /// A regular user with limited permissions + Regular, + /// An admin user with full system privileges + Admin, +} + +pub struct UserManager { + conn: sqlx::PgPool, +} + +impl UserManager { + pub fn new(conn: sqlx::PgPool) -> Self { + Self { conn } + } + + /// Find a user by their username + pub async fn find(&self, username: &str) -> Option { + sqlx::query_as("SELECT * FROM users WHERE username = $1") + .bind(username) + .fetch_optional(&self.conn) + .await + .unwrap() + } + + /// Create a new user with the given details + /// + /// Returns an Option containing the created user, or None if a user already exists with the same username + pub async fn create(&self, username: &str, password: &str, role: UserRole) -> Option { + // Check if a user already exists with the same username + if self.find(username).await.is_some() { + return None; + } + + let u = User { + username: username.to_string(), + password: bcrypt::hash(password, bcrypt::DEFAULT_COST).unwrap(), + user_role: role, + }; + + sqlx::query("INSERT INTO users (username, \"password\", user_role) VALUES ($1, $2, $3)") + .bind(&u.username) + .bind(&u.password) + .bind(&u.user_role) + .execute(&self.conn) + .await + .unwrap(); + + Some(u) + } + + /// Login a user with the given username and password + pub async fn login(&self, username: &str, password: &str) -> Option<(Session, UserRole)> { + let u = self.find(username).await?; + + if !u.verify_pw(password) { + return None; + } + + Some((u.session(&self.conn).await, u.user_role)) + } + + /// Find all users in the system + pub async fn find_all(&self) -> Vec { + sqlx::query_as("SELECT * FROM users") + .fetch_all(&self.conn) + .await + .unwrap() + } + + pub async fn verify(&self, session_id: &str) -> Option { + let ses: Option = sqlx::query_as("SELECT * FROM user_session WHERE id = $1") + .bind(uuid::Uuid::from_str(session_id).unwrap_or(uuid::Uuid::nil())) + .fetch_optional(&self.conn) + .await + .unwrap(); + + if ses.is_some() { + self.find(&ses.unwrap().user).await + } else { + None + } + } +} + +impl User { + /// Generate a new session token for the user + /// + /// Returns a Session instance containing the generated token and associated user + pub async fn session(&self, conn: &sqlx::PgPool) -> Session { + sqlx::query_as( + "INSERT INTO user_session (token, \"user\") VALUES ($1, $2) RETURNING id, token, \"user\"", + ) + .bind(gen_token(64)) + .bind(&self.username) + .fetch_one(conn) + .await + .unwrap() + } + + /// Check if the user is an admin + pub const fn is_admin(&self) -> bool { + matches!(self.user_role, UserRole::Admin) + } + + /// Verify that a provided password matches the hashed password for the user + /// + /// Returns a boolean indicating whether the passwords match or not + pub fn verify_pw(&self, password: &str) -> bool { + bcrypt::verify(password, &self.password).unwrap() + } + + /// Change the password of a User + /// + /// Returns a Result indicating whether the password change was successful or not + pub async fn passwd(self, old: &str, new: &str, conn: &sqlx::PgPool) -> Result<(), ()> { + if self.verify_pw(old) { + sqlx::query("UPDATE users SET \"password\" = $1 WHERE username = $2;") + .bind(bcrypt::hash(new, bcrypt::DEFAULT_COST).unwrap()) + .bind(&self.username) + .fetch_one(conn) + .await + .unwrap(); + + return Ok(()); + } + + Err(()) + } +} + +impl ToAPI for User { + async fn api(&self) -> serde_json::Value { + json!({ + "username": self.username, + "role": self.user_role + }) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, FromRow)] +pub struct Session { + /// The unique ID of the session token + pub id: uuid::Uuid, + /// The generated session token + pub token: String, + /// The username associated with the session token + pub user: String, +} diff --git a/src/main.rs b/src/main.rs index ac3d36b..b2bd6aa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ use std::path::Path; +use library::user::UserManager; use rocket::{http::Method, routes}; use tokio::sync::OnceCell; @@ -46,6 +47,10 @@ async fn launch() -> _ { sqlx::migrate!("./migrations").run(pg).await.unwrap(); let lib = library::Library::new().await; + let um = UserManager::new(pg.clone()); + + um.create("admin", "admin", library::user::UserRole::Admin) + .await; let library = lib.clone(); @@ -81,9 +86,12 @@ async fn launch() -> _ { pages::yt::yt_tag_page, pages::yt::yt_channel_page, pages::index::index_page, - pages::watch::watch_page + pages::watch::watch_page, + pages::user::login, + pages::user::login_post ], ) .attach(cors) .manage(lib) + .manage(um) } diff --git a/src/pages/components.rs b/src/pages/components.rs index 4a4ffc1..78f792b 100644 --- a/src/pages/components.rs +++ b/src/pages/components.rs @@ -2,7 +2,10 @@ use core::num; use maud::{html, PreEscaped}; -use crate::library::Video; +use crate::library::{ + user::{User, UserManager}, + Video, +}; use rocket::{ http::{ContentType, Status}, @@ -10,7 +13,21 @@ use rocket::{ }; pub struct HTMX { - is_htmx: bool, + pub is_htmx: bool, + pub session: Option, + user: Option, +} + +impl HTMX { + pub async fn user(&mut self, um: &UserManager) -> Option { + if let Some(user) = &self.user { + return Some(user.clone()); + } + + let user = um.verify(&self.session.clone().unwrap_or_default()).await; + self.user = user.clone(); + user + } } #[rocket::async_trait] @@ -24,6 +41,11 @@ impl<'r> FromRequest<'r> for HTMX { .get("HX-Request") .collect::>() .is_empty(), + session: req + .cookies() + .get("session_id") + .map(|x| x.value().to_string()), + user: None, }) } } @@ -49,7 +71,7 @@ pub fn script(script: &str) -> PreEscaped { ) } -pub fn shell(content: PreEscaped, title: &str) -> PreEscaped { +pub fn shell(content: PreEscaped, title: &str, user: Option) -> PreEscaped { html! { html { head { @@ -62,13 +84,17 @@ pub fn shell(content: PreEscaped, title: &str) -> PreEscaped { header class="bg-gray-800 text-white shadow-md py-2" { (script(include_str!("../scripts/header.js"))); - div class="flex justify-start px-6" { + div class="flex justify-between px-6" { a href="/" class="flex items-center space-x-2" { img src="/favicon" alt="Logo" class="w-10 h-10 rounded-md"; span class="font-semibold text-xl" { "WatchDogs" }; }; + @if user.is_some() { + p { (user.unwrap().username) }; + }; + }; }; @@ -81,15 +107,16 @@ pub fn shell(content: PreEscaped, title: &str) -> PreEscaped { } } -pub fn render_page( +pub async fn render_page( htmx: HTMX, content: PreEscaped, title: &str, + user: Option, ) -> (Status, (ContentType, String)) { if !htmx.is_htmx { ( Status::Ok, - (ContentType::HTML, shell(content, title).into_string()), + (ContentType::HTML, shell(content, title, user).into_string()), ) } else { (Status::Ok, (ContentType::HTML, content.into_string())) diff --git a/src/pages/index.rs b/src/pages/index.rs index d1ddd46..7e679a1 100644 --- a/src/pages/index.rs +++ b/src/pages/index.rs @@ -7,7 +7,7 @@ use rocket::{ use serde_json::json; use crate::{ - library::Library, + library::{user::UserManager, Library}, pages::components::{htmx_link, video_element}, }; @@ -35,10 +35,13 @@ pub async fn search( #[get("/d/")] pub async fn channel_page( - htmx: HTMX, + mut htmx: HTMX, dir: &str, library: &State, + um: &State, ) -> (Status, (ContentType, String)) { + let user = htmx.user(um).await; + if dir.ends_with(".json") { let dir_videos = library .get_directory_videos(dir.split_once(".json").map(|x| x.0).unwrap_or_default()) @@ -57,11 +60,17 @@ pub async fn channel_page( }; ); - render_page(htmx, content, dir) + render_page(htmx, content, dir, user).await } #[get("/")] -pub async fn index_page(htmx: HTMX, library: &State) -> (Status, (ContentType, String)) { +pub async fn index_page( + mut htmx: HTMX, + library: &State, + um: &State, +) -> (Status, (ContentType, String)) { + let user = htmx.user(um).await; + let content = html!( h1 class="text-center text-4xl font-extrabold leading-tight mt-4" { "Random Videos" }; div class="grid grid-cols-3 gap-6 p-6" { @@ -78,5 +87,5 @@ pub async fn index_page(htmx: HTMX, library: &State) -> (Status, (Conte }; ); - render_page(htmx, content, "WatchDogs") + render_page(htmx, content, "WatchDogs", user).await } diff --git a/src/pages/mod.rs b/src/pages/mod.rs index f4f6eb9..283e932 100644 --- a/src/pages/mod.rs +++ b/src/pages/mod.rs @@ -3,6 +3,7 @@ use rocket::http::{ContentType, Status}; pub mod assets; pub mod components; pub mod index; +pub mod user; pub mod watch; pub mod yt; diff --git a/src/pages/user.rs b/src/pages/user.rs new file mode 100644 index 0000000..930b638 --- /dev/null +++ b/src/pages/user.rs @@ -0,0 +1,64 @@ +use crate::{ + library::{user::UserManager, Library}, + pages::components::{htmx_link, video_element}, +}; +use maud::html; +use rocket::http::CookieJar; +use rocket::{ + form::Form, + get, + http::{ContentType, Cookie, Status}, + post, + response::Redirect, + FromForm, State, +}; +use serde_json::json; + +use super::{ + api_response, + components::{render_page, video_element_wide, HTMX}, + vec_to_api, +}; + +#[get("/login")] +pub async fn login(mut htmx: HTMX, um: &State) -> (Status, (ContentType, String)) { + let user = htmx.user(um).await; + + let content = html!( + h2 { "Login" }; + form action="/login" method="POST" { + input type="text" name="username" placeholder="Username" required; + input type="password" name="password" placeholder="Password" required; + input type="submit" value="Login"; + } + ); + + render_page(htmx, content, "Login", user).await +} + +#[derive(FromForm)] +pub struct LoginForm { + username: String, + password: String, +} + +#[post("/login", data = "")] +pub async fn login_post( + login_form: Form, + um: &State, + cookies: &CookieJar<'_>, +) -> Option { + let login_data = login_form.into_inner(); + + let (session, _) = um.login(&login_data.username, &login_data.password).await?; + + let session_cookie = Cookie::build(("session_id", session.id.to_string())) + .path("/") // Set the cookie path to the root so it’s available for the whole app + .http_only(true) // Make the cookie HTTP only for security + .max_age(rocket::time::Duration::days(7)) // Set the cookie expiration (7 days in this case) + .build(); + + cookies.add(session_cookie); + + Some(Redirect::to("/")) +} diff --git a/src/pages/watch.rs b/src/pages/watch.rs index ec504f1..101cfff 100644 --- a/src/pages/watch.rs +++ b/src/pages/watch.rs @@ -7,7 +7,7 @@ use rocket::{ use serde_json::json; use crate::{ - library::{self, Library}, + library::{self, user::UserManager, Library}, pages::components::{format_date, video_element}, }; @@ -18,10 +18,13 @@ use super::{ #[get("/watch?")] pub async fn watch_page( - htmx: HTMX, + mut htmx: HTMX, library: &State, v: String, + um: &State, ) -> (Status, (ContentType, String)) { + let user = htmx.user(um).await; + let video = if let Some(video) = library.get_video_by_id(&v).await { video } else { @@ -61,5 +64,5 @@ pub async fn watch_page( }; ); - render_page(htmx, content, &format!("{} - WatchDogs", video.title)) + render_page(htmx, content, &format!("{} - WatchDogs", video.title), user).await }