extractors

This commit is contained in:
JMARyA 2025-05-06 12:39:14 +02:00
parent 3e3fd0fd83
commit bab73914bd
Signed by: jmarya
GPG key ID: 901B2ADDF27C2263
6 changed files with 471 additions and 213 deletions

224
Cargo.lock generated
View file

@ -207,6 +207,84 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "axum"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5"
dependencies = [
"axum-core",
"bytes",
"form_urlencoded",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"hyper 1.6.0",
"hyper-util",
"itoa",
"matchit",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-core"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68464cd0412f486726fb3373129ef5d2993f90c34bc2bc1c1e9943b2f4fc7ca6"
dependencies = [
"bytes",
"futures-core",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-extra"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45bf463831f5131b7d3c756525b305d40f1185b688565648a92e1392ca35713d"
dependencies = [
"axum",
"axum-core",
"bytes",
"cookie",
"futures-util",
"headers",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"serde",
"tower",
"tower-layer",
"tower-service",
]
[[package]]
name = "backtrace"
version = "0.3.74"
@ -222,6 +300,12 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "base64"
version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]]
name = "base64"
version = "0.22.1"
@ -238,6 +322,8 @@ checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3"
name = "based_auth"
version = "0.1.0"
dependencies = [
"axum",
"axum-extra",
"bcrypt",
"chrono",
"data-encoding",
@ -257,7 +343,7 @@ version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b1866ecef4f2d06a0bb77880015fdf2b89e25a1c2e5addacb87e459c86dc67e"
dependencies = [
"base64",
"base64 0.22.1",
"blowfish",
"getrandom 0.2.16",
"subtle",
@ -1009,6 +1095,30 @@ dependencies = [
"hashbrown 0.15.3",
]
[[package]]
name = "headers"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "322106e6bd0cba2d5ead589ddb8150a13d7c4217cf80d7c4f682ca994ccc6aa9"
dependencies = [
"base64 0.21.7",
"bytes",
"headers-core",
"http 1.3.1",
"httpdate",
"mime",
"sha1",
]
[[package]]
name = "headers-core"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4"
dependencies = [
"http 1.3.1",
]
[[package]]
name = "heck"
version = "0.5.0"
@ -1093,6 +1203,29 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "http-body"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [
"bytes",
"http 1.3.1",
]
[[package]]
name = "http-body-util"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [
"bytes",
"futures-core",
"http 1.3.1",
"http-body 1.0.1",
"pin-project-lite",
]
[[package]]
name = "httparse"
version = "1.10.1"
@ -1123,7 +1256,7 @@ dependencies = [
"futures-util",
"h2",
"http 0.2.12",
"http-body",
"http-body 0.4.6",
"httparse",
"httpdate",
"itoa",
@ -1135,6 +1268,41 @@ dependencies = [
"want",
]
[[package]]
name = "hyper"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"smallvec",
"tokio",
]
[[package]]
name = "hyper-util"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "497bbc33a26fdd4af9ed9c70d63f61cf56a938375fbb32df34db9b1cd6d643f2"
dependencies = [
"bytes",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"hyper 1.6.0",
"pin-project-lite",
"tokio",
"tower-service",
]
[[package]]
name = "iana-time-zone"
version = "0.1.63"
@ -1475,6 +1643,12 @@ dependencies = [
"regex-automata 0.1.10",
]
[[package]]
name = "matchit"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3"
[[package]]
name = "md-5"
version = "0.10.6"
@ -2162,7 +2336,7 @@ dependencies = [
"either",
"futures",
"http 0.2.12",
"hyper",
"hyper 0.14.32",
"indexmap",
"log",
"memchr",
@ -2321,6 +2495,16 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a"
dependencies = [
"itoa",
"serde",
]
[[package]]
name = "serde_spanned"
version = "0.6.8"
@ -2464,7 +2648,7 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f743f2a3cea30a58cd479013f75550e879009e3a02f616f18ca699335aa248c3"
dependencies = [
"base64",
"base64 0.22.1",
"bytes",
"chrono",
"crc",
@ -2541,7 +2725,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0afdd3aa7a629683c2d750c2df343025545087081ab5942593a5288855b1b7a7"
dependencies = [
"atoi",
"base64",
"base64 0.22.1",
"bitflags",
"byteorder",
"bytes",
@ -2585,7 +2769,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0bedbe1bbb5e2615ef347a5e9d8cd7680fb63e77d9dafc0f29be15e53f1ebe6"
dependencies = [
"atoi",
"base64",
"base64 0.22.1",
"bitflags",
"byteorder",
"chrono",
@ -2695,6 +2879,12 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "sync_wrapper"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263"
[[package]]
name = "synstructure"
version = "0.13.2"
@ -2928,6 +3118,28 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfb942dfe1d8e29a7ee7fcbde5bd2b9a25fb89aa70caea2eba3bee836ff41076"
[[package]]
name = "tower"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9"
dependencies = [
"futures-core",
"futures-util",
"pin-project-lite",
"sync_wrapper",
"tokio",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-layer"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
[[package]]
name = "tower-service"
version = "0.3.3"

View file

@ -3,15 +3,23 @@ name = "based_auth"
version = "0.1.0"
edition = "2024"
[features]
default = []
rocket = ["dep:rocket"]
axum = ["dep:axum", "dep:axum-extra"]
[dependencies]
env_logger = "0.10.0"
hex = "0.4.3"
chrono = { version = "0.4", features = ["serde"] }
log = "0.4.20"
rocket = { version = "0.5.1", features = ["json"] }
serde = { version = "1.0.195", features = ["derive"] }
uuid = { version = "1.8.0", features = ["v4", "serde"] }
rand = "0.8.5"
data-encoding = "2.6.0"
bcrypt = "0.16.0"
owl = { git = "https://git.hydrar.de/red/owl" }
rocket = { version = "0.5.1", features = ["json"], optional = true }
axum = { version = "0.8.4", optional = true }
axum-extra = { version = "0.10.1", features = ["cookie", "typed-header"], optional = true }

239
src/extractor.rs Normal file
View file

@ -0,0 +1,239 @@
#[cfg(feature = "axum")]
use axum::{
extract::FromRequestParts,
http::{StatusCode, request::Parts},
};
#[cfg(feature = "axum")]
use axum_extra::{
TypedHeader,
headers::{Authorization, authorization::Bearer},
};
use owl::db::Model;
#[cfg(feature = "rocket")]
use rocket::{Request, http::Status, outcome::Outcome};
use crate::{Sessions, User};
pub struct UserAuth(pub Model<User>);
#[cfg(feature = "axum")]
impl<S> FromRequestParts<S> for UserAuth
where
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let jar = axum_extra::extract::CookieJar::from_headers(&parts.headers);
if let Some(cookie) = jar.get("session") {
if let Some(user) = User::from_session(cookie.value().to_string()).await {
return Ok(UserAuth(user));
}
}
Err((StatusCode::UNAUTHORIZED, "Unauthorized"))
}
}
#[cfg(feature = "rocket")]
#[rocket::async_trait]
impl<'r> rocket::request::FromRequest<'r> for UserAuth {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
if let Some(session_id) = request.cookies().get("session") {
if let Some(user) = User::from_session(session_id.value().to_string()).await {
return Outcome::Success(UserAuth(user));
}
}
Outcome::Error((Status::Unauthorized, ()))
}
}
/// Struct which extracts a user with session from `Token` HTTP Header.
pub struct APIUser(pub Model<User>);
#[cfg(feature = "rocket")]
#[rocket::async_trait]
impl<'r> rocket::request::FromRequest<'r> for APIUser {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
if let Some(auth_header) = request.headers().get_one("Authorization") {
// Expect "Bearer <token>"
if let Some(token) = auth_header.strip_prefix("Bearer ").map(str::trim) {
if let Some(user) = User::from_session(token.to_string()).await {
return Outcome::Success(APIUser(user));
}
}
}
Outcome::Error((Status::Unauthorized, ()))
}
}
#[cfg(feature = "axum")]
impl<S> FromRequestParts<S> for APIUser
where
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);
async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let TypedHeader(Authorization(bearer)) =
TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state)
.await
.map_err(|_| {
(
StatusCode::UNAUTHORIZED,
"Missing or invalid Authorization header",
)
})?;
let token = bearer.token();
match User::from_session(token.to_string()).await {
Some(user) => Ok(APIUser(user)),
None => Err((StatusCode::UNAUTHORIZED, "Invalid token")),
}
}
}
/// Maybe User?
///
/// This struct extracts a user if possible, but also allows anybody.
///
/// # Example:
///
/// ```ignore
///
/// // Publicly accessable
/// #[get("/")]
/// pub async fn index(ctx: RequestContext, user: MaybeUser) -> StringResponse {
/// match user {
/// MaybeUser::User(user) => println!("You are {}", user.username),
/// MaybeUser::Anonymous => println!("Who are you?")
/// }
/// }
/// ```
pub enum MaybeUser {
User(Model<User>),
Anonymous,
}
#[cfg(feature = "rocket")]
#[rocket::async_trait]
impl<'r> rocket::request::FromRequest<'r> for MaybeUser {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
if let Some(session_id) = request.cookies().get("session") {
if let Some(user) = User::from_session(session_id.value().to_string()).await {
return Outcome::Success(MaybeUser::User(user));
}
}
Outcome::Success(MaybeUser::Anonymous)
}
}
#[cfg(feature = "axum")]
impl<S> FromRequestParts<S> for MaybeUser
where
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let jar = axum_extra::extract::CookieJar::from_headers(&parts.headers);
if let Some(cookie) = jar.get("session") {
if let Some(user) = User::from_session(cookie.value().to_string()).await {
return Ok(MaybeUser::User(user));
}
}
Ok(MaybeUser::Anonymous)
}
}
impl From<MaybeUser> for Option<Model<User>> {
fn from(value: MaybeUser) -> Self {
value.take_user()
}
}
impl MaybeUser {
#[must_use]
pub const fn user(&self) -> Option<&Model<User>> {
match self {
MaybeUser::User(user) => Some(user),
MaybeUser::Anonymous => None,
}
}
#[must_use]
pub fn take_user(self) -> Option<Model<User>> {
match self {
MaybeUser::User(user) => Some(user),
MaybeUser::Anonymous => None,
}
}
}
/// Admin User
///
/// This struct expects an Admin User and returns `Forbidden` otherwise.
///
/// # Example:
///
/// ```ignore
///
/// // Only admin users can access this route
/// #[get("/admin")]
/// pub async fn admin_panel(ctx: RequestContext, user: AdminUser) -> StringResponse {
/// ...
/// }
/// ```
pub struct AdminUser(pub Model<User>);
#[cfg(feature = "rocket")]
#[rocket::async_trait]
impl<'r> rocket::request::FromRequest<'r> for AdminUser {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
if let Some(session_id) = request.cookies().get("session") {
if let Some(user) = User::from_session(session_id.value().to_string()).await {
if user.read().is_admin() {
return Outcome::Success(AdminUser(user));
}
}
}
Outcome::Error((Status::Unauthorized, ()))
}
}
#[cfg(feature = "axum")]
impl<S> FromRequestParts<S> for AdminUser
where
S: Send + Sync,
{
type Rejection = (StatusCode, &'static str);
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let jar = axum_extra::extract::CookieJar::from_headers(&parts.headers);
if let Some(cookie) = jar.get("session") {
if let Some(user) = User::from_session(cookie.value().to_string()).await {
if user.read().is_admin() {
return Ok(AdminUser(user));
}
}
}
Err((StatusCode::UNAUTHORIZED, "Unauthorized"))
}
}

