Skip to main content

flams_search/
lib.rs

1#![allow(unexpected_cfgs)]
2#![cfg_attr(all(doc, CHANNEL_NIGHTLY), feature(doc_cfg))]
3#![doc = include_str!("../README.md")]
4/*!
5 * ## Feature flags
6 */
7#![cfg_attr(doc,doc = document_features::document_features!())]
8
9use crate::index::SearchIndex;
10use flams_backend_types::search::FragmentQueryFilter;
11use flams_backend_types::search::SearchResult;
12use flams_math_archives::{
13    Archive, LocallyBuilt,
14    artifacts::{Artifact, ContentResult, FileOrString},
15    backend::{AnyBackend, GlobalBackend, LocalBackend},
16    build_target,
17    formats::BuildResult,
18    utils::errors::{ArtifactSaveError, FileError},
19};
20use flams_system::FlamsExtension;
21use ftml_uris::DocumentElementUri;
22use ftml_uris::{DocumentUri, SymbolUri, UriPath, UriWithArchive};
23
24#[cfg(all(feature = "tantivy", not(feature = "vectorsearch")))]
25use crate::schema::SearchSchema;
26
27pub mod index;
28pub mod query;
29#[cfg(all(feature = "tantivy", not(feature = "vectorsearch")))]
30pub mod schema;
31pub mod textify;
32
/// Memory budget, in bytes, handed to the tantivy `IndexWriter` that `initialize` creates.
#[cfg(feature = "tantivy")]
const MEMORY_SIZE: usize = 150 * 1_000_000;
35
36flams_system::register_exension!(FlamsExtension {
37    name: "search",
38    on_start: initialize,
39    on_build_result: |b, uri, rel_path, a| {
40        //println!("Here");
41        if let Some(content) = a.as_any().downcast_ref::<ContentResult>() {
42            index(b, uri, rel_path, content);
43        }
44    },
45    on_reload: initialize
46});
47
// Build target under which per-document tantivy search indices are saved (`index` passes
// `TANTIVY.id()` when `tantivy` is enabled and `vectorsearch` is not). Its `run` step only
// returns a default `BuildResult`; the index artifact itself is saved from the
// `on_build_result` hook when a `ContentResult` is reported.
#[cfg(feature = "tantivy")]
build_target!(TANTIVY {
    name: "tantivy_search",
    description: "search index",
    run: |_| BuildResult::default()
});

// Build target under which per-document vector-search indices are saved (`index` passes
// `VECTORSEARCH.id()` when `vectorsearch` is enabled). Like `TANTIVY`, its `run` step only
// returns a default `BuildResult`.
#[cfg(feature = "vectorsearch")]
build_target!(VECTORSEARCH {
    name: "vector_search",
    description: "search index",
    run: |_| BuildResult::default()
});
61
62// -------------------------------------------------
63
/// Namespace for turning texts into embeddings with the shared [`MODEL`].
#[cfg(feature = "vectorsearch")]
pub(crate) struct Embedder;

