From bab73914bdddc53cbec5c1d2fabdcfe857838aa8 Mon Sep 17 00:00:00 2001 From: JMARyA Date: Tue, 6 May 2025 12:39:14 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20extractors?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Cargo.lock | 224 ++++++++++++++++++++++++++++++++++++++++++-- Cargo.toml | 10 +- src/extractor.rs | 239 +++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 29 +----- src/mod.rs | 32 ------- src/user.rs | 150 ----------------------------- 6 files changed, 471 insertions(+), 213 deletions(-) create mode 100644 src/extractor.rs delete mode 100644 src/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 25e7ec1..120ba2f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -207,6 +207,84 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "axum" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6" +dependencies = [ + "bytes", + "futures-core", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-extra" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45bf463831f5131b7d3c756525b305d40f1185b688565648a92e1392ca35713d" +dependencies = [ + "axum", + "axum-core", + "bytes", + "cookie", + "futures-util", + "headers", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "serde", + "tower", + "tower-layer", + "tower-service", +] + [[package]] name = "backtrace" version = "0.3.74" @@ -222,6 +300,12 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "base64" version = "0.22.1" @@ -238,6 +322,8 @@ checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3" name = "based_auth" version = "0.1.0" dependencies = [ + "axum", + "axum-extra", "bcrypt", "chrono", "data-encoding", @@ -257,7 +343,7 @@ version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b1866ecef4f2d06a0bb77880015fdf2b89e25a1c2e5addacb87e459c86dc67e" dependencies = [ - "base64", + "base64 0.22.1", "blowfish", "getrandom 0.2.16", "subtle", @@ -1009,6 +1095,30 @@ dependencies = [ "hashbrown 0.15.3", ] +[[package]] +name = "headers" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9" +dependencies = [ + "base64 0.21.7", + "bytes", + "headers-core", + "http 1.3.1", + "httpdate", + "mime", + "sha1", +] + +[[package]] +name = "headers-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" +dependencies = [ + "http 1.3.1", +] + [[package]] name = "heck" version = "0.5.0" @@ -1093,6 +1203,29 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http 1.3.1", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http 1.3.1", + "http-body 1.0.1", + "pin-project-lite", +] + [[package]] name = "httparse" version = "1.10.1" @@ -1123,7 +1256,7 @@ dependencies = [ "futures-util", "h2", "http 0.2.12", - "http-body", + "http-body 0.4.6", "httparse", "httpdate", "itoa", @@ -1135,6 +1268,41 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", +] + +[[package]] +name = "hyper-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "497bbc33a26fdd4af9ed9c70d63f61cf56a938375fbb32df34db9b1cd6d643f2" +dependencies = [ + "bytes", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "hyper 1.6.0", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "iana-time-zone" version = "0.1.63" @@ -1475,6 +1643,12 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "md-5" version = "0.10.6" @@ -2162,7 +2336,7 @@ dependencies = [ "either", "futures", "http 0.2.12", - "hyper", + "hyper 0.14.32", "indexmap", "log", "memchr", @@ -2321,6 +2495,16 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_spanned" version = "0.6.8" @@ -2464,7 +2648,7 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f743f2a3cea30a58cd479013f75550e879009e3a02f616f18ca699335aa248c3" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "chrono", "crc", @@ -2541,7 +2725,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0afdd3aa7a629683c2d750c2df343025545087081ab5942593a5288855b1b7a7" dependencies = [ "atoi", - "base64", + "base64 0.22.1", "bitflags", "byteorder", "bytes", @@ -2585,7 +2769,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a0bedbe1bbb5e2615ef347a5e9d8cd7680fb63e77d9dafc0f29be15e53f1ebe6" dependencies = [ "atoi", - "base64", + "base64 0.22.1", "bitflags", "byteorder", "chrono", @@ -2695,6 +2879,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + [[package]] name = "synstructure" version = "0.13.2" @@ -2928,6 +3118,28 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfb942dfe1d8e29a7ee7fcbde5bd2b9a25fb89aa70caea2eba3bee836ff41076" +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + [[package]] name = "tower-service" version = "0.3.3" diff --git a/Cargo.toml b/Cargo.toml index 010e777..1a33dfe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,15 +3,23 @@ name = "based_auth" version = "0.1.0" edition = "2024" +[features] +default = [] +rocket = ["dep:rocket"] +axum = ["dep:axum", "dep:axum-extra"] + [dependencies] env_logger = "0.10.0" hex = "0.4.3" chrono = { version = "0.4", features = ["serde"] } log = "0.4.20" -rocket = { version = "0.5.1", features = ["json"] } serde = { version = "1.0.195", features = ["derive"] } uuid = { version = "1.8.0", features = ["v4", "serde"] } rand = "0.8.5" data-encoding = "2.6.0" bcrypt = "0.16.0" owl = { git = "https://git.hydrar.de/red/owl" } + +rocket = { version = "0.5.1", features = ["json"], optional = true } +axum = { version = "0.8.4", optional = true } +axum-extra = { version = "0.10.1", features = ["cookie", "typed-header"], optional = true } diff --git a/src/extractor.rs b/src/extractor.rs new file mode 100644 index 0000000..364c6f0 --- /dev/null +++ b/src/extractor.rs @@ -0,0 +1,239 @@ +#[cfg(feature = "axum")] +use axum::{ + extract::FromRequestParts, + http::{StatusCode, request::Parts}, +}; +#[cfg(feature = "axum")] +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; +use owl::db::Model; +#[cfg(feature = "rocket")] +use rocket::{Request, http::Status, outcome::Outcome}; + +use crate::{Sessions, User}; + +pub struct UserAuth(pub Model); + +#[cfg(feature = "axum")] +impl FromRequestParts for UserAuth +where + S: Send + Sync, +{ + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let jar = axum_extra::extract::CookieJar::from_headers(&parts.headers); + + if let Some(cookie) = jar.get("session") { + if let Some(user) = User::from_session(cookie.value().to_string()).await { + return Ok(UserAuth(user)); + } + } + + Err((StatusCode::UNAUTHORIZED, "Unauthorized")) + } +} + +#[cfg(feature = "rocket")] +#[rocket::async_trait] +impl<'r> rocket::request::FromRequest<'r> for UserAuth { + type Error = (); + + async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { + if let Some(session_id) = request.cookies().get("session") { + if let Some(user) = User::from_session(session_id.value().to_string()).await { + return Outcome::Success(UserAuth(user)); + } + } + Outcome::Error((Status::Unauthorized, ())) + } +} + +/// Struct which extracts a user with session from `Token` HTTP Header. +pub struct APIUser(pub Model); + +#[cfg(feature = "rocket")] +#[rocket::async_trait] +impl<'r> rocket::request::FromRequest<'r> for APIUser { + type Error = (); + + async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { + if let Some(auth_header) = request.headers().get_one("Authorization") { + // Expect "Bearer " + if let Some(token) = auth_header.strip_prefix("Bearer ").map(str::trim) { + if let Some(user) = User::from_session(token.to_string()).await { + return Outcome::Success(APIUser(user)); + } + } + } + + Outcome::Error((Status::Unauthorized, ())) + } +} + +#[cfg(feature = "axum")] +impl FromRequestParts for APIUser +where + S: Send + Sync, +{ + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { + let TypedHeader(Authorization(bearer)) = + TypedHeader::>::from_request_parts(parts, state) + .await + .map_err(|_| { + ( + StatusCode::UNAUTHORIZED, + "Missing or invalid Authorization header", + ) + })?; + + let token = bearer.token(); + + match User::from_session(token.to_string()).await { + Some(user) => Ok(APIUser(user)), + None => Err((StatusCode::UNAUTHORIZED, "Invalid token")), + } + } +} + +/// Maybe User? +/// +/// This struct extracts a user if possible, but also allows anybody. +/// +/// # Example: +/// +/// ```ignore +/// +/// // Publicly accessable +/// #[get("/")] +/// pub async fn index(ctx: RequestContext, user: MaybeUser) -> StringResponse { +/// match user { +/// MaybeUser::User(user) => println!("You are {}", user.username), +/// MaybeUser::Anonymous => println!("Who are you?") +/// } +/// } +/// ``` +pub enum MaybeUser { + User(Model), + Anonymous, +} + +#[cfg(feature = "rocket")] +#[rocket::async_trait] +impl<'r> rocket::request::FromRequest<'r> for MaybeUser { + type Error = (); + + async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { + if let Some(session_id) = request.cookies().get("session") { + if let Some(user) = User::from_session(session_id.value().to_string()).await { + return Outcome::Success(MaybeUser::User(user)); + } + } + + Outcome::Success(MaybeUser::Anonymous) + } +} + +#[cfg(feature = "axum")] +impl FromRequestParts for MaybeUser +where + S: Send + Sync, +{ + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let jar = axum_extra::extract::CookieJar::from_headers(&parts.headers); + + if let Some(cookie) = jar.get("session") { + if let Some(user) = User::from_session(cookie.value().to_string()).await { + return Ok(MaybeUser::User(user)); + } + } + + Ok(MaybeUser::Anonymous) + } +} + +impl From for Option> { + fn from(value: MaybeUser) -> Self { + value.take_user() + } +} + +impl MaybeUser { + #[must_use] + pub const fn user(&self) -> Option<&Model> { + match self { + MaybeUser::User(user) => Some(user), + MaybeUser::Anonymous => None, + } + } + + #[must_use] + pub fn take_user(self) -> Option> { + match self { + MaybeUser::User(user) => Some(user), + MaybeUser::Anonymous => None, + } + } +} + +/// Admin User +/// +/// This struct expects an Admin User and returns `Forbidden` otherwise. +/// +/// # Example: +/// +/// ```ignore +/// +/// // Only admin users can access this route +/// #[get("/admin")] +/// pub async fn admin_panel(ctx: RequestContext, user: AdminUser) -> StringResponse { +/// ... +/// } +/// ``` +pub struct AdminUser(pub Model); + +#[cfg(feature = "rocket")] +#[rocket::async_trait] +impl<'r> rocket::request::FromRequest<'r> for AdminUser { + type Error = (); + + async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { + if let Some(session_id) = request.cookies().get("session") { + if let Some(user) = User::from_session(session_id.value().to_string()).await { + if user.read().is_admin() { + return Outcome::Success(AdminUser(user)); + } + } + } + + Outcome::Error((Status::Unauthorized, ())) + } +} + +#[cfg(feature = "axum")] +impl FromRequestParts for AdminUser +where + S: Send + Sync, +{ + type Rejection = (StatusCode, &'static str); + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let jar = axum_extra::extract::CookieJar::from_headers(&parts.headers); + + if let Some(cookie) = jar.get("session") { + if let Some(user) = User::from_session(cookie.value().to_string()).await { + if user.read().is_admin() { + return Ok(AdminUser(user)); + } + } + } + + Err((StatusCode::UNAUTHORIZED, "Unauthorized")) + } +} diff --git a/src/lib.rs b/src/lib.rs index 76ed36a..d5e6304 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,38 +3,19 @@ use data_encoding::HEXUPPER; use rand::RngCore; pub mod csrf; +pub mod extractor; pub mod profile_pic; mod session; mod user; +pub use extractor::APIUser; +pub use extractor::AdminUser; +pub use extractor::MaybeUser; +pub use extractor::UserAuth; pub use session::Session; pub use session::Sessions; -pub use user::APIUser; -pub use user::AdminUser; -pub use user::MaybeUser; pub use user::User; -pub use user::UserAuth; pub use user::UserRole; -/// A macro to check if a user has admin privileges. -/// -/// This macro checks whether the provided user has admin privileges by calling the `is_admin` method on it. -/// If the user is not an admin, it returns a `Forbidden` error with a message indicating the restriction. -/// -/// # Arguments -/// * `$u` - The user to check. -/// -/// # Returns -/// The macro does not return a value directly but controls the flow of execution. If the user is not an admin, -/// it returns a `Forbidden` error immediately and prevents further execution. -#[macro_export] -macro_rules! check_admin { - ($u:ident) => { - if !$u.is_admin() { - return Err($crate::request::api::api_error("Forbidden")); - } - }; -} - pub fn gen_random(token_length: usize) -> String { let mut token_bytes = vec![0u8; token_length]; diff --git a/src/mod.rs b/src/mod.rs deleted file mode 100644 index c6b8f25..0000000 --- a/src/mod.rs +++ /dev/null @@ -1,32 +0,0 @@ -pub mod csrf; -pub mod profile_pic; -mod session; -mod user; -pub use session::Session; -pub use session::Sessions; -pub use user::APIUser; -pub use user::AdminUser; -pub use user::MaybeUser; -pub use user::User; -pub use user::UserAuth; -pub use user::UserRole; - -/// A macro to check if a user has admin privileges. -/// -/// This macro checks whether the provided user has admin privileges by calling the `is_admin` method on it. -/// If the user is not an admin, it returns a `Forbidden` error with a message indicating the restriction. -/// -/// # Arguments -/// * `$u` - The user to check. -/// -/// # Returns -/// The macro does not return a value directly but controls the flow of execution. If the user is not an admin, -/// it returns a `Forbidden` error immediately and prevents further execution. -#[macro_export] -macro_rules! check_admin { - ($u:ident) => { - if !$u.is_admin() { - return Err($crate::request::api::api_error("Forbidden")); - } - }; -} diff --git a/src/user.rs b/src/user.rs index a3bd37f..30ef1de 100644 --- a/src/user.rs +++ b/src/user.rs @@ -1,23 +1,9 @@ use owl::{db::model::file::File, get, prelude::*, query, save, update}; -use rocket::{Request, http::Status, outcome::Outcome, request::FromRequest}; use serde::{Deserialize, Serialize}; -use super::Sessions; - // TODO : 2FA /// User -/// -/// # Example: -/// -/// ```ignore -/// -/// // Needs login -/// #[get("/myaccount")] -/// pub async fn account_page(ctx: RequestContext, user: User) -> StringResponse { -/// ... -/// } -/// ``` #[derive(Debug)] #[model] pub struct User { @@ -102,139 +88,3 @@ impl User { bcrypt::verify(password, &self.password).unwrap() } } - -/// extracts a user from a request with `session` cookie -async fn extract_user(request: &Request<'_>) -> Option> { - if let Some(session_id) = request.cookies().get("session") { - if let Some(user) = User::from_session(session_id.value().to_string()).await { - return Some(user); - } - return None; - } - - None -} - -pub struct UserAuth(pub Model); - -#[rocket::async_trait] -impl<'r> FromRequest<'r> for UserAuth { - type Error = (); - - async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { - if let Some(user) = extract_user(request).await { - return Outcome::Success(UserAuth(user)); - } - Outcome::Error((Status::Unauthorized, ())) - } -} - -/// Struct which extracts a user with session from `Token` HTTP Header. -pub struct APIUser(pub Model); - -#[rocket::async_trait] -impl<'r> FromRequest<'r> for APIUser { - type Error = (); - - async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { - match request.headers().get_one("token") { - Some(key) => { - if let Some(user) = User::from_session(key.to_string()).await { - return Outcome::Success(APIUser(user)); - } - return Outcome::Error((Status::Unauthorized, ())); - } - None => Outcome::Error((Status::Unauthorized, ())), - } - } -} - -/// Maybe User? -/// -/// This struct extracts a user if possible, but also allows anybody. -/// -/// # Example: -/// -/// ```ignore -/// -/// // Publicly accessable -/// #[get("/")] -/// pub async fn index(ctx: RequestContext, user: MaybeUser) -> StringResponse { -/// match user { -/// MaybeUser::User(user) => println!("You are {}", user.username), -/// MaybeUser::Anonymous => println!("Who are you?") -/// } -/// } -/// ``` -pub enum MaybeUser { - User(Model), - Anonymous, -} - -#[rocket::async_trait] -impl<'r> FromRequest<'r> for MaybeUser { - type Error = (); - - async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { - if let Some(user) = extract_user(request).await { - return Outcome::Success(MaybeUser::User(user)); - } - - Outcome::Success(MaybeUser::Anonymous) - } -} - -impl From for Option> { - fn from(value: MaybeUser) -> Self { - value.take_user() - } -} - -impl MaybeUser { - #[must_use] - pub const fn user(&self) -> Option<&Model> { - match self { - MaybeUser::User(user) => Some(user), - MaybeUser::Anonymous => None, - } - } - - #[must_use] - pub fn take_user(self) -> Option> { - match self { - MaybeUser::User(user) => Some(user), - MaybeUser::Anonymous => None, - } - } -} - -/// Admin User -/// -/// This struct expects an Admin User and returns `Forbidden` otherwise. -/// -/// # Example: -/// -/// ```ignore -/// -/// // Only admin users can access this route -/// #[get("/admin")] -/// pub async fn admin_panel(ctx: RequestContext, user: AdminUser) -> StringResponse { -/// ... -/// } -/// ``` -pub struct AdminUser(pub Model); - -#[rocket::async_trait] -impl<'r> FromRequest<'r> for AdminUser { - type Error = (); - - async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { - if let Some(user) = extract_user(request).await { - if user.read().is_admin() { - return Outcome::Success(AdminUser(user)); - } - } - - Outcome::Error((Status::Unauthorized, ())) - } -}