use maud::PreEscaped; use super::User; use crate::{get_pg, ui::prelude::script}; use std::str::FromStr; pub trait CSRF { fn get_csrf(&self) -> impl std::future::Future; fn verify_csrf(&self, csrf: &str) -> impl std::future::Future; fn update_csrf(&self) -> impl std::future::Future>; } impl CSRF for User { /// Javascript to update the `value` of an element with id `csrf`. /// /// This is useful for htmx requests to update the CSRF token in place. async fn update_csrf(&self) -> PreEscaped { script(&format!( "document.querySelectorAll('.csrf').forEach(element => {{ element.value = '{}'; }});", self.get_csrf().await )) } /// Get CSRF Token for the current session async fn get_csrf(&self) -> uuid::Uuid { let res: (uuid::Uuid,) = sqlx::query_as("SELECT csrf FROM user_session WHERE token = $1") .bind(&self.session) .fetch_one(get_pg!()) .await .unwrap(); res.0 } /// Verify CSRF and generate a new one async fn verify_csrf(&self, csrf: &str) -> bool { if self.get_csrf().await == uuid::Uuid::from_str(csrf).unwrap_or_default() { sqlx::query("UPDATE user_session SET csrf = gen_random_uuid() WHERE token = $1") .bind(&self.session) .execute(get_pg!()) .await .unwrap(); return true; } false } }