Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from io import StringIO | |
| from momentfm import MOMENTPipeline | |
| import logging | |
# Logging setup: INFO-level root configuration plus this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the pretrained MOMENT pipeline, configured for the reconstruction task.
# A failure is logged and re-raised, so importing this module aborts when the
# model cannot be loaded or initialised.
try:
    model = MOMENTPipeline.from_pretrained(
        "AutonLab/MOMENT-1-large",
        model_kwargs={"task_name": "reconstruction"},
    )
    model.init()
except Exception as exc:
    logger.error("Model loading failed: %s", exc)
    raise
else:
    logger.info("Model loaded successfully")
def validate_data(data_input):
    """Parse CSV text into a clean, time-ordered DataFrame.

    Args:
        data_input: CSV text whose header row contains at least the
            ``timestamp`` and ``value`` columns. Whitespace around column
            names and after delimiters is tolerated.

    Returns:
        A DataFrame sorted by ``timestamp`` with the index renumbered from 0,
        ``timestamp`` parsed to datetimes and ``value`` converted to numbers.

    Raises:
        ValueError: If the input is not a string, a required column is
            missing, there are no data rows, a timestamp cannot be parsed,
            or a value is non-numeric, missing, or non-finite.
        Exception: Any other error raised by pandas while parsing is logged
            and re-raised unchanged.
    """
    try:
        if not isinstance(data_input, str):
            raise ValueError("Input must be CSV text")

        # skipinitialspace accepts "timestamp, value" style headers and cells.
        df = pd.read_csv(StringIO(data_input), skipinitialspace=True)
        df.columns = [str(col).strip() for col in df.columns]

        # Validate columns
        if not all(col in df.columns for col in ('timestamp', 'value')):
            raise ValueError("CSV must contain 'timestamp' and 'value' columns")

        # A header-only CSV parses fine but leaves nothing to analyse.
        if df.empty:
            raise ValueError("CSV contains no data rows")

        # Convert timestamps; unparseable entries become NaT and are rejected.
        df['timestamp'] = pd.to_datetime(df['timestamp'], errors='coerce')
        if df['timestamp'].isnull().any():
            raise ValueError("Invalid timestamp format")

        # Convert values to numeric. errors='raise' rejects text, but blank
        # cells (NaN) and "inf" still get through, and a single NaN would turn
        # the downstream mean/std threshold into NaN, so reject them here.
        df['value'] = pd.to_numeric(df['value'], errors='raise')
        if not np.isfinite(df['value'].to_numpy(dtype=float)).all():
            raise ValueError(
                "'value' column contains missing or non-finite entries"
            )

        # Stable sort keeps input order for equal timestamps; the index is
        # renumbered so positions and labels agree for callers.
        return df.sort_values('timestamp', kind='stable').reset_index(drop=True)
    except Exception as e:
        logger.error("Data validation error: %s", e)
        raise
def _reconstruct_series(values, context_length=512):
    """Reconstruct a 1-D float32 series with the MOMENT model, window by window.

    Following MOMENT's published usage, the pipeline is called with keyword
    arguments on torch tensors shaped ``[batch, channels, context_length]``,
    together with an ``input_mask`` marking real (1) versus padded (0) steps.
    The series is cut into consecutive windows of ``context_length`` points;
    a window shorter than that is left-padded with zeros and masked out.

    Args:
        values: 1-D NumPy array of float32 observations.
        context_length: Number of time steps fed to the model per call.

    Returns:
        A float32 NumPy array of the same length as ``values`` holding the
        model's reconstruction of every point.
    """
    import torch  # local import: torch is not imported at file level

    n_points = len(values)
    reconstructed = np.empty(n_points, dtype=np.float32)
    device = next(model.parameters()).device

    model.eval()  # inference only: disable dropout
    with torch.no_grad():
        for start in range(0, n_points, context_length):
            chunk = values[start:start + context_length]
            pad = context_length - len(chunk)

            window = np.zeros(context_length, dtype=np.float32)
            observed = np.zeros(context_length, dtype=np.float32)
            window[pad:] = chunk
            observed[pad:] = 1.0

            x_enc = torch.from_numpy(window).reshape(1, 1, context_length)
            input_mask = torch.from_numpy(observed).reshape(1, context_length)
            x_enc = x_enc.to(device)
            input_mask = input_mask.to(device)

            output = model(
                x_enc=x_enc,
                input_mask=input_mask,
                # All-ones mask: nothing is hidden from the model, so the
                # reconstruction is deterministic.
                mask=torch.ones_like(input_mask),
            )
            window_recon = output.reconstruction.detach().cpu().numpy()
            # Drop the padded prefix before writing back.
            reconstructed[start:start + len(chunk)] = (
                window_recon.reshape(-1)[pad:]
            )
    return reconstructed


