"""Gradio app: reconstruction-based anomaly detection with the MOMENT model.

The user pastes a CSV with ``timestamp`` and ``value`` columns. The series is
reconstructed by MOMENT, and points whose absolute reconstruction error
exceeds ``mean + sensitivity * std`` of all errors are flagged as anomalies.
"""

import logging
from io import StringIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from momentfm import MOMENTPipeline

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Per the upstream MOMENT README, the pretrained checkpoints take inputs of
# shape (batch, n_channels, 512); shorter series are padded and masked.
MOMENT_SEQ_LEN = 512

# Default content of the input textbox (one CSV record per line).
SAMPLE_CSV = """timestamp,value
2025-04-01 00:00:00,100
2025-04-01 01:00:00,102
2025-04-01 02:00:00,98
2025-04-01 03:00:00,105
2025-04-01 04:00:00,103
2025-04-01 05:00:00,107
2025-04-01 06:00:00,200
2025-04-01 07:00:00,108
2025-04-01 08:00:00,110
2025-04-01 09:00:00,98
2025-04-01 10:00:00,99
2025-04-01 11:00:00,102
2025-04-01 12:00:00,101"""

# Initialize model with reconstruction task
try:
    model = MOMENTPipeline.from_pretrained(
        "AutonLab/MOMENT-1-large",
        model_kwargs={"task_name": "reconstruction"},
    )
    model.init()
    model.eval()  # inference only: make sure dropout etc. are disabled
    logger.info("Model loaded successfully")
except Exception as e:
    # Without a model the app cannot do anything useful, so fail at import.
    logger.exception("Model loading failed: %s", e)
    raise


def validate_data(data_input):
    """Parse and validate CSV text into a DataFrame sorted by timestamp.

    Args:
        data_input: CSV text with at least ``timestamp`` and ``value`` columns.

    Returns:
        A DataFrame sorted by ``timestamp`` whose ``timestamp`` column is
        datetime-typed and whose ``value`` column is numeric and finite.

    Raises:
        ValueError: if the input is not a string, a required column is
            missing, there are no data rows, a timestamp cannot be parsed,
            or a value is non-numeric, missing or infinite.
        Exception: any parsing error raised by ``pandas.read_csv``.
    """
    try:
        if not isinstance(data_input, str):
            raise ValueError("Input must be CSV text")
        df = pd.read_csv(StringIO(data_input))

        # Validate columns
        if not all(col in df.columns for col in ('timestamp', 'value')):
            raise ValueError("CSV must contain 'timestamp' and 'value' columns")

        # A header-only CSV would otherwise fail later with an obscure
        # "zero-size array" error when the statistics are computed.
        if df.empty:
            raise ValueError("CSV contains no data rows")

        # Convert timestamps; unparseable entries become NaT and are rejected.
        df['timestamp'] = pd.to_datetime(df['timestamp'], errors='coerce')
        if df['timestamp'].isnull().any():
            raise ValueError("Invalid timestamp format")

        # Convert values to numeric (raises ValueError on non-numeric text).
        df['value'] = pd.to_numeric(df['value'], errors='raise')
        # Blank cells parse as NaN; NaN/inf would poison the reconstruction
        # and the mean/std threshold, so reject them explicitly.
        if not np.isfinite(df['value'].to_numpy(dtype=np.float64)).all():
            raise ValueError("Values must be finite numbers")

        return df.sort_values('timestamp')
    except Exception as e:
        logger.error("Data validation error: %s", e)
        raise