View file

@ -3,38 +3,19 @@ use data_encoding::HEXUPPER;
use rand::RngCore;
pub mod csrf;
pub mod extractor;
pub mod profile_pic;
mod session;
mod user;
pub use extractor::APIUser;
pub use extractor::AdminUser;
pub use extractor::MaybeUser;
pub use extractor::UserAuth;
pub use session::Session;
pub use session::Sessions;
pub use user::APIUser;
pub use user::AdminUser;
pub use user::MaybeUser;
pub use user::User;
pub use user::UserAuth;
pub use user::UserRole;
/// A macro to check if a user has admin privileges.
///
/// This macro checks whether the provided user has admin privileges by calling the `is_admin` method on it.
/// If the user is not an admin, it returns a `Forbidden` error with a message indicating the restriction.
///
/// # Arguments
/// * `$u` - The user to check.
///
/// # Returns
/// The macro does not return a value directly but controls the flow of execution. If the user is not an admin,
/// it returns a `Forbidden` error immediately and prevents further execution.
#[macro_export]
macro_rules! check_admin {
($u:ident) => {
if !$u.is_admin() {
return Err($crate::request::api::api_error("Forbidden"));
}
};
}
pub fn gen_random(token_length: usize) -> String {
let mut token_bytes = vec![0u8; token_length];

View file

@ -1,32 +0,0 @@
pub mod csrf;
pub mod profile_pic;
mod session;
mod user;
pub use session::Session;
pub use session::Sessions;
pub use user::APIUser;
pub use user::AdminUser;
pub use user::MaybeUser;
pub use user::User;
pub use user::UserAuth;
pub use user::UserRole;
/// A macro to check if a user has admin privileges.
///
/// This macro checks whether the provided user has admin privileges by calling the `is_admin` method on it.
/// If the user is not an admin, it returns a `Forbidden` error with a message indicating the restriction.
///
/// # Arguments
/// * `$u` - The user to check.
///
/// # Returns
/// The macro does not return a value directly but controls the flow of execution. If the user is not an admin,
/// it returns a `Forbidden` error immediately and prevents further execution.
#[macro_export]
macro_rules! check_admin {
($u:ident) => {
if !$u.is_admin() {
return Err($crate::request::api::api_error("Forbidden"));
}
};
}

View file

@ -1,23 +1,9 @@
use owl::{db::model::file::File, get, prelude::*, query, save, update};
use rocket::{Request, http::Status, outcome::Outcome, request::FromRequest};
use serde::{Deserialize, Serialize};
use super::Sessions;
// TODO : 2FA
/// User
///
/// # Example:
///
/// ```ignore
///
/// // Needs login
/// #[get("/myaccount")]
/// pub async fn account_page(ctx: RequestContext, user: User) -> StringResponse {
/// ...
/// }
/// ```
#[derive(Debug)]
#[model]
pub struct User {
@ -102,139 +88,3 @@ impl User {
bcrypt::verify(password, &self.password).unwrap()
}
}
/// extracts a user from a request with `session` cookie
async fn extract_user(request: &Request<'_>) -> Option<Model<User>> {
if let Some(session_id) = request.cookies().get("session") {
if let Some(user) = User::from_session(session_id.value().to_string()).await {
return Some(user);
}
return None;
}
None
}
pub struct UserAuth(pub Model<User>);
#[rocket::async_trait]
impl<'r> FromRequest<'r> for UserAuth {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
if let Some(user) = extract_user(request).await {
return Outcome::Success(UserAuth(user));
}
Outcome::Error((Status::Unauthorized, ()))
}
}
/// Struct which extracts a user with session from `Token` HTTP Header.
pub struct APIUser(pub Model<User>);
#[rocket::async_trait]
impl<'r> FromRequest<'r> for APIUser {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
match request.headers().get_one("token") {
Some(key) => {
if let Some(user) = User::from_session(key.to_string()).await {
return Outcome::Success(APIUser(user));
}
return Outcome::Error((Status::Unauthorized, ()));
}
None => Outcome::Error((Status::Unauthorized, ())),
}
}
}
/// Maybe User?
///
/// This struct extracts a user if possible, but also allows anybody.
///
/// # Example:
///
/// ```ignore
///
/// // Publicly accessable
/// #[get("/")]
/// pub async fn index(ctx: RequestContext, user: MaybeUser) -> StringResponse {
/// match user {
/// MaybeUser::User(user) => println!("You are {}", user.username),
/// MaybeUser::Anonymous => println!("Who are you?")
/// }
/// }
/// ```
pub enum MaybeUser {
User(Model<User>),
Anonymous,
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for MaybeUser {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
if let Some(user) = extract_user(request).await {
return Outcome::Success(MaybeUser::User(user));
}
Outcome::Success(MaybeUser::Anonymous)
}
}
impl From<MaybeUser> for Option<Model<User>> {
fn from(value: MaybeUser) -> Self {
value.take_user()
}
}
impl MaybeUser {
#[must_use]
pub const fn user(&self) -> Option<&Model<User>> {
match self {
MaybeUser::User(user) => Some(user),
MaybeUser::Anonymous => None,
}
}
#[must_use]
pub fn take_user(self) -> Option<Model<User>> {
match self {
MaybeUser::User(user) => Some(user),
MaybeUser::Anonymous => None,
}
}
}
/// Admin User
///
/// This struct expects an Admin User and returns `Forbidden` otherwise.
///
/// # Example:
///
/// ```ignore
///
/// // Only admin users can access this route
/// #[get("/admin")]
/// pub async fn admin_panel(ctx: RequestContext, user: AdminUser) -> StringResponse {
/// ...
/// }
/// ```
pub struct AdminUser(pub Model<User>);
#[rocket::async_trait]
impl<'r> FromRequest<'r> for AdminUser {
type Error = ();
async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
if let Some(user) = extract_user(request).await {
if user.read().is_admin() {
return Outcome::Success(AdminUser(user));
}
}
Outcome::Error((Status::Unauthorized, ()))
}
}