def detect_anomalies(data_input, sensitivity=3.0):
    """Perform reconstruction-based anomaly detection on CSV text.

    Each point's anomaly score is the absolute difference between the observed
    value and the model's reconstruction. A point is flagged when its score
    exceeds ``mean(scores) + sensitivity * std(scores)``.

    Args:
        data_input: CSV text accepted by ``validate_data``.
        sensitivity: Number of standard deviations above the mean error at
            which a point is flagged; larger values flag fewer points.

    Returns:
        A tuple ``(figure, stats, records)``: the matplotlib figure, a dict of
        summary statistics, and the processed rows as a list of dicts. On any
        failure the error is logged and ``(None, {"error": message}, None)``
        is returned instead of raising.
    """
    try:
        df = validate_data(data_input)
        values = df['value'].to_numpy(dtype=np.float32)
        if values.size == 0:
            raise ValueError("No data points to analyze")

        # Get reconstruction
        reconstructed = _reconstruct_series(values)

        # Calculate reconstruction error (absolute error per point)
        errors = np.abs(values - reconstructed)

        # Dynamic threshold (z-score based)
        threshold = np.mean(errors) + sensitivity * np.std(errors)
        df['anomaly_score'] = errors
        df['is_anomaly'] = errors > threshold

        # Prepare outputs
        stats = {
            "data_points": len(df),
            "anomalous_points": int(df['is_anomaly'].sum()),
            "detection_threshold": float(threshold),
            "max_error": float(np.max(errors)),
        }

        # Create plot
        fig, ax = plt.subplots(figsize=(12, 5))
        try:
            flagged = df[df['is_anomaly']]
            ax.plot(df['timestamp'], df['value'], 'b-', label='Value')
            ax.scatter(
                flagged['timestamp'],
                flagged['value'],
                color='red', s=100, label='Anomaly',
            )
            ax.set_title(f'Anomaly Detection (Threshold: {threshold:.2f})')
            ax.legend()
        finally:
            # Unregister the figure from pyplot so a long-running server does
            # not accumulate one open figure per request; the Figure object
            # itself remains usable for rendering.
            plt.close(fig)

        return fig, stats, df.to_dict('records')
    except Exception as e:
        logger.exception("Detection error: %s", e)
        return None, {"error": str(e)}, None
# Gradio Interface
# Sample hourly series pre-filled in the input box; the 06:00 reading (200)
# is a deliberate outlier so the demo has something to flag.
_SAMPLE_CSV = (
    "timestamp,value\n"
    "2025-04-01 00:00:00,100\n"
    "2025-04-01 01:00:00,102\n"
    "2025-04-01 02:00:00,98\n"
    "2025-04-01 03:00:00,105\n"
    "2025-04-01 04:00:00,103\n"
    "2025-04-01 05:00:00,107\n"
    "2025-04-01 06:00:00,200\n"
    "2025-04-01 07:00:00,108\n"
    "2025-04-01 08:00:00,110\n"
    "2025-04-01 09:00:00,98\n"
    "2025-04-01 10:00:00,99\n"
    "2025-04-01 11:00:00,102\n"
    "2025-04-01 12:00:00,101"
)

with gr.Blocks(title="MOMENT Anomaly Detector") as demo:
    gr.Markdown("## π Equipment Anomaly Detection using MOMENT")

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            data_input = gr.Textbox(
                label="Paste time-series data (CSV format)",
                value=_SAMPLE_CSV,
                lines=15,
            )
            sensitivity = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=3.0,
                step=0.1,
                label="Detection Sensitivity (Z-Score)",
            )
            submit_btn = gr.Button("Analyze Data", variant="primary")

        # Right column: results.
        with gr.Column():
            plot_output = gr.Plot(label="Anomaly Detection Results")
            stats_output = gr.JSON(label="Detection Statistics")
            # gr.JSON has no `max_lines` option (that belongs to gr.Textbox);
            # passing it makes component construction fail at startup.
            data_output = gr.JSON(label="Processed Data")

    # Wire the button: (CSV text, sensitivity) -> (plot, stats, rows).
    submit_btn.click(
        detect_anomalies,
        inputs=[data_input, sensitivity],
        outputs=[plot_output, stats_output, data_output],
    )

if __name__ == "__main__":
    # Bind to all interfaces on port 7860 so the app is reachable from
    # outside the container.
    demo.launch(server_name="0.0.0.0", server_port=7860)