use rocket::{Request, http::Status, outcome::Outcome, request::FromRequest}; use serde::{Deserialize, Serialize}; use serde_json::json; use sqlx::FromRow; use super::{Session, gen_token}; use crate::{get_pg, request::api::ToAPI}; /// User /// /// # Example: /// /// ```ignore /// /// // Needs login /// #[get("/myaccount")] /// pub async fn account_page(ctx: RequestContext, user: User) -> StringResponse { /// ... /// } /// ``` #[derive(Debug, Clone, Serialize, Deserialize, FromRow)] pub struct User { /// The username chosen by the user pub username: String, /// The hashed password for the user pub password: String, /// The role of the user pub user_role: UserRole, } #[derive(Debug, Clone, Serialize, Deserialize, sqlx::Type)] #[sqlx(type_name = "user_role", rename_all = "lowercase")] pub enum UserRole { /// A regular user with limited permissions Regular, /// An admin user with full system privileges Admin, } impl User { // Get a user from session ID pub async fn from_session(session: &str) -> Option { sqlx::query_as("SELECT * FROM users WHERE username = (SELECT \"user\" FROM user_session WHERE token = $1)").bind(session).fetch_optional(get_pg!()).await.unwrap() } /// Find a user by their username pub async fn find(username: &str) -> Option { sqlx::query_as("SELECT * FROM users WHERE username = $1") .bind(username) .fetch_optional(get_pg!()) .await .unwrap() } /// Create a new user with the given details /// /// Returns an Option containing the created user, or None if a user already exists with the same username pub async fn create(username: &str, password: &str, role: UserRole) -> Option { // Check if a user already exists with the same username if Self::find(username).await.is_some() { return None; } let u = Self { username: username.to_string(), password: bcrypt::hash(password, bcrypt::DEFAULT_COST).unwrap(), user_role: role, }; sqlx::query("INSERT INTO users (username, \"password\", user_role) VALUES ($1, $2, $3)") .bind(&u.username) .bind(&u.password) .bind(&u.user_role) .execute(get_pg!()) .await .unwrap(); Some(u) } /// Login a user with the given username and password pub async fn login(username: &str, password: &str) -> Option<(Session, UserRole)> { let u = Self::find(username).await?; if !u.verify_pw(password) { return None; } Some((u.session().await, u.user_role)) } /// Change the password of a User /// /// Returns a Result indicating whether the password change was successful or not pub async fn passwd(self, old: &str, new: &str) -> Result<(), ()> { if self.verify_pw(old) { sqlx::query("UPDATE users SET \"password\" = $1 WHERE username = $2;") .bind(bcrypt::hash(new, bcrypt::DEFAULT_COST).unwrap()) .bind(&self.username) .fetch_one(get_pg!()) .await .unwrap(); return Ok(()); } Err(()) } /// Find all users in the system pub async fn find_all() -> Vec { sqlx::query_as("SELECT * FROM users") .fetch_all(get_pg!()) .await .unwrap() } /// Generate a new session token for the user /// /// Returns a Session instance containing the generated token and associated user pub async fn session(&self) -> Session { sqlx::query_as( "INSERT INTO user_session (token, \"user\") VALUES ($1, $2) RETURNING id, token, \"user\"", ) .bind(gen_token(64)) .bind(&self.username) .fetch_one(get_pg!()) .await .unwrap() } /// Check if the user is an admin pub const fn is_admin(&self) -> bool { matches!(self.user_role, UserRole::Admin) } /// Verify that a provided password matches the hashed password for the user /// /// Returns a boolean indicating whether the passwords match or not pub fn verify_pw(&self, password: &str) -> bool { bcrypt::verify(password, &self.password).unwrap() } } impl ToAPI for User { async fn api(&self) -> serde_json::Value { json!({ "username": self.username, "role": self.user_role }) } } async fn extract_user<'r>(request: &'r Request<'_>) -> Option { if let Some(session_id) = request.cookies().get("session_id") { if let Some(user) = User::from_session(session_id.value()).await { return Some(user); } else { return None; } } match request.headers().get_one("token") { Some(key) => { if let Some(user) = User::from_session(key).await { return Some(user); } else { return None; } } None => None, } } #[rocket::async_trait] impl<'r> FromRequest<'r> for User { type Error = (); async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { if let Some(user) = extract_user(request).await { return Outcome::Success(user); } else { return Outcome::Error((Status::Unauthorized, ())); } } } /// Maybe User? /// /// This struct extracts a user if possible, but also allows anybody. /// /// # Example: /// /// ```ignore /// /// // Publicly accessable /// #[get("/")] /// pub async fn index(ctx: RequestContext, user: MaybeUser) -> StringResponse { /// match user { /// MaybeUser::User(user) => println!("You are {}", user.username), /// MaybeUser::Anonymous => println!("Who are you?") /// } /// } /// ``` pub enum MaybeUser { User(User), Anonymous, } #[rocket::async_trait] impl<'r> FromRequest<'r> for MaybeUser { type Error = (); async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { if let Some(user) = extract_user(request).await { return Outcome::Success(MaybeUser::User(user)); } else { return Outcome::Success(MaybeUser::Anonymous); } } } /// Admin User /// /// This struct expects an Admin User and returns `Forbidden` otherwise. /// /// # Example: /// /// ```ignore /// /// // Only admin users can access this route /// #[get("/admin")] /// pub async fn admin_panel(ctx: RequestContext, user: AdminUser) -> StringResponse { /// ... /// } /// ``` pub struct AdminUser(User); #[rocket::async_trait] impl<'r> FromRequest<'r> for AdminUser { type Error = (); async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { if let Some(user) = extract_user(request).await { if user.is_admin() { return Outcome::Success(AdminUser(user)); } } else { } Outcome::Error((Status::Unauthorized, ())) } }