update

JMARyA 2025-01-02 19:00:47 +01:00
parent 0f6e5f5b10
commit 8df8edeeca
Signed by: jmarya
GPG key ID: 901B2ADDF27C2263
15 changed files with 591 additions and 124 deletions


@@ -6,10 +6,12 @@ use serde::Serialize;
use serde_json::json;
use sqlx::FromRow;
use crate::archive::{Document, Domain, WebsiteArchive};
use crate::{
archive::{Document, Domain, WebsiteArchive},
conf::get_config,
};
// TODO : Chunked embeddings + better search + ranking
// TODO : Real citese embeddings + search
// TODO : Cite found chunks in search res?
#[derive(Debug, Clone, FromRow, Serialize)]
pub struct DocEmbedding {
@@ -18,6 +20,7 @@ pub struct DocEmbedding {
pub ver: String,
pub chunk: i32,
#[allow(dead_code)]
#[serde(skip)]
embed_mxbai_embed_large: pgvector::Vector,
@@ -25,24 +28,42 @@ pub struct DocEmbedding {
pub similarity: f64,
}
impl DocEmbedding {
pub async fn total_chunks(&self) -> i64 {
let res: (i64,) = sqlx::query_as(
"SELECT MAX(chunk) FROM doc_embedding WHERE domain = $1 AND path = $2 AND ver = $3",
)
.bind(&self.domain)
.bind(&self.path)
.bind(&self.ver)
.fetch_one(get_pg!())
.await
.unwrap();
res.0
}
}
#[derive(Debug, Clone, Serialize)]
pub struct SearchResult {
pub domain: String,
pub path: String,
pub total_chunks: i64,
pub chunks: Vec<DocEmbedding>,
}
impl SearchResult {
pub fn new(domain: String, path: String) -> Self {
pub fn new(domain: String, path: String, total_chunks: i64) -> Self {
Self {
domain,
path,
total_chunks,
chunks: vec![],
}
}
pub fn similarity(&self) -> f64 {
total_score(&self.chunks)
total_score(&self.chunks) * (self.chunks.len() as f64 / self.total_chunks as f64)
}
}
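The reworked `similarity()` scales the aggregated chunk score by how much of the document actually matched, so a result where most chunks are hits ranks above one carried by a single strong chunk. A minimal sketch of that weighting (assumption: `total_score`, defined elsewhere in this module, averages the per-chunk similarities):

```rust
// Sketch only, not the committed code: coverage-weighted ranking as free functions.
// `chunk_scores` are the per-chunk similarity values for one document,
// `total_chunks` is MAX(chunk) for that document (see total_chunks() above).
fn coverage_weighted_score(chunk_scores: &[f64], total_chunks: i64) -> f64 {
    assert!(!chunk_scores.is_empty() && total_chunks > 0);
    // Stand-in for total_score(): mean of the matched chunks' similarities.
    let aggregate: f64 = chunk_scores.iter().sum::<f64>() / chunk_scores.len() as f64;
    // Weight by the fraction of the document's chunks that matched the query.
    aggregate * (chunk_scores.len() as f64 / total_chunks as f64)
}
```

For example, a document matching 4 of its 4 chunks at 0.6 now scores 0.6, while one matching 1 of 10 chunks at 0.9 scores only 0.09.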
@@ -99,6 +120,13 @@ pub fn chunked(s: &str) -> Vec<String> {
.collect()
}
fn remove_data_urls(input: &str) -> String {
let re = regex::Regex::new("data:(.*?)(;base64)?,(.*)").unwrap();
// Replace all occurrences of data URLs with an empty string
re.replace_all(input, "").to_string()
}
impl Embedding for Document {
async fn embedding(&self, ver: Option<String>) -> Option<Vec<Vec<f32>>> {
let latest = "latest".to_string();
@@ -110,7 +138,7 @@ impl Embedding for Document {
);
let content_html = self.render_local(ver.clone()).await?;
let content = html2md::parse_html(&content_html);
let content = remove_data_urls(&html2md::parse_html(&content_html));
let mut embeddings = Vec::new();
let content = chunked(&content);
@@ -133,7 +161,8 @@ impl Embedding for Document {
}
pub async fn generate_embedding(mut input: String) -> Option<Vec<f32>> {
if let Ok(ollama_url) = std::env::var("OLLAMA_URL") {
// TODO : Ollama load balancing
if let Some(ollama_url) = get_config().ai.as_ref().map(|x| x.OLLAMA_URL.clone()) {
let (host, port) = ollama_url.split_once(':')?;
let ollama = ollama_rs::Ollama::new(format!("http://{host}"), port.parse().ok()?);
@@ -231,13 +260,10 @@ impl EmbedStore {
}
pub async fn search_vector(v: &pgvector::Vector, limit: i64, offset: i64) -> Vec<SearchResult> {
// TODO : fix search
// + new ranked algorithm
// + better repr
// limit should cover SearchResults not the query -> rework
let results: Vec<DocEmbedding> = sqlx::query_as(
"SELECT *, 1 / (1 + (embed_mxbai_embed_large <-> $1)) AS similarity FROM doc_embedding ORDER BY embed_mxbai_embed_large <-> $1 LIMIT $2 OFFSET $3",
"SELECT *, 1 / (1 + (embed_mxbai_embed_large <-> $1)) AS similarity FROM doc_embedding ORDER BY embed_mxbai_embed_large <=> $1 LIMIT $2 OFFSET $3",
)
.bind(v)
.bind(limit)
@@ -249,26 +275,24 @@ impl EmbedStore {
let mut search_res: HashMap<String, HashMap<String, SearchResult>> = HashMap::new();
for res in results {
let domain = search_res
.entry(res.domain.clone())
.or_insert(HashMap::new());
let doc = domain
.entry(res.path.clone())
.or_insert(SearchResult::new(res.domain.clone(), res.path.clone()));
let domain = search_res.entry(res.domain.clone()).or_default();
let doc = domain.entry(res.path.clone()).or_insert(SearchResult::new(
res.domain.clone(),
res.path.clone(),
res.total_chunks().await,
));
doc.chunks.push(res);
}
let mut flat = search_res
.into_values()
.map(|x| x.into_values().collect::<Vec<SearchResult>>())
.flatten()
.flat_map(|x| x.into_values().collect::<Vec<SearchResult>>())
.collect::<Vec<SearchResult>>();
flat.sort_by(|a, b| {
b.similarity()
.partial_cmp(&a.similarity())
.unwrap_or(std::cmp::Ordering::Equal)
.then(b.chunks.len().cmp(&a.chunks.len()))
});
flat
}
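One detail worth flagging in the updated query: in pgvector, `<->` is Euclidean (L2) distance while `<=>` is cosine distance, so the rows are now ordered by cosine distance but the reported `similarity` column is still derived from `<->`. If cosine ranking is the intent throughout, a consistent variant could look like this sketch (assumption about intent; `1 - cosine_distance` is the usual similarity mapping, and the trailing binds follow the surrounding code):

```rust
// Hypothetical consistent variant, not the committed query:
// rank and report using the same <=> (cosine distance) operator.
let results: Vec<DocEmbedding> = sqlx::query_as(
    "SELECT *, 1 - (embed_mxbai_embed_large <=> $1) AS similarity \
     FROM doc_embedding \
     ORDER BY embed_mxbai_embed_large <=> $1 \
     LIMIT $2 OFFSET $3",
)
.bind(v)
.bind(limit)
.bind(offset)
.fetch_all(get_pg!())
.await
.unwrap();
```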


@@ -3,7 +3,7 @@ use std::{io::Read, path::PathBuf};
use based::{request::RequestContext, result::LogAndIgnore};
use maud::html;
use crate::{blacklist::check_blacklist, favicon::download_fav_for, pages::component::render_page};
use crate::{blacklist::check_blacklist, conf::get_config, favicon::download_fav_for, render_page};
/// Read directory entries into `Vec<String>`
pub fn read_dir(dir: &PathBuf) -> Vec<String> {
@@ -22,16 +22,19 @@ pub fn read_dir(dir: &PathBuf) -> Vec<String> {
/// Rewrite all URLs in `input` to the format `/s/<domain>/<path..>`
fn internalize_urls(input: &str) -> String {
// TODO : Ignore blacklisted urls
let url_pattern = r"https?://([a-zA-Z0-9.-]+)(/[\w./-]*)";
let re = regex::Regex::new(url_pattern).unwrap();
re.replace_all(input, |caps: &regex::Captures| {
format!(
"/s/{}/{}",
&caps[1].trim_start_matches("www."), // Domain
&caps[2] // Path
)
let domain = caps[1].trim_start_matches("www.");
let path = &caps[2];
// Don't transform if the domain is blacklisted
if check_blacklist(domain) {
return format!("https://{domain}/{path}");
}
format!("/s/{domain}/{path}")
})
.to_string()
}
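For illustration, a hypothetical input for the blacklist-aware rewrite (the actual outcome depends on the configured blacklist patterns):

```rust
// Hypothetical example, not part of the commit.
let doc = "See https://www.example.com/docs/page.html for details.";
let rewritten = internalize_urls(doc);
// With example.com not blacklisted:
//   "See /s/example.com//docs/page.html for details."
// (caps[2] keeps its leading '/', hence the doubled slash; a blacklisted
//  domain would instead be rewritten to "https://example.com//docs/page.html")
```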
@@ -82,6 +85,23 @@ impl Domain {
Document::new(&self.name, path, self.dir.parent().unwrap().to_path_buf())
}
/// Get all paths associated with the domain
pub fn all_paths(&self) -> Vec<PathEntry> {
let mut queue = self.paths("/").0;
let mut ret = Vec::new();
ret.push(PathEntry(self.name.clone(), "/".to_string()));
while let Some(el) = queue.pop() {
ret.push(el.clone());
let paths = self.paths(&el.1).0;
queue.extend(paths);
}
ret
}
/// Retrieves entries and metadata for a given path within the domain.
///
/// # Parameters
@@ -98,6 +118,12 @@ impl Domain {
base_path = base_path.join(p);
}
let path = path
.split("/")
.filter(|x| !x.is_empty())
.collect::<Vec<&str>>()
.join("/");
let dir_content = read_dir(&base_path);
let mut ret = Vec::new();
@@ -106,6 +132,11 @@ impl Domain {
for entry in dir_content {
let url_path = format!("{path}/{entry}");
let url_path = url_path
.split("/")
.filter(|x| !x.is_empty())
.collect::<Vec<&str>>()
.join("/");
if entry.starts_with("index_") && entry.ends_with(".html") {
is_doc = true;
continue;
@@ -119,6 +150,7 @@ impl Domain {
}
/// Represents an entry within a domain's path, containing its name and URL path.
#[derive(Debug, Clone)]
pub struct PathEntry(String, String);
impl PathEntry {
@@ -203,7 +235,7 @@ impl Document {
.unwrap();
let content = String::from_utf8_lossy(&buf);
if std::env::var("ROUTE_INTERNAL").unwrap_or("false".to_string()) == "true" {
if get_config().ROUTE_INTERNAL {
Some(internalize_urls(&content))
} else {
Some(content.to_string())
@@ -291,6 +323,7 @@ impl WebsiteArchive {
///
/// This function downloads the content of the URL, processes it, and saves it to the archive.
pub async fn archive_url(&self, url: &str) {
// TODO : refactor
let parsed_url = url::Url::parse(url).unwrap();
let domain = parsed_url.domain().unwrap().trim_start_matches("www.");

src/args.rs (new file, 61 lines)

@@ -0,0 +1,61 @@
use clap::{arg, command};
pub fn get_args() -> clap::ArgMatches {
command!()
.about("Web Archive")
.arg(
arg!(-d --dir <dir> "Web archive directory")
.required(false)
.default_value("./websites"),
)
.subcommand(
command!()
.name("serve")
.about("Start web archive server")
.arg(
arg!(-c --config <config> "Web archive config file")
.required(false)
.default_value("./config.toml"),
),
)
.subcommand(
command!()
.name("archive")
.about("Work with web archives")
.subcommand(
command!()
.name("download")
.about("Download a new URL into the archive")
.arg(
arg!(-c --config <config> "Web archive config file")
.required(false)
.default_value("./config.toml"),
)
.arg(arg!([URL] "The URL to download").required(true))
)
.subcommand(
command!()
.name("list")
.about("List domains contained in the archive. If a domain is provided all paths of this domain will be listed.")
.arg(arg!([DOMAIN] "A domain to list").required(false))
.arg(arg!(-j --json "Output JSON").required(false)),
)
.subcommand(
command!()
.name("versions")
.about("List saved versions of a document")
.arg(arg!(-j --json "Output JSON").required(false))
.arg(arg!([DOMAIN] "A domain").required(true))
.arg(arg!([PATH] "A path").required(false))
)
.subcommand(
command!()
.name("get")
.about("Get a saved document")
.arg(arg!(--md "Output Markdown").required(false))
.arg(arg!([DOMAIN] "A domain").required(true))
.arg(arg!([PATH] "A path").required(false))
.arg(arg!([VERSION] "A version").required(false))
))
.get_matches()
}
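Judging from these definitions, typical invocations would presumably look like the following (binary name assumed to be `webarc`; paths and URLs are placeholders):

```
webarc -d ./websites serve -c ./config.toml
webarc archive download -c ./config.toml https://example.com
webarc archive list --json
webarc archive versions example.com /
webarc archive get --md example.com /
```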


@@ -1,17 +1,17 @@
use crate::conf::get_config;
/// Checks if a domain is present in the blacklist of unwanted domains.
///
/// This function checks the `BLACKLIST_DOMAINS` list in the `websites` section of the config for regular expressions to match against.
/// If a match is found, it immediately returns `true`. Otherwise, it returns `false`.
pub fn check_blacklist(domain: &str) -> bool {
let blacklist_raw = std::env::var("BLACKLIST_DOMAINS").unwrap_or_default();
let conf = get_config();
let conf = conf.websites.as_ref();
if blacklist_raw.is_empty() {
return false;
}
let blacklisted_domains = conf
.map(|x| x.BLACKLIST_DOMAINS.as_ref())
.unwrap_or_default();
let blacklist: Vec<&str> = blacklist_raw.split(',').collect();
for domain_regex in blacklist {
for domain_regex in blacklisted_domains.unwrap_or(&Vec::new()) {
let rgx = regex::Regex::new(domain_regex).unwrap();
if rgx.is_match(domain) {
return true;

src/conf.rs (new file, 70 lines)

@@ -0,0 +1,70 @@
use std::sync::Arc;
use serde::Deserialize;
use tokio::sync::OnceCell;
pub static CONFIG: OnceCell<Arc<Config>> = OnceCell::const_new();
/// Get a reference to global config
pub fn get_config() -> &'static Arc<Config> {
crate::conf::CONFIG.get().unwrap()
}
/// Load a global config
pub fn load_config(path: &str) {
// TODO : Other load locations
if let Ok(file_content) = std::fs::read_to_string(path) {
let conf: Config =
toml::from_str(&file_content).expect("Could not deserialize config file");
crate::conf::CONFIG.set(std::sync::Arc::new(conf)).unwrap();
}
}
/// Load a default global config
pub fn load_default_config() {
if crate::conf::CONFIG.get().is_none() {
crate::conf::CONFIG
.set(std::sync::Arc::new(Config::default()))
.unwrap();
}
}
#[allow(non_snake_case)]
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
pub ROUTE_INTERNAL: bool,
pub DOWNLOAD_ON_DEMAND: bool,
pub ai: Option<AIConfig>,
pub websites: Option<WebsiteConfig>,
}
#[allow(non_snake_case)]
#[derive(Debug, Clone, Deserialize)]
pub struct AIConfig {
pub OLLAMA_URL: String,
}
#[allow(non_snake_case)]
#[derive(Debug, Clone, Deserialize)]
pub struct WebsiteConfig {
pub BLACKLIST_DOMAINS: Option<Vec<String>>,
pub domains: Option<Vec<DomainConfig>>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct DomainConfig {
// TODO : Domain specific config
pub blacklist_paths: Option<Vec<String>>,
pub no_javascript: bool,
}
impl Default for Config {
fn default() -> Self {
Self {
ROUTE_INTERNAL: false,
DOWNLOAD_ON_DEMAND: false,
ai: None,
websites: None,
}
}
}
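For reference, a guess at what a matching `config.toml` could look like, derived purely from the struct and field names above (values are made up; `OLLAMA_URL` is given as `host:port` because `generate_embedding` splits it on `:`, and 11434 is Ollama's usual default port):

```rust
// Hypothetical config embedded as a string and parsed the same way
// load_config() does; not part of the commit.
let example = r#"
ROUTE_INTERNAL = true
DOWNLOAD_ON_DEMAND = false

[ai]
OLLAMA_URL = "127.0.0.1:11434"

[websites]
BLACKLIST_DOMAINS = ["^ads\\."]

[[websites.domains]]
blacklist_paths = ["^/tracking"]
no_javascript = true
"#;
let conf: Config = toml::from_str(example).expect("should deserialize");
assert!(conf.ai.is_some() && conf.websites.is_some());
```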

src/lib.rs (new file, 29 lines)

@@ -0,0 +1,29 @@
use based::{
page::Shell,
request::{RequestContext, StringResponse},
};
use maud::{html, PreEscaped};
pub mod ai;
pub mod archive;
pub mod blacklist;
pub mod conf;
pub mod favicon;
pub async fn render_page(content: PreEscaped<String>, ctx: RequestContext) -> StringResponse {
based::page::render_page(
content,
"Website Archive",
ctx,
&Shell::new(
html! {
script src="https://cdn.tailwindcss.com" {};
meta name="viewport" content="width=device-width, initial-scale=1.0" {};
script src="/assets/htmx.min.js" {};
},
html! {},
Some("bg-zinc-950 text-white min-h-screen flex pt-8 justify-center".to_string()),
),
)
.await
}


@@ -1,49 +1,170 @@
use ai::EmbedStore;
use archive::WebsiteArchive;
use based::get_pg;
use rocket::routes;
use webarc::ai::EmbedStore;
use webarc::archive::WebsiteArchive;
use webarc::conf::{get_config, load_config, load_default_config};
mod ai;
mod archive;
mod blacklist;
mod favicon;
mod args;
mod pages;
#[rocket::launch]
async fn launch() -> _ {
#[tokio::main]
async fn main() {
env_logger::init();
let arc = WebsiteArchive::new("./websites");
let args = args::get_args();
if std::env::var("DATABASE_URL").is_ok() {
let pg = get_pg!();
sqlx::migrate!("./migrations").run(pg).await.unwrap();
let archive_dir: &String = args.get_one("dir").unwrap();
match args.subcommand() {
Some(("serve", serve_args)) => {
let config: &String = serve_args.get_one("config").unwrap();
load_config(config);
let arc = WebsiteArchive::new(archive_dir);
if std::env::var("DATABASE_URL").is_ok() {
let pg = get_pg!();
sqlx::migrate!("./migrations").run(pg).await.unwrap();
}
let archive = arc.clone();
if get_config().ai.is_some() {
tokio::spawn(async move {
EmbedStore::generate_embeddings_for(&archive).await;
});
}
let archive = arc.clone();
tokio::spawn(async move {
webarc::favicon::download_favicons_for_sites(&archive.domains()).await;
});
rocket::build()
.mount(
"/",
routes![
based::htmx::htmx_script_route,
pages::index,
pages::render_website,
pages::domain_info_route,
pages::favicon_route,
pages::vector_search,
pages::render_txt_website
],
)
.manage(arc)
.launch()
.await
.unwrap();
}
Some(("archive", archive_args)) => {
let arc = WebsiteArchive::new(archive_dir);
match archive_args.subcommand() {
Some(("list", list_args)) => {
let json = list_args.get_flag("json");
load_default_config();
let elements = if let Some(domain) = list_args.get_one::<String>("DOMAIN") {
arc.get_domain(domain)
.all_paths()
.into_iter()
.map(|x| x.path().clone())
.collect()
} else {
arc.domains()
};
if json {
println!(
"{}",
serde_json::to_string(&serde_json::json!(elements)).unwrap()
);
} else {
if let Some(domain) = list_args.get_one::<String>("DOMAIN") {
println!("Paths in {domain}:");
} else {
println!("Domains in {}:", archive_dir);
}
if elements.is_empty() {
println!("No domains");
}
for d in elements {
println!("- {d}");
}
}
}
Some(("download", dl_args)) => {
let url: &String = dl_args.get_one("URL").unwrap();
let config: &String = dl_args.get_one("config").unwrap();
load_config(config);
arc.archive_url(url).await;
println!("Saved {url} to archive");
}
Some(("versions", ver_args)) => {
load_default_config();
let domain: &String = ver_args.get_one("DOMAIN").unwrap();
let path: String = if let Some(path) = ver_args.get_one::<String>("PATH") {
path.clone()
} else {
"/".to_string()
};
let versions = arc.get_domain(domain).path(&path).versions();
let json = ver_args.get_flag("json");
if json {
println!("{}", serde_json::to_string(&versions).unwrap());
} else {
println!("Versions for {domain} / {path}:");
for v in versions {
println!("- {v}");
}
}
}
Some(("get", get_args)) => {
load_default_config();
let domain: &String = get_args.get_one("DOMAIN").unwrap();
let path = if let Some(path) = get_args.get_one::<String>("PATH") {
path.clone()
} else {
"/".to_string()
};
let doc = arc.get_domain(domain).path(&path);
let ver = if let Some(ver) = get_args.get_one::<String>("VERSION") {
ver.clone()
} else {
doc.versions().first().unwrap().clone()
};
let md = get_args.get_flag("md");
let content = doc.render_local(Some(ver)).await;
if content.is_none() {
println!("No document found");
std::process::exit(1);
}
if md {
let markdown = html2md::parse_html(&content.unwrap());
println!("{markdown}");
} else {
println!("{}", content.unwrap());
}
}
Some((&_, _)) => {}
None => {}
};
}
Some((&_, _)) => {}
None => {}
}
let archive = arc.clone();
if std::env::var("OLLAMA_URL").is_ok() {
tokio::spawn(async move {
EmbedStore::generate_embeddings_for(&archive).await;
});
}
let archive = arc.clone();
tokio::spawn(async move {
favicon::download_favicons_for_sites(&archive.domains()).await;
});
rocket::build()
.mount(
"/",
routes![
based::htmx::htmx_script_route,
pages::index,
pages::render_website,
pages::domain_info_route,
pages::favicon_route,
pages::vector_search,
pages::render_txt_website
],
)
.manage(arc)
}


@@ -1,7 +1,3 @@
use based::{
page::Shell,
request::{RequestContext, StringResponse},
};
use maud::{html, PreEscaped};
/// Generates an SVG arrow icon with the specified color.
@@ -78,24 +74,6 @@ pub fn gen_path_header(
}
}
pub async fn render_page(content: PreEscaped<String>, ctx: RequestContext) -> StringResponse {
based::page::render_page(
content,
"Website Archive",
ctx,
&Shell::new(
html! {
script src="https://cdn.tailwindcss.com" {};
meta name="viewport" content="width=device-width, initial-scale=1.0" {};
script src="/assets/htmx.min.js" {};
},
html! {},
Some("bg-zinc-950 text-white min-h-screen flex pt-8 justify-center".to_string()),
),
)
.await
}
pub fn favicon(site: &str) -> PreEscaped<String> {
html! {
img class="h-8 w-8 m-2" src=(format!("/favicon/{site}")) {};


@@ -13,12 +13,14 @@ pub mod component;
use component::*;
use serde_json::json;
use crate::{
use webarc::{
ai::{generate_embedding, EmbedStore, SearchResult},
archive::WebsiteArchive,
conf::get_config,
render_page,
};
const SEARCH_BAR_STYLE: &'static str = "w-full px-4 mb-4 py-2 text-white bg-black border-2 border-neon-blue placeholder-neon-blue focus:ring-2 focus:ring-neon-pink focus:outline-none font-mono text-lg";
const SEARCH_BAR_STYLE: &str = "w-full px-4 mb-4 py-2 text-white bg-black border-2 border-neon-blue placeholder-neon-blue focus:ring-2 focus:ring-neon-pink focus:outline-none font-mono text-lg";
/// Get the favicon of a domain
#[get("/favicon/<domain>")]
@@ -29,6 +31,8 @@ pub async fn favicon_route(domain: &str) -> Option<DataResponse> {
.read_to_end(&mut buf)
.ok()?;
// TODO : Default favicon
Some(DataResponse::new(
buf,
"image/x-icon".to_string(),
@@ -171,12 +175,7 @@ pub async fn render_website(
"text/html".to_string(),
Some(60 * 60 * 24),
));
} else if std::env::var("DOWNLOAD_ON_DEMAND")
.unwrap_or("false".to_string())
.as_str()
== "true"
&& time.is_none()
{
} else if get_config().DOWNLOAD_ON_DEMAND && time.is_none() {
arc.archive_url(&format!("https://{domain}/{}", path.to_str().unwrap()))
.await;
@@ -213,9 +212,7 @@ pub async fn vector_search(
page: Option<i64>,
ctx: RequestContext,
) -> Option<StringResponse> {
if std::env::var("OLLAMA_URL").is_err() {
return None;
}
get_config().ai.as_ref()?;
let page = page.unwrap_or(1);