2024-12-18 14:33:53 +01:00
use rocket ::{ Request , http ::Status , outcome ::Outcome , request ::FromRequest } ;
2024-12-17 23:28:43 +01:00
use serde ::{ Deserialize , Serialize } ;
use serde_json ::json ;
use sqlx ::FromRow ;
2024-12-18 14:33:53 +01:00
use super ::{ Session , gen_token } ;
2024-12-17 23:28:43 +01:00
use crate ::{ get_pg , request ::api ::ToAPI } ;
2024-12-18 19:55:21 +01:00
/// User
///
/// # Example:
///
/// ```ignore
///
/// // Needs login
/// #[get("/myaccount")]
/// pub async fn account_page(ctx: RequestContext, user: User) -> StringResponse {
/// ...
/// }
/// ```
2024-12-17 23:28:43 +01:00
#[ derive(Debug, Clone, Serialize, Deserialize, FromRow) ]
pub struct User {
/// The username chosen by the user
pub username : String ,
/// The hashed password for the user
pub password : String ,
/// The role of the user
pub user_role : UserRole ,
}
/// Permission level of a [`User`].
///
/// In the database this maps to the Postgres enum type `user_role`, whose labels are the
/// lowercase variant names (`regular`, `admin`). Serde (de)serializes the variants by their
/// Rust names (`Regular`, `Admin`), which is the form API responses expose.
///
/// The enum is fieldless, so it is `Copy` and comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
#[sqlx(type_name = "user_role", rename_all = "lowercase")]
pub enum UserRole {
    /// A regular user with limited permissions
    Regular,
    /// An admin user with full system privileges
    Admin,
}
impl User {
2024-12-18 14:33:53 +01:00
// Get a user from session ID
pub async fn from_session ( session : & str ) -> Option < Self > {
sqlx ::query_as ( " SELECT * FROM users WHERE username = (SELECT \" user \" FROM user_session WHERE token = $1) " ) . bind ( session ) . fetch_optional ( get_pg! ( ) ) . await . unwrap ( )
}
2024-12-17 23:28:43 +01:00
/// Find a user by their username
pub async fn find ( username : & str ) -> Option < Self > {
sqlx ::query_as ( " SELECT * FROM users WHERE username = $1 " )
. bind ( username )
. fetch_optional ( get_pg! ( ) )
. await
. unwrap ( )
}
/// Create a new user with the given details
///
/// Returns an Option containing the created user, or None if a user already exists with the same username
pub async fn create ( username : & str , password : & str , role : UserRole ) -> Option < Self > {
// Check if a user already exists with the same username
if Self ::find ( username ) . await . is_some ( ) {
return None ;
}
let u = Self {
username : username . to_string ( ) ,
password : bcrypt ::hash ( password , bcrypt ::DEFAULT_COST ) . unwrap ( ) ,
user_role : role ,
} ;
sqlx ::query ( " INSERT INTO users (username, \" password \" , user_role) VALUES ($1, $2, $3) " )
. bind ( & u . username )
. bind ( & u . password )
. bind ( & u . user_role )
. execute ( get_pg! ( ) )
. await
. unwrap ( ) ;
Some ( u )
}
/// Login a user with the given username and password
pub async fn login ( username : & str , password : & str ) -> Option < ( Session , UserRole ) > {
let u = Self ::find ( username ) . await ? ;
if ! u . verify_pw ( password ) {
return None ;
}
Some ( ( u . session ( ) . await , u . user_role ) )
}
/// Change the password of a User
///
/// Returns a Result indicating whether the password change was successful or not
pub async fn passwd ( self , old : & str , new : & str ) -> Result < ( ) , ( ) > {
if self . verify_pw ( old ) {
sqlx ::query ( " UPDATE users SET \" password \" = $1 WHERE username = $2; " )
. bind ( bcrypt ::hash ( new , bcrypt ::DEFAULT_COST ) . unwrap ( ) )
. bind ( & self . username )
. fetch_one ( get_pg! ( ) )
. await
. unwrap ( ) ;
return Ok ( ( ) ) ;
}
Err ( ( ) )
}
/// Find all users in the system
pub async fn find_all ( ) -> Vec < Self > {
sqlx ::query_as ( " SELECT * FROM users " )
. fetch_all ( get_pg! ( ) )
. await
. unwrap ( )
}
/// Generate a new session token for the user
///
/// Returns a Session instance containing the generated token and associated user
pub async fn session ( & self ) -> Session {
sqlx ::query_as (
" INSERT INTO user_session (token, \" user \" ) VALUES ($1, $2) RETURNING id, token, \" user \" " ,
)
. bind ( gen_token ( 64 ) )
. bind ( & self . username )
. fetch_one ( get_pg! ( ) )
. await
. unwrap ( )
}
/// Check if the user is an admin
pub const fn is_admin ( & self ) -> bool {
matches! ( self . user_role , UserRole ::Admin )
}
/// Verify that a provided password matches the hashed password for the user
///
/// Returns a boolean indicating whether the passwords match or not
pub fn verify_pw ( & self , password : & str ) -> bool {
bcrypt ::verify ( password , & self . password ) . unwrap ( )
}
}
impl ToAPI for User {
    /// Public JSON view of a user.
    ///
    /// The view contains only the username and the role. The password hash is deliberately
    /// left out.
    async fn api(&self) -> serde_json::Value {
        let username = &self.username;
        let role = &self.user_role;

        json!({ "username": username, "role": role })
    }
}
2024-12-18 19:55:21 +01:00
async fn extract_user < ' r > ( request : & ' r Request < '_ > ) -> Option < User > {
if let Some ( session_id ) = request . cookies ( ) . get ( " session_id " ) {
if let Some ( user ) = User ::from_session ( session_id . value ( ) ) . await {
return Some ( user ) ;
} else {
return None ;
}
}
match request . headers ( ) . get_one ( " token " ) {
Some ( key ) = > {
if let Some ( user ) = User ::from_session ( key ) . await {
return Some ( user ) ;
} else {
return None ;
}
}
None = > None ,
}
}
2024-12-17 23:28:43 +01:00
#[ rocket::async_trait ]
impl < ' r > FromRequest < ' r > for User {
type Error = ( ) ;
async fn from_request ( request : & ' r Request < '_ > ) -> rocket ::request ::Outcome < Self , Self ::Error > {
2024-12-18 19:55:21 +01:00
if let Some ( user ) = extract_user ( request ) . await {
return Outcome ::Success ( user ) ;
} else {
return Outcome ::Error ( ( Status ::Unauthorized , ( ) ) ) ;
2024-12-18 14:33:53 +01:00
}
2024-12-18 19:55:21 +01:00
}
}
2024-12-18 14:33:53 +01:00
2024-12-18 19:55:21 +01:00
/// Maybe User?
///
/// Optional-login request guard. It extracts the user behind the request's session when
/// there is one, and otherwise lets the request through as [`MaybeUser::Anonymous`]. The
/// guard itself never fails.
///
/// # Example:
///
/// ```ignore
///
/// // Publicly accessible
/// #[get("/")]
/// pub async fn index(ctx: RequestContext, user: MaybeUser) -> StringResponse {
///     match user {
///         MaybeUser::User(user) => println!("You are {}", user.username),
///         MaybeUser::Anonymous => println!("Who are you?")
///     }
/// }
/// ```
#[derive(Debug, Clone)]
pub enum MaybeUser {
    /// The request carried a session that resolved to this user.
    User(User),
    /// No session was presented, or it did not resolve to a user.
    Anonymous,
}
#[ rocket::async_trait ]
impl < ' r > FromRequest < ' r > for MaybeUser {
type Error = ( ) ;
async fn from_request ( request : & ' r Request < '_ > ) -> rocket ::request ::Outcome < Self , Self ::Error > {
if let Some ( user ) = extract_user ( request ) . await {
return Outcome ::Success ( MaybeUser ::User ( user ) ) ;
} else {
return Outcome ::Success ( MaybeUser ::Anonymous ) ;
}
}
}
2024-12-18 20:03:59 +01:00
impl From < MaybeUser > for Option < User > {
fn from ( value : MaybeUser ) -> Self {
match value {
MaybeUser ::User ( user ) = > Some ( user ) ,
MaybeUser ::Anonymous = > None ,
}
}
}
2024-12-18 19:55:21 +01:00
/// Admin User
///
/// Request guard that only succeeds for a logged-in user with the admin role. For any
/// other request the guard fails and the route is not run.
///
/// The wrapped [`User`] is public, so handlers can read the admin's details through `.0`.
///
/// # Example:
///
/// ```ignore
///
/// // Only admin users can access this route
/// #[get("/admin")]
/// pub async fn admin_panel(ctx: RequestContext, user: AdminUser) -> StringResponse {
///     let admin_name = &user.0.username;
///     ...
/// }
/// ```
#[derive(Debug, Clone)]
pub struct AdminUser(pub User);
#[ rocket::async_trait ]
impl < ' r > FromRequest < ' r > for AdminUser {
type Error = ( ) ;
async fn from_request ( request : & ' r Request < '_ > ) -> rocket ::request ::Outcome < Self , Self ::Error > {
if let Some ( user ) = extract_user ( request ) . await {
if user . is_admin ( ) {
return Outcome ::Success ( AdminUser ( user ) ) ;
2024-12-17 23:28:43 +01:00
}
2024-12-18 19:55:21 +01:00
} else {
2024-12-17 23:28:43 +01:00
}
2024-12-18 19:55:21 +01:00
Outcome ::Error ( ( Status ::Unauthorized , ( ) ) )
2024-12-17 23:28:43 +01:00
}
}