# AI-OMS-Analyze / app.py
import gradio as gr
import pandas as pd
import joblib
from pathlib import Path
from dotenv import load_dotenv
from scripts.recommendation import summarize_events
from scripts.data_cleansing import cleanse_data
ROOT = Path(__file__).resolve().parents[0]
load_dotenv(ROOT / '.env')
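# NOTE: the .env file is assumed to supply credentials such as HF_TOKEN,
# which the Generative AI and weak-labeling options below depend on.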
def preview_csv(file_obj):
try:
df = pd.read_csv(file_obj.name, dtype=str)
return df.head(10).to_html(index=False)
except Exception as e:
return f"Error reading file: {e}"
def parse_row_selection(df, rows_text: str):
if not rows_text:
return df
idx = []
    for token in rows_text.split(','):
        token = token.strip()
        # keep only numeric tokens that are valid positional indexes
        if token.isdigit() and int(token) < len(df):
            idx.append(int(token))
return df.iloc[idx]
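# Example: parse_row_selection(df, "0, 2, 5") returns rows 0, 2 and 5.
# Non-numeric tokens are ignored and an empty string returns the whole frame.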
with gr.Blocks() as demo:
gr.Markdown("# OMS Analyze — Prototype")
gr.Markdown("> Created by PEACE, Powered by AI, Version 0.0.1")
with gr.Tabs():
# Upload & Preview tab
with gr.TabItem('Upload & Preview'):
gr.Markdown("**Usecase Scenario — Upload & Preview**: อัปโหลดไฟล์ CSV เพื่อตรวจสอบข้อมูลต้นฉบับ ทำความสะอาดข้อมูล (ลบข้อมูลซ้ำ, จัดการค่าที่หายไป) เปรียบเทียบตัวอย่างก่อน/หลัง และดาวน์โหลดไฟล์ที่ทำความสะอาดแล้ว")
csv_up = gr.File(label='Upload CSV (data.csv)')
with gr.Row():
remove_dup = gr.Checkbox(label='Remove Duplicates', value=False)
missing_handling = gr.Radio(choices=['drop','impute_mean','impute_median','impute_mode'], value='drop', label='Missing Values Handling')
apply_clean = gr.Button('Apply Cleansing')
with gr.Tabs():
with gr.TabItem('Original Data'):
original_preview = gr.Dataframe(label='Original Data Preview')
with gr.TabItem('Cleansed Data'):
cleansed_preview = gr.Dataframe(label='Cleansed Data Preview')
download_cleansed = gr.File(label='Download Cleansed CSV')
clean_status = gr.Textbox(label='Cleansing Status', interactive=False)
def initial_preview(file):
if file is None:
return pd.DataFrame(), pd.DataFrame(), "Upload a file"
df = pd.read_csv(file.name, dtype=str)
return df.head(100), pd.DataFrame(), "File uploaded, apply cleansing if needed"
def apply_cleansing(file, remove_duplicates, missing_strategy):
if file is None:
return pd.DataFrame(), "No file", None
try:
df = pd.read_csv(file.name, dtype=str)
df_clean, orig_shape, clean_shape = cleanse_data(df, remove_duplicates, missing_strategy)
status = f"Original: {orig_shape[0]} rows, {orig_shape[1]} cols → Cleaned: {clean_shape[0]} rows, {clean_shape[1]} cols"
# Save cleansed data for download
out_file = ROOT / 'outputs' / 'cleansed_data.csv'
out_file.parent.mkdir(exist_ok=True)
df_clean.to_csv(out_file, index=False, encoding='utf-8-sig')
return df_clean.head(100), status, str(out_file)
except Exception as e:
return pd.DataFrame(), f"Error: {e}", None
csv_up.change(fn=initial_preview, inputs=csv_up, outputs=[original_preview, cleansed_preview, clean_status])
apply_clean.click(fn=apply_cleansing, inputs=[csv_up, remove_dup, missing_handling], outputs=[cleansed_preview, clean_status, download_cleansed])
# Summary tab
with gr.TabItem('Summary'):
gr.Markdown("**Usecase Scenario — Summary**: สร้างสรุปภาพรวมของชุดข้อมูลทั้งหมด รวมสถิติพื้นฐาน และคำนวณดัชนีความน่าเชื่อถือ (เช่น SAIFI, SAIDI, CAIDI) พร้อมตัวเลือกใช้ Generative AI ในการขยายความ")
csv_in_sum = gr.File(label='Upload CSV for Overall Summary')
with gr.Row():
use_hf_sum = gr.Checkbox(label='Use Generative AI for Summary', value=False)
total_customers = gr.Number(label='Total Customers (for reliability calculation)', value=500000, precision=0)
run_sum = gr.Button('Generate Overall Summary')
with gr.Row():
model_selector_sum = gr.Dropdown(
choices=[
'meta-llama/Llama-3.1-8B-Instruct:novita',
'meta-llama/Llama-4-Scout-17B-16E-Instruct:novita',
'Qwen/Qwen3-VL-235B-A22B-Instruct:novita',
'deepseek-ai/DeepSeek-R1:novita',
'moonshotai/Kimi-K2-Instruct-0905:novita'
],
value='meta-llama/Llama-3.1-8B-Instruct:novita',
label='GenAI Model',
interactive=True,
visible=False
)
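            # The ':novita' suffix is assumed to route requests to the Novita
            # provider through the Hugging Face inference router.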
with gr.Tabs():
with gr.TabItem('AI Summary'):
ai_summary_out = gr.Textbox(label='AI Generated Summary', lines=10)
with gr.TabItem('Basic Statistics'):
basic_stats_out = gr.JSON(label='Basic Statistics')
with gr.TabItem('Reliability Indices'):
reliability_out = gr.Dataframe(label='Reliability Metrics')
sum_status = gr.Textbox(label='Summary Status', interactive=False)
def run_overall_summary(file, use_hf_flag, total_cust, model):
if file is None:
                    return '', {}, pd.DataFrame(), 'No file provided'
try:
from scripts.summary import summarize_overall
df = pd.read_csv(file.name, dtype=str)
result = summarize_overall(df, use_hf=use_hf_flag, model=model, total_customers=total_cust)
# Prepare outputs
                    ai_summary = result.get('ai_summary', 'AI summary could not be generated')
basic_stats = {
'total_events': result.get('total_events'),
'date_range': result.get('date_range'),
'event_types': result.get('event_types'),
'total_affected_customers': result.get('total_affected_customers')
}
# Reliability metrics as DataFrame
reliability_df = result.get('reliability_df', pd.DataFrame())
status = f"Summary generated for {len(df)} events. AI used: {use_hf_flag}"
return ai_summary, basic_stats, reliability_df, status
except Exception as e:
return f"Error: {str(e)}", {}, pd.DataFrame(), f'Summary failed: {e}'
def update_model_visibility_sum(use_hf_flag):
return gr.update(visible=use_hf_flag, interactive=use_hf_flag)
use_hf_sum.change(fn=update_model_visibility_sum, inputs=use_hf_sum, outputs=model_selector_sum)
run_sum.click(fn=run_overall_summary, inputs=[csv_in_sum, use_hf_sum, total_customers, model_selector_sum], outputs=[ai_summary_out, basic_stats_out, reliability_out, sum_status])
# Recommendation tab
with gr.TabItem('Recommendation'):
gr.Markdown("**Usecase Scenario — Recommendation**: สร้างสรุปเหตุการณ์ (เช่น สรุปเหตุการณ์ไฟฟ้าขัอข้องหรือบำรุงรักษา) สำหรับแถวที่เลือก ปรับระดับรายละเอียด และเลือกใช้ Generative AI เพื่อเพิ่มความชัดเจน พร้อมดาวน์โหลดไฟล์สรุป")
csv_in = gr.File(label='Upload CSV (data.csv)')
with gr.Row():
                rows = gr.Textbox(label='Rows (comma-separated indexes; leave empty for all)', placeholder='e.g. 0,1,2')
use_hf = gr.Checkbox(label='Use Generative AI', value=False)
verbosity = gr.Radio(choices=['analyze','recommend'], value='analyze', label='Summary Type', interactive=True)
run_btn = gr.Button('Generate Summaries', interactive=True)
with gr.Row():
model_selector = gr.Dropdown(
choices=[
'meta-llama/Llama-3.1-8B-Instruct:novita',
'meta-llama/Llama-4-Scout-17B-16E-Instruct:novita',
'Qwen/Qwen3-VL-235B-A22B-Instruct:novita',
'deepseek-ai/DeepSeek-R1:novita',
'moonshotai/Kimi-K2-Instruct-0905:novita'
],
value='meta-llama/Llama-3.1-8B-Instruct:novita',
label='GenAI Model',
interactive=True,
visible=False
)
out = gr.Dataframe(headers=['EventNumber','OutageDateTime','Summary'])
status = gr.Textbox(label='Status', interactive=False)
download = gr.File(label='Download summaries')
def run_summarize(file, rows_text, use_hf_flag, verbosity_level, model):
print(f"Debug: file={file}, rows_text={rows_text}, use_hf_flag={use_hf_flag}, verbosity_level={verbosity_level}, model={model}")
if file is None:
return pd.DataFrame([], columns=['EventNumber','OutageDateTime','Summary']), 'No file provided', None
df = pd.read_csv(file.name, dtype=str)
df_sel = parse_row_selection(df, rows_text)
res = summarize_events(df_sel, use_hf=use_hf_flag, verbosity=verbosity_level, model=model)
out_df = pd.DataFrame(res)
out_file = ROOT / 'outputs' / 'summaries_from_ui.csv'
out_file.parent.mkdir(exist_ok=True)
out_df.to_csv(out_file, index=False, encoding='utf-8-sig')
status_text = f"Summaries generated: {len(out_df)} rows. HF used: {use_hf_flag}"
return out_df, status_text, str(out_file)
def update_model_visibility(use_hf_flag):
return gr.update(visible=use_hf_flag, interactive=use_hf_flag)
use_hf.change(fn=update_model_visibility, inputs=use_hf, outputs=model_selector)
run_btn.click(fn=run_summarize, inputs=[csv_in, rows, use_hf, verbosity, model_selector], outputs=[out, status, download])
# Anomaly Detection tab
with gr.TabItem('Anomaly Detection'):
gr.Markdown("**Usecase Scenario — Anomaly Detection**: ตรวจจับเหตุการณ์ที่มีพฤติกรรมผิดปกติในชุดข้อมูล (เช่น เหตุการณ์ที่มีค่าสูง/ต่ำผิดปกติ) โดยใช้หลาย algorithm ปรับระดับ contamination และส่งออกผลลัพธ์พร้อมธงความผิดปกติ")
csv_in_anom = gr.File(label='Upload CSV for Anomaly')
with gr.Row():
alg = gr.Radio(choices=['iso+lof','iso','lof','autoencoder'], value='iso+lof', label='Algorithm')
contamination = gr.Slider(minimum=0.01, maximum=0.2, value=0.05, step=0.01, label='Contamination')
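                # contamination = expected fraction of anomalous rows; scikit-learn's
                # IsolationForest and LocalOutlierFactor take a parameter of the same
                # name, which detect_anomalies presumably forwards.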
run_anom = gr.Button('Run Anomaly Detection')
anom_out = gr.Dataframe()
anom_status = gr.Textbox(label='Anomaly Status', interactive=False)
anom_download = gr.File(label='Download anomalies CSV')
def run_anomaly_ui(file, algorithm, contamination):
if file is None:
return pd.DataFrame(), 'No file provided', None
from scripts.anomaly import detect_anomalies
df = pd.read_csv(file.name, dtype=str)
res = detect_anomalies(df, contamination=contamination, algorithm=algorithm)
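                # NOTE: the reordering below assumes detect_anomalies always emits
                # 'ensemble_flag' and 'final_flag'; a KeyError here means the chosen
                # algorithm did not produce them.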
# Reorder columns to put ensemble_flag and final_flag at the end
cols = [c for c in res.columns if c not in ['ensemble_flag', 'final_flag']] + ['ensemble_flag', 'final_flag']
res = res[cols]
out_file = ROOT / 'outputs' / 'anomalies_from_ui.csv'
out_file.parent.mkdir(exist_ok=True)
res.to_csv(out_file, index=False, encoding='utf-8-sig')
status = f"Anomaly detection done. Rows: {len(res)}. Flags: {res['final_flag'].sum()}"
return res, status, str(out_file)
run_anom.click(fn=run_anomaly_ui, inputs=[csv_in_anom, alg, contamination], outputs=[anom_out, anom_status, anom_download])
# Classification tab
with gr.TabItem('Classification'):
gr.Markdown("**Usecase Scenario — Classification**: ฝึกและทดสอบโมเดลเพื่อจำแนกสาเหตุของเหตุการณ์ กำหนดคอลัมน์เป้าหมาย ปรับ hyperparameters, เปิดใช้งาน weak-labeling และดาวน์โหลดโมเดล/ผลการทำนาย")
csv_in_cls = gr.File(label='Upload CSV for Classification')
with gr.Row():
label_col = gr.Dropdown(choices=['CauseType','SubCauseType'], value='CauseType', label='Target Column')
do_weak = gr.Checkbox(label='Run weak-labeling using HF (requires HF_TOKEN)', value=False)
model_type = gr.Radio(choices=['rf','gb','mlp'], value='rf', label='Model Type')
run_cls = gr.Button('Train Classifier')
def update_hyperparams_visibility(model_choice):
rf_visible = model_choice == 'rf'
gb_visible = model_choice == 'gb'
mlp_visible = model_choice == 'mlp'
return [
gr.update(visible=rf_visible),
gr.update(visible=rf_visible),
gr.update(visible=rf_visible),
gr.update(visible=rf_visible),
gr.update(visible=gb_visible),
gr.update(visible=gb_visible),
gr.update(visible=gb_visible),
gr.update(visible=mlp_visible),
gr.update(visible=mlp_visible),
gr.update(visible=mlp_visible),
]
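            # The order of gr.update() values above must match the outputs list
            # passed to model_type.change(...) further below.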
with gr.Accordion("Hyperparameters (Advanced)", open=False):
gr.Markdown("Adjust hyperparameters for the selected model. Defaults are set for good performance.")
rf_n_estimators = gr.Slider(minimum=50, maximum=500, value=100, step=10, label="RF: n_estimators", visible=True)
rf_max_depth = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="RF: max_depth", visible=True)
rf_min_samples_split = gr.Slider(minimum=2, maximum=10, value=2, step=1, label="RF: min_samples_split", visible=True)
rf_min_samples_leaf = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="RF: min_samples_leaf", visible=True)
gb_n_estimators = gr.Slider(minimum=50, maximum=500, value=100, step=10, label="GB: n_estimators", visible=False)
gb_max_depth = gr.Slider(minimum=3, maximum=20, value=3, step=1, label="GB: max_depth", visible=False)
gb_learning_rate = gr.Slider(minimum=0.01, maximum=0.3, value=0.1, step=0.01, label="GB: learning_rate", visible=False)
mlp_hidden_layer_sizes = gr.Textbox(value="(100,)", label="MLP: hidden_layer_sizes (tuple)", visible=False)
mlp_alpha = gr.Slider(minimum=0.0001, maximum=0.01, value=0.0001, step=0.0001, label="MLP: alpha", visible=False)
mlp_max_iter = gr.Slider(minimum=100, maximum=4000, value=500, step=50, label="MLP: max_iter", visible=False)
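                # mlp_hidden_layer_sizes is entered as a Python tuple literal, e.g.
                # "(100,)" or "(64, 32)", and is parsed with ast.literal_eval in
                # run_classify_ui below.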
model_type.change(fn=update_hyperparams_visibility, inputs=model_type, outputs=[rf_n_estimators, rf_max_depth, rf_min_samples_split, rf_min_samples_leaf, gb_n_estimators, gb_max_depth, gb_learning_rate, mlp_hidden_layer_sizes, mlp_alpha, mlp_max_iter])
cls_out = gr.Textbox(label='Classification Report')
model_path_state = gr.State()
cls_download_model = gr.File(label='Download saved model')
cls_download_preds = gr.File(label='Download predictions CSV')
# Test section
gr.Markdown("---")
gr.Markdown("**ทดสอบโมเดล**: อัปโหลดไฟล์ CSV ใหม่เพื่อทดสอบโมเดลที่ฝึกแล้ว")
test_csv = gr.File(label='Upload CSV for Testing')
run_test = gr.Button('Test Model')
test_out = gr.Dataframe(label='Test Predictions')
test_status = gr.Textbox(label='Test Status', interactive=False)
test_download = gr.File(label='Download Test Predictions')
def run_classify_ui(file, label_col_choice, use_weak, model_choice, rf_n_est, rf_max_d, rf_min_ss, rf_min_sl, gb_n_est, gb_max_d, gb_lr, mlp_hls, mlp_a, mlp_mi):
if file is None:
return 'No file provided', None, None, None
from scripts.classify import train_classifier
df = pd.read_csv(file.name, dtype=str)
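                # NOTE: use_weak (the weak-labeling checkbox) is received here but is
                # not currently forwarded to train_classifier.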
try:
hyperparams = {}
if model_choice == 'rf':
hyperparams = {'n_estimators': int(rf_n_est), 'max_depth': int(rf_max_d), 'min_samples_split': int(rf_min_ss), 'min_samples_leaf': int(rf_min_sl)}
elif model_choice == 'gb':
hyperparams = {'n_estimators': int(gb_n_est), 'max_depth': int(gb_max_d), 'learning_rate': gb_lr}
elif model_choice == 'mlp':
import ast
hyperparams = {'hidden_layer_sizes': ast.literal_eval(mlp_hls), 'alpha': mlp_a, 'max_iter': int(mlp_mi)}
res = train_classifier(df, label_col=label_col_choice, model_type=model_choice, hyperparams=hyperparams)
report = res.get('report','')
model_file = res.get('model_file')
preds_file = res.get('predictions_file')
                    # ensure returned file paths are strings for Gradio
                    model_file = str(model_file) if model_file else None
                    preds_file = str(preds_file) if preds_file else None
                    return report, model_file, preds_file, model_file
except Exception as e:
return f'Training failed: {e}', None, None, None
def run_test_ui(test_file, model_path):
if test_file is None:
return pd.DataFrame(), 'No test file provided', None
if model_path is None:
return pd.DataFrame(), 'No trained model available. Please train a model first.', None
try:
from scripts.classify import parse_and_features
# Load model
model_data = joblib.load(model_path)
pipeline = model_data['pipeline']
le = model_data['label_encoder']
# Load and preprocess test data
df_test = pd.read_csv(test_file.name, dtype=str)
df_test = parse_and_features(df_test)
# Define features (same as training)
feature_cols = ['duration_min','Load(MW)_num','Capacity(kVA)_num','AffectedCustomer_num','hour','weekday','device_freq','OpDeviceType','Owner','Weather','EventType']
X_test = df_test[feature_cols]
# Predict
y_pred_encoded = pipeline.predict(X_test)
y_pred = le.inverse_transform(y_pred_encoded)
# Create output df
pred_df = df_test.copy()
pred_df['Predicted_CauseType'] = y_pred
# Save predictions
out_file = ROOT / 'outputs' / 'test_predictions.csv'
out_file.parent.mkdir(exist_ok=True)
pred_df.to_csv(out_file, index=False, encoding='utf-8-sig')
status = f"Test completed. Predictions for {len(pred_df)} rows."
return pred_df.head(100), status, str(out_file)
except Exception as e:
return pd.DataFrame(), f'Test failed: {e}', None
run_cls.click(fn=run_classify_ui, inputs=[csv_in_cls, label_col, do_weak, model_type, rf_n_estimators, rf_max_depth, rf_min_samples_split, rf_min_samples_leaf, gb_n_estimators, gb_max_depth, gb_learning_rate, mlp_hidden_layer_sizes, mlp_alpha, mlp_max_iter], outputs=[cls_out, cls_download_model, cls_download_preds, model_path_state])
run_test.click(fn=run_test_ui, inputs=[test_csv, model_path_state], outputs=[test_out, test_status, test_download])
# Label Suggestion tab
with gr.TabItem('Label Suggestion'):
gr.Markdown("**Usecase Scenario — Label Suggestion**: ให้คำแนะนำป้ายกำกับสาเหตุที่เป็นไปได้สำหรับเหตุการณ์ที่ไม่มีฉลาก โดยเทียบความคล้ายกับตัวอย่างที่มีฉลาก ปรับจำนวนคำแนะนำสูงสุด และส่งออกเป็นไฟล์ CSV")
csv_in_ls = gr.File(label='Upload CSV (defaults to data/data_3.csv)')
with gr.Row():
top_k = gr.Slider(minimum=1, maximum=5, value=1, step=1, label='Top K suggestions')
run_ls = gr.Button('Run Label Suggestion')
ls_out = gr.Dataframe()
ls_status = gr.Textbox(label='Label Suggestion Status', interactive=False)
ls_download = gr.File(label='Download label suggestions')
def run_label_suggestion(file, top_k_suggest):
# delegate to scripts.label_suggestion
from scripts.label_suggestion import suggest_labels_to_file
if file is None:
default = ROOT / 'data' / 'data_3.csv'
if not default.exists():
return pd.DataFrame(), 'No file provided and default data/data_3.csv not found', None
df = pd.read_csv(default, dtype=str)
else:
df = pd.read_csv(file.name, dtype=str)
out_file = ROOT / 'outputs' / 'label_suggestions.csv'
out_df = suggest_labels_to_file(df, out_path=str(out_file), top_k=int(top_k_suggest))
status = f"Label suggestion done. Unknown rows processed: {len(out_df)}. Output: {out_file}"
                return out_df, status, (str(out_file) if len(out_df) > 0 else None)
run_ls.click(fn=run_label_suggestion, inputs=[csv_in_ls, top_k], outputs=[ls_out, ls_status, ls_download])
# Forecasting tab
with gr.TabItem('Forecasting'):
gr.Markdown("**Usecase Scenario — Forecasting**: พยากรณ์จำนวนเหตุการณ์หรือเวลาหยุดทำงานในอนาคตโดยเลือกโมเดล (Prophet, LSTM, Bi-LSTM, GRU, Naive) ปรับพารามิเตอร์ และส่งออกผลการพยากรณ์")
gr.Markdown("*Multivariate forecasting (ใช้หลายฟีเจอร์) รองรับเฉพาะโมเดล LSTM, Bi-LSTM, GRU เท่านั้น*")
csv_in_fc = gr.File(label='Upload CSV for Forecasting')
with gr.Row():
metric_fc = gr.Radio(choices=['count','downtime_minutes'], value='count', label='Metric to Forecast')
model_type_fc = gr.Radio(choices=['prophet','lstm','bilstm','gru','naive'], value='lstm', label='Forecasting Model', elem_id='forecast_model_radio')
periods_fc = gr.Slider(minimum=1, maximum=30, value=7, step=1, label='Forecast Periods (days)')
multivariate_fc = gr.Checkbox(value=False, label='Use Multivariate (Multiple Features)', interactive=False)
run_fc = gr.Button('Run Forecasting')
# Add state to track current model
current_model_state = gr.State(value='lstm')
def update_multivariate_visibility(model_choice):
# Multivariate is only supported for LSTM, Bi-LSTM, GRU
supported_models = ['lstm', 'bilstm', 'gru']
is_supported = model_choice in supported_models
return gr.update(interactive=is_supported, value=False)
def update_model_state(model_choice):
return model_choice
# Hyperparameter controls for forecasting
with gr.Accordion("Hyperparameters (Advanced)", open=False):
gr.Markdown("Adjust hyperparameters for the selected forecasting model. Defaults are set for good performance.")
# Prophet hyperparameters
prophet_changepoint_prior = gr.Slider(minimum=0.001, maximum=0.5, value=0.05, step=0.001, label="Prophet: changepoint_prior_scale", visible=False)
prophet_seasonality_prior = gr.Slider(minimum=0.01, maximum=10.0, value=10.0, step=0.1, label="Prophet: seasonality_prior_scale", visible=False)
prophet_seasonality_mode = gr.Radio(choices=['additive', 'multiplicative'], value='additive', label="Prophet: seasonality_mode", visible=False)
# Deep learning hyperparameters (LSTM, Bi-LSTM, GRU)
dl_seq_length = gr.Slider(minimum=3, maximum=30, value=7, step=1, label="DL: sequence_length (lag/input length)", visible=True)
dl_epochs = gr.Slider(minimum=10, maximum=200, value=100, step=10, label="DL: epochs", visible=True)
dl_batch_size = gr.Slider(minimum=4, maximum=64, value=16, step=4, label="DL: batch_size", visible=True)
dl_learning_rate = gr.Slider(minimum=0.0001, maximum=0.01, value=0.001, step=0.0001, label="DL: learning_rate", visible=True)
dl_units = gr.Slider(minimum=32, maximum=256, value=100, step=16, label="DL: units (neurons)", visible=True)
dl_dropout = gr.Slider(minimum=0.0, maximum=0.5, value=0.2, step=0.05, label="DL: dropout_rate", visible=True)
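                # sequence_length is the sliding input window: the network sees the
                # previous N days to predict the next value (assuming scripts.forecast
                # builds its training sequences this way).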
# Naive has no hyperparameters
def update_forecast_hyperparams_visibility(model_choice):
prophet_visible = model_choice == 'prophet'
dl_visible = model_choice in ['lstm', 'bilstm', 'gru']
return [
gr.update(visible=prophet_visible), # prophet_changepoint_prior
gr.update(visible=prophet_visible), # prophet_seasonality_prior
gr.update(visible=prophet_visible), # prophet_seasonality_mode
gr.update(visible=dl_visible), # dl_seq_length
gr.update(visible=dl_visible), # dl_epochs
gr.update(visible=dl_visible), # dl_batch_size
gr.update(visible=dl_visible), # dl_learning_rate
gr.update(visible=dl_visible), # dl_units
gr.update(visible=dl_visible), # dl_dropout
]
with gr.Tabs():
with gr.TabItem('Historical Data'):
hist_out = gr.Dataframe(label='Historical Time Series Data')
with gr.TabItem('Forecast Results'):
fcst_out = gr.Dataframe(label='Forecast Results')
with gr.TabItem('Time Series Plot'):
plot_out = gr.Plot(label='Historical + Forecast Plot')
fc_status = gr.Textbox(label='Forecast Status', interactive=False)
fc_download = gr.File(label='Download forecast CSV')
def run_forecast_ui(file, metric, model_type, periods, multivariate, current_model, prophet_cp, prophet_sp, prophet_sm, dl_sl, dl_e, dl_bs, dl_lr, dl_u, dl_d):
# Use current_model if available, otherwise use model_type
actual_model = current_model if current_model else model_type
if file is None:
return pd.DataFrame(), pd.DataFrame(), None, 'No file provided', None
try:
from scripts.forecast import run_forecast
import matplotlib.pyplot as plt
df = pd.read_csv(file.name, dtype=str)
# Build hyperparams dict based on model type
hyperparams = {}
if actual_model == 'prophet':
hyperparams = {
'changepoint_prior_scale': prophet_cp,
'seasonality_prior_scale': prophet_sp,
'seasonality_mode': prophet_sm
}
elif actual_model in ['lstm', 'bilstm', 'gru']:
hyperparams = {
'seq_length': int(dl_sl),
'epochs': int(dl_e),
'batch_size': int(dl_bs),
'learning_rate': dl_lr,
'units': int(dl_u),
'dropout_rate': dl_d
}
ts, fcst = run_forecast(df, metric=metric, periods=periods, model_type=actual_model, multivariate=multivariate, hyperparams=hyperparams)
# Create time series plot
fig, ax = plt.subplots(figsize=(14, 7))
# Plot historical data
if len(ts) > 0 and 'y' in ts.columns:
ax.plot(ts['ds'], ts['y'], 'b-', label='Historical Data', linewidth=2, marker='o', markersize=4)
# Plot forecast data
if len(fcst) > 0 and 'yhat' in fcst.columns:
ax.plot(fcst['ds'], fcst['yhat'], 'r--', label='Forecast', linewidth=3, marker='s', markersize=5)
if 'yhat_lower' in fcst.columns and 'yhat_upper' in fcst.columns:
ax.fill_between(fcst['ds'], fcst['yhat_lower'], fcst['yhat_upper'],
color='red', alpha=0.3, label='Confidence Interval')
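                        # 'yhat', 'yhat_lower' and 'yhat_upper' follow Prophet's column
                        # naming; the guards above tolerate models that return point
                        # forecasts only.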
# Add vertical line to separate historical from forecast
if len(ts) > 0 and len(fcst) > 0:
last_hist_date = ts['ds'].max()
ax.axvline(x=last_hist_date, color='gray', linestyle='--', alpha=0.7, label='Forecast Start')
                    ax.set_title(f'Time Series Forecast: {actual_model.upper()} ({metric.replace("_", " ").title()})',
                                 fontsize=16, fontweight='bold', pad=20)
ax.set_xlabel('Date', fontsize=14)
ax.set_ylabel(metric.replace('_', ' ').title(), fontsize=14)
ax.legend(loc='upper left', fontsize=12)
ax.grid(True, alpha=0.3)
# Format x-axis dates
import matplotlib.dates as mdates
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
ax.xaxis.set_major_locator(mdates.DayLocator(interval=max(1, len(ts) // 10)))
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
                    # Save forecast results (use actual_model, i.e. the model that was run)
                    mode = 'multivariate' if multivariate else 'univariate'
                    if multivariate and actual_model not in ['lstm', 'bilstm', 'gru']:
                        mode += ' (fallback: model does not support multivariate)'
                    out_file = ROOT / 'outputs' / f'forecast_{metric}_{actual_model}_{mode.replace(" ", "_")}.csv'
                    out_file.parent.mkdir(exist_ok=True)
                    fcst.to_csv(out_file, index=False)
                    status = f"Forecasting completed using {actual_model.upper()} ({mode}). Historical data: {len(ts)} days, Forecast: {len(fcst)} days."
                    if multivariate and actual_model not in ['lstm', 'bilstm', 'gru']:
                        status += " Note: Model does not support multivariate - used univariate instead."
return ts, fcst, fig, status, str(out_file)
except Exception as e:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(14, 7))
ax.text(0.5, 0.5, f'Forecasting Error:\n{str(e)}',
transform=ax.transAxes, ha='center', va='center',
fontsize=14, bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"))
ax.set_title('Time Series Forecast - Error Occurred', fontsize=16, fontweight='bold')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.axis('off')
return pd.DataFrame(), pd.DataFrame(), fig, f'Forecasting failed: {e}', None
model_type_fc.change(fn=update_multivariate_visibility, inputs=[model_type_fc], outputs=[multivariate_fc])
model_type_fc.change(fn=update_model_state, inputs=[model_type_fc], outputs=[current_model_state])
model_type_fc.change(fn=update_forecast_hyperparams_visibility, inputs=[model_type_fc], outputs=[prophet_changepoint_prior, prophet_seasonality_prior, prophet_seasonality_mode, dl_seq_length, dl_epochs, dl_batch_size, dl_learning_rate, dl_units, dl_dropout])
run_fc.click(fn=run_forecast_ui, inputs=[csv_in_fc, metric_fc, model_type_fc, periods_fc, multivariate_fc, current_model_state, prophet_changepoint_prior, prophet_seasonality_prior, prophet_seasonality_mode, dl_seq_length, dl_epochs, dl_batch_size, dl_learning_rate, dl_units, dl_dropout], outputs=[hist_out, fcst_out, plot_out, fc_status, fc_download])
if __name__ == '__main__':
demo.launch(server_name="0.0.0.0", server_port=7860)