use rocket::{Request, http::Status, outcome::Outcome, request::FromRequest}; use serde::{Deserialize, Serialize}; use serde_json::json; use sqlx::FromRow; use super::{Session, gen_token}; use crate::{get_pg, request::api::ToAPI}; #[derive(Debug, Clone, Serialize, Deserialize, FromRow)] pub struct User { /// The username chosen by the user pub username: String, /// The hashed password for the user pub password: String, /// The role of the user pub user_role: UserRole, } #[derive(Debug, Clone, Serialize, Deserialize, sqlx::Type)] #[sqlx(type_name = "user_role", rename_all = "lowercase")] pub enum UserRole { /// A regular user with limited permissions Regular, /// An admin user with full system privileges Admin, } impl User { // Get a user from session ID pub async fn from_session(session: &str) -> Option { sqlx::query_as("SELECT * FROM users WHERE username = (SELECT \"user\" FROM user_session WHERE token = $1)").bind(session).fetch_optional(get_pg!()).await.unwrap() } /// Find a user by their username pub async fn find(username: &str) -> Option { sqlx::query_as("SELECT * FROM users WHERE username = $1") .bind(username) .fetch_optional(get_pg!()) .await .unwrap() } /// Create a new user with the given details /// /// Returns an Option containing the created user, or None if a user already exists with the same username pub async fn create(username: &str, password: &str, role: UserRole) -> Option { // Check if a user already exists with the same username if Self::find(username).await.is_some() { return None; } let u = Self { username: username.to_string(), password: bcrypt::hash(password, bcrypt::DEFAULT_COST).unwrap(), user_role: role, }; sqlx::query("INSERT INTO users (username, \"password\", user_role) VALUES ($1, $2, $3)") .bind(&u.username) .bind(&u.password) .bind(&u.user_role) .execute(get_pg!()) .await .unwrap(); Some(u) } /// Login a user with the given username and password pub async fn login(username: &str, password: &str) -> Option<(Session, UserRole)> { let u = Self::find(username).await?; if !u.verify_pw(password) { return None; } Some((u.session().await, u.user_role)) } /// Change the password of a User /// /// Returns a Result indicating whether the password change was successful or not pub async fn passwd(self, old: &str, new: &str) -> Result<(), ()> { if self.verify_pw(old) { sqlx::query("UPDATE users SET \"password\" = $1 WHERE username = $2;") .bind(bcrypt::hash(new, bcrypt::DEFAULT_COST).unwrap()) .bind(&self.username) .fetch_one(get_pg!()) .await .unwrap(); return Ok(()); } Err(()) } /// Find all users in the system pub async fn find_all() -> Vec { sqlx::query_as("SELECT * FROM users") .fetch_all(get_pg!()) .await .unwrap() } /// Generate a new session token for the user /// /// Returns a Session instance containing the generated token and associated user pub async fn session(&self) -> Session { sqlx::query_as( "INSERT INTO user_session (token, \"user\") VALUES ($1, $2) RETURNING id, token, \"user\"", ) .bind(gen_token(64)) .bind(&self.username) .fetch_one(get_pg!()) .await .unwrap() } /// Check if the user is an admin pub const fn is_admin(&self) -> bool { matches!(self.user_role, UserRole::Admin) } /// Verify that a provided password matches the hashed password for the user /// /// Returns a boolean indicating whether the passwords match or not pub fn verify_pw(&self, password: &str) -> bool { bcrypt::verify(password, &self.password).unwrap() } } impl ToAPI for User { async fn api(&self) -> serde_json::Value { json!({ "username": self.username, "role": self.user_role }) } } #[rocket::async_trait] impl<'r> FromRequest<'r> for User { type Error = (); async fn from_request(request: &'r Request<'_>) -> rocket::request::Outcome { if let Some(session_id) = request.cookies().get("session_id") { if let Some(user) = User::from_session(session_id.value()).await { return Outcome::Success(user); } else { return Outcome::Error((Status::Unauthorized, ())); } } match request.headers().get_one("token") { Some(key) => { if let Some(user) = User::from_session(key).await { Outcome::Success(user) } else { Outcome::Error((Status::Unauthorized, ())) } } None => Outcome::Error((Status::Unauthorized, ())), } } }