zainulabedin949's picture
Update app.py
bd1a142 verified
raw
history blame
4.63 kB
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from io import StringIO
from momentfm import MOMENTPipeline
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load MOMENT once at import time, configured for the reconstruction task;
# anomaly scores are later derived from its reconstruction error. Failing
# here aborts start-up, since the app is useless without the model.
try:
    model = MOMENTPipeline.from_pretrained(
        "AutonLab/MOMENT-1-large",
        model_kwargs={"task_name": "reconstruction"},  # Correct task name
    )
    model.init()
    # Inference only: switch off dropout etc. so reconstructions (and thus
    # anomaly scores) are deterministic from one request to the next.
    model.eval()
    logger.info("Model loaded successfully")
except Exception as e:
    # logger.exception keeps the traceback, which is what you need when a
    # download/checkpoint problem happens on a remote host.
    logger.exception("Model loading failed: %s", e)
    raise
def validate_data(data_input):
    """Parse CSV text into a clean, time-sorted DataFrame.

    Args:
        data_input: CSV text with ``timestamp`` and ``value`` columns.
            Surrounding whitespace in header names is ignored.

    Returns:
        A DataFrame sorted by ``timestamp`` with a fresh ``0..n-1`` index,
        where ``timestamp`` is datetime64 and ``value`` is numeric.

    Raises:
        ValueError: if the input is not a string, a required column is
            missing, there are no data rows, a timestamp cannot be parsed,
            or a value is missing / non-numeric. ``pd.read_csv`` may raise
            its own errors for malformed CSV. All errors are logged and
            re-raised.
    """
    try:
        if not isinstance(data_input, str):
            raise ValueError("Input must be CSV text")
        df = pd.read_csv(StringIO(data_input))
        # Tolerate headers such as "timestamp, value" (space after comma).
        df.columns = [str(col).strip() for col in df.columns]

        # Validate columns
        if not {'timestamp', 'value'}.issubset(df.columns):
            raise ValueError("CSV must contain 'timestamp' and 'value' columns")
        # A header-only CSV would otherwise yield an empty frame that fails
        # much later, deep inside the scoring code.
        if df.empty:
            raise ValueError("CSV contains no data rows")

        # Convert timestamps; unparseable entries become NaT and are rejected.
        df['timestamp'] = pd.to_datetime(df['timestamp'], errors='coerce')
        if df['timestamp'].isnull().any():
            raise ValueError("Invalid timestamp format")

        # Convert values to numeric. Blank cells arrive as NaN and are left
        # untouched by to_numeric, so reject them explicitly: a single NaN
        # would poison the reconstruction error and the threshold.
        df['value'] = pd.to_numeric(df['value'], errors='raise')
        if df['value'].isnull().any():
            raise ValueError("Missing values in 'value' column")

        return df.sort_values('timestamp').reset_index(drop=True)
    except Exception as e:
        logger.error("Data validation error: %s", e)
        raise
# MOMENT consumes fixed-length windows of this many time steps.
_MOMENT_SEQ_LEN = 512


def _reconstruct_series(values):
    """Return MOMENT's reconstruction of a 1-D float32 series (same length).

    MOMENT takes torch tensors shaped (batch, n_channels, 512). The series is
    cut into consecutive 512-step windows; a shorter (final or only) window
    is left-padded with zeros, and the padding is excluded via ``input_mask``.
    """
    import torch

    device = next(model.parameters()).device
    reconstructed = np.empty_like(values)
    for start in range(0, values.size, _MOMENT_SEQ_LEN):
        window = values[start:start + _MOMENT_SEQ_LEN]
        pad = _MOMENT_SEQ_LEN - window.size
        x_enc = torch.zeros((1, 1, _MOMENT_SEQ_LEN), dtype=torch.float32)
        x_enc[0, 0, pad:] = torch.as_tensor(window, dtype=torch.float32)
        # 1 = observed step, 0 = padding to be ignored by the model.
        input_mask = torch.zeros((1, _MOMENT_SEQ_LEN), dtype=torch.long)
        input_mask[0, pad:] = 1
        with torch.no_grad():
            output = model(x_enc=x_enc.to(device), input_mask=input_mask.to(device))
        reconstructed[start:start + window.size] = (
            output.reconstruction[0, 0, pad:].detach().cpu().numpy()
        )
    return reconstructed


