# Author: zainulabedin949
# Commit: 9458d26 (verified) - "Update app.py"
# File size: 4.53 kB
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from io import StringIO
import logging
from momentfm import MOMENTPipeline
# Root logging at INFO level plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the MOMENT checkpoint once at import time. If loading fails, the
# error is logged and re-raised, so the rest of the module (including the
# UI definition) never runs with a missing model.
try:
    model = MOMENTPipeline.from_pretrained(
        "AutonLab/MOMENT-1-large",
        model_kwargs={"task_name": "reconstruction"},
    )
    model.init()
except Exception as exc:
    logger.error("Model loading failed: %s", exc)
    raise
else:
    logger.info("Model loaded successfully")
def validate_and_process_data(data_input):
    """Parse CSV text into a validated, timestamp-sorted DataFrame.

    Args:
        data_input: CSV text whose header row includes at least the
            ``timestamp`` and ``value`` columns. Extra columns are kept.

    Returns:
        A ``pandas.DataFrame`` with these properties:
        - ``timestamp`` is converted with ``pd.to_datetime``.
        - ``value`` is converted with ``pd.to_numeric``.
        - Rows are sorted by ``timestamp`` and re-indexed 0..n-1.

    Raises:
        ValueError: in any of these cases:
            - The input is not a string, or cannot be parsed as CSV.
            - A required column is missing.
            - There are no data rows.
            - Any timestamp is missing or unparseable.
            - Any value is non-numeric or missing.
            Every error is logged before it propagates.
    """
    try:
        if not isinstance(data_input, str):
            raise ValueError("Input must be CSV text")
        df = pd.read_csv(StringIO(data_input))

        required = ['timestamp', 'value']
        missing = [col for col in required if col not in df.columns]
        if missing:
            raise ValueError(f"Missing columns: {missing}")

        # A header-only CSV would otherwise be accepted here and only fail
        # later, during scoring, with a much less helpful message.
        if df.empty:
            raise ValueError("No data rows found")

        # errors='coerce' turns unparseable timestamps into NaT, so a single
        # null check catches both missing and malformed entries.
        df['timestamp'] = pd.to_datetime(df['timestamp'], errors='coerce')
        if df['timestamp'].isnull().any():
            raise ValueError("Invalid timestamp format")

        try:
            df['value'] = pd.to_numeric(df['value'])
        except (ValueError, TypeError) as err:
            raise ValueError("Non-numeric values found") from err
        # Blank cells parse as NaN. Left in, they would propagate NaN into
        # the mean/std-based anomaly threshold computed downstream.
        if df['value'].isnull().any():
            raise ValueError("Missing values found")

        return df.sort_values('timestamp').reset_index(drop=True)
    except Exception as e:
        logger.error("Data processing error: %s", e)
        raise
def detect_anomalies(data_input, threshold=0.1):
    """Flag anomalous points in CSV time-series text via reconstruction error.

    The ``value`` column is passed to the module-level MOMENT ``model``.
    Each point's absolute reconstruction error is its anomaly score. Points
    whose score exceeds ``mean + 3 * std`` of all scores are flagged.

    Args:
        data_input: CSV text with ``timestamp`` and ``value`` columns. It is
            validated by ``validate_and_process_data``.
        threshold: accepted for interface compatibility (the Gradio slider
            passes it in) but NOT used. The cutoff is always the 3-sigma
            rule. NOTE(review): either wire this into the cutoff or remove
            the slider; as written the "Sensitivity" control has no effect.

    Returns:
        On success, a tuple ``(fig, stats, records)``:
        - ``fig``: a matplotlib Figure.
        - ``stats``: a dict with ``data_points``, ``anomalies``,
          ``threshold_used`` and ``max_score``.
        - ``records``: the per-row data (including ``anomaly_score`` and
          ``is_anomaly``) as a list of dicts.
        On any failure, returns ``(None, {"error": message}, None)`` instead
        of raising.
    """
    try:
        df = validate_and_process_data(data_input)
        values = df['value'].values.astype(np.float32)

        # NOTE(review): assumes `model.reconstruct` accepts a 1-D float32
        # array and returns something broadcastable against `values`.
        # Confirm against the momentfm API.
        reconstruction = model.reconstruct(values)
        errors = np.abs(values - reconstruction)

        # Dynamic threshold: 3σ above the mean reconstruction error.
        # The `threshold` argument is not consulted here (see docstring).
        threshold_value = np.mean(errors) + 3 * np.std(errors)
        df['anomaly_score'] = errors
        df['is_anomaly'] = errors > threshold_value

        fig, ax = plt.subplots(figsize=(12, 5))
        ax.plot(df['timestamp'], df['value'], 'b-', label='Value')
        ax.scatter(
            df.loc[df['is_anomaly'], 'timestamp'],
            df.loc[df['is_anomaly'], 'value'],
            color='red', s=100, label='Anomaly',
        )
        ax.set_title(f'Anomaly Detection (Threshold: {threshold_value:.2f})')
        ax.legend()
        # Close the figure in pyplot (original intent: prevent duplicate
        # plots). The Figure object itself is still returned to the caller.
        plt.close(fig)

        stats = {
            "data_points": len(df),
            "anomalies": int(df['is_anomaly'].sum()),
            "threshold_used": float(threshold_value),
            "max_score": float(np.max(errors)),
        }
        return fig, stats, df.to_dict('records')
    except Exception as e:
        # The error is returned to the UI rather than raised, so this log
        # line is the only trace. logger.exception keeps the traceback.
        logger.exception("Detection error: %s", e)
        return None, {"error": str(e)}, None
# Gradio Interface
# Example series pre-filled into the CSV textbox: 13 hourly readings, one
# of which (06:00) spikes to 200.
SAMPLE_CSV = "\n".join([
    "timestamp,value",
    "2025-04-01 00:00:00,100",
    "2025-04-01 01:00:00,102",
    "2025-04-01 02:00:00,98",
    "2025-04-01 03:00:00,105",
    "2025-04-01 04:00:00,103",
    "2025-04-01 05:00:00,107",
    "2025-04-01 06:00:00,200",
    "2025-04-01 07:00:00,108",
    "2025-04-01 08:00:00,110",
    "2025-04-01 09:00:00,98",
    "2025-04-01 10:00:00,99",
    "2025-04-01 11:00:00,102",
    "2025-04-01 12:00:00,101",
])

with gr.Blocks(title="Anomaly Detector") as demo:
    gr.Markdown("# 🚨 Time-Series Anomaly Detection")
    with gr.Row():
        # First column: inputs.
        with gr.Column():
            data_input = gr.Textbox(
                label="Paste CSV Data",
                value=SAMPLE_CSV,
                lines=15,
            )
            threshold = gr.Slider(
                minimum=0.01,
                maximum=1.0,
                value=0.3,
                label="Sensitivity (higher = stricter)",
            )
            submit_btn = gr.Button(value="Analyze", variant="primary")
        # Second column: outputs, in the same order detect_anomalies
        # returns them.
        with gr.Column():
            plot_output = gr.Plot(label="Results")
            stats_output = gr.JSON(label="Statistics")
            data_output = gr.JSON(label="Detailed Data")

    # Wiring: (CSV text, slider value) -> (figure, stats dict, row records).
    submit_btn.click(
        fn=detect_anomalies,
        inputs=[data_input, threshold],
        outputs=[plot_output, stats_output, data_output],
    )
if __name__ == "__main__":
    # Serve the UI on every network interface (0.0.0.0), port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")