update chunked embed
Some checks failed
ci/woodpecker/push/build Pipeline failed

This commit is contained in:
JMARyA 2024-12-31 02:03:03 +01:00
parent e50d31479c
commit 6aea22576c
Signed by: jmarya
GPG key ID: 901B2ADDF27C2263
5 changed files with 136 additions and 25 deletions

3
env
View file

@ -9,7 +9,8 @@ ROUTE_INTERNAL=true
DOWNLOAD_ON_DEMAND=true
# Blacklisted domains (Comma-separated regex)
BLACKLIST_DOMAINS="google.com,.*.youtube.com"
# You can blacklist sites which won't work well
BLACKLIST_DOMAINS="^gitlab"
# Database
DATABASE_URL=postgres://user:pass@postgres/webarc

View file

@ -5,6 +5,7 @@ CREATE TABLE doc_embedding (
domain VARCHAR(500) NOT NULL,
path VARCHAR(1000) NOT NULL,
ver VARCHAR(10) NOT NULL,
chunk INTEGER NOT NULL,
embed_mxbai_embed_large vector(1024) NOT NULL,
PRIMARY KEY (domain, path, ver)
PRIMARY KEY (domain, path, ver, chunk)
)

127
src/ai.rs
View file

@ -1,4 +1,4 @@
use std::collections::VecDeque;
use std::collections::{HashMap, VecDeque};
use based::{get_pg, request::api::ToAPI, result::LogNoneAndPass};
use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest};
@ -8,11 +8,15 @@ use sqlx::FromRow;
use crate::archive::{Document, Domain, WebsiteArchive};
// TODO : Chunked embeddings + better search + ranking
// TODO : Real citation embeddings + search
#[derive(Debug, Clone, FromRow, Serialize)]
pub struct DocEmbedding {
pub domain: String,
pub path: String,
pub ver: String,
pub chunk: i32,
#[serde(skip)]
embed_mxbai_embed_large: pgvector::Vector,
@ -21,6 +25,53 @@ pub struct DocEmbedding {
pub similarity: f64,
}
/// A single search hit: one archived document (identified by domain + path)
/// together with every embedding chunk of it that matched the query.
#[derive(Debug, Clone, Serialize)]
pub struct SearchResult {
/// Domain the document was archived from.
pub domain: String,
/// Path of the document within the domain.
pub path: String,
/// Matching per-chunk embeddings; scored via `total_score` / `similarity()`.
pub chunks: Vec<DocEmbedding>,
}
impl SearchResult {
    /// Build an empty result for the document at `domain`/`path`;
    /// matching chunks are appended by the search code afterwards.
    pub fn new(domain: String, path: String) -> Self {
        Self {
            domain,
            path,
            chunks: Vec::new(),
        }
    }

