2024-12-18 14:33:53 +01:00
use rocket ::{ Request , http ::Status , outcome ::Outcome , request ::FromRequest } ;
2024-12-17 23:28:43 +01:00
use serde ::{ Deserialize , Serialize } ;
use serde_json ::json ;
use sqlx ::FromRow ;
2024-12-18 14:33:53 +01:00
use super ::{ Session , gen_token } ;
2024-12-17 23:28:43 +01:00
use crate ::{ get_pg , request ::api ::ToAPI } ;
/// An account row from the `users` table.
///
/// NOTE(review): `Serialize` and `Debug` are derived on the whole struct, so serializing or
/// debug-printing a `User` directly includes the `password` hash. The `ToAPI` impl omits it;
/// confirm no caller serializes or logs `User` itself.
#[derive(Debug, Clone, Serialize, Deserialize, FromRow)]
pub struct User {
    /// The username chosen by the user
    pub username: String,
    /// The bcrypt hash of the user's password (the plaintext is never stored)
    pub password: String,
    /// The role of the user
    pub user_role: UserRole,
}
/// The permission level of a [`User`].
///
/// Mapped to the PostgreSQL type `user_role`, with each variant stored under its
/// lowercase name (`regular`, `admin`).
///
/// The enum is fieldless, so it is `Copy` and comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
#[sqlx(type_name = "user_role", rename_all = "lowercase")]
pub enum UserRole {
    /// A regular user with limited permissions
    Regular,
    /// An admin user with full system privileges
    Admin,
}
impl User {
2024-12-18 14:33:53 +01:00
// Get a user from session ID
pub async fn from_session ( session : & str ) -> Option < Self > {
sqlx ::query_as ( " SELECT * FROM users WHERE username = (SELECT \" user \" FROM user_session WHERE token = $1) " ) . bind ( session ) . fetch_optional ( get_pg! ( ) ) . await . unwrap ( )
}
2024-12-17 23:28:43 +01:00
/// Find a user by their username
pub async fn find ( username : & str ) -> Option < Self > {
sqlx ::query_as ( " SELECT * FROM users WHERE username = $1 " )
. bind ( username )
. fetch_optional ( get_pg! ( ) )
. await
. unwrap ( )
}
/// Create a new user with the given details
///
/// Returns an Option containing the created user, or None if a user already exists with the same username
pub async fn create ( username : & str , password : & str , role : UserRole ) -> Option < Self > {
// Check if a user already exists with the same username
if Self ::find ( username ) . await . is_some ( ) {
return None ;
}
let u = Self {
username : username . to_string ( ) ,
password : bcrypt ::hash ( password , bcrypt ::DEFAULT_COST ) . unwrap ( ) ,
user_role : role ,
} ;
sqlx ::query ( " INSERT INTO users (username, \" password \" , user_role) VALUES ($1, $2, $3) " )
. bind ( & u . username )
. bind ( & u . password )
. bind ( & u . user_role )
. execute ( get_pg! ( ) )
. await
. unwrap ( ) ;
Some ( u )
}
/// Login a user with the given username and password
pub async fn login ( username : & str , password : & str ) -> Option < ( Session , UserRole ) > {
let u = Self ::find ( username ) . await ? ;
if ! u . verify_pw ( password ) {
return None ;
}
Some ( ( u . session ( ) . await , u . user_role ) )
}
/// Change the password of a User
///
/// Returns a Result indicating whether the password change was successful or not
pub async fn passwd ( self , old : & str , new : & str ) -> Result < ( ) , ( ) > {
if self . verify_pw ( old ) {
sqlx ::query ( " UPDATE users SET \" password \" = $1 WHERE username = $2; " )
. bind ( bcrypt ::hash ( new , bcrypt ::DEFAULT_COST ) . unwrap ( ) )
. bind ( & self . username )
. fetch_one ( get_pg! ( ) )
. await
. unwrap ( ) ;
return Ok ( ( ) ) ;
}
Err ( ( ) )
}
/// Find all users in the system
pub async fn find_all ( ) -> Vec < Self > {
sqlx ::query_as ( " SELECT * FROM users " )
. fetch_all ( get_pg! ( ) )
. await
. unwrap ( )
}
/// Generate a new session token for the user
///
/// Returns a Session instance containing the generated token and associated user
pub async fn session ( & self ) -> Session {
sqlx ::query_as (
" INSERT INTO user_session (token, \" user \" ) VALUES ($1, $2) RETURNING id, token, \" user \" " ,
)
. bind ( gen_token ( 64 ) )
. bind ( & self . username )
. fetch_one ( get_pg! ( ) )
. await
. unwrap ( )
}
/// Check if the user is an admin
pub const fn is_admin ( & self ) -> bool {
matches! ( self . user_role , UserRole ::Admin )
}
/// Verify that a provided password matches the hashed password for the user
///
/// Returns a boolean indicating whether the passwords match or not
pub fn verify_pw ( & self , password : & str ) -> bool {
bcrypt ::verify ( password , & self . password ) . unwrap ( )
}
}
impl ToAPI for User {
    /// Render the public API view of this user: its username and role.
    ///
    /// The password hash is deliberately left out.
    async fn api(&self) -> serde_json::Value {
        let Self {
            username,
            user_role,
            ..
        } = self;
        json!({
            "username": username,
            "role": user_role,
        })
    }
}
#[ rocket::async_trait ]
impl < ' r > FromRequest < ' r > for User {
type Error = ( ) ;
async fn from_request ( request : & ' r Request < '_ > ) -> rocket ::request ::Outcome < Self , Self ::Error > {
2024-12-18 14:33:53 +01:00
if let Some ( session_id ) = request . cookies ( ) . get ( " session_id " ) {
if let Some ( user ) = User ::from_session ( session_id . value ( ) ) . await {
return Outcome ::Success ( user ) ;
} else {
return Outcome ::Error ( ( Status ::Unauthorized , ( ) ) ) ;
}
}
2024-12-17 23:28:43 +01:00
match request . headers ( ) . get_one ( " token " ) {
Some ( key ) = > {
2024-12-18 14:33:53 +01:00
if let Some ( user ) = User ::from_session ( key ) . await {
2024-12-17 23:28:43 +01:00
Outcome ::Success ( user )
} else {
Outcome ::Error ( ( Status ::Unauthorized , ( ) ) )
}
}
None = > Outcome ::Error ( ( Status ::Unauthorized , ( ) ) ) ,
}
}
}