Llogiq on stuff

Semantic Search with Rust, Bert and Qdrant

First, a bit of backstory: While working at Qdrant, I proposed a talk to RustLab Italy 2023 where I would live-code a semantic search. Unfortunately, I missed that my window manager (sway) did not support mirroring the screen, so I had to crane my neck the whole time to at least partially see what I was doing, which wasn’t conducive to explaining it. To everyone who attended the talk and was unhappy about my lackluster explanations, I offer this post, along with my sincere apology.

Before we delve into the code, let me give a short explanation of what semantic search is: Unlike a plain text search, we first transform the input (which may be text or anything else, depending on our transformation) into a vector (usually of floating-point numbers). The transformation has been designed and optimized to give inputs with similar “meaning” nearby values according to some distance metric. In machine learning circles, we call the transformation a “transformer model” and the resulting vector an “embedding”. A vector database is a database specialized in searching such a vector space, finding nearby points very efficiently.
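
As a concrete example of such a distance metric, here is a minimal sketch (not from the talk, just for illustration) of cosine similarity, the metric we’ll end up using later on; values close to 1.0 mean the two inputs are semantically close:

fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // dot product of the two embeddings
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    // Euclidean norm (length) of each embedding
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    // 1.0 = same direction, 0.0 = orthogonal, -1.0 = opposite
    dot / (norm_a * norm_b)
}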

Now the program I wrote didn’t only search but also set up the database for the search. So the code contained the following steps:

  1. set up the model
  2. set up the Qdrant client
  3. match on the first program argument. If “insert”, go to step 4, if “find”, go to step 8, otherwise exit with an error
  4. (optionally) delete and re-create the collection in Qdrant
  5. read a JSONL file containing the text along with some data
  6. iterate over the lines, parse the JSON into a HashMap, get the “text” value, embed it using the model and collect all of that into a Vec.
  7. insert the points into Qdrant in chunks of 100 and exit
  8. embed the next argument and
  9. search for it in Qdrant, print the result and exit

The reason I used JSONL for importing my points was that I had a file from the Qdrant page search lying around, created by crawling the page and containing text, page links and CSS selectors, so the page could highlight the match. In hindsight, I could have pre-populated my database and the talk would still have been pretty decent and 33% shorter, but this way it was really from scratch, unless you count the Cargo.toml, which, because the network at the conference had been flaky the day before, I had pre-populated with the following dependencies, running cargo build to pre-download everything I might need during the talk:

[dependencies]
qdrant-client = "1.6.0"
rust-bert = { version = "0.21.0", features = ["download-libtorch"] }
serde_json = "1.0.108"
tokio = { version = "1.34.0", features = ["macros"] }

I didn’t use an argument parsing library because the arg parsing I had to do was so primitive (just two subcommands with one argument each), and choosing a crate would invariably have led to questions as to why. std::env::args is actually OK, folks. Just don’t forget that the zeroth item is the running program. I also had main return a Result<(), Box<dyn std::error::Error>>, which let me use the ? operator almost everywhere but in one closure (which used unwraps because I wasn’t ready to also explain how Rust will collect an Iterator of Result<T, E> into a Result<Vec<T>, E>).
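
For the curious, that collecting trick looks roughly like this (a small sketch for illustration, not part of the talk code):

// An Iterator of Result<T, E> can be collected into a Result<Vec<T>, E>;
// the first Err short-circuits the collection instead of panicking.
fn parse_all(lines: &[&str]) -> Result<Vec<u32>, std::num::ParseIntError> {
    lines.iter().map(|s| s.parse::<u32>()).collect()
}
// parse_all(&["1", "2"]) == Ok(vec![1, 2]); parse_all(&["1", "x"]) is an Err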

For the embedding I chose the rust-bert crate, because it is the easiest way to get workable embeddings. The model I used, AllMiniLmL12V2, is a variant of the simple and well-tested BERT model with an output size of 384 values, trained to perform semantic search via cosine similarity (both facts will feature later in the code). For the data storage and retrieval, I obviously chose Qdrant, not only for the reason that it’s my former employer, but also because it performs really well, and I could run it locally without any compromise. One thing that was a bit irritating is that rust-bert always embeds a slice of texts and always returns a Vec<Vec<f32>> where we only need one Vec<f32>, so I did the .into_iter().next().unwrap() dance to extract that.

In practice, choosing such a small-ish model keeps the embedding step fast, leading to lean resource usage and acceptable recall in most cases. For specialized applications, the model may be re-trained to give better results without any runtime performance penalty. While the trend has been going towards larger and larger models (the so-called Large Language Models, or LLMs for short), there is still a lot of mileage in the small ones, and they’re much less cost-intensive.

In Qdrant, the workflow for setting up the database for search is to create the collection, which must at least be configured with the vector size (here 384) and the distance function (here cosine similarity). I left everything else at the default value, which for such a small demo is just fine. Qdrant, being the speed devil among the vector databases, has a lot of knobs to turn to make it, in the best Rust tradition, run blazingly fast.

In the code I removed the collection first, so I would delete any leftovers from earlier failed attempts at setting up the collection, just as a safety measure. Then I embedded all the JSONL objects, converted them into Qdrant’s PointStructs and upserted (a term that means a mixture of inserting and updating) them in batches of 100, another safety measure to steer well clear of any possible timeouts that might have derailed the demo. For the search, the minimum parameters are the vector, the limit (one can also use a distance threshold instead) and a with_payload argument that will make Qdrant actually retrieve the payload. I cannot explain why the latter is not the default, but it’s not too painful to ask for it. I can only conjecture that some users only care for IDs, having downstream data sources from which to retrieve their payloads.

Because Qdrant’s API is async, I also pulled in tokio with the macros feature enabled so I could use #[tokio::main]. This way, I had to await a few things, but it otherwise didn’t detract from the clarity of the code, which is nice.

I hope that clears things up. Here is the code 100% as it was at the end of my talk:

use std::collections::HashMap;

use qdrant_client::{
    prelude::*,
    qdrant::{VectorParams, VectorsConfig},
};
use rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsBuilder;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let model = SentenceEmbeddingsBuilder::remote(
        rust_bert::pipelines::sentence_embeddings::SentenceEmbeddingsModelType::AllMiniLmL12V2,
    )
    .create_model()?;
    let client = QdrantClient::from_url(&std::env::var("QDRANT_URL").unwrap())
        .with_api_key(std::env::var("QDRANT_API_KEY"))
        .build()?;
    let mut args = std::env::args();
    match args.nth(1).as_deref() {
        Some("insert") => {
            let _ = client.delete_collection("points").await;
            client
                .create_collection(&CreateCollection {
                    collection_name: "points".into(),
                    vectors_config: Some(VectorsConfig {
                        config: Some(qdrant_client::qdrant::vectors_config::Config::Params(
                            VectorParams {
                                size: 384,
                                distance: Distance::Cosine as i32,
                                ..Default::default()
                            },
                        )),
                    }),
                    ..Default::default()
                })
                .await?;
            let Some(input_file) = args.next() else {
                eprintln!("usage: semantic insert <path/to/file>");
                return Ok(());
            };
            let contents = std::fs::read_to_string(input_file)?;
            let mut id = 0;
            let points = contents
                .lines()
                .map(|line| {
                    let payload: HashMap<String, Value> = serde_json::from_str(line).unwrap();
                    let text = payload.get("text").unwrap().to_string();
                    let embeddings = model
                        .encode(&[text])
                        .unwrap()
                        .into_iter()
                        .next()
                        .unwrap()
                        .into();
                    id += 1;
                    PointStruct {
                        id: Some(id.into()),
                        payload,
                        vectors: Some(embeddings),
                        ..Default::default()
                    }
                })
                .collect::<Vec<_>>();
            for batch in points.chunks(100) {
                let _ = client
                    .upsert_points("points", batch.to_owned(), None)
                    .await?;
                print!(".");
            }
            println!()
        }
        Some("find") => {
            let Some(text) = args.next() else {
                eprintln!("usage: semantic insert <path/to/file>");
                return Ok(());
            };
            let vector = model.encode(&[text])?.into_iter().next().unwrap();
            let result = client
                .search_points(&SearchPoints {
                    collection_name: "points".into(),
                    vector,
                    limit: 3,
                    with_payload: Some(true.into()),
                    ..Default::default()
                })
                .await?;
            println!("{result:#?}");
        }
        _ => {
            eprintln!("usage: semantic insert <path/to/file>");
            return Ok(());
        }
    }
    Ok(())
}