    /// Overall relevance of this document, derived from its chunk scores
    /// (midpoint of average and maximum chunk similarity).
    pub fn similarity(&self) -> f64 {
        let chunk_scores = self.chunks.as_slice();
        total_score(chunk_scores)
    }
}
/// Mean similarity across a set of chunk embeddings.
///
/// Returns 0.0 for an empty slice. The previous implementation divided by
/// `e.len()` unconditionally, so an empty slice produced `0.0 / 0.0 == NaN`,
/// and a NaN score silently breaks any downstream comparison or sorting of
/// search results.
pub fn avg_sim(e: &[DocEmbedding]) -> f64 {
    if e.is_empty() {
        return 0.0;
    }
    e.iter().map(|d| d.similarity).sum::<f64>() / e.len() as f64
}
/// Highest similarity among the given chunk embeddings.
///
/// Starts from 0.0, so an empty slice yields 0.0 and scores below zero
/// never lower the result (similarities here are non-negative anyway).
pub fn max_sim(e: &[DocEmbedding]) -> f64 {
    e.iter().map(|d| d.similarity).fold(0.0, f64::max)
}
/// Combined relevance score for a document: the midpoint between the
/// average and the maximum of its chunk similarities.
pub fn total_score(e: &[DocEmbedding]) -> f64 {
    let avg = avg_sim(e);
    let max = max_sim(e);
    0.5 * (avg + max)
}
impl ToAPI for DocEmbedding {
async fn api(&self) -> serde_json::Value {
json!({
@ -33,12 +84,23 @@ impl ToAPI for DocEmbedding {
}
pub trait Embedding {
fn embedding(&self, ver: Option<String>)
-> impl std::future::Future<Output = Option<Vec<f32>>>;
fn embedding(
&self,
ver: Option<String>,
) -> impl std::future::Future<Output = Option<Vec<Vec<f32>>>>;
}
/// Split `s` into consecutive chunks of at most 500 characters each.
///
/// The split counts Unicode scalar values (`char`s), so multi-byte
/// characters are never cut in half. The final chunk may be shorter than
/// 500 characters; an empty input yields an empty vector.
pub fn chunked(s: &str) -> Vec<String> {
    const CHUNK_SIZE: usize = 500;

    let mut chunks = Vec::new();
    let mut buf = String::new();
    let mut count = 0;
    for ch in s.chars() {
        buf.push(ch);
        count += 1;
        if count == CHUNK_SIZE {
            // Hand the filled buffer off and start a fresh one.
            chunks.push(std::mem::take(&mut buf));
            count = 0;
        }
    }
    if !buf.is_empty() {
        chunks.push(buf);
    }
    chunks
}
impl Embedding for Document {
async fn embedding(&self, ver: Option<String>) -> Option<Vec<f32>> {
async fn embedding(&self, ver: Option<String>) -> Option<Vec<Vec<f32>>> {
let latest = "latest".to_string();
log::info!(
"Generating Vector embeddings for {} / {} @ {}",
@ -47,9 +109,26 @@ impl Embedding for Document {
ver.as_ref().unwrap_or(&latest)
);
let content_html = self.render_local(ver).await?;
let content_html = self.render_local(ver.clone()).await?;
let content = html2md::parse_html(&content_html);
generate_embedding(content).await
let mut embeddings = Vec::new();
let content = chunked(&content);
let len = content.len();
for (index, c) in content.into_iter().enumerate() {
log::info!(
"Generating Vector embeddings for {} / {} @ {} [ {} / {} ]",
self.domain,
self.path,
ver.as_ref().unwrap_or(&latest),
index + 1,
len
);
embeddings.push(generate_embedding(c).await?);
}
Some(embeddings)
}
}
@ -129,16 +208,19 @@ impl EmbedStore {
.execute(get_pg!())
.await;
sqlx::query("INSERT INTO doc_embedding VALUES ($1, $2, $3, $4)")
for (index, embed) in embed.iter().enumerate() {
sqlx::query("INSERT INTO doc_embedding VALUES ($1, $2, $3, $4, $5)")
.bind(&doc.domain)
.bind(&doc.path)
.bind(ver)
.bind(index as i64)
.bind(embed)
.execute(get_pg!())
.await
.unwrap();
}
}
}
pub async fn ensure_embedding(doc: &Document) {
for ver in doc.versions() {
@ -148,8 +230,12 @@ impl EmbedStore {
}
}
pub async fn search_vector(v: &pgvector::Vector, limit: i64, offset: i64) -> Vec<DocEmbedding> {
sqlx::query_as(
pub async fn search_vector(v: &pgvector::Vector, limit: i64, offset: i64) -> Vec<SearchResult> {
// TODO : fix search
// + new ranked algorithm
// + better repr
let results: Vec<DocEmbedding> = sqlx::query_as(
"SELECT *, 1 / (1 + (embed_mxbai_embed_large <-> $1)) AS similarity FROM doc_embedding ORDER BY embed_mxbai_embed_large <-> $1 LIMIT $2 OFFSET $3",
)
.bind(v)
@ -157,7 +243,28 @@ impl EmbedStore {
.bind(offset)
.fetch_all(get_pg!())
.await
.unwrap()
.unwrap();
let mut search_res: HashMap<String, HashMap<String, SearchResult>> = HashMap::new();
for res in results {
let domain = search_res
.entry(res.domain.clone())
.or_insert(HashMap::new());
let doc = domain
.entry(res.path.clone())
.or_insert(SearchResult::new(res.domain.clone(), res.path.clone()));
doc.chunks.push(res);
}
let mut flat = search_res
.into_values()
.map(|x| x.into_values().collect::<Vec<SearchResult>>())
.flatten()
.collect::<Vec<SearchResult>>();
flat.sort_by(|a, b| b.chunks.len().cmp(&a.chunks.len()));
flat
}
pub async fn generate_embeddings_for(arc: &WebsiteArchive) {

View file

@ -1,6 +1,6 @@
use std::{io::Read, path::PathBuf};
use based::request::RequestContext;
use based::{request::RequestContext, result::LogAndIgnore};
use maud::html;
use crate::{blacklist::check_blacklist, favicon::download_fav_for, pages::component::render_page};
@ -62,7 +62,8 @@ impl Domain {
/// A new `Domain` instance.
pub fn new(name: &str, dir: PathBuf) -> Self {
if !check_blacklist(name) {
std::fs::create_dir_all(&dir).unwrap();
std::fs::create_dir_all(&dir)
.log_err_and_ignore(&format!("Could not create domain dir {name}"));
}
Self {
name: name.to_string(),

View file

@ -14,7 +14,7 @@ use component::*;
use serde_json::json;
use crate::{
ai::{generate_embedding, DocEmbedding, EmbedStore},
ai::{generate_embedding, EmbedStore, SearchResult},
archive::WebsiteArchive,
};
@ -176,7 +176,7 @@ pub async fn render_website(
None
}
pub fn gen_search_element(x: &DocEmbedding) -> PreEscaped<String> {
pub fn gen_search_element(x: &SearchResult) -> PreEscaped<String> {
html! {
div class="text-xl font-bold mt-4 p-4 flex items-center w-full max-w-4xl max-h-40 mx-auto bg-neutral-800 shadow-md rounded-lg overflow-hidden border border-neutral-900 hover:cursor-pointer"
hx-get=(format!("/d/{}/{}", x.domain, x.path))
@ -186,7 +186,7 @@ pub fn gen_search_element(x: &DocEmbedding) -> PreEscaped<String> {
a { (x.domain) };
(slash_seperator());
(gen_path_header(x.path.split('/').collect(), &x.domain, false));
p class="font-bold p-2 text-stone-400" { (format!("{:.2} %", x.similarity * 100.0)) };
p class="font-bold p-2 text-stone-400" { (format!("{:.2} % [{} matches]", x.similarity() * 100.0, x.chunks.len())) };
};
}
}
@ -220,13 +220,14 @@ pub async fn vector_search(
EmbedStore::search_vector(&input, limit as i64, offset as i64).await
})
},
5,
50,
)
.pager(page as u64, vector)
.await;
// API Route
if query.ends_with(".json") {
// TODO : Better search API
return Some(respond_json(&json!(&results.page(page as u64))));
}