#[cfg(feature = "vectorsearch")]
impl Embedder {
    /// Embeds every text in `texts` with the shared embedding model, in input order.
    ///
    /// # Errors
    /// - `"No model"` if the embedding model could not be initialized;
    /// - `"Error embedding texts: …"` if the model itself reports an error;
    /// - `"Unexpected embedding dimension: …"` if the model returns a vector whose length
    ///   does not fit the fixed-size `Embedding` type. This used to be undefined behaviour
    ///   via `unwrap_unchecked`.
    pub fn embed<S: AsRef<str> + Send + Sync>(
        texts: impl AsRef<[S]>,
    ) -> Result<Vec<flams_backend_types::search::Embedding>, String> {
        let Some(model) = MODEL.as_ref() else {
            return Err("No model".to_string());
        };
        // The guard is a temporary, so the model lock is held only for the embedding call
        // itself, not for the conversion below.
        let raw = model
            .lock()
            .embed(texts, None)
            .map_err(|e| format!("Error embedding texts: {e}"))?;
        raw.into_iter()
            .map(|v| {
                let len = v.len();
                v.try_into()
                    .map(flams_backend_types::search::Embedding::new)
                    .map_err(|_| format!("Unexpected embedding dimension: {len}"))
            })
            .collect()
    }
}
92
/// Lazily initialized embedding model shared by all vector-search operations.
///
/// The model files are cached in the configured `embedding_dir`, or in
/// `<CONFIG_DIR>/embedding` if none is configured. `None` means initialization failed; the
/// reason is logged. Failing gracefully (instead of panicking) matters here: a panic during
/// `LazyLock` initialization would poison `MODEL`, and every later access would panic too.
#[cfg(feature = "vectorsearch")]
pub(crate) static MODEL: std::sync::LazyLock<Option<parking_lot::Mutex<fastembed::TextEmbedding>>> =
    std::sync::LazyLock::new(|| {
        tracing::info_span!("initializing vector search model").in_scope(|| {
            use flams_system::settings::CONFIG_DIR;
            // Possible alternative ONNX runtime backend (currently unused):
            // https://ort.pyke.io/backends/candle
            let settings = flams_system::settings::Settings::get();
            let model_path = match settings.embedding_dir.as_ref() {
                Some(d) => (*d).to_path_buf(),
                None => {
                    let Some(config_dir) = CONFIG_DIR.as_ref() else {
                        tracing::error!(
                            "Error initializing embedding model: no embedding directory configured and no default directory"
                        );
                        return None;
                    };
                    config_dir.join("embedding")
                }
            };

            match fastembed::TextEmbedding::try_new(
                fastembed::InitOptions::new(fastembed::EmbeddingModel::ParaphraseMLMiniLML12V2Q)
                    .with_show_download_progress(false)
                    .with_cache_dir(model_path),
            ) {
                Ok(m) => {
                    tracing::info!("Model initialized");
                    Some(parking_lot::Mutex::new(m))
                }
                Err(e) => {
                    tracing::error!("Error downloading embedding model: {e}");
                    None
                }
            }
        })
    });
129
130// -------------------------------------------------
131
/// Process-wide searcher, constructed on first access via `Searcher::new`.
static SEARCHER: std::sync::LazyLock<Searcher> = std::sync::LazyLock::new(Searcher::new);
/// Root tracing span (target `"search"`, no parent) that search operations run in.
static SPAN: std::sync::LazyLock<tracing::Span> =
    std::sync::LazyLock::new(|| tracing::info_span!(target:"search",parent:None,"search"));

/// Search backend built on an in-RAM tantivy index.
///
/// Used when the `tantivy` feature is enabled and `vectorsearch` is not.
#[cfg(all(feature = "tantivy", not(feature = "vectorsearch")))]
pub struct Searcher {
    // The tantivy index; `initialize` replaces it with a freshly loaded one.
    index: parking_lot::RwLock<tantivy::index::Index>,
    // Reader over `index`; queries take a searcher snapshot from it.
    reader: parking_lot::RwLock<tantivy::IndexReader>,
    // Holds no data; `initialize` locks it while swapping in a new `index`/`reader`.
    writer: parking_lot::Mutex<()>,
}

