Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| DreaMS Gradio Web Application | |
| This module provides a web interface for the DreaMS (Deep Representations Empowering | |
| the Annotation of Mass Spectra) tool using Gradio. It allows users to upload MS/MS | |
| files and perform library matching with DreaMS embeddings. | |
| Author: DreaMS Team | |
| License: MIT | |
| """ | |
| import gradio as gr | |
| import spaces | |
| import shutil | |
| import urllib.request | |
| from datetime import datetime | |
| from functools import partial | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| import pandas as pd | |
| import numpy as np | |
| from pathlib import Path | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from rdkit import Chem | |
| from rdkit.Chem.Draw import rdMolDraw2D | |
| import base64 | |
| from io import BytesIO | |
| from PIL import Image | |
| import io | |
| import dreams.utils.spectra as su | |
| import dreams.utils.io as dio | |
| from dreams.utils.data import MSData | |
| from dreams.api import dreams_embeddings | |
| from dreams.definitions import * | |
| # ============================================================================= | |
| # CONSTANTS AND CONFIGURATION | |
| # ============================================================================= | |
# Image sizes were lowered on purpose to keep rendering fast.
SPECTRUM_IMG_SIZE: int = 800  # was 1500
SMILES_IMG_SIZE: int = 120  # was 200

# File-system locations used by the app: bundled examples, the DreaMS data
# directory, and the MassSpecGym library with precomputed DreaMS embeddings.
EXAMPLE_PATH: Path = Path('./data')
DATA_PATH: Path = Path('./DreaMS/data')
LIBRARY_PATH: Path = Path('DreaMS/data/MassSpecGym_DreaMS.hdf5')
# Rendered SMILES images are memoized here so they are not regenerated
_smiles_cache = {}


def clear_smiles_cache():
    """Empty the SMILES image cache so the memory it holds can be released."""
    # The dict is only mutated in place (never rebound), so no ``global``
    # declaration is needed.
    _smiles_cache.clear()
    print("SMILES image cache cleared")
| # ============================================================================= | |
| # UTILITY FUNCTIONS FOR IMAGE CONVERSION | |
| # ============================================================================= | |
| def _validate_input_file(file_path): | |
| """ | |
| Validate that the input file exists and has a supported format | |
| Args: | |
| file_path: Path to the input file | |
| Returns: | |
| bool: True if file is valid, False otherwise | |
| """ | |
| if not file_path or not Path(file_path).exists(): | |
| return False | |
| supported_extensions = ['.mgf', '.mzML', '.mzml'] | |
| file_ext = Path(file_path).suffix.lower() | |
| return file_ext in supported_extensions | |
| def _convert_pil_to_base64(img, format='PNG'): | |
| """ | |
| Convert a PIL Image to base64 encoded string | |
| Args: | |
| img: PIL Image object | |
| format: Image format (default: 'PNG') | |
| Returns: | |
| str: Base64 encoded image string | |
| """ | |
| buffered = io.BytesIO() | |
| img.save(buffered, format=format, optimize=True) # Added optimize=True | |
| img_str = base64.b64encode(buffered.getvalue()) | |
| return f"data:image/{format.lower()};base64,{repr(img_str)[2:-1]}" | |
| def _crop_transparent_edges(img): | |
| """ | |
| Crop transparent edges from a PIL Image | |
| Args: | |
| img: PIL Image object (should be RGBA) | |
| Returns: | |
| PIL Image: Cropped image | |
| """ | |
| # Convert to RGBA if not already | |
| if img.mode != 'RGBA': | |
| img = img.convert('RGBA') | |
| # Get the bounding box of non-transparent pixels | |
| bbox = img.getbbox() | |
| if bbox: | |
| # Crop the image to remove transparent space | |
| img = img.crop(bbox) | |
| return img | |
def smiles_to_html_img(smiles, img_size=SMILES_IMG_SIZE):
    """
    Convert SMILES string to HTML image for display in Gradio dataframe

    Results (including failures) are stored in ``_smiles_cache`` keyed by the
    SMILES string and image size, so the same molecule is drawn only once.

    Args:
        smiles: SMILES string representation of molecule
        img_size: Width and height given to the RDKit drawer (default: SMILES_IMG_SIZE)

    Returns:
        str: HTML img tag with base64 encoded image, or a red ``<div>`` with an
            error message when the SMILES cannot be parsed or drawn
    """
    import html  # local import: only needed for escaping text placed into markup

    # Check cache first
    cache_key = f"{smiles}_{img_size}"
    if cache_key in _smiles_cache:
        return _smiles_cache[cache_key]

    try:
        # Parse SMILES to RDKit molecule
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            result = "<div style='text-align: center; color: red;'>Invalid SMILES</div>"
        else:
            # Create PNG drawing with Cairo backend
            d2d = rdMolDraw2D.MolDraw2DCairo(img_size, img_size)
            opts = d2d.drawOptions()
            opts.clearBackground = False  # keep the background transparent so it can be cropped
            opts.padding = 0.05
            opts.bondLineWidth = 1.5
            d2d.DrawMolecule(mol)
            d2d.FinishDrawing()

            # Get PNG data, crop the empty border and embed as a data URI
            img = Image.open(io.BytesIO(d2d.GetDrawingText()))
            img = _crop_transparent_edges(img)
            img_str = _convert_pil_to_base64(img)

            # The SMILES ends up inside a single-quoted attribute: escape it
            title = html.escape(str(smiles), quote=True)
            result = f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='{title}' />"
    except Exception as e:
        # Best effort: a failed rendering must not abort the whole results table.
        # Exception text may contain markup characters, so escape it.
        result = f"<div style='text-align: center; color: red;'>Error: {html.escape(str(e))}</div>"

    # Cache the result (successful or not) in a single place
    _smiles_cache[cache_key] = result
    return result
def spectrum_to_html_img(spec1, spec2, img_size=SPECTRUM_IMG_SIZE):
    """
    Convert spectrum plot to HTML image for display in Gradio dataframe

    Args:
        spec1: First spectrum data
        spec2: Second spectrum data (for mirror plot)
        img_size: Not used by the current implementation (the figure size and
            dpi are fixed below); kept so existing callers keep working

    Returns:
        str: HTML img tag with base64 encoded spectrum plot, or a red ``<div>``
            with an error message when plotting fails
    """
    import html  # local import: only needed for escaping the error text

    try:
        # Use non-interactive matplotlib backend
        matplotlib.use('Agg')

        # Create the spectrum plot using DreaMS utility function
        su.plot_spectrum(spec=spec1, mirror_spec=spec2, figsize=(1.6, 0.8))

        # Save figure to buffer with transparent background
        buffered = BytesIO()
        plt.savefig(buffered, format='png', bbox_inches='tight', dpi=80, transparent=True)
        buffered.seek(0)

        # Convert to PIL Image, crop edges, and convert to base64
        img = _crop_transparent_edges(Image.open(buffered))
        img_str = _convert_pil_to_base64(img)
        return f"<img src='{img_str}' style='max-width: 100%; height: auto;' title='Spectrum comparison' />"
    except Exception as e:
        # Best effort: one broken plot must not abort the whole results table
        return f"<div style='text-align: center; color: red;'>Error: {html.escape(str(e))}</div>"
    finally:
        # Close the current figure on every path; previously a failure while
        # plotting or saving left the figure open and leaked its memory.
        plt.close()
| # ============================================================================= | |
| # DATA DOWNLOAD AND SETUP FUNCTIONS | |
| # ============================================================================= | |
| def _download_file(url, target_path, description): | |
| """ | |
| Download a file from URL if it doesn't exist | |
| Args: | |
| url: Source URL | |
| target_path: Target file path | |
| description: Description for logging | |
| """ | |
| if not target_path.exists(): | |
| print(f"Downloading {description}...") | |
| target_path.parent.mkdir(parents=True, exist_ok=True) | |
| urllib.request.urlretrieve(url, target_path) | |
| print(f"Downloaded {description} to {target_path}") | |
def setup():
    """
    Prepare the application: fetch required data files and run a smoke test

    Downloads (each only if not already present):
        - MassSpecGym spectral library
        - Example MS/MS files for testing

    Afterwards DreaMS embeddings are computed for the small example .mgf file
    to confirm that the model pipeline runs.

    Raises:
        Exception: Any failure during download or the embedding test is logged
            and re-raised
    """
    banner = "=" * 60
    print(banner)
    print("Setting up DreaMS application...")
    print(banner)

    # Start from an empty image cache
    clear_smiles_cache()

    try:
        mgf_example = EXAMPLE_PATH / 'example_5_spectra.mgf'

        # (url, destination, log label) - the library first, then the examples
        downloads = (
            ('https://huggingface.co/datasets/roman-bushuiev/GeMS/resolve/main/data/auxiliary/MassSpecGym_DreaMS.hdf5',
             LIBRARY_PATH,
             "MassSpecGym spectral library"),
            ('https://huggingface.co/datasets/titodamiani/PiperNET/resolve/main/lcms/rawfiles/202312_147_P55-Leaf-r2_1uL.mzML',
             EXAMPLE_PATH / '202312_147_P55-Leaf-r2_1uL.mzML',
             "PiperNET example spectra"),
            ('https://raw.githubusercontent.com/pluskal-lab/DreaMS/refs/heads/main/data/examples/example_5_spectra.mgf',
             mgf_example,
             "DreaMS example spectra"),
        )
        for source_url, destination, label in downloads:
            _download_file(source_url, destination, label)

        # Smoke test: embed the small example file
        print("\nTesting DreaMS embeddings...")
        smoke_embs = dreams_embeddings(mgf_example)
        print(f"✓ Setup complete - DreaMS embeddings test successful (shape: {smoke_embs.shape})")
        print(banner)
    except Exception as e:
        print(f"✗ Setup failed: {e}")
        print("The application may not work properly. Please check your internet connection and try again.")
        raise
| # ============================================================================= | |
| # CORE PREDICTION FUNCTIONS | |
| # ============================================================================= | |
def _predict_gpu(in_pth, progress):
    """
    Load the query spectra and compute their DreaMS embeddings

    NOTE(review): the name suggests GPU execution, but no device selection or
    GPU decorator is visible in this function - confirm where that is handled.

    Args:
        in_pth: Input file path
        progress: Gradio progress tracker (called with a fraction and ``desc``)

    Returns:
        Embeddings returned by ``dreams_embeddings`` (an object with ``.shape``)
    """
    progress(0.2, desc="Loading spectra data...")
    query_data = MSData.load(in_pth)

    progress(0.3, desc="Computing DreaMS embeddings...")
    query_embs = dreams_embeddings(query_data)

    print(f'Shape of the query embeddings: {query_embs.shape}')
    return query_embs
def _create_result_row(i, j, n, msdata, msdata_lib, sims, cos_sim, embs, similarity_threshold, calculate_modified_cosine=False):
    """
    Build the record describing one (query spectrum, library hit) pair

    Args:
        i: Query spectrum index
        j: Library spectrum index
        n: Zero-based rank of the hit (stored as ``n + 1`` under 'topk')
        msdata: Query MS data
        msdata_lib: Library MS data
        sims: Similarity matrix indexed as ``sims[i, j]``
        cos_sim: Callable computing the modified cosine similarity
        embs: Query embeddings
        similarity_threshold: Images are rendered only for hits scoring above it
        calculate_modified_cosine: Whether to calculate modified cosine similarity

    Returns:
        dict: Result row data; key order defines the DataFrame column order
    """
    def _query_value(column):
        # Optional metadata: None when the query file lacks the column
        return msdata.get_values(column, i) if column in msdata.columns() else None

    hit_smiles = msdata_lib.get_smiles(j)
    query_spec = msdata.get_spectra(i)
    library_spec = msdata_lib.get_spectra(j)
    score = sims[i, j]

    # Rendering images is the expensive part, so skip it for low-scoring hits
    render = score > similarity_threshold

    row = {}
    row['scan_number'] = _query_value(SCAN_NUMBER)
    row['rt'] = _query_value(RT)
    row['charge'] = _query_value(CHARGE)
    row['precursor_mz'] = msdata.get_prec_mzs(i)
    row['topk'] = n + 1
    row['library_j'] = j
    row['library_SMILES'] = smiles_to_html_img(hit_smiles) if render else None
    row['library_SMILES_raw'] = hit_smiles
    row['Spectrum'] = spectrum_to_html_img(query_spec, library_spec) if render else None
    row['Spectrum_raw'] = su.unpad_peak_list(query_spec)
    row['library_ID'] = msdata_lib.get_values('IDENTIFIER', j)
    row['DreaMS_similarity'] = score
    row['i'] = i
    row['j'] = j
    row['DreaMS_embedding'] = embs[i]

    # The classical score is optional because it slows the run down
    if calculate_modified_cosine:
        row['Modified_cosine_similarity'] = cos_sim(
            spec1=query_spec,
            prec_mz1=msdata.get_prec_mzs(i),
            spec2=library_spec,
            prec_mz2=msdata_lib.get_prec_mzs(j),
        )

    return row
def _process_results_dataframe(df, in_pth, similarity_threshold, calculate_modified_cosine=False):
    """
    Process and clean the results DataFrame

    Writes the full (unfiltered) results to a timestamped CSV next to the
    input file and returns a reduced table for display.

    Args:
        df: Raw results DataFrame (one row per query/library hit pair)
        in_pth: Input file path; the CSV path is derived from it
        similarity_threshold: Only rows with a DreaMS similarity above this
            value are kept in the display table
        calculate_modified_cosine: Whether modified cosine similarity was calculated

    Returns:
        tuple: (display_df, csv_path) where csv_path is a string
    """
    has_mod_cos = calculate_modified_cosine and 'Modified_cosine_similarity' in df.columns

    # Remove helper columns and normalize the numeric columns
    df = df.drop(columns=['i', 'j', 'library_j'])
    df['DreaMS_similarity'] = df['DreaMS_similarity'].astype(float).round(4)
    if has_mod_cos:
        df['Modified_cosine_similarity'] = df['Modified_cosine_similarity'].astype(float).round(4)
    df['precursor_mz'] = df['precursor_mz'].astype(float).round(4)
    df['rt'] = df['rt'].astype(float).round(2)
    df['charge'] = df['charge'].astype(str)  # Keep charge as string

    # Rename columns for display
    column_mapping = {
        'topk': 'Top k',
        'library_ID': 'Library ID',
        "scan_number": "Scan number",
        "rt": "Retention time",
        "charge": "Charge",
        "precursor_mz": "Precursor m/z",
        "library_SMILES": "Molecule",
        "library_SMILES_raw": "SMILES",
        "Spectrum": "Spectrum",
        "Spectrum_raw": "Input Spectrum",
        "DreaMS_similarity": "DreaMS similarity",
        "DreaMS_embedding": "DreaMS embedding",
    }
    if has_mod_cos:
        column_mapping["Modified_cosine_similarity"] = "Modified cos similarity"
    df = df.rename(columns=column_mapping)

    # Save full results to CSV (without the HTML image columns)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    df_path = dio.append_to_stem(in_pth, f"MassSpecGym_hits_{timestamp}").with_suffix('.csv')
    df.drop(columns=['Molecule', 'Spectrum', 'Top k']).to_csv(df_path, index=False)

    # Filter and prepare final display DataFrame
    df = df.drop(columns=['DreaMS embedding', "SMILES", "Input Spectrum"])
    df = df[df['Top k'] == 1].sort_values('DreaMS similarity', ascending=False)
    df = df.drop(columns=['Top k'])
    df = df[df["DreaMS similarity"] > similarity_threshold]

    # Add row numbers
    df.insert(0, 'Row', range(1, len(df) + 1))

    # The results table in the UI declares its headers and datatypes
    # positionally with "DreaMS similarity" BEFORE "Library ID"; without this
    # reordering the "number"/"str" datatypes land on the wrong columns.
    display_order = ['Row', 'Scan number', 'Retention time', 'Charge', 'Precursor m/z',
                     'Molecule', 'Spectrum', 'DreaMS similarity', 'Library ID']
    ordered = [col for col in display_order if col in df.columns]
    # Any remaining columns (e.g. the optional modified cosine score) go last
    ordered += [col for col in df.columns if col not in ordered]
    df = df[ordered]

    return df, str(df_path)
def _predict_core(lib_pth, in_pth, similarity_threshold, calculate_modified_cosine, progress):
    """
    Core prediction function that orchestrates the entire prediction pipeline

    Steps: copy library and input to private temporary files, load the library
    embeddings, embed the query spectra, compute cosine similarities, build one
    result row per query spectrum (top-1 hit) and post-process the table.

    Args:
        lib_pth: Library file path
        in_pth: Input file path
        similarity_threshold: Minimum DreaMS similarity for a hit to be
            rendered and shown in the display table
        calculate_modified_cosine: Whether to calculate modified cosine similarity
        progress: Gradio progress tracker

    Returns:
        tuple: (results_dataframe, csv_file_path)
    """
    import uuid  # local import: only used to build unique temp file names

    in_pth = Path(in_pth)
    lib_pth = Path(lib_pth)

    # Clear cache at start to prevent memory buildup
    clear_smiles_cache()

    # Create temporary copies of library and input files to allow multiple processes.
    # A timestamp alone has one-second resolution, so two requests started in
    # the same second would share (and delete) each other's copies; the random
    # suffix makes the names unique per call.
    progress(0, desc="Creating temporary file copies...")
    tag = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex}"
    temp_lib_path = lib_pth.parent / f"temp_{tag}_{lib_pth.name}"
    temp_in_path = in_pth.parent / f"temp_{tag}_{in_pth.name}"

    try:
        # Copy inside the try block so a failure of the second copy still
        # removes the first one in the finally clause
        shutil.copy2(lib_pth, temp_lib_path)
        shutil.copy2(in_pth, temp_in_path)

        # Load library data
        progress(0.1, desc="Loading library data...")
        msdata_lib = MSData.load(temp_lib_path, in_mem=True)
        embs_lib = msdata_lib[DREAMS_EMBEDDING]
        print(f'Shape of the library embeddings: {embs_lib.shape}')

        # Get query embeddings
        embs = _predict_gpu(temp_in_path, progress)

        # Compute similarity matrix
        progress(0.4, desc="Computing similarity matrix...")
        sims = cosine_similarity(embs, embs_lib)
        print(f'Shape of the similarity matrix: {sims.shape}')

        # Get top-k candidates per query, best hit first
        k = 1
        topk_cands = np.argsort(sims, axis=1)[:, -k:][:, ::-1]

        # Load query data for processing
        msdata = MSData.load(temp_in_path, in_mem=True)
        print(f'Available columns: {msdata.columns()}')

        # Construct results DataFrame
        progress(0.5, desc="Constructing results table...")
        rows = []
        cos_sim = su.PeakListModifiedCosine()
        total_spectra = len(topk_cands)

        for i, topk in enumerate(topk_cands):
            progress(0.5 + 0.4 * (i / total_spectra),
                     desc=f"Processing hits for spectrum {i+1}/{total_spectra}...")
            for n, j in enumerate(topk):
                rows.append(_create_result_row(
                    i, j, n, msdata, msdata_lib, sims, cos_sim, embs,
                    similarity_threshold, calculate_modified_cosine))

            # Clear cache every 100 spectra to prevent memory buildup
            if (i + 1) % 100 == 0:
                clear_smiles_cache()

        df = pd.DataFrame(rows)

        # Process and clean results
        progress(0.9, desc="Post-processing results...")
        df, csv_path = _process_results_dataframe(df, in_pth, similarity_threshold, calculate_modified_cosine)

        progress(1.0, desc=f"Predictions complete! Found {len(df)} high-confidence matches.")
        return df, csv_path
    finally:
        # Clean up temporary files
        for temp_path in (temp_lib_path, temp_in_path):
            if temp_path.exists():
                temp_path.unlink()
def predict(lib_pth, in_pth, similarity_threshold=0.75, calculate_modified_cosine=False, progress=gr.Progress(track_tqdm=True)):
    """
    Main prediction function with error handling

    Args:
        lib_pth: Library file path
        in_pth: Input file path
        similarity_threshold: Minimum DreaMS similarity for displayed matches
        calculate_modified_cosine: Whether to calculate modified cosine similarity
        progress: Gradio progress tracker

    Returns:
        tuple: (results_dataframe, csv_file_path)

    Raises:
        gr.Error: If prediction fails or input is invalid
    """
    try:
        # Validate input file
        if not _validate_input_file(in_pth):
            raise gr.Error("Invalid input file. Please provide a valid .mgf or .mzML file.")

        # Check if library exists
        if not Path(lib_pth).exists():
            raise gr.Error("Spectral library not found. Please ensure the library file exists.")

        return _predict_core(lib_pth, in_pth, similarity_threshold, calculate_modified_cosine, progress)
    except gr.Error:
        # Re-raise Gradio errors as-is
        raise
    except Exception as e:
        error_msg = str(e)
        if "CUDA" in error_msg or "cuda" in error_msg:
            error_msg = f"GPU/CUDA error: {error_msg}. The app is falling back to CPU mode."
        elif isinstance(e, RuntimeError) or "RuntimeError" in error_msg:
            # Test the exception type: the message of a RuntimeError does not
            # normally contain the class name, so a text-only check rarely matched
            error_msg = f"Runtime error: {error_msg}. This may be due to memory or device issues."
        else:
            error_msg = f"Error: {error_msg}"

        print(f"Prediction failed: {error_msg}")
        # Chain the original exception so the traceback keeps the root cause
        raise gr.Error(error_msg) from e
| # ============================================================================= | |
| # GRADIO INTERFACE SETUP | |
| # ============================================================================= | |
def _create_gradio_interface():
    """
    Create and configure the Gradio interface

    NOTE(review): the nesting of components inside the ``with`` blocks was
    reconstructed from a source whose indentation was lost - confirm that the
    Row/Accordion contents match the intended layout.

    Returns:
        gr.Blocks: Configured Gradio app
    """
    # JavaScript that reloads the page with the light theme selected
    js_func = """
    function refresh() {
        const url = new URL(window.location);
        if (url.searchParams.get('__theme') !== 'light') {
            url.searchParams.set('__theme', 'light');
            window.location.href = url.href;
        }
    }
    """

    # Create app with custom theme
    app = gr.Blocks(
        theme=gr.themes.Default(primary_hue="yellow", secondary_hue="pink"),
        js=js_func
    )

    with app:
        # Header and description
        gr.Image("https://raw.githubusercontent.com/pluskal-lab/DreaMS/cc806fa6fea281c1e57dd81fc512f71de9290017/assets/dreams_background.png",
                 label="DreaMS")
        gr.Markdown(value="""
            DreaMS (Deep Representations Empowering the Annotation of Mass Spectra) is a transformer-based
            neural network designed to interpret tandem mass spectrometry (MS/MS) data (<a href="https://www.nature.com/articles/s41587-025-02663-3">Bushuiev et al., Nature Biotechnology, 2025</a>).
            This website provides an easy access to perform library matching with DreaMS against the <a href="https://huggingface.co/datasets/roman-bushuiev/MassSpecGym">MassSpecGym</a> spectral library (combination of GNPS, MoNA, and Pluskal lab data). Please upload
            your file with MS/MS data and click on the "Run DreaMS" button.
        """)

        # Input section
        with gr.Row(equal_height=True):
            in_pth = gr.File(
                file_count="single",
                label="Input MS/MS file (.mgf or .mzML)",
            )

        # Example files
        gr.Examples(
            examples=["./data/example_5_spectra.mgf", "./data/202312_147_P55-Leaf-r2_1uL.mzML"],
            inputs=[in_pth],
            label="Examples (click on a file to load as input)",
        )

        # Settings section
        with gr.Accordion("⚙️ Settings", open=False):
            similarity_threshold = gr.Slider(
                minimum=-1.0,
                maximum=1.0,
                value=0.75,
                step=0.01,
                label="Similarity threshold",
                info="Only display library matches with DreaMS similarity above this threshold (rendering less results also makes calculation faster)"
            )
            calculate_modified_cosine = gr.Checkbox(
                label="Calculate modified cosine similarity",
                value=False,
                info="Enable to also calculate traditional modified cosine similarity scores between the input spectra and library hits (a bit slower)"
            )

        # Prediction button
        predict_button = gr.Button(value="Run DreaMS", variant="primary")

        # Results: downloadable CSV plus the table itself
        gr.Markdown("## Predictions")
        df_file = gr.File(label="Download predictions as .csv", interactive=False, visible=True)

        # headers, datatype and column_widths are positional and must stay aligned
        headers = ["Row", "Scan number", "Retention time", "Charge", "Precursor m/z", "Molecule", "Spectrum",
                   "DreaMS similarity", "Library ID"]
        datatype = ["number", "number", "number", "str", "number", "html", "html", "number", "str"]
        column_widths = ["20px", "30px", "30px", "25px", "30px", "40px", "40px", "40px", "50px"]

        df = gr.Dataframe(
            headers=headers,
            datatype=datatype,
            col_count=(len(headers), "fixed"),
            column_widths=column_widths,
            max_height=1000,
            show_fullscreen_button=True,
            show_row_numbers=False,
            show_search='filter',
        )

        # Connect prediction logic
        inputs = [in_pth, similarity_threshold, calculate_modified_cosine]
        outputs = [df, df_file]

        # Function to update dataframe headers based on setting
        def update_headers(show_cosine):
            if show_cosine:
                # The label matches the column name produced by the results
                # post-processing ("Modified cos similarity"), and the extra
                # column gets its own datatype so all lists stay aligned.
                return gr.update(headers=headers + ["Modified cos similarity"],
                                 datatype=datatype + ["number"],
                                 col_count=(len(headers) + 1, "fixed"),
                                 column_widths=column_widths + ["40px"])
            return gr.update(headers=headers,
                             datatype=datatype,
                             col_count=(len(headers), "fixed"),
                             column_widths=column_widths)

        # Update headers when setting changes
        calculate_modified_cosine.change(
            fn=update_headers,
            inputs=[calculate_modified_cosine],
            outputs=[df]
        )

        # Bind the library path so the click handler only receives UI inputs
        predict_func = partial(predict, LIBRARY_PATH)
        predict_button.click(predict_func, inputs=inputs, outputs=outputs, show_progress="first")

    return app
| # ============================================================================= | |
| # MAIN EXECUTION | |
| # ============================================================================= | |
# Setup runs in both cases: when executed as a script and when imported as a module
setup()

if __name__ == "__main__":
    # Only direct execution builds and launches the Gradio interface
    app = _create_gradio_interface()
    app.launch(allowed_paths=['./assets'])