This commit is contained in:
parent e50d31479c
commit 6aea22576c
5 changed files with 136 additions and 25 deletions
env (3 changes)

@@ -9,7 +9,8 @@ ROUTE_INTERNAL=true
 DOWNLOAD_ON_DEMAND=true
 
 # Blacklisted domains (Comma-seperated regex)
-BLACKLIST_DOMAINS="google.com,.*.youtube.com"
+# You can blacklist sites which wont work well
+BLACKLIST_DOMAINS="^gitlab"
 
 # Database
 DATABASE_URL=postgres://user:pass@postgres/webarc
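
`BLACKLIST_DOMAINS` holds comma-separated regular expressions matched against domain names (the codebase's `check_blacklist` consumes it). A sketch of how such a value might be parsed, assuming the `regex` crate; the parsing details are illustrative, not lifted from this repo:

```rust
use regex::Regex;

// Illustrative parser for a comma-separated regex blacklist like the env
// value above; returns true if any pattern matches the domain.
fn is_blacklisted(blacklist: &str, domain: &str) -> bool {
    blacklist
        .split(',')
        .filter(|p| !p.is_empty())
        .filter_map(|p| Regex::new(p).ok())
        .any(|re| re.is_match(domain))
}

fn main() {
    assert!(is_blacklisted("^gitlab", "gitlab.com"));
    assert!(!is_blacklisted("^gitlab", "example.org"));
}
```
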
@@ -5,6 +5,7 @@ CREATE TABLE doc_embedding (
     domain VARCHAR(500) NOT NULL,
     path VARCHAR(1000) NOT NULL,
     ver VARCHAR(10) NOT NULL,
+    chunk INTEGER NOT NULL,
     embed_mxbai_embed_large vector(1024) NOT NULL,
-    PRIMARY KEY (domain, path, ver)
+    PRIMARY KEY (domain, path, ver, chunk)
 )
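
The composite key now includes `chunk`, so one document version is stored as several rows, one per chunk. As a minimal sketch of how those rows might be read back in order; the pool handle and the `ChunkRow`/`chunks_for` names are invented for this sketch, not part of the commit:

```rust
use sqlx::PgPool;

// Hypothetical reader for the per-chunk rows defined above.
#[derive(sqlx::FromRow)]
struct ChunkRow {
    chunk: i32,
    embed_mxbai_embed_large: pgvector::Vector,
}

async fn chunks_for(
    pool: &PgPool,
    domain: &str,
    path: &str,
    ver: &str,
) -> sqlx::Result<Vec<ChunkRow>> {
    sqlx::query_as(
        "SELECT chunk, embed_mxbai_embed_large FROM doc_embedding \
         WHERE domain = $1 AND path = $2 AND ver = $3 ORDER BY chunk",
    )
    .bind(domain)
    .bind(path)
    .bind(ver)
    .fetch_all(pool)
    .await
}
```
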
src/ai.rs (141 changes)

@@ -1,4 +1,4 @@
-use std::collections::VecDeque;
+use std::collections::{HashMap, VecDeque};
 
 use based::{get_pg, request::api::ToAPI, result::LogNoneAndPass};
 use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest};
@@ -8,11 +8,15 @@ use sqlx::FromRow;
 
 use crate::archive::{Document, Domain, WebsiteArchive};
 
-// TODO : Chunked embeddings + better search + ranking
+// TODO : Real citese embeddings + search
 
 #[derive(Debug, Clone, FromRow, Serialize)]
 pub struct DocEmbedding {
     pub domain: String,
     pub path: String,
     pub ver: String,
+    pub chunk: i32,
+
+    #[serde(skip)]
     embed_mxbai_embed_large: pgvector::Vector,
@@ -21,6 +25,53 @@ pub struct DocEmbedding {
     pub similarity: f64,
 }
 
+#[derive(Debug, Clone, Serialize)]
+pub struct SearchResult {
+    pub domain: String,
+    pub path: String,
+    pub chunks: Vec<DocEmbedding>,
+}
+
+impl SearchResult {
+    pub fn new(domain: String, path: String) -> Self {
+        Self {
+            domain,
+            path,
+            chunks: vec![],
+        }
+    }
+
+    pub fn similarity(&self) -> f64 {
+        total_score(&self.chunks)
+    }
+}
+
+pub fn avg_sim(e: &[DocEmbedding]) -> f64 {
+    let mut score = 0.0;
+
+    for e in e {
+        score += e.similarity;
+    }
+
+    score / e.len() as f64
+}
+
+pub fn max_sim(e: &[DocEmbedding]) -> f64 {
+    let mut score = 0.0;
+
+    for e in e {
+        if e.similarity > score {
+            score = e.similarity;
+        }
+    }
+
+    score
+}
+
+pub fn total_score(e: &[DocEmbedding]) -> f64 {
+    (avg_sim(e) + max_sim(e)) / 2.0
+}
+
 impl ToAPI for DocEmbedding {
     async fn api(&self) -> serde_json::Value {
         json!({
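
The blended score added here is just the arithmetic mean of the average and the peak chunk similarity: total = (avg + max) / 2. A standalone check of that arithmetic on bare f64 values, with `DocEmbedding` reduced to its `similarity` field:

```rust
// Standalone restatement of the scoring above on plain similarity values:
// with chunks scoring [0.9, 0.5, 0.7], avg = 0.7, max = 0.9, total = 0.8.
fn avg_sim(sims: &[f64]) -> f64 {
    sims.iter().sum::<f64>() / sims.len() as f64
}

fn max_sim(sims: &[f64]) -> f64 {
    sims.iter().copied().fold(0.0, f64::max)
}

fn total_score(sims: &[f64]) -> f64 {
    (avg_sim(sims) + max_sim(sims)) / 2.0
}

fn main() {
    assert!((total_score(&[0.9, 0.5, 0.7]) - 0.8).abs() < 1e-12);
}
```

One caveat visible in the diff: `avg_sim` divides by `e.len()`, so an empty slice would yield NaN; in practice each `SearchResult` is only built from matched rows and carries at least one chunk.
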
@@ -33,12 +84,23 @@ impl ToAPI for DocEmbedding {
 }
 
 pub trait Embedding {
-    fn embedding(&self, ver: Option<String>)
-        -> impl std::future::Future<Output = Option<Vec<f32>>>;
+    fn embedding(
+        &self,
+        ver: Option<String>,
+    ) -> impl std::future::Future<Output = Option<Vec<Vec<f32>>>>;
 }
 
+pub fn chunked(s: &str) -> Vec<String> {
+    const CHUNK_SIZE: usize = 500;
+    s.chars()
+        .collect::<Vec<char>>()
+        .chunks(CHUNK_SIZE)
+        .map(|chunk| chunk.iter().collect())
+        .collect()
+}
+
 impl Embedding for Document {
-    async fn embedding(&self, ver: Option<String>) -> Option<Vec<f32>> {
+    async fn embedding(&self, ver: Option<String>) -> Option<Vec<Vec<f32>>> {
         let latest = "latest".to_string();
         log::info!(
             "Generating Vector embeddings for {} / {} @ {}",
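
`chunked` splits on `char` boundaries rather than bytes, so multi-byte UTF-8 text never gets cut mid-codepoint; each chunk holds at most 500 chars. A quick usage check on a standalone copy of the function above:

```rust
fn chunked(s: &str) -> Vec<String> {
    const CHUNK_SIZE: usize = 500;
    s.chars()
        .collect::<Vec<char>>()
        .chunks(CHUNK_SIZE)
        .map(|chunk| chunk.iter().collect())
        .collect()
}

fn main() {
    let text = "ä".repeat(1200); // 2 bytes per char, still 1200 chars
    let parts = chunked(&text);
    let lens: Vec<usize> = parts.iter().map(|p| p.chars().count()).collect();
    assert_eq!(lens, vec![500, 500, 200]); // last chunk carries the remainder
}
```

The char-level split keeps chunk boundaries valid UTF-8 at the cost of cutting words mid-token; 500 chars is the commit's fixed chunk size.
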
@@ -47,9 +109,26 @@ impl Embedding for Document {
             ver.as_ref().unwrap_or(&latest)
         );
 
-        let content_html = self.render_local(ver).await?;
+        let content_html = self.render_local(ver.clone()).await?;
         let content = html2md::parse_html(&content_html);
-        generate_embedding(content).await
+
+        let mut embeddings = Vec::new();
+        let content = chunked(&content);
+        let len = content.len();
+
+        for (index, c) in content.into_iter().enumerate() {
+            log::info!(
+                "Generating Vector embeddings for {} / {} @ {} [ {} / {} ]",
+                self.domain,
+                self.path,
+                ver.as_ref().unwrap_or(&latest),
+                index + 1,
+                len
+            );
+            embeddings.push(generate_embedding(c).await?);
+        }
+
+        Some(embeddings)
     }
 }
 
@@ -129,14 +208,17 @@ impl EmbedStore {
             .execute(get_pg!())
             .await;
 
-        sqlx::query("INSERT INTO doc_embedding VALUES ($1, $2, $3, $4)")
-            .bind(&doc.domain)
-            .bind(&doc.path)
-            .bind(ver)
-            .bind(embed)
-            .execute(get_pg!())
-            .await
-            .unwrap();
+        for (index, embed) in embed.iter().enumerate() {
+            sqlx::query("INSERT INTO doc_embedding VALUES ($1, $2, $3, $4, $5)")
+                .bind(&doc.domain)
+                .bind(&doc.path)
+                .bind(ver)
+                .bind(index as i64)
+                .bind(embed)
+                .execute(get_pg!())
+                .await
+                .unwrap();
+        }
     }
 }
 
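
One detail in the loop above: the `chunk` column is `INTEGER` (int4) and the struct field is `i32`, but the insert binds `index as i64` (int8), which leans on Postgres's assignment cast. Binding an `i32` would match the column exactly. A sketch of that variant; `insert_chunk` and `pool` are names invented here, standing in for the project's `get_pg!()`:

```rust
use sqlx::PgPool;

// Hypothetical variant of the insert above with the index bound as i32,
// matching the INTEGER column and the DocEmbedding::chunk field.
async fn insert_chunk(
    pool: &PgPool,
    domain: &str,
    path: &str,
    ver: &str,
    index: usize,
    embed: &pgvector::Vector,
) -> sqlx::Result<()> {
    sqlx::query("INSERT INTO doc_embedding VALUES ($1, $2, $3, $4, $5)")
        .bind(domain)
        .bind(path)
        .bind(ver)
        .bind(index as i32)
        .bind(embed)
        .execute(pool)
        .await?;
    Ok(())
}
```
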
@@ -148,8 +230,12 @@ impl EmbedStore {
         }
     }
 
-    pub async fn search_vector(v: &pgvector::Vector, limit: i64, offset: i64) -> Vec<DocEmbedding> {
-        sqlx::query_as(
+    pub async fn search_vector(v: &pgvector::Vector, limit: i64, offset: i64) -> Vec<SearchResult> {
+        // TODO : fix search
+        // + new ranked algorithm
+        // + better repr
+
+        let results: Vec<DocEmbedding> = sqlx::query_as(
             "SELECT *, 1 / (1 + (embed_mxbai_embed_large <-> $1)) AS similarity FROM doc_embedding ORDER BY embed_mxbai_embed_large <-> $1 LIMIT $2 OFFSET $3",
         )
         .bind(v)
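
The query maps pgvector's L2 distance (`<->`) to a similarity via sim = 1 / (1 + d): identical vectors (d = 0) score 1.0 and the score decays toward 0 as distance grows. A standalone check of that mapping:

```rust
// The distance-to-similarity transform used in the SQL above.
fn similarity(distance: f64) -> f64 {
    1.0 / (1.0 + distance)
}

fn main() {
    assert_eq!(similarity(0.0), 1.0); // exact match
    assert_eq!(similarity(1.0), 0.5);
    assert!(similarity(9.0) < 0.11); // distant vectors tend toward 0
}
```

Note that the `ORDER BY` still ranks rows by raw distance; the computed `similarity` column is carried along for scoring and display.
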
@@ -157,7 +243,28 @@ impl EmbedStore {
         .bind(offset)
         .fetch_all(get_pg!())
         .await
-        .unwrap()
+        .unwrap();
+
+        let mut search_res: HashMap<String, HashMap<String, SearchResult>> = HashMap::new();
+
+        for res in results {
+            let domain = search_res
+                .entry(res.domain.clone())
+                .or_insert(HashMap::new());
+            let doc = domain
+                .entry(res.path.clone())
+                .or_insert(SearchResult::new(res.domain.clone(), res.path.clone()));
+            doc.chunks.push(res);
+        }
+
+        let mut flat = search_res
+            .into_values()
+            .map(|x| x.into_values().collect::<Vec<SearchResult>>())
+            .flatten()
+            .collect::<Vec<SearchResult>>();
+
+        flat.sort_by(|a, b| b.chunks.len().cmp(&a.chunks.len()));
+        flat
     }
 
     pub async fn generate_embeddings_for(arc: &WebsiteArchive) {
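
As the TODOs admit, the final ordering is crude: `flat.sort_by` ranks results by chunk count alone, even though `SearchResult::similarity()` already computes a blended score. A possible refinement, sketched here as an assumption rather than anything this commit implements, is to sort by that score; minimal stand-in types keep the snippet self-contained:

```rust
// Stand-ins for the crate's DocEmbedding/SearchResult, reduced to what the
// re-ranking needs; the real types live in src/ai.rs.
struct Chunk {
    similarity: f64,
}

struct SearchResult {
    chunks: Vec<Chunk>,
}

impl SearchResult {
    // Same (avg + max) / 2 blend as total_score in the diff above.
    fn similarity(&self) -> f64 {
        let avg =
            self.chunks.iter().map(|c| c.similarity).sum::<f64>() / self.chunks.len() as f64;
        let max = self.chunks.iter().map(|c| c.similarity).fold(0.0, f64::max);
        (avg + max) / 2.0
    }
}

// Hypothetical replacement for the chunk-count sort: highest blended score first.
fn rank_by_similarity(mut results: Vec<SearchResult>) -> Vec<SearchResult> {
    results.sort_by(|a, b| {
        b.similarity()
            .partial_cmp(&a.similarity())
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    results
}
```
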
@@ -1,6 +1,6 @@
 use std::{io::Read, path::PathBuf};
 
-use based::request::RequestContext;
+use based::{request::RequestContext, result::LogAndIgnore};
 use maud::html;
 
 use crate::{blacklist::check_blacklist, favicon::download_fav_for, pages::component::render_page};
@@ -62,7 +62,8 @@ impl Domain {
     /// A new `Domain` instance.
     pub fn new(name: &str, dir: PathBuf) -> Self {
         if !check_blacklist(name) {
-            std::fs::create_dir_all(&dir).unwrap();
+            std::fs::create_dir_all(&dir)
+                .log_err_and_ignore(&format!("Could not create domain dir {name}"));
         }
         Self {
             name: name.to_string(),
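
`log_err_and_ignore` comes from the `based` crate's `result::LogAndIgnore`, whose definition isn't part of this diff. For readers unfamiliar with the pattern, a guess at its shape; the trait body here is an assumption inferred from the call site, not based's actual code:

```rust
// Hypothetical reconstruction of an extension trait like based's LogAndIgnore:
// log the error under a context message, then drop the Result entirely.
trait LogAndIgnore {
    fn log_err_and_ignore(self, msg: &str);
}

impl<T, E: std::fmt::Debug> LogAndIgnore for Result<T, E> {
    fn log_err_and_ignore(self, msg: &str) {
        if let Err(e) = self {
            log::error!("{msg}: {e:?}");
        }
    }
}
```
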
@@ -14,7 +14,7 @@ use component::*;
 use serde_json::json;
 
 use crate::{
-    ai::{generate_embedding, DocEmbedding, EmbedStore},
+    ai::{generate_embedding, EmbedStore, SearchResult},
     archive::WebsiteArchive,
 };
 
@@ -176,7 +176,7 @@ pub async fn render_website(
     None
 }
 
-pub fn gen_search_element(x: &DocEmbedding) -> PreEscaped<String> {
+pub fn gen_search_element(x: &SearchResult) -> PreEscaped<String> {
     html! {
         div class="text-xl font-bold mt-4 p-4 flex items-center w-full max-w-4xl max-h-40 mx-auto bg-neutral-800 shadow-md rounded-lg overflow-hidden border border-neutral-900 hover:cursor-pointer"
             hx-get=(format!("/d/{}/{}", x.domain, x.path))
@@ -186,7 +186,7 @@ pub fn gen_search_element(x: &DocEmbedding) -> PreEscaped<String> {
             a { (x.domain) };
             (slash_seperator());
             (gen_path_header(x.path.split('/').collect(), &x.domain, false));
-            p class="font-bold p-2 text-stone-400" { (format!("{:.2} %", x.similarity * 100.0)) };
+            p class="font-bold p-2 text-stone-400" { (format!("{:.2} % [{} matches]", x.similarity() * 100.0, x.chunks.len())) };
         };
     }
 }
@@ -220,13 +220,14 @@ pub async fn vector_search(
                 EmbedStore::search_vector(&input, limit as i64, offset as i64).await
             })
         },
         5,
         50,
     )
     .pager(page as u64, vector)
     .await;
 
     // API Route
     if query.ends_with(".json") {
+        // TODO : Better search API
         return Some(respond_json(&json!(&results.page(page as u64))));
     }