Spaces:
Runtime error
Runtime error
| import os | |
| import time | |
| from itertools import islice | |
| import shutil | |
| from threading import Thread | |
| import lancedb | |
| import gradio as gr | |
| import polars as pl | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
# Extra CSS for the page: let long values in the results table scroll inside
# their cell rather than spilling out of it.
STYLE = """
.gradio-container td span {
overflow: auto !important;
}
""".strip()

# Sentence-embedding model used to turn a search phrase into a query vector.
EMBEDDING_MODEL = SentenceTransformer("TaylorAI/bge-micro")

# Loading / search limits.
MAX_N_ROWS = 3_000_000  # cap on rows taken from the streamed dataset
N_ROWS_BATCH = 5_000  # rows pulled and written per batch
N_SEARCH_RESULTS = 15  # nearest neighbours returned per query
CRAWL_DUMP = "CC-MAIN-2020-05"  # dataset config name; also substituted into the README header

# Lazily-opened LanceDB connection, populated by db().
DB = None

# One entry per displayed result column: (column name, gradio datatype, css width).
_DISPLAY_SPEC = (
    ("text", "str", "300px"),
    ("url", "str", "100px"),
    ("token_count", "number", "50px"),
    ("count", "number", "25px"),
)
DISPLAY_COLUMNS = [column for column, _, _ in _DISPLAY_SPEC]
DISPLAY_COLUMN_TYPES = [dtype for _, dtype, _ in _DISPLAY_SPEC]
DISPLAY_COLUMN_WIDTHS = [width for _, _, width in _DISPLAY_SPEC]
def rename_embedding_column(row):
    """Move a record's "embedding" value under the key "vector", in place.

    The dataset stores embeddings as "embedding"; the rows are renamed before
    being written to LanceDB so the vector lives in a "vector" column.

    Args:
        row: A mutable mapping containing an "embedding" key.

    Returns:
        The same (mutated) mapping.

    Raises:
        KeyError: If ``row`` has no "embedding" key (``row`` is left unchanged).
    """
    row["vector"] = row.pop("embedding")
    return row
def read_header_markdown(path: str = "./README.md") -> str:
    """Return the README body to display as the app header.

    A leading HuggingFace metadata block (front matter delimited by ``---``
    lines at the very top of the file) is stripped. Everything after it is
    kept, including any ``---`` horizontal rules in the markdown body. Then
    the ``{{CRAWL_DUMP}}`` placeholder is replaced with ``CRAWL_DUMP``.

    Args:
        path: README file to read; defaults to ``./README.md``.

    Returns:
        The markdown body with the placeholder substituted.
    """
    delimiter = "\n---\n"
    # Text mode with the default newline handling folds \r\n into \n, so the
    # delimiter checks below also work for Windows-edited files.
    with open(path, "r", encoding="utf-8") as fp:
        text = fp.read()
    # Only strip a block that opens on the first line. Splitting on every
    # "\n---\n" (as before) dropped all content preceding a horizontal rule.
    if text.startswith("---\n"):
        # Search from index 3 so an empty front matter ("---\n---\n") is found.
        end = text.find(delimiter, 3)
        if end != -1:
            text = text[end + len(delimiter):]
    return text.replace("{{CRAWL_DUMP}}", CRAWL_DUMP)
def db():
    """Return the shared LanceDB connection to ``./data``, opening it on first call.

    The connection is cached in the module-level ``DB`` global.
    """
    global DB
    if DB is not None:
        return DB
    DB = lancedb.connect("data")
    return DB
