From 6aea22576c651584087f5903af14c8e323f90f17 Mon Sep 17 00:00:00 2001 From: JMARyA Date: Tue, 31 Dec 2024 02:03:03 +0100 Subject: [PATCH] update chunked embed --- env | 3 +- migrations/0001_embedding.sql | 3 +- src/ai.rs | 141 ++++++++++++++++++++++++++++++---- src/archive.rs | 5 +- src/pages/mod.rs | 9 ++- 5 files changed, 136 insertions(+), 25 deletions(-) diff --git a/env b/env index c2354bd..f0a8e54 100644 --- a/env +++ b/env @@ -9,7 +9,8 @@ ROUTE_INTERNAL=true DOWNLOAD_ON_DEMAND=true # Blacklisted domains (Comma-seperated regex) -BLACKLIST_DOMAINS="google.com,.*.youtube.com" +# You can blacklist sites which wont work well +BLACKLIST_DOMAINS="^gitlab" # Database DATABASE_URL=postgres://user:pass@postgres/webarc diff --git a/migrations/0001_embedding.sql b/migrations/0001_embedding.sql index 1826b2b..15f027e 100644 --- a/migrations/0001_embedding.sql +++ b/migrations/0001_embedding.sql @@ -5,6 +5,7 @@ CREATE TABLE doc_embedding ( domain VARCHAR(500) NOT NULL, path VARCHAR(1000) NOT NULL, ver VARCHAR(10) NOT NULL, + chunk INTEGER NOT NULL, embed_mxbai_embed_large vector(1024) NOT NULL, - PRIMARY KEY (domain, path, ver) + PRIMARY KEY (domain, path, ver, chunk) ) diff --git a/src/ai.rs b/src/ai.rs index 44aad7e..1463258 100644 --- a/src/ai.rs +++ b/src/ai.rs @@ -1,4 +1,4 @@ -use std::collections::VecDeque; +use std::collections::{HashMap, VecDeque}; use based::{get_pg, request::api::ToAPI, result::LogNoneAndPass}; use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest}; @@ -8,11 +8,15 @@ use sqlx::FromRow; use crate::archive::{Document, Domain, WebsiteArchive}; +// TODO : Chunked embeddings + better search + ranking +// TODO : Real citese embeddings + search + #[derive(Debug, Clone, FromRow, Serialize)] pub struct DocEmbedding { pub domain: String, pub path: String, pub ver: String, + pub chunk: i32, #[serde(skip)] embed_mxbai_embed_large: pgvector::Vector, @@ -21,6 +25,53 @@ pub struct DocEmbedding { pub similarity: f64, } +#[derive(Debug, Clone, Serialize)] +pub struct SearchResult { + pub domain: String, + pub path: String, + pub chunks: Vec, +} + +impl SearchResult { + pub fn new(domain: String, path: String) -> Self { + Self { + domain, + path, + chunks: vec![], + } + } + + pub fn similarity(&self) -> f64 { + total_score(&self.chunks) + } +} + +pub fn avg_sim(e: &[DocEmbedding]) -> f64 { + let mut score = 0.0; + + for e in e { + score += e.similarity; + } + + score / e.len() as f64 +} + +pub fn max_sim(e: &[DocEmbedding]) -> f64 { + let mut score = 0.0; + + for e in e { + if e.similarity > score { + score = e.similarity; + } + } + + score +} + +pub fn total_score(e: &[DocEmbedding]) -> f64 { + (avg_sim(e) + max_sim(e)) / 2.0 +} + impl ToAPI for DocEmbedding { async fn api(&self) -> serde_json::Value { json!({ @@ -33,12 +84,23 @@ impl ToAPI for DocEmbedding { } pub trait Embedding { - fn embedding(&self, ver: Option) - -> impl std::future::Future>>; + fn embedding( + &self, + ver: Option, + ) -> impl std::future::Future>>>; +} + +pub fn chunked(s: &str) -> Vec { + const CHUNK_SIZE: usize = 500; + s.chars() + .collect::>() + .chunks(CHUNK_SIZE) + .map(|chunk| chunk.iter().collect()) + .collect() } impl Embedding for Document { - async fn embedding(&self, ver: Option) -> Option> { + async fn embedding(&self, ver: Option) -> Option>> { let latest = "latest".to_string(); log::info!( "Generating Vector embeddings for {} / {} @ {}", @@ -47,9 +109,26 @@ impl Embedding for Document { ver.as_ref().unwrap_or(&latest) ); - let content_html = self.render_local(ver).await?; + let content_html = self.render_local(ver.clone()).await?; let content = html2md::parse_html(&content_html); - generate_embedding(content).await + + let mut embeddings = Vec::new(); + let content = chunked(&content); + let len = content.len(); + + for (index, c) in content.into_iter().enumerate() { + log::info!( + "Generating Vector embeddings for {} / {} @ {} [ {} / {} ]", + self.domain, + self.path, + ver.as_ref().unwrap_or(&latest), + index + 1, + len + ); + embeddings.push(generate_embedding(c).await?); + } + + Some(embeddings) } } @@ -129,14 +208,17 @@ impl EmbedStore { .execute(get_pg!()) .await; - sqlx::query("INSERT INTO doc_embedding VALUES ($1, $2, $3, $4)") - .bind(&doc.domain) - .bind(&doc.path) - .bind(ver) - .bind(embed) - .execute(get_pg!()) - .await - .unwrap(); + for (index, embed) in embed.iter().enumerate() { + sqlx::query("INSERT INTO doc_embedding VALUES ($1, $2, $3, $4, $5)") + .bind(&doc.domain) + .bind(&doc.path) + .bind(ver) + .bind(index as i64) + .bind(embed) + .execute(get_pg!()) + .await + .unwrap(); + } } } @@ -148,8 +230,12 @@ impl EmbedStore { } } - pub async fn search_vector(v: &pgvector::Vector, limit: i64, offset: i64) -> Vec { - sqlx::query_as( + pub async fn search_vector(v: &pgvector::Vector, limit: i64, offset: i64) -> Vec { + // TODO : fix search + // + new ranked algorithm + // + better repr + + let results: Vec = sqlx::query_as( "SELECT *, 1 / (1 + (embed_mxbai_embed_large <-> $1)) AS similarity FROM doc_embedding ORDER BY embed_mxbai_embed_large <-> $1 LIMIT $2 OFFSET $3", ) .bind(v) @@ -157,7 +243,28 @@ impl EmbedStore { .bind(offset) .fetch_all(get_pg!()) .await - .unwrap() + .unwrap(); + + let mut search_res: HashMap> = HashMap::new(); + + for res in results { + let domain = search_res + .entry(res.domain.clone()) + .or_insert(HashMap::new()); + let doc = domain + .entry(res.path.clone()) + .or_insert(SearchResult::new(res.domain.clone(), res.path.clone())); + doc.chunks.push(res); + } + + let mut flat = search_res + .into_values() + .map(|x| x.into_values().collect::>()) + .flatten() + .collect::>(); + + flat.sort_by(|a, b| b.chunks.len().cmp(&a.chunks.len())); + flat } pub async fn generate_embeddings_for(arc: &WebsiteArchive) { diff --git a/src/archive.rs b/src/archive.rs index 39e2a91..361705b 100644 --- a/src/archive.rs +++ b/src/archive.rs @@ -1,6 +1,6 @@ use std::{io::Read, path::PathBuf}; -use based::request::RequestContext; +use based::{request::RequestContext, result::LogAndIgnore}; use maud::html; use crate::{blacklist::check_blacklist, favicon::download_fav_for, pages::component::render_page}; @@ -62,7 +62,8 @@ impl Domain { /// A new `Domain` instance. pub fn new(name: &str, dir: PathBuf) -> Self { if !check_blacklist(name) { - std::fs::create_dir_all(&dir).unwrap(); + std::fs::create_dir_all(&dir) + .log_err_and_ignore(&format!("Could not create domain dir {name}")); } Self { name: name.to_string(), diff --git a/src/pages/mod.rs b/src/pages/mod.rs index 99d11c5..9cdc829 100644 --- a/src/pages/mod.rs +++ b/src/pages/mod.rs @@ -14,7 +14,7 @@ use component::*; use serde_json::json; use crate::{ - ai::{generate_embedding, DocEmbedding, EmbedStore}, + ai::{generate_embedding, EmbedStore, SearchResult}, archive::WebsiteArchive, }; @@ -176,7 +176,7 @@ pub async fn render_website( None } -pub fn gen_search_element(x: &DocEmbedding) -> PreEscaped { +pub fn gen_search_element(x: &SearchResult) -> PreEscaped { html! { div class="text-xl font-bold mt-4 p-4 flex items-center w-full max-w-4xl max-h-40 mx-auto bg-neutral-800 shadow-md rounded-lg overflow-hidden border border-neutral-900 hover:cursor-pointer" hx-get=(format!("/d/{}/{}", x.domain, x.path)) @@ -186,7 +186,7 @@ pub fn gen_search_element(x: &DocEmbedding) -> PreEscaped { a { (x.domain) }; (slash_seperator()); (gen_path_header(x.path.split('/').collect(), &x.domain, false)); - p class="font-bold p-2 text-stone-400" { (format!("{:.2} %", x.similarity * 100.0)) }; + p class="font-bold p-2 text-stone-400" { (format!("{:.2} % [{} matches]", x.similarity() * 100.0, x.chunks.len())) }; }; } } @@ -220,13 +220,14 @@ pub async fn vector_search( EmbedStore::search_vector(&input, limit as i64, offset as i64).await }) }, - 5, + 50, ) .pager(page as u64, vector) .await; // API Route if query.ends_with(".json") { + // TODO : Better search API return Some(respond_json(&json!(&results.page(page as u64)))); }