def _reconstruct(values):
    """Return MOMENT's reconstruction of a 1-D float32 array (same length).

    The series is processed in windows of ``MOMENT_SEQ_LEN`` points. A window
    shorter than that is left-padded with zeros and the padding is marked as
    unobserved through ``input_mask``. A trailing partial window is slid back
    so that it still sees a full window of context whenever the series is
    long enough; only its not-yet-covered tail is kept.
    """
    n_points = len(values)
    device = next(model.parameters()).device
    reconstruction = np.empty(n_points, dtype=np.float32)

    for start in range(0, n_points, MOMENT_SEQ_LEN):
        stop = min(start + MOMENT_SEQ_LEN, n_points)
        window = values[max(0, stop - MOMENT_SEQ_LEN):stop]
        pad = MOMENT_SEQ_LEN - len(window)

        x_enc = torch.zeros((1, 1, MOMENT_SEQ_LEN), dtype=torch.float32)
        input_mask = torch.zeros((1, MOMENT_SEQ_LEN), dtype=torch.long)
        x_enc[0, 0, pad:] = torch.tensor(window, dtype=torch.float32)
        input_mask[0, pad:] = 1  # 1 = observed, 0 = padding

        with torch.no_grad():
            output = model(
                x_enc=x_enc.to(device),
                input_mask=input_mask.to(device),
            )
        window_recon = output.reconstruction[0, 0].detach().cpu().numpy()

        # The last (stop - start) positions of the window map to [start, stop).
        reconstruction[start:stop] = window_recon[MOMENT_SEQ_LEN - (stop - start):]

    return reconstruction


def _plot_anomalies(df, threshold):
    """Draw the series with flagged points highlighted and return the figure."""
    fig, ax = plt.subplots(figsize=(12, 5))
    ax.plot(df['timestamp'], df['value'], 'b-', label='Value')
    flagged = df[df['is_anomaly']]
    ax.scatter(
        flagged['timestamp'],
        flagged['value'],
        color='red',
        s=100,
        label='Anomaly',
    )
    ax.set_title(f'Anomaly Detection (Threshold: {threshold:.2f})')
    ax.legend()
    # pyplot keeps every figure alive until it is closed; in a long-running
    # server that leaks one figure per request. The Figure object itself
    # remains renderable after being unregistered from pyplot.
    plt.close(fig)
    return fig


def detect_anomalies(data_input, sensitivity=3.0):
    """Perform reconstruction-based anomaly detection on CSV text.

    Args:
        data_input: CSV text accepted by ``validate_data``.
        sensitivity: number of standard deviations above the mean absolute
            reconstruction error at which a point is flagged.

    Returns:
        A ``(figure, stats, records)`` tuple for the Gradio outputs. On any
        failure the error is logged and ``(None, {"error": message}, None)``
        is returned so the UI can display the message instead of crashing.
    """
    try:
        df = validate_data(data_input)
        values = df['value'].to_numpy(dtype=np.float32)

        reconstructed = _reconstruct(values)

        # Per-point absolute reconstruction error
        errors = np.abs(values - reconstructed)

        # Dynamic threshold (z-score based)
        threshold = np.mean(errors) + sensitivity * np.std(errors)

        df['anomaly_score'] = errors
        df['is_anomaly'] = errors > threshold

        fig = _plot_anomalies(df, threshold)

        stats = {
            "data_points": len(df),
            "anomalous_points": int(df['is_anomaly'].sum()),
            "detection_threshold": float(threshold),
            "max_error": float(np.max(errors)),
        }

        # Timestamps are not JSON-native; stringify them so the JSON output
        # does not depend on the serializer's fallback handling.
        records = df.assign(timestamp=df['timestamp'].map(str)).to_dict('records')

        return fig, stats, records

    except Exception as e:
        # UI boundary: report the problem to the user rather than raising.
        logger.exception("Detection error: %s", e)
        return None, {"error": str(e)}, None


# Gradio Interface
with gr.Blocks(title="MOMENT Anomaly Detector") as demo:
    gr.Markdown("## 🔍 Equipment Anomaly Detection using MOMENT")

    with gr.Row():
        with gr.Column():
            data_input = gr.Textbox(
                label="Paste time-series data (CSV format)",
                value=SAMPLE_CSV,
                lines=15,
            )
            sensitivity = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=3.0,
                step=0.1,
                label="Detection Sensitivity (Z-Score)",
            )
            submit_btn = gr.Button("Analyze Data", variant="primary")

        with gr.Column():
            plot_output = gr.Plot(label="Anomaly Detection Results")
            stats_output = gr.JSON(label="Detection Statistics")
            # NOTE: gr.JSON has no `max_lines` option; passing it raises a
            # TypeError on current Gradio releases, so it is not used here.
            data_output = gr.JSON(label="Processed Data")

    submit_btn.click(
        detect_anomalies,
        inputs=[data_input, sensitivity],
        outputs=[plot_output, stats_output, data_output],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)