def _plot_anomalies(df, threshold):
    """Plot the series with flagged points highlighted and return the Figure."""
    from matplotlib.figure import Figure

    # Build the Figure directly rather than via pyplot: pyplot keeps every
    # figure in a global registry, so one figure per request would leak.
    fig = Figure(figsize=(12, 5))
    ax = fig.subplots()
    flagged = df[df['is_anomaly']]
    ax.plot(df['timestamp'], df['value'], 'b-', label='Value')
    ax.scatter(
        flagged['timestamp'],
        flagged['value'],
        color='red', s=100, label='Anomaly'
    )
    ax.set_title(f'Anomaly Detection (Threshold: {threshold:.2f})')
    ax.legend()
    return fig


def detect_anomalies(data_input, sensitivity=3.0):
    """Flag points whose MOMENT reconstruction error is unusually large.

    Args:
        data_input: CSV text with ``timestamp`` and ``value`` columns.
        sensitivity: Number of standard deviations above the mean absolute
            reconstruction error at which a point is flagged; larger values
            flag fewer points.

    Returns:
        On success ``(figure, stats, records)``: a matplotlib Figure, a dict
        of summary statistics, and the processed rows as a list of dicts
        (timestamps rendered as strings). On any failure
        ``(None, {"error": message}, None)`` so the UI can show the message.
    """
    try:
        df = validate_data(data_input)
        values = df['value'].to_numpy(dtype=np.float32)
        if values.size == 0:
            raise ValueError("No data points to analyze")

        # Absolute reconstruction error per point.
        errors = np.abs(values - _reconstruct_series(values))

        # Dynamic threshold (z-score based).
        threshold = np.mean(errors) + sensitivity * np.std(errors)
        df['anomaly_score'] = errors
        df['is_anomaly'] = errors > threshold

        fig = _plot_anomalies(df, threshold)
        stats = {
            "data_points": len(df),
            "anomalous_points": int(df['is_anomaly'].sum()),
            "detection_threshold": float(threshold),
            "max_error": float(np.max(errors))
        }
        # Stringify timestamps so every record is plain JSON.
        records = df.assign(timestamp=df['timestamp'].astype(str)).to_dict('records')
        return fig, stats, records
    except Exception as e:
        # UI boundary: report the error to the user instead of crashing.
        logger.exception("Detection error: %s", e)
        return None, {"error": str(e)}, None
# Gradio Interface: CSV text + sensitivity slider in, plot + stats + rows out.
with gr.Blocks(title="MOMENT Anomaly Detector") as demo:
    gr.Markdown("## 🔍 Equipment Anomaly Detection using MOMENT")
    with gr.Row():
        with gr.Column():
            data_input = gr.Textbox(
                label="Paste time-series data (CSV format)",
                value="""timestamp,value
2025-04-01 00:00:00,100
2025-04-01 01:00:00,102
2025-04-01 02:00:00,98
2025-04-01 03:00:00,105
2025-04-01 04:00:00,103
2025-04-01 05:00:00,107
2025-04-01 06:00:00,200
2025-04-01 07:00:00,108
2025-04-01 08:00:00,110
2025-04-01 09:00:00,98
2025-04-01 10:00:00,99
2025-04-01 11:00:00,102
2025-04-01 12:00:00,101""",
                lines=15,
            )
            # Passed to detect_anomalies as the z-score multiplier: a higher
            # value raises the threshold, so fewer points are flagged.
            sensitivity = gr.Slider(
                minimum=1.0,
                maximum=5.0,
                value=3.0,
                step=0.1,
                label="Detection Sensitivity (Z-Score)",
            )
            submit_btn = gr.Button("Analyze Data", variant="primary")
        with gr.Column():
            plot_output = gr.Plot(label="Anomaly Detection Results")
            stats_output = gr.JSON(label="Detection Statistics")
            # NOTE(review): confirm the installed Gradio's gr.JSON accepts
            # `max_lines` (it is a documented Textbox option).
            data_output = gr.JSON(
                label="Processed Data",
                max_lines=15,
            )
    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(
        detect_anomalies,
        inputs=[data_input, sensitivity],
        outputs=[plot_output, stats_output, data_output],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)