update chunked embed

Some checks failed: ci/woodpecker/push/build pipeline failed

JMARyA 2024-12-31 02:03:03 +01:00
parent e50d31479c
commit 6aea22576c
Signed by: jmarya (GPG key ID: 901B2ADDF27C2263)

5 changed files with 136 additions and 25 deletions

env (3 changes)

@@ -9,7 +9,8 @@ ROUTE_INTERNAL=true
 DOWNLOAD_ON_DEMAND=true
 
 # Blacklisted domains (Comma-seperated regex)
-BLACKLIST_DOMAINS="google.com,.*.youtube.com"
+# You can blacklist sites which wont work well
+BLACKLIST_DOMAINS="^gitlab"
 
 # Database
 DATABASE_URL=postgres://user:pass@postgres/webarc
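For context, a minimal sketch of how a comma-separated regex blacklist like this could be evaluated, assuming the `regex` crate (the real `check_blacklist` lives elsewhere in the codebase and is not part of this diff):

    use regex::Regex;

    /// Hypothetical helper: does `domain` match any pattern in BLACKLIST_DOMAINS?
    fn is_blacklisted(blacklist: &str, domain: &str) -> bool {
        blacklist
            .split(',')
            .filter_map(|pat| Regex::new(pat).ok())
            .any(|re| re.is_match(domain))
    }

    // With BLACKLIST_DOMAINS="^gitlab": "gitlab.com" matches, "example.com" does not.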

View file

@@ -5,6 +5,7 @@ CREATE TABLE doc_embedding (
     domain VARCHAR(500) NOT NULL,
     path VARCHAR(1000) NOT NULL,
     ver VARCHAR(10) NOT NULL,
+    chunk INTEGER NOT NULL,
     embed_mxbai_embed_large vector(1024) NOT NULL,
-    PRIMARY KEY (domain, path, ver)
+    PRIMARY KEY (domain, path, ver, chunk)
 )
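With `chunk` added to the primary key, a document version is no longer one embedding row but one row per chunk. A hedged sketch of reading a version's chunks back in order (assuming sqlx and the `DocEmbedding` row type from src/ai.rs; `similarity` is normally computed per query, so a placeholder is selected here):

    // Hypothetical query: all chunk embeddings for one archived page version.
    let chunks: Vec<DocEmbedding> = sqlx::query_as(
        "SELECT *, 0.0::float8 AS similarity FROM doc_embedding
         WHERE domain = $1 AND path = $2 AND ver = $3 ORDER BY chunk",
    )
    .bind("example.com")
    .bind("index.html")
    .bind("2024-12-31")
    .fetch_all(get_pg!())
    .await
    .unwrap();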

src/ai.rs (141 changes)

@@ -1,4 +1,4 @@
-use std::collections::VecDeque;
+use std::collections::{HashMap, VecDeque};
 
 use based::{get_pg, request::api::ToAPI, result::LogNoneAndPass};
 use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest};
@@ -8,11 +8,15 @@ use sqlx::FromRow;
 use crate::archive::{Document, Domain, WebsiteArchive};
 
+// TODO : Chunked embeddings + better search + ranking
+// TODO : Real citese embeddings + search
+
 #[derive(Debug, Clone, FromRow, Serialize)]
 pub struct DocEmbedding {
     pub domain: String,
     pub path: String,
     pub ver: String,
+    pub chunk: i32,
 
     #[serde(skip)]
     embed_mxbai_embed_large: pgvector::Vector,
@@ -21,6 +25,53 @@ pub struct DocEmbedding {
     pub similarity: f64,
 }
 
+#[derive(Debug, Clone, Serialize)]
+pub struct SearchResult {
+    pub domain: String,
+    pub path: String,
+    pub chunks: Vec<DocEmbedding>,
+}
+
+impl SearchResult {
+    pub fn new(domain: String, path: String) -> Self {
+        Self {
+            domain,
+            path,
+            chunks: vec![],
+        }
+    }
+
+    pub fn similarity(&self) -> f64 {
+        total_score(&self.chunks)
+    }
+}
+
+pub fn avg_sim(e: &[DocEmbedding]) -> f64 {
+    let mut score = 0.0;
+    for e in e {
+        score += e.similarity;
+    }
+    score / e.len() as f64
+}
+
+pub fn max_sim(e: &[DocEmbedding]) -> f64 {
+    let mut score = 0.0;
+    for e in e {
+        if e.similarity > score {
+            score = e.similarity;
+        }
+    }
+    score
+}
+
+pub fn total_score(e: &[DocEmbedding]) -> f64 {
+    (avg_sim(e) + max_sim(e)) / 2.0
+}
+
 impl ToAPI for DocEmbedding {
     async fn api(&self) -> serde_json::Value {
         json!({
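For intuition about the ranking math above: a document whose chunks have similarities [0.9, 0.5, 0.4] scores avg_sim = (0.9 + 0.5 + 0.4) / 3 = 0.6 and max_sim = 0.9, so total_score = (0.6 + 0.9) / 2 = 0.75. Blending the mean with the maximum rewards one strongly matching chunk without letting many weak chunks dominate. Note that avg_sim on an empty slice divides 0.0 by 0.0 and yields NaN, so SearchResult::similarity() implicitly assumes at least one chunk has been pushed.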
@@ -33,12 +84,23 @@ impl ToAPI for DocEmbedding {
 }
 
 pub trait Embedding {
-    fn embedding(&self, ver: Option<String>)
-        -> impl std::future::Future<Output = Option<Vec<f32>>>;
+    fn embedding(
+        &self,
+        ver: Option<String>,
+    ) -> impl std::future::Future<Output = Option<Vec<Vec<f32>>>>;
+}
+
+pub fn chunked(s: &str) -> Vec<String> {
+    const CHUNK_SIZE: usize = 500;
+    s.chars()
+        .collect::<Vec<char>>()
+        .chunks(CHUNK_SIZE)
+        .map(|chunk| chunk.iter().collect())
+        .collect()
 }
 
 impl Embedding for Document {
-    async fn embedding(&self, ver: Option<String>) -> Option<Vec<f32>> {
+    async fn embedding(&self, ver: Option<String>) -> Option<Vec<Vec<f32>>> {
         let latest = "latest".to_string();
         log::info!(
             "Generating Vector embeddings for {} / {} @ {}",
@@ -47,9 +109,26 @@ impl Embedding for Document {
             ver.as_ref().unwrap_or(&latest)
         );
 
-        let content_html = self.render_local(ver).await?;
+        let content_html = self.render_local(ver.clone()).await?;
         let content = html2md::parse_html(&content_html);
-        generate_embedding(content).await
+
+        let mut embeddings = Vec::new();
+        let content = chunked(&content);
+        let len = content.len();
+
+        for (index, c) in content.into_iter().enumerate() {
+            log::info!(
+                "Generating Vector embeddings for {} / {} @ {} [ {} / {} ]",
+                self.domain,
+                self.path,
+                ver.as_ref().unwrap_or(&latest),
+                index + 1,
+                len
+            );
+            embeddings.push(generate_embedding(c).await?);
+        }
+
+        Some(embeddings)
     }
 }
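Two details of the chunked embedding flow above are easy to miss. First, chunking is character-based rather than byte-based: collecting into a Vec<char> and slicing 500-char windows can never split a multi-byte UTF-8 sequence, at the cost of an extra allocation per document. Second, the `?` inside the loop means that if any single chunk fails to embed, the whole method returns None instead of a partial set. A small illustration of the chunking (hypothetical values, not from this diff):

    // 1200 chars yield chunks of 500, 500, and 200 chars.
    let text = "a".repeat(1200);
    let chunks = chunked(&text);
    assert_eq!(chunks.len(), 3);
    assert_eq!(chunks[2].chars().count(), 200);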
@@ -129,14 +208,17 @@ impl EmbedStore {
             .execute(get_pg!())
             .await;
 
-        sqlx::query("INSERT INTO doc_embedding VALUES ($1, $2, $3, $4)")
-            .bind(&doc.domain)
-            .bind(&doc.path)
-            .bind(ver)
-            .bind(embed)
-            .execute(get_pg!())
-            .await
-            .unwrap();
+        for (index, embed) in embed.iter().enumerate() {
+            sqlx::query("INSERT INTO doc_embedding VALUES ($1, $2, $3, $4, $5)")
+                .bind(&doc.domain)
+                .bind(&doc.path)
+                .bind(ver)
+                .bind(index as i64)
+                .bind(embed)
+                .execute(get_pg!())
+                .await
+                .unwrap();
+        }
     }
 }
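Since the INSERT above is positional, it silently depends on the column order from the migration. An equivalent, more self-documenting form of the same statement (same five binds, using the column names from the migration):

    sqlx::query(
        "INSERT INTO doc_embedding (domain, path, ver, chunk, embed_mxbai_embed_large)
         VALUES ($1, $2, $3, $4, $5)",
    )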
@@ -148,8 +230,12 @@ impl EmbedStore {
         }
     }
 
-    pub async fn search_vector(v: &pgvector::Vector, limit: i64, offset: i64) -> Vec<DocEmbedding> {
-        sqlx::query_as(
+    pub async fn search_vector(v: &pgvector::Vector, limit: i64, offset: i64) -> Vec<SearchResult> {
+        // TODO : fix search
+        // + new ranked algorithm
+        // + better repr
+        let results: Vec<DocEmbedding> = sqlx::query_as(
             "SELECT *, 1 / (1 + (embed_mxbai_embed_large <-> $1)) AS similarity FROM doc_embedding ORDER BY embed_mxbai_embed_large <-> $1 LIMIT $2 OFFSET $3",
         )
         .bind(v)
@@ -157,7 +243,28 @@ impl EmbedStore {
         .bind(offset)
         .fetch_all(get_pg!())
         .await
-        .unwrap()
+        .unwrap();
+
+        let mut search_res: HashMap<String, HashMap<String, SearchResult>> = HashMap::new();
+
+        for res in results {
+            let domain = search_res
+                .entry(res.domain.clone())
+                .or_insert(HashMap::new());
+            let doc = domain
+                .entry(res.path.clone())
+                .or_insert(SearchResult::new(res.domain.clone(), res.path.clone()));
+            doc.chunks.push(res);
+        }
+
+        let mut flat = search_res
+            .into_values()
+            .map(|x| x.into_values().collect::<Vec<SearchResult>>())
+            .flatten()
+            .collect::<Vec<SearchResult>>();
+
+        flat.sort_by(|a, b| b.chunks.len().cmp(&a.chunks.len()));
+
+        flat
     }
 
     pub async fn generate_embeddings_for(arc: &WebsiteArchive) {
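Two things in search_vector worth spelling out: the SQL turns pgvector's L2 distance (<->) into a similarity in (0, 1], and the grouping pass attaches every chunk that survived the LIMIT to its (domain, path) document, after which documents are ranked by how many chunks matched rather than by total_score.

    // similarity = 1 / (1 + L2 distance)
    // distance 0.0 -> 1.00
    // distance 1.0 -> 0.50
    // distance 3.0 -> 0.25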

View file

@@ -1,6 +1,6 @@
 use std::{io::Read, path::PathBuf};
 
-use based::request::RequestContext;
+use based::{request::RequestContext, result::LogAndIgnore};
 use maud::html;
 
 use crate::{blacklist::check_blacklist, favicon::download_fav_for, pages::component::render_page};
@@ -62,7 +62,8 @@ impl Domain {
     /// A new `Domain` instance.
     pub fn new(name: &str, dir: PathBuf) -> Self {
         if !check_blacklist(name) {
-            std::fs::create_dir_all(&dir).unwrap();
+            std::fs::create_dir_all(&dir)
+                .log_err_and_ignore(&format!("Could not create domain dir {name}"));
         }
 
         Self {
             name: name.to_string(),
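`LogAndIgnore` comes from the author's `based` crate, which is not part of this diff. A plausible reconstruction of such an extension trait, labeled as an assumption (the real definition may differ):

    /// Hypothetical shape of the extension trait imported above.
    pub trait LogAndIgnore {
        fn log_err_and_ignore(self, msg: &str);
    }

    impl<T, E: std::fmt::Debug> LogAndIgnore for Result<T, E> {
        fn log_err_and_ignore(self, msg: &str) {
            if let Err(e) = self {
                log::error!("{msg}: {e:?}");
            }
        }
    }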

View file

@@ -14,7 +14,7 @@ use component::*;
 use serde_json::json;
 
 use crate::{
-    ai::{generate_embedding, DocEmbedding, EmbedStore},
+    ai::{generate_embedding, EmbedStore, SearchResult},
     archive::WebsiteArchive,
 };
@@ -176,7 +176,7 @@ pub async fn render_website(
     None
 }
 
-pub fn gen_search_element(x: &DocEmbedding) -> PreEscaped<String> {
+pub fn gen_search_element(x: &SearchResult) -> PreEscaped<String> {
     html! {
         div class="text-xl font-bold mt-4 p-4 flex items-center w-full max-w-4xl max-h-40 mx-auto bg-neutral-800 shadow-md rounded-lg overflow-hidden border border-neutral-900 hover:cursor-pointer"
             hx-get=(format!("/d/{}/{}", x.domain, x.path))
@@ -186,7 +186,7 @@ pub fn gen_search_element(x: &SearchResult) -> PreEscaped<String> {
             a { (x.domain) };
             (slash_seperator());
             (gen_path_header(x.path.split('/').collect(), &x.domain, false));
-            p class="font-bold p-2 text-stone-400" { (format!("{:.2} %", x.similarity * 100.0)) };
+            p class="font-bold p-2 text-stone-400" { (format!("{:.2} % [{} matches]", x.similarity() * 100.0, x.chunks.len())) };
         };
     }
 }
@@ -220,13 +220,14 @@ pub async fn vector_search(
                 EmbedStore::search_vector(&input, limit as i64, offset as i64).await
             })
         },
-        5,
+        50,
     )
     .pager(page as u64, vector)
     .await;
 
     // API Route
     if query.ends_with(".json") {
+        // TODO : Better search API
         return Some(respond_json(&json!(&results.page(page as u64))));
     }