Compare commits

...
Sign in to create a new pull request.

1 commit
main ... sqlite

Author SHA1 Message Date
397369ae1e
awfully cursed 2024-12-27 00:07:01 +01:00
5 changed files with 140 additions and 3 deletions

View file

@ -19,7 +19,7 @@ serde = { version = "1.0.195", features = ["derive"] }
serde_json = "1.0.111"
tokio = { version = "1.35.1", features = ["full"] }
uuid = { version = "1.8.0", features = ["v4", "serde"] }
sqlx = { version = "0.8", features = ["postgres", "runtime-tokio-native-tls", "derive", "uuid", "chrono", "json"] }
sqlx = { version = "0.8", features = ["postgres", "runtime-tokio-native-tls", "derive", "uuid", "chrono", "json", "sqlite"] }
maud = "0.26.0"
rand = "0.8.5"
data-encoding = "2.6.0"

21
examples/sqlite.rs Normal file
View file

@ -0,0 +1,21 @@
use based::{get_sqlite, get_sqlite_or_create};
use based::request::{RequestContext, StringResponse};
use based::{
get_pg,
page::{Shell, render_page},
};
use maud::html;
use rocket::get;
use rocket::routes;
#[rocket::launch]
async fn launch() -> _ {
// Logging
env_logger::init();
// Database
unsafe { std::env::set_var("DATABASE_URL", "test.db"); }
let db = based::db::Database::new().await;
rocket::build().mount("/", routes![])
}

View file

@ -4,7 +4,7 @@ use serde_json::json;
use sqlx::FromRow;
use super::{Session, gen_token};
use crate::{get_pg, request::api::ToAPI};
use crate::{db::Database, get_pg, request::api::ToAPI, with};
/// User
///
@ -40,7 +40,27 @@ pub enum UserRole {
impl User {
// Get a user from session ID
pub async fn from_session(session: &str) -> Option<Self> {
sqlx::query_as("SELECT * FROM users WHERE username = (SELECT \"user\" FROM user_session WHERE token = $1)").bind(session).fetch_optional(get_pg!()).await.unwrap()
with!(
Database::new().await,
{
sqlx::query_as::<_, Self>(
"SELECT * FROM users WHERE username = (SELECT \"user\" FROM user_session WHERE token = $1)"
)
.bind(session)
.fetch_optional(&pg)
.await
.ok()
},
{
sqlx::query_as::<_, Self>(
"SELECT * FROM users WHERE username = (SELECT \"user\" FROM user_session WHERE token = $1)"
)
.bind(session)
.fetch_optional(&sqlite)
.await
.ok()
}
)
}
/// Find a user by their username

47
src/db.rs Normal file
View file

@ -0,0 +1,47 @@
use crate::{get_pg, get_sqlite, get_sqlite_or_create};
pub enum DatabaseBackend {
Sqlite,
Postgres
}
pub struct Database {
pub in_use: DatabaseBackend,
pub sqlite: Option<sqlx::Pool<sqlx::Sqlite>>,
pub postgres: Option<sqlx::Pool<sqlx::Postgres>>
}
impl Database {
pub async fn new() -> Self {
if std::env::var("DATABASE_URL").unwrap().contains("postgres") {
return Self {
in_use: DatabaseBackend::Postgres,
sqlite: None,
postgres: Some(get_pg!().clone())
};
} else {
return Self {
in_use: DatabaseBackend::Sqlite,
sqlite: Some(get_sqlite_or_create!().clone()),
postgres: None
};
}
}
}
#[macro_export]
macro_rules! with {
($db:expr, $pg:block, $sqlite:block) => {{
match $db.in_use {
crate::db::DatabaseBackend::Postgres => {
let pg = $db.postgres.clone().expect("Postgres connection not available");
$pg
}
crate::db::DatabaseBackend::Sqlite => {
let sqlite = $db.sqlite.clone().expect("Sqlite connection not available");
$sqlite
}
}
}};
}

View file

@ -7,14 +7,22 @@ pub mod htmx;
pub mod page;
pub mod request;
pub mod result;
pub mod db;
// TODO : API Pagination?
// TODO : CORS?
// TODO : CSRF?
// TODO : Implement SQLITE backend
// Code sharing with macro
// Seperate migrations
// which db by $DB env var "postgres" "sqlite"
// Postgres
pub static PG: OnceCell<sqlx::PgPool> = OnceCell::const_new();
pub static SQ: OnceCell<sqlx::SqlitePool> = OnceCell::const_new();
/// A macro to retrieve or initialize the PostgreSQL connection pool.
///
@ -43,3 +51,44 @@ macro_rules! get_pg {
}
};
}
#[macro_export]
macro_rules! get_sqlite {
() => {
if let Some(client) = $crate::SQ.get() {
client
} else {
let client = sqlx::sqlite::SqlitePoolOptions::new()
.max_connections(5)
.connect(&std::env::var("DATABASE_URL").unwrap())
.await
.unwrap();
$crate::SQ.set(client).unwrap();
$crate::SQ.get().unwrap()
}
};
}
#[macro_export]
macro_rules! get_sqlite_or_create {
() => {
if let Some(client) = $crate::SQ.get() {
client
} else {
let db_url = std::env::var("DATABASE_URL").unwrap();
if !std::fs::exists(&db_url).unwrap_or(false) {
std::fs::File::create(&db_url).unwrap();
}
let client = sqlx::sqlite::SqlitePoolOptions::new()
.max_connections(5)
.connect(&db_url)
.await
.unwrap();
$crate::SQ.set(client).unwrap();
$crate::SQ.get().unwrap()
}
};
}