/// Search backend built on an in-memory list of embedded index entries.
///
/// Every query scans the whole list. Used whenever the `vectorsearch` feature is enabled.
#[cfg(feature = "vectorsearch")]
pub struct Searcher {
    // All loaded entries; `initialize` replaces the contents, `add`/`add_one` append.
    index: parking_lot::RwLock<Vec<SearchIndex>>,
}
147
148impl Searcher {
149    #[inline]
150    #[must_use]
151    pub fn get() -> &'static Self {
152        &SEARCHER
153    }
154
155    #[cfg(feature = "vectorsearch")]
156    pub fn size(&self) -> (usize, usize) {
157        let slf = self.index.read();
158        (slf.len(), slf.len() * std::mem::size_of::<SearchIndex>())
159    }
160
161    #[cfg(all(feature = "tantivy", not(feature = "vectorsearch")))]
162    pub fn size(&self) -> (usize, usize) {
163        let reader = self.reader.read();
164        (
165            reader.searcher().num_docs() as usize,
166            reader
167                .searcher()
168                .space_usage()
169                .expect("test")
170                .total()
171                .get_bytes() as usize,
172        )
173    }
174
175    #[cfg(all(feature = "tantivy", not(feature = "vectorsearch")))]
176    fn new() -> Self {
177        let index =
178            tantivy::index::Index::create_in_ram(schema::SearchSchema::get().schema.clone());
179        Self {
180            reader: parking_lot::RwLock::new(index.reader().expect("Failed to build reader")),
181            index: parking_lot::RwLock::new(index),
182            writer: parking_lot::Mutex::new(()),
183        }
184    }
185
186    #[cfg(feature = "vectorsearch")]
187    #[inline]
188    const fn new() -> Self {
189        Self {
190            index: parking_lot::RwLock::new(Vec::new()),
191        }
192    }
193
194    #[cfg(feature = "vectorsearch")]
195    #[inline]
196    pub fn add_one(&self, index: SearchIndex) {
197        self.index.write().push(index);
198    }
199
200    #[cfg(feature = "vectorsearch")]
201    #[inline]
202    pub fn add(&self, iter: impl IntoIterator<Item = SearchIndex>) {
203        self.index.write().extend(iter);
204    }
205
206    #[cfg(all(feature = "tantivy", not(feature = "vectorsearch")))]
207    pub fn query(
208        &self,
209        s: &str,
210        mut opts: FragmentQueryFilter,
211        num_results: usize,
212    ) -> Option<Vec<(f32, SearchResult)>> {
213        SPAN.in_scope(move || {
214            let searcher = self.reader.read().searcher();
215            let in_documents = std::mem::take(&mut opts.in_documents)
216                .into_iter()
217                .map(|u| u.to_string())
218                .collect::<Vec<_>>();
219            let query = query::build_query(s, &self.index.read(), opts)?;
220            let top_num = if num_results == 0 {
221                usize::MAX / 2
222            } else {
223                num_results
224            };
225            let mut ret = Vec::new();
226            let iter = if in_documents.is_empty() {
227                searcher
228                    .search(&*query, &tantivy::collector::TopDocs::with_limit(top_num))
229                    .map_err(|e| tracing::error!("Search Error A: {e}"))
230                    .ok()?
231            } else {
232                searcher
233                    .search(
234                        &*query,
235                        &tantivy::collector::BytesFilterCollector::new(
236                            "uri".to_string(),
237                            move |u: &[u8]| {
238                                in_documents.iter().any(|d| u.starts_with(d.as_bytes()))
239                            },
240                            tantivy::collector::TopDocs::with_limit(top_num),
241                        ),
242                    )
243                    .map_err(|e| tracing::error!("Search Error B: {e}"))
244                    .ok()?
245            };
246            for (s, a) in iter {
247                let Ok(doc) = searcher
248                    .doc::<tantivy::schema::TantivyDocument>(a)
249                    .map_err(|e| tracing::error!("Search Error: {e}"))
250                else {
251                    continue;
252                };
253                if let Some(doc) = SearchIndex::from_document(doc) {
254                    ret.push((s, doc));
255                };
256            }
257            Some(ret)
258        })
259    }
260
261    #[cfg(all(feature = "tantivy", not(feature = "vectorsearch")))]
262    #[allow(clippy::type_complexity)]
263    pub fn query_symbols(
264        &self,
265        s: &str,
266        num_results: usize,
267    ) -> Option<Vec<(f32, SymbolUri, DocumentElementUri)>> {
268        SPAN.in_scope(move || {
269            const FILTER: FragmentQueryFilter = {
270                use flams_backend_types::search::QueryFilterFlags;
271
272                let mut f = FragmentQueryFilter::new();
273                f.flags = QueryFilterFlags::definition_like_only();
274                f
275            };
276            let searcher = self.reader.read().searcher();
277
278            let query = query::build_query(s, &self.index.read(), FILTER)?;
279            let top_num = if num_results == 0 {
280                usize::MAX / 2
281            } else {
282                num_results
283            };
284            let mut ret: Vec<(f32, SymbolUri, DocumentElementUri)> = Vec::new();
285            for (score, a) in searcher
286                .search(
287                    &*query,
288                    &tantivy::collector::TopDocs::with_limit(top_num * 3),
289                )
290                .map_err(|e| tracing::error!("Search Error A: {e}"))
291                .ok()?
292            {
293                let Ok(doc) = searcher
294                    .doc::<tantivy::schema::TantivyDocument>(a)
295                    .map_err(|e| tracing::error!("Search Error: {e}"))
296                else {
297                    continue;
298                };
299                if let Some(doc) = SearchIndex::from_document(doc) {
300                    if let SearchResult::Paragraph { fors, uri, .. } = doc {
301                        for sym in fors {
302                            if let Some((r, _, e)) = ret.iter_mut().find(|(_, k, _)| *k == sym) {
303                                if score > *r {
304                                    *e = uri.clone();
305                                }
306                            } else {
307                                ret.push((score, sym, uri.clone()));
308                            }
309                        }
310                    }
311                }
312            }
313            ret.sort_by_key(|(s, _, _)| ordered_float::OrderedFloat(-*s));
314            ret.truncate(num_results);
315            Some(ret)
316        })
317    }
318
319    #[cfg(feature = "vectorsearch")]
320    #[allow(clippy::cast_possible_truncation)]
321    pub fn query_symbols(
322        &self,
323        s: &str,
324        num_results: usize,
325    ) -> Option<Vec<(f32, SymbolUri, DocumentElementUri)>> {
326        // SAFETY: invariant: input.len() == output.len()
327        let query = unsafe { crate::Embedder::embed([s]).ok()?.pop().unwrap_unchecked() };
328        let top_num = if num_results == 0 {
329            usize::MAX / 2
330        } else {
331            num_results
332        };
333        let mut ret: Vec<(f32, SymbolUri, DocumentElementUri)> =
334            Vec::with_capacity(if num_results == 0 { 1 } else { num_results + 1 });
335        let searcher = self.index.read();
336        for par in searcher.iter().filter(|e| {
337            if let SearchIndex::Paragraph {
338                definition_like: true,
339                fors,
340                ..
341            } = e
342                && !fors.is_empty()
343            {
344                true
345            } else {
346                false
347            }
348        }) {
349            let SearchIndex::Paragraph {
350                title,
351                fors,
352                body,
353                uri: elem_uri,
354                ..
355            } = par
356            else {
357                // SAFETY: filter_map above
358                unsafe {
359                    use std::hint::unreachable_unchecked;
360                    unreachable_unchecked()
361                }
362            };
363            let title_score = title.as_ref().map(|t| (t % &query) as f32);
364            let body_score = (body % &query) as f32;
365            let neg_score = ordered_float::OrderedFloat(
366                -title_score.map_or(body_score, |t| t.mul_add(2.0, body_score) / 3.0),
367            );
368            let index = ret
369                .binary_search_by_key(&neg_score, |(e, _, _)| ordered_float::OrderedFloat(-*e))
370                .unwrap_or_else(|i| i);
371
372            // this could be optimized to iterate less
373            for f in fors {
374                if let Some((i, (_, _, _))) =
375                    ret.iter().enumerate().find(|(_, (_, uri, _))| uri == f)
376                {
377                    if i >= index {
378                        let (_, uri, _) = ret.remove(i);
379                        ret.insert(index, (-neg_score.0, uri, elem_uri.clone()));
380                    }
381                } else {
382                    ret.insert(index, (-neg_score.0, f.clone(), elem_uri.clone()));
383                }
384            }
385            ret.truncate(top_num);
386        }
387        drop(searcher);
388        Some(ret)
389    }
390
391    #[cfg(feature = "vectorsearch")]
392    #[allow(clippy::cast_possible_truncation)]
393    pub fn query(
394        &self,
395        s: &str,
396        opts: FragmentQueryFilter,
397        num_results: usize,
398    ) -> Option<Vec<(f32, SearchResult)>> {
399        // SAFETY: invariant: input.len() == output.len()
400        let query = unsafe { crate::Embedder::embed([s]).ok()?.pop().unwrap_unchecked() };
401        let top_num = if num_results == 0 {
402            usize::MAX / 2
403        } else {
404            num_results
405        };
406        let mut ret: Vec<(f32, SearchResult)> =
407            Vec::with_capacity(if num_results == 0 { 1 } else { num_results + 1 });
408        let searcher = self.index.read();
409        for e in searcher.iter().filter(|e| filter(&opts, e)) {
410            match e {
411                SearchIndex::Document { uri, title, body } => {
412                    let title_score = title.as_ref().map(|t| (t % &query) as f32);
413                    let body_score = (body % &query) as f32;
414                    let neg_score = ordered_float::OrderedFloat(
415                        -title_score.map_or(body_score, |t| t.mul_add(2.0, body_score) / 3.0),
416                    );
417                    let i = ret
418                        .binary_search_by_key(&neg_score, |(e, _)| ordered_float::OrderedFloat(-*e))
419                        .unwrap_or_else(|i| i);
420                    ret.insert(i, (-neg_score.0, SearchResult::Document(uri.clone())));
421                    if ret.len() > top_num {
422                        let _ = ret.pop();
423                    }
424                }
425                SearchIndex::Paragraph {
426                    uri,
427                    kind,
428                    definition_like,
429                    title,
430                    fors,
431                    body,
432                } => {
433                    let title_score = title.as_ref().map(|t| (t % &query) as f32);
434                    let body_score = (body % &query) as f32;
435                    let neg_score = ordered_float::OrderedFloat(
436                        -title_score.map_or(body_score, |t| t.mul_add(2.0, body_score) / 3.0),
437                    );
438                    let i = ret
439                        .binary_search_by_key(&neg_score, |(e, _)| ordered_float::OrderedFloat(-*e))
440                        .unwrap_or_else(|i| i);
441                    ret.insert(
442                        i,
443                        (
444                            -neg_score.0,
445                            SearchResult::Paragraph {
446                                uri: uri.clone(),
447                                fors: fors.clone(),
448                                def_like: *definition_like,
449                                kind: *kind,
450                            },
451                        ),
452                    );
453                    if ret.len() > top_num {
454                        let _ = ret.pop();
455                    }
456                }
457            }
458        }
459        drop(searcher);
460        Some(ret)
461    }
462}
463
464pub fn index(backend: &AnyBackend, uri: &DocumentUri, rel_path: &UriPath, result: &ContentResult) {
465    backend.with_buildable_archive(uri.archive_id(), |a| {
466        if let Some(a) = a {
467            let it = index::index_document(&result.document, &result.ftml);
468            let _ = a.save(
469                uri,
470                Some(rel_path),
471                FileOrString::Str(String::new().into_boxed_str()),
472                #[cfg(all(feature = "tantivy", not(feature = "vectorsearch")))]
473                TANTIVY.id(),
474                #[cfg(feature = "vectorsearch")]
475                VECTORSEARCH.id(),
476                Some(Box::new(IndexFile(it)) as _),
477                GlobalBackend.triple_store(),
478                false,
479            );
480        } else {
481            tracing::error!("Archive not found! {}", uri.archive_id());
482        }
483    });
484}
485
486struct IndexFile(Vec<SearchIndex>);
487impl Artifact for IndexFile {
488    fn as_any(&self) -> &dyn std::any::Any {
489        self as _
490    }
491    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
492        self as _
493    }
494    fn into_any(self: Box<Self>) -> Box<dyn std::any::Any> {
495        self as _
496    }
497
498    #[cfg(all(not(feature = "tantivy"), feature = "vectorsearch"))]
499    fn kind(&self) -> &'static str {
500        "vectorsearch"
501    }
502
503    #[cfg(feature = "tantivy")]
504    fn kind(&self) -> &'static str {
505        "tantivy"
506    }
507
508    fn write(&self, into: &std::path::Path) -> Result<(), ArtifactSaveError> {
509        let file = std::fs::File::create(into)
510            .map_err(|e| ArtifactSaveError::Fs(FileError::Creation(into.to_path_buf(), e)))?;
511        #[cfg(all(feature = "tantivy", not(feature = "vectorsearch")))]
512        {
513            bincode::encode_into_std_write(
514                &self.0,
515                &mut std::io::BufWriter::new(file),
516                bincode::config::standard(),
517            )?;
518        }
519        #[cfg(feature = "vectorsearch")]
520        {
521            bincode::encode_into_std_write(
522                &self.0,
523                &mut std::io::BufWriter::new(file),
524                bincode::config::standard(),
525            )?;
526        }
527        Ok(())
528    }
529}
530
531#[allow(clippy::too_many_lines)]
532fn initialize() {
533    #[cfg(feature = "vectorsearch")]
534    SPAN.in_scope(|| {
535        use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
536
537        let _ = std::thread::spawn(|| {
538            std::sync::LazyLock::force(&MODEL);
539        });
540        let mut index = SEARCHER.index.write();
541        let nidx = tracing::info_span!("Loading search indices").in_scope(move || {
542            GlobalBackend
543                .all_archives()
544                .par_iter()
545                .filter_map(|a| match a {
546                    Archive::Local(a) => Some(a),
547                    Archive::Ext(_, _) => None,
548                })
549                .flat_map(|a| {
550                    let out = a.out_dir();
551                    if out.exists() && out.is_dir() {
552                        Some(
553                            walkdir::WalkDir::new(out)
554                                .into_iter()
555                                .filter_map(Result::ok)
556                                .filter(|entry| entry.file_name() == "vectorsearch")
557                                .filter_map(|e| {
558                                    let Ok(f) = std::fs::File::open(e.path()) else {
559                                        tracing::error!(
560                                            "error reading file {}",
561                                            e.path().display()
562                                        );
563                                        return None;
564                                    };
565                                    let file = std::io::BufReader::new(f);
566
567                                    let Ok(v): Result<Vec<SearchIndex>, _> =
568                                        bincode::decode_from_reader(
569                                            file,
570                                            bincode::config::standard(),
571                                        )
572                                    else {
573                                        tracing::error!(
574                                            "error deserializing file {}",
575                                            e.path().display()
576                                        );
577                                        return None;
578                                    };
579                                    Some(v)
580                                })
581                                .collect::<Vec<_>>(),
582                        )
583                    } else {
584                        None
585                    }
586                })
587                .flatten()
588                .flatten()
589                .collect::<Vec<_>>()
590        });
591        *index = nidx;
592    });
593
594    #[cfg(all(feature = "tantivy", not(feature = "vectorsearch")))]
595    SPAN.in_scope(|| {
596        use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
597        let index = tantivy::index::Index::create_in_ram(SearchSchema::get().schema.clone());
598        let mut writer = index
599            .writer(MEMORY_SIZE)
600            .expect("Failed to instantiate search writer");
601        let wr = &writer;
602        tracing::info_span!("Loading search indices").in_scope(move || {
603            GlobalBackend
604                .all_archives()
605                .par_iter()
606                .filter_map(|a| match a {
607                    Archive::Local(a) => Some(a),
608                    Archive::Ext(_, _) => None,
609                })
610                .for_each(|a| {
611                    let out = a.out_dir();
612                    if out.exists() && out.is_dir() {
613                        for e in walkdir::WalkDir::new(out)
614                            .into_iter()
615                            .filter_map(Result::ok)
616                            .filter(|entry| entry.file_name() == "tantivy")
617                        {
618                            let Ok(f) = std::fs::File::open(e.path()) else {
619                                tracing::error!("error reading file {}", e.path().display());
620                                return;
621                            };
622                            let file = std::io::BufReader::new(f);
623
624                            let Ok(v): Result<Vec<SearchIndex>, _> =
625                                bincode::decode_from_reader(file, bincode::config::standard())
626                            else {
627                                tracing::error!("error deserializing file {}", e.path().display());
628                                return;
629                            };
630                            for d in v {
631                                use tantivy::schema::Value;
632
633                                let d: tantivy::TantivyDocument = d.to_document();
634                                if let Err(e) = wr.add_document(d) {
635                                    tracing::error!("{e}");
636                                }
637                            }
638                        }
639                    }
640                });
641        });
642        match writer.commit() {
643            Ok(i) => tracing::info!("Loaded {i} entries"),
644            Err(e) => tracing::error!("Error: {e}"),
645        }
646        let slf = Searcher::get();
647        let writer = slf.writer.lock();
648        let mut old_index = slf.index.write();
649        let mut reader = slf.reader.write();
650        let Ok(r) = index.reader() else {
651            tracing::error!("Failed to instantiate search reader");
652            return;
653        };
654        *reader = r;
655        *old_index = index;
656        drop(reader);
657        drop(old_index);
658        drop(writer);
659    });
660}
661
/// Decides whether an index entry passes a query's fragment filter.
///
/// An entry passes when the flag for its category is enabled in `cplx.flags` and, if
/// `cplx.languages` is non-empty, its URI's language is one of those languages.
/// Definition-like paragraphs count as definitions regardless of their `kind`. Paragraphs
/// of any kind not listed below are rejected.
#[cfg(feature = "vectorsearch")]
fn filter(cplx: &FragmentQueryFilter, idx: &SearchIndex) -> bool {
    use flams_backend_types::search::SearchResultKind;
    use ftml_uris::IsNarrativeUri;

    // An empty language list accepts every language.
    let language_ok = |lang: &_| cplx.languages.is_empty() || cplx.languages.contains(lang);

    match idx {
        SearchIndex::Document { uri, .. } => {
            cplx.flags.allow_documents() && language_ok(&uri.language)
        }
        SearchIndex::Paragraph {
            definition_like: true,
            uri,
            ..
        } => cplx.flags.allow_definitions() && language_ok(&uri.language()),
        SearchIndex::Paragraph {
            kind: SearchResultKind::Assertion,
            uri,
            ..
        } => cplx.flags.allow_assertions() && language_ok(&uri.language()),
        SearchIndex::Paragraph {
            kind: SearchResultKind::Example,
            uri,
            ..
        } => cplx.flags.allow_examples() && language_ok(&uri.language()),
        SearchIndex::Paragraph {
            kind: SearchResultKind::Paragraph,
            uri,
            ..
        } => cplx.flags.allow_paragraphs() && language_ok(&uri.language()),
        SearchIndex::Paragraph {
            kind: SearchResultKind::Problem,
            uri,
            ..
        } => cplx.flags.allow_problems() && language_ok(&uri.language()),
        SearchIndex::Paragraph { .. } => false,
    }
}