diff --git a/src/auth/mod.rs b/src/auth/mod.rs index fd2d82d..82585bb 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -4,6 +4,8 @@ use serde::{Deserialize, Serialize}; use sqlx::FromRow; mod user; +pub use user::AdminUser; +pub use user::MaybeUser; pub use user::User; pub use user::UserRole; diff --git a/src/auth/user.rs b/src/auth/user.rs index 08b2b12..c31063b 100644 --- a/src/auth/user.rs +++ b/src/auth/user.rs @@ -6,6 +6,18 @@ use sqlx::FromRow; use super::{Session, gen_token}; use crate::{get_pg, request::api::ToAPI}; +/// User +/// +/// # Example: +/// +/// ```ignore +/// +/// // Needs login +/// #[get("/myaccount")] +/// pub async fn account_page(ctx: RequestContext, user: User) -> StringResponse { +/// ... +/// } +/// ``` #[derive(Debug, Clone, Serialize, Deserialize, FromRow)] pub struct User { /// The username chosen by the user @@ -139,28 +151,102 @@ impl ToAPI for User { } } +async fn extract_user<'r>(request: &'r Request<'_>) -> Option { + if let Some(session_id) = request.cookies().get("session_id") { + if let Some(user) = User::from_session(session_id.value()).await { + return Some(user); + } else { + return None; + } + } + + match request.headers().get_one("token") { + Some(key) => { + if let Some(user) = User::from_session(key).await { + return Some(user); + } else { + return None; + } + } + None => None, + } +} + #[rocket::async_trait] impl<'r> FromRequest<'r> for User { type Error = (); async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { - if let Some(session_id) = request.cookies().get("session_id") { - if let Some(user) = User::from_session(session_id.value()).await { - return Outcome::Success(user); - } else { - return Outcome::Error((Status::Unauthorized, ())); - } - } - - match request.headers().get_one("token") { - Some(key) => { - if let Some(user) = User::from_session(key).await { - Outcome::Success(user) - } else { - Outcome::Error((Status::Unauthorized, ())) - } - } - None => Outcome::Error((Status::Unauthorized, ())), + if let Some(user) = extract_user(request).await { + return Outcome::Success(user); + } else { + return Outcome::Error((Status::Unauthorized, ())); } } } + +/// Maybe User? +/// +/// This struct extracts a user if possible, but also allows anybody. +/// +/// # Example: +/// +/// ```ignore +/// +/// // Publicly accessable +/// #[get("/")] +/// pub async fn index(ctx: RequestContext, user: MaybeUser) -> StringResponse { +/// match user { +/// MaybeUser::User(user) => println!("You are {}", user.username), +/// MaybeUser::Anonymous => println!("Who are you?") +/// } +/// } +/// ``` +pub enum MaybeUser { + User(User), + Anonymous, +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for MaybeUser { + type Error = (); + + async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { + if let Some(user) = extract_user(request).await { + return Outcome::Success(MaybeUser::User(user)); + } else { + return Outcome::Success(MaybeUser::Anonymous); + } + } +} + +/// Admin User +/// +/// This struct expects an Admin User and returns `Forbidden` otherwise. +/// +/// # Example: +/// +/// ```ignore +/// +/// // Only admin users can access this route +/// #[get("/admin")] +/// pub async fn admin_panel(ctx: RequestContext, user: AdminUser) -> StringResponse { +/// ... +/// } +/// ``` +pub struct AdminUser(User); + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for AdminUser { + type Error = (); + + async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { + if let Some(user) = extract_user(request).await { + if user.is_admin() { + return Outcome::Success(AdminUser(user)); + } + } else { + } + Outcome::Error((Status::Unauthorized, ())) + } +}