def load_data_sample():
    """Stream up to MAX_N_ROWS dataset rows into the LanceDB table "data".

    Any existing ``data`` directory is removed first. Rows are read in
    batches of N_ROWS_BATCH. The first batch creates the table and a vector
    index, and later batches are appended. When more than one batch was
    loaded, the index is rebuilt at the end so it covers every row rather
    than only the first batch.
    """
    # NOTE(review): presumably delays the heavy download until the UI is up —
    # confirm the intent of this pause.
    time.sleep(5)
    # Remove any data that was already there; we want to replace it.
    if os.path.exists("data"):
        shutil.rmtree("data")
    rows = load_dataset(
        "airtrain-ai/fineweb-edu-fortified",
        name=CRAWL_DUMP,
        split="train",
        streaming=True,
    )
    print("Loading data")
    # islice pulls only as many streamed rows as requested, nothing extra.
    sample = islice(rows, MAX_N_ROWS)
    table = None
    n_rows_loaded = 0
    while batch := list(islice(sample, N_ROWS_BATCH)):
        # LanceDB gets the embedding under the column name "vector".
        data = [rename_embedding_column(row) for row in batch]
        n_rows_loaded += len(data)
        if table is None:
            print("Creating table")
            table = db().create_table("data", data=data)
            # Index early so searches run fast while the rest is loading.
            print("Indexing table")
            table.create_index(num_sub_vectors=1)
        else:
            table.add(data)
        print(f"Loaded {n_rows_loaded} rows")
    # The index above was trained on the first batch only; rows added later
    # are not covered by it. Rebuild once over the full table.
    if table is not None and n_rows_loaded > N_ROWS_BATCH:
        print("Re-indexing table")
        table.create_index(num_sub_vectors=1, replace=True)
    print("Done loading data")
def search(search_phrase: str) -> tuple["pandas.DataFrame", int]:
    """Run a vector search for *search_phrase* against the "data" table.

    Args:
        search_phrase: Free text to embed with EMBEDDING_MODEL and search for.

    Returns:
        A pair of:
        - a pandas DataFrame restricted to DISPLAY_COLUMNS, holding the top
          N_SEARCH_RESULTS matches;
        - the table's row count (``count_rows()``), taken after the search.
    """
    # Data is loaded on a background thread; wait until the table exists.
    # NOTE(review): this waits indefinitely if loading fails before the
    # table is created.
    while "data" not in db().table_names():
        time.sleep(1)
    embedding = EMBEDDING_MODEL.encode([search_phrase])[0]
    table = db().open_table("data")
    data_frame = table.search(embedding).limit(N_SEARCH_RESULTS).to_polars()
    return (
        # Keep only the displayed columns; Gradio's table takes pandas.
        data_frame.select(DISPLAY_COLUMNS).to_pandas(),
        table.count_rows(),
    )
# UI: header markdown, a search box, a row counter, and the results table.
with gr.Blocks(css=STYLE) as demo:
    # NOTE(review): STYLE is also passed via css= above; this inline <style>
    # duplicates it — possibly a deliberate workaround, confirm before removing.
    gr.HTML(f"<style>{STYLE}</style>")
    with gr.Row():
        gr.Markdown(read_header_markdown())
    with gr.Row():
        input_text = gr.Textbox(label="Search phrase", scale=100)
        search_button = gr.Button("Search", scale=1, min_width=100)
    with gr.Row():
        rows_searched = gr.Number(
            label="Rows searched",
            show_label=True,
        )
    with gr.Row():
        search_results = gr.DataFrame(
            headers=DISPLAY_COLUMNS,
            type="pandas",
            datatype=DISPLAY_COLUMN_TYPES,
            row_count=N_SEARCH_RESULTS,
            col_count=(len(DISPLAY_COLUMNS), "fixed"),
            column_widths=DISPLAY_COLUMN_WIDTHS,
            # elem_classes takes bare class names; a leading "." would end up
            # in the HTML class attribute and never match `.df-text-col`.
            elem_classes="df-text-col",
        )
    # search() returns (results DataFrame, row count), matching these outputs.
    search_button.click(
        search,
        [input_text],
        [search_results, rows_searched],
    )

# Load data on another thread so searching can start before loading finishes.
data_load_thread = Thread(target=load_data_sample, daemon=True)
data_load_thread.start()
print("Launching app")
demo.launch()