import gradio as gr
import pandas as pd
from pathlib import Path
from scripts.recommendation import summarize_events
from scripts.data_cleansing import cleanse_data
from dotenv import load_dotenv
import os
import numpy as np
import joblib

ROOT = Path(__file__).resolve().parents[0]
load_dotenv(ROOT / '.env')
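
# load_dotenv above reads local secrets such as HF_TOKEN (used by the
# Generative AI and weak-labeling options below) from a .env file next to this script.


# Render the first rows of an uploaded CSV as an HTML table for a quick preview.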
def preview_csv(file_obj):
    try:
        df = pd.read_csv(file_obj.name, dtype=str)
        return df.head(10).to_html(index=False)
    except Exception as e:
        return f"Error reading file: {e}"
def parse_row_selection(df, rows_text: str):
    if not rows_text:
        return df
    idx = []
    for token in rows_text.split(','):
        token = token.strip()
        if token.isdigit():
            idx.append(int(token))
    # If no valid integer tokens were found, fall back to the full DataFrame.
    return df.iloc[idx] if idx else df


with gr.Blocks() as demo:
    gr.Markdown("# OMS Analyze — Prototype")
    gr.Markdown("> Created by PEACE, Powered by AI, Version 0.0.1")
    with gr.Tabs():
        # Upload & Preview tab
        with gr.TabItem('Upload & Preview'):
            gr.Markdown("**Usecase Scenario — Upload & Preview**: Upload a CSV file to inspect the original data, cleanse it (remove duplicates, handle missing values), compare before/after samples, and download the cleansed file.")
            csv_up = gr.File(label='Upload CSV (data.csv)')
            with gr.Row():
                remove_dup = gr.Checkbox(label='Remove Duplicates', value=False)
                missing_handling = gr.Radio(choices=['drop', 'impute_mean', 'impute_median', 'impute_mode'], value='drop', label='Missing Values Handling')
            apply_clean = gr.Button('Apply Cleansing')
            with gr.Tabs():
                with gr.TabItem('Original Data'):
                    original_preview = gr.Dataframe(label='Original Data Preview')
                with gr.TabItem('Cleansed Data'):
                    cleansed_preview = gr.Dataframe(label='Cleansed Data Preview')
            download_cleansed = gr.File(label='Download Cleansed CSV')
            clean_status = gr.Textbox(label='Cleansing Status', interactive=False)
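
            # Callbacks: show the raw upload immediately; cleansing runs on demand
            # when the Apply Cleansing button is clicked.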
            def initial_preview(file):
                if file is None:
                    return pd.DataFrame(), pd.DataFrame(), "Upload a file"
                df = pd.read_csv(file.name, dtype=str)
                return df.head(100), pd.DataFrame(), "File uploaded, apply cleansing if needed"

            def apply_cleansing(file, remove_duplicates, missing_strategy):
                if file is None:
                    return pd.DataFrame(), "No file", None
                try:
                    df = pd.read_csv(file.name, dtype=str)
                    df_clean, orig_shape, clean_shape = cleanse_data(df, remove_duplicates, missing_strategy)
                    status = f"Original: {orig_shape[0]} rows, {orig_shape[1]} cols → Cleaned: {clean_shape[0]} rows, {clean_shape[1]} cols"
                    # Save cleansed data for download
                    out_file = ROOT / 'outputs' / 'cleansed_data.csv'
                    out_file.parent.mkdir(exist_ok=True)
                    df_clean.to_csv(out_file, index=False, encoding='utf-8-sig')
                    return df_clean.head(100), status, str(out_file)
                except Exception as e:
                    return pd.DataFrame(), f"Error: {e}", None

            csv_up.change(fn=initial_preview, inputs=csv_up, outputs=[original_preview, cleansed_preview, clean_status])
            apply_clean.click(fn=apply_cleansing, inputs=[csv_up, remove_dup, missing_handling], outputs=[cleansed_preview, clean_status, download_cleansed])
        # Summary tab
        with gr.TabItem('Summary'):
            gr.Markdown("**Usecase Scenario — Summary**: Generate an overall summary of the entire dataset, including basic statistics and reliability indices (e.g. SAIFI, SAIDI, CAIDI), with an option to use Generative AI for elaboration.")
            csv_in_sum = gr.File(label='Upload CSV for Overall Summary')
            with gr.Row():
                use_hf_sum = gr.Checkbox(label='Use Generative AI for Summary', value=False)
                total_customers = gr.Number(label='Total Customers (for reliability calculation)', value=500000, precision=0)
            run_sum = gr.Button('Generate Overall Summary')
            with gr.Row():
                model_selector_sum = gr.Dropdown(
                    choices=[
                        'meta-llama/Llama-3.1-8B-Instruct:novita',
                        'meta-llama/Llama-4-Scout-17B-16E-Instruct:novita',
                        'Qwen/Qwen3-VL-235B-A22B-Instruct:novita',
                        'deepseek-ai/DeepSeek-R1:novita',
                        'moonshotai/Kimi-K2-Instruct-0905:novita'
                    ],
                    value='meta-llama/Llama-3.1-8B-Instruct:novita',
                    label='GenAI Model',
                    interactive=True,
                    visible=False
                )
            with gr.Tabs():
                with gr.TabItem('AI Summary'):
                    ai_summary_out = gr.Textbox(label='AI Generated Summary', lines=10)
                with gr.TabItem('Basic Statistics'):
                    basic_stats_out = gr.JSON(label='Basic Statistics')
                with gr.TabItem('Reliability Indices'):
                    reliability_out = gr.Dataframe(label='Reliability Metrics')
            sum_status = gr.Textbox(label='Summary Status', interactive=False)
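
            # Return order must match the outputs wired in run_sum.click below:
            # (AI summary text, basic-stats dict, reliability DataFrame, status message).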
            def run_overall_summary(file, use_hf_flag, total_cust, model):
                if file is None:
                    return '', {}, pd.DataFrame(), 'No file provided'
                try:
                    from scripts.summary import summarize_overall
                    df = pd.read_csv(file.name, dtype=str)
                    result = summarize_overall(df, use_hf=use_hf_flag, model=model, total_customers=total_cust)
                    # Prepare outputs
                    ai_summary = result.get('ai_summary', 'AI summary could not be generated')
                    basic_stats = {
                        'total_events': result.get('total_events'),
                        'date_range': result.get('date_range'),
                        'event_types': result.get('event_types'),
                        'total_affected_customers': result.get('total_affected_customers')
                    }
                    # Reliability metrics as DataFrame
                    reliability_df = result.get('reliability_df', pd.DataFrame())
                    status = f"Summary generated for {len(df)} events. AI used: {use_hf_flag}"
                    return ai_summary, basic_stats, reliability_df, status
                except Exception as e:
                    return f"Error: {str(e)}", {}, pd.DataFrame(), f'Summary failed: {e}'

            def update_model_visibility_sum(use_hf_flag):
                return gr.update(visible=use_hf_flag, interactive=use_hf_flag)

            use_hf_sum.change(fn=update_model_visibility_sum, inputs=use_hf_sum, outputs=model_selector_sum)
            run_sum.click(fn=run_overall_summary, inputs=[csv_in_sum, use_hf_sum, total_customers, model_selector_sum], outputs=[ai_summary_out, basic_stats_out, reliability_out, sum_status])
        # Recommendation tab
        with gr.TabItem('Recommendation'):
            gr.Markdown("**Usecase Scenario — Recommendation**: Generate event summaries (e.g. power outage or maintenance events) for the selected rows, adjust the level of detail, optionally use Generative AI for extra clarity, and download the summary file.")
            csv_in = gr.File(label='Upload CSV (data.csv)')
            with gr.Row():
                rows = gr.Textbox(label='Rows (comma-separated indexes) or empty = all', placeholder='e.g. 0,1,2')
                use_hf = gr.Checkbox(label='Use Generative AI', value=False)
                verbosity = gr.Radio(choices=['analyze', 'recommend'], value='analyze', label='Summary Type', interactive=True)
            run_btn = gr.Button('Generate Summaries', interactive=True)
            with gr.Row():
                model_selector = gr.Dropdown(
                    choices=[
                        'meta-llama/Llama-3.1-8B-Instruct:novita',
                        'meta-llama/Llama-4-Scout-17B-16E-Instruct:novita',
                        'Qwen/Qwen3-VL-235B-A22B-Instruct:novita',
                        'deepseek-ai/DeepSeek-R1:novita',
                        'moonshotai/Kimi-K2-Instruct-0905:novita'
                    ],
                    value='meta-llama/Llama-3.1-8B-Instruct:novita',
                    label='GenAI Model',
                    interactive=True,
                    visible=False
                )
            out = gr.Dataframe(headers=['EventNumber', 'OutageDateTime', 'Summary'])
            status = gr.Textbox(label='Status', interactive=False)
            download = gr.File(label='Download summaries')
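
            # Summarize the selected rows via scripts.recommendation.summarize_events
            # and save the result so the download File component can serve it.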
            def run_summarize(file, rows_text, use_hf_flag, verbosity_level, model):
                print(f"Debug: file={file}, rows_text={rows_text}, use_hf_flag={use_hf_flag}, verbosity_level={verbosity_level}, model={model}")
                if file is None:
                    return pd.DataFrame([], columns=['EventNumber', 'OutageDateTime', 'Summary']), 'No file provided', None
                df = pd.read_csv(file.name, dtype=str)
                df_sel = parse_row_selection(df, rows_text)
                res = summarize_events(df_sel, use_hf=use_hf_flag, verbosity=verbosity_level, model=model)
                out_df = pd.DataFrame(res)
                out_file = ROOT / 'outputs' / 'summaries_from_ui.csv'
                out_file.parent.mkdir(exist_ok=True)
                out_df.to_csv(out_file, index=False, encoding='utf-8-sig')
                status_text = f"Summaries generated: {len(out_df)} rows. HF used: {use_hf_flag}"
                return out_df, status_text, str(out_file)

            def update_model_visibility(use_hf_flag):
                return gr.update(visible=use_hf_flag, interactive=use_hf_flag)

            use_hf.change(fn=update_model_visibility, inputs=use_hf, outputs=model_selector)
            run_btn.click(fn=run_summarize, inputs=[csv_in, rows, use_hf, verbosity, model_selector], outputs=[out, status, download])
        # Anomaly Detection tab
        with gr.TabItem('Anomaly Detection'):
            gr.Markdown("**Usecase Scenario — Anomaly Detection**: Detect events with abnormal behaviour in the dataset (e.g. events with unusually high/low values) using multiple algorithms, tune the contamination level, and export the results with anomaly flags.")
            csv_in_anom = gr.File(label='Upload CSV for Anomaly')
            with gr.Row():
                alg = gr.Radio(choices=['iso+lof', 'iso', 'lof', 'autoencoder'], value='iso+lof', label='Algorithm')
                contamination = gr.Slider(minimum=0.01, maximum=0.2, value=0.05, step=0.01, label='Contamination')
            run_anom = gr.Button('Run Anomaly Detection')
            anom_out = gr.Dataframe()
            anom_status = gr.Textbox(label='Anomaly Status', interactive=False)
            anom_download = gr.File(label='Download anomalies CSV')
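
            # Run scripts.anomaly.detect_anomalies, then move the flag columns to the
            # end of the table so they are easy to spot in the output.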
            def run_anomaly_ui(file, algorithm, contamination):
                if file is None:
                    return pd.DataFrame(), 'No file provided', None
                from scripts.anomaly import detect_anomalies
                df = pd.read_csv(file.name, dtype=str)
                res = detect_anomalies(df, contamination=contamination, algorithm=algorithm)
                # Reorder columns to put ensemble_flag and final_flag at the end
                cols = [c for c in res.columns if c not in ['ensemble_flag', 'final_flag']] + ['ensemble_flag', 'final_flag']
                res = res[cols]
                out_file = ROOT / 'outputs' / 'anomalies_from_ui.csv'
                out_file.parent.mkdir(exist_ok=True)
                res.to_csv(out_file, index=False, encoding='utf-8-sig')
                status = f"Anomaly detection done. Rows: {len(res)}. Flags: {res['final_flag'].sum()}"
                return res, status, str(out_file)

            run_anom.click(fn=run_anomaly_ui, inputs=[csv_in_anom, alg, contamination], outputs=[anom_out, anom_status, anom_download])
        # Classification tab
        with gr.TabItem('Classification'):
            gr.Markdown("**Usecase Scenario — Classification**: Train and test models to classify event causes: set the target column, tune hyperparameters, enable weak-labeling, and download the model/predictions.")
            csv_in_cls = gr.File(label='Upload CSV for Classification')
            with gr.Row():
                label_col = gr.Dropdown(choices=['CauseType', 'SubCauseType'], value='CauseType', label='Target Column')
                do_weak = gr.Checkbox(label='Run weak-labeling using HF (requires HF_TOKEN)', value=False)
                model_type = gr.Radio(choices=['rf', 'gb', 'mlp'], value='rf', label='Model Type')
            run_cls = gr.Button('Train Classifier')
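
            # Show only the hyperparameter controls relevant to the chosen model; the
            # returned updates must match the outputs list wired in model_type.change below.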
            def update_hyperparams_visibility(model_choice):
                rf_visible = model_choice == 'rf'
                gb_visible = model_choice == 'gb'
                mlp_visible = model_choice == 'mlp'
                return [
                    gr.update(visible=rf_visible),
                    gr.update(visible=rf_visible),
                    gr.update(visible=rf_visible),
                    gr.update(visible=rf_visible),
                    gr.update(visible=gb_visible),
                    gr.update(visible=gb_visible),
                    gr.update(visible=gb_visible),
                    gr.update(visible=mlp_visible),
                    gr.update(visible=mlp_visible),
                    gr.update(visible=mlp_visible),
                ]
            with gr.Accordion("Hyperparameters (Advanced)", open=False):
                gr.Markdown("Adjust hyperparameters for the selected model. Defaults are set for good performance.")
                rf_n_estimators = gr.Slider(minimum=50, maximum=500, value=100, step=10, label="RF: n_estimators", visible=True)
                rf_max_depth = gr.Slider(minimum=5, maximum=50, value=10, step=1, label="RF: max_depth", visible=True)
                rf_min_samples_split = gr.Slider(minimum=2, maximum=10, value=2, step=1, label="RF: min_samples_split", visible=True)
                rf_min_samples_leaf = gr.Slider(minimum=1, maximum=5, value=1, step=1, label="RF: min_samples_leaf", visible=True)
                gb_n_estimators = gr.Slider(minimum=50, maximum=500, value=100, step=10, label="GB: n_estimators", visible=False)
                gb_max_depth = gr.Slider(minimum=3, maximum=20, value=3, step=1, label="GB: max_depth", visible=False)
                gb_learning_rate = gr.Slider(minimum=0.01, maximum=0.3, value=0.1, step=0.01, label="GB: learning_rate", visible=False)
                mlp_hidden_layer_sizes = gr.Textbox(value="(100,)", label="MLP: hidden_layer_sizes (tuple)", visible=False)
                mlp_alpha = gr.Slider(minimum=0.0001, maximum=0.01, value=0.0001, step=0.0001, label="MLP: alpha", visible=False)
                mlp_max_iter = gr.Slider(minimum=100, maximum=4000, value=500, step=50, label="MLP: max_iter", visible=False)
            model_type.change(fn=update_hyperparams_visibility, inputs=model_type, outputs=[rf_n_estimators, rf_max_depth, rf_min_samples_split, rf_min_samples_leaf, gb_n_estimators, gb_max_depth, gb_learning_rate, mlp_hidden_layer_sizes, mlp_alpha, mlp_max_iter])
            cls_out = gr.Textbox(label='Classification Report')
            model_path_state = gr.State()
            cls_download_model = gr.File(label='Download saved model')
            cls_download_preds = gr.File(label='Download predictions CSV')
            # Test section
            gr.Markdown("---")
            gr.Markdown("**Test the Model**: Upload a new CSV file to test the trained model.")
            test_csv = gr.File(label='Upload CSV for Testing')
            run_test = gr.Button('Test Model')
            test_out = gr.Dataframe(label='Test Predictions')
            test_status = gr.Textbox(label='Test Status', interactive=False)
            test_download = gr.File(label='Download Test Predictions')
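
            # Gather the hyperparameters for the chosen model and delegate to
            # scripts.classify.train_classifier; the saved model path is also kept in
            # model_path_state so the test step below can reload it.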
            def run_classify_ui(file, label_col_choice, use_weak, model_choice, rf_n_est, rf_max_d, rf_min_ss, rf_min_sl, gb_n_est, gb_max_d, gb_lr, mlp_hls, mlp_a, mlp_mi):
                # NOTE: use_weak (the weak-labeling checkbox) is received here but not
                # yet forwarded to train_classifier.
                if file is None:
                    return 'No file provided', None, None, None
                from scripts.classify import train_classifier
                df = pd.read_csv(file.name, dtype=str)
                try:
                    hyperparams = {}
                    if model_choice == 'rf':
                        hyperparams = {'n_estimators': int(rf_n_est), 'max_depth': int(rf_max_d), 'min_samples_split': int(rf_min_ss), 'min_samples_leaf': int(rf_min_sl)}
                    elif model_choice == 'gb':
                        hyperparams = {'n_estimators': int(gb_n_est), 'max_depth': int(gb_max_d), 'learning_rate': gb_lr}
                    elif model_choice == 'mlp':
                        import ast
                        # hidden_layer_sizes arrives as text, e.g. "(100,)"; parse it safely
                        hyperparams = {'hidden_layer_sizes': ast.literal_eval(mlp_hls), 'alpha': mlp_a, 'max_iter': int(mlp_mi)}
                    res = train_classifier(df, label_col=label_col_choice, model_type=model_choice, hyperparams=hyperparams)
                    report = res.get('report', '')
                    # Ensure returned file paths are strings for Gradio's File components
                    model_file = res.get('model_file')
                    model_file = str(model_file) if model_file else None
                    preds_file = res.get('predictions_file')
                    preds_file = str(preds_file) if preds_file else None
                    return report, model_file, preds_file, model_file
                except Exception as e:
                    return f'Training failed: {e}', None, None, None
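
            # Reload the saved pipeline and label encoder with joblib, rebuild the same
            # features used during training, and predict on the uploaded test CSV.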
            def run_test_ui(test_file, model_path):
                if test_file is None:
                    return pd.DataFrame(), 'No test file provided', None
                if model_path is None:
                    return pd.DataFrame(), 'No trained model available. Please train a model first.', None
                try:
                    from scripts.classify import parse_and_features
                    # Load model
                    model_data = joblib.load(model_path)
                    pipeline = model_data['pipeline']
                    le = model_data['label_encoder']
                    # Load and preprocess test data
                    df_test = pd.read_csv(test_file.name, dtype=str)
                    df_test = parse_and_features(df_test)
                    # Define features (same as training)
                    feature_cols = ['duration_min', 'Load(MW)_num', 'Capacity(kVA)_num', 'AffectedCustomer_num', 'hour', 'weekday', 'device_freq', 'OpDeviceType', 'Owner', 'Weather', 'EventType']
                    X_test = df_test[feature_cols]
                    # Predict
                    y_pred_encoded = pipeline.predict(X_test)
                    y_pred = le.inverse_transform(y_pred_encoded)
                    # Create output df
                    pred_df = df_test.copy()
                    pred_df['Predicted_CauseType'] = y_pred
                    # Save predictions
                    out_file = ROOT / 'outputs' / 'test_predictions.csv'
                    out_file.parent.mkdir(exist_ok=True)
                    pred_df.to_csv(out_file, index=False, encoding='utf-8-sig')
                    status = f"Test completed. Predictions for {len(pred_df)} rows."
                    return pred_df.head(100), status, str(out_file)
                except Exception as e:
                    return pd.DataFrame(), f'Test failed: {e}', None

            run_cls.click(fn=run_classify_ui, inputs=[csv_in_cls, label_col, do_weak, model_type, rf_n_estimators, rf_max_depth, rf_min_samples_split, rf_min_samples_leaf, gb_n_estimators, gb_max_depth, gb_learning_rate, mlp_hidden_layer_sizes, mlp_alpha, mlp_max_iter], outputs=[cls_out, cls_download_model, cls_download_preds, model_path_state])
            run_test.click(fn=run_test_ui, inputs=[test_csv, model_path_state], outputs=[test_out, test_status, test_download])
        # Label Suggestion tab
        with gr.TabItem('Label Suggestion'):
            gr.Markdown("**Usecase Scenario — Label Suggestion**: Suggest possible cause labels for unlabeled events by comparing their similarity to labeled examples, adjust the maximum number of suggestions, and export the results as a CSV file.")
            csv_in_ls = gr.File(label='Upload CSV (defaults to data/data_3.csv)')
            with gr.Row():
                top_k = gr.Slider(minimum=1, maximum=5, value=1, step=1, label='Top K suggestions')
            run_ls = gr.Button('Run Label Suggestion')
            ls_out = gr.Dataframe()
            ls_status = gr.Textbox(label='Label Suggestion Status', interactive=False)
            ls_download = gr.File(label='Download label suggestions')
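
            # Falls back to the bundled data/data_3.csv when no file is uploaded.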
            def run_label_suggestion(file, top_k_suggest):
                # delegate to scripts.label_suggestion
                from scripts.label_suggestion import suggest_labels_to_file
                if file is None:
                    default = ROOT / 'data' / 'data_3.csv'
                    if not default.exists():
                        return pd.DataFrame(), 'No file provided and default data/data_3.csv not found', None
                    df = pd.read_csv(default, dtype=str)
                else:
                    df = pd.read_csv(file.name, dtype=str)
                out_file = ROOT / 'outputs' / 'label_suggestions.csv'
                out_df = suggest_labels_to_file(df, out_path=str(out_file), top_k=int(top_k_suggest))
                status = f"Label suggestion done. Unknown rows processed: {len(out_df)}. Output: {out_file}"
                return out_df, status, (str(out_file) if len(out_df) > 0 else None)

            run_ls.click(fn=run_label_suggestion, inputs=[csv_in_ls, top_k], outputs=[ls_out, ls_status, ls_download])
        # Forecasting tab
        with gr.TabItem('Forecasting'):
            gr.Markdown("**Usecase Scenario — Forecasting**: Forecast future event counts or downtime by choosing a model (Prophet, LSTM, Bi-LSTM, GRU, Naive), tuning its parameters, and exporting the forecast results.")
            gr.Markdown("*Multivariate forecasting (using multiple features) is supported only by the LSTM, Bi-LSTM, and GRU models.*")
            csv_in_fc = gr.File(label='Upload CSV for Forecasting')
            with gr.Row():
                metric_fc = gr.Radio(choices=['count', 'downtime_minutes'], value='count', label='Metric to Forecast')
                model_type_fc = gr.Radio(choices=['prophet', 'lstm', 'bilstm', 'gru', 'naive'], value='lstm', label='Forecasting Model', elem_id='forecast_model_radio')
                periods_fc = gr.Slider(minimum=1, maximum=30, value=7, step=1, label='Forecast Periods (days)')
                multivariate_fc = gr.Checkbox(value=False, label='Use Multivariate (Multiple Features)', interactive=False)
            run_fc = gr.Button('Run Forecasting')
            # Add state to track current model
            current_model_state = gr.State(value='lstm')
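
            # Enable the multivariate checkbox only for models that support it, and
            # mirror the radio selection into current_model_state.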
            def update_multivariate_visibility(model_choice):
                # Multivariate is only supported for LSTM, Bi-LSTM, GRU
                supported_models = ['lstm', 'bilstm', 'gru']
                is_supported = model_choice in supported_models
                return gr.update(interactive=is_supported, value=False)

            def update_model_state(model_choice):
                return model_choice
            # Hyperparameter controls for forecasting
            with gr.Accordion("Hyperparameters (Advanced)", open=False):
                gr.Markdown("Adjust hyperparameters for the selected forecasting model. Defaults are set for good performance.")
                # Prophet hyperparameters
                prophet_changepoint_prior = gr.Slider(minimum=0.001, maximum=0.5, value=0.05, step=0.001, label="Prophet: changepoint_prior_scale", visible=False)
                prophet_seasonality_prior = gr.Slider(minimum=0.01, maximum=10.0, value=10.0, step=0.1, label="Prophet: seasonality_prior_scale", visible=False)
                prophet_seasonality_mode = gr.Radio(choices=['additive', 'multiplicative'], value='additive', label="Prophet: seasonality_mode", visible=False)
                # Deep learning hyperparameters (LSTM, Bi-LSTM, GRU)
                dl_seq_length = gr.Slider(minimum=3, maximum=30, value=7, step=1, label="DL: sequence_length (lag/input length)", visible=True)
                dl_epochs = gr.Slider(minimum=10, maximum=200, value=100, step=10, label="DL: epochs", visible=True)
                dl_batch_size = gr.Slider(minimum=4, maximum=64, value=16, step=4, label="DL: batch_size", visible=True)
                dl_learning_rate = gr.Slider(minimum=0.0001, maximum=0.01, value=0.001, step=0.0001, label="DL: learning_rate", visible=True)
                dl_units = gr.Slider(minimum=32, maximum=256, value=100, step=16, label="DL: units (neurons)", visible=True)
                dl_dropout = gr.Slider(minimum=0.0, maximum=0.5, value=0.2, step=0.05, label="DL: dropout_rate", visible=True)
                # Naive has no hyperparameters

            def update_forecast_hyperparams_visibility(model_choice):
                prophet_visible = model_choice == 'prophet'
                dl_visible = model_choice in ['lstm', 'bilstm', 'gru']
                return [
                    gr.update(visible=prophet_visible),  # prophet_changepoint_prior
                    gr.update(visible=prophet_visible),  # prophet_seasonality_prior
                    gr.update(visible=prophet_visible),  # prophet_seasonality_mode
                    gr.update(visible=dl_visible),  # dl_seq_length
                    gr.update(visible=dl_visible),  # dl_epochs
                    gr.update(visible=dl_visible),  # dl_batch_size
                    gr.update(visible=dl_visible),  # dl_learning_rate
                    gr.update(visible=dl_visible),  # dl_units
                    gr.update(visible=dl_visible),  # dl_dropout
                ]
            with gr.Tabs():
                with gr.TabItem('Historical Data'):
                    hist_out = gr.Dataframe(label='Historical Time Series Data')
                with gr.TabItem('Forecast Results'):
                    fcst_out = gr.Dataframe(label='Forecast Results')
                with gr.TabItem('Time Series Plot'):
                    plot_out = gr.Plot(label='Historical + Forecast Plot')
            fc_status = gr.Textbox(label='Forecast Status', interactive=False)
            fc_download = gr.File(label='Download forecast CSV')
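
            # Run the forecast, plot history plus prediction (with a confidence band
            # when available), and save the forecast CSV; errors are rendered in-figure.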
            def run_forecast_ui(file, metric, model_type, periods, multivariate, current_model, prophet_cp, prophet_sp, prophet_sm, dl_sl, dl_e, dl_bs, dl_lr, dl_u, dl_d):
                # Use current_model if available, otherwise use model_type
                actual_model = current_model if current_model else model_type
                if file is None:
                    return pd.DataFrame(), pd.DataFrame(), None, 'No file provided', None
                try:
                    from scripts.forecast import run_forecast
                    import matplotlib.pyplot as plt
                    df = pd.read_csv(file.name, dtype=str)
                    # Build hyperparams dict based on model type
                    hyperparams = {}
                    if actual_model == 'prophet':
                        hyperparams = {
                            'changepoint_prior_scale': prophet_cp,
                            'seasonality_prior_scale': prophet_sp,
                            'seasonality_mode': prophet_sm
                        }
                    elif actual_model in ['lstm', 'bilstm', 'gru']:
                        hyperparams = {
                            'seq_length': int(dl_sl),
                            'epochs': int(dl_e),
                            'batch_size': int(dl_bs),
                            'learning_rate': dl_lr,
                            'units': int(dl_u),
                            'dropout_rate': dl_d
                        }
                    ts, fcst = run_forecast(df, metric=metric, periods=periods, model_type=actual_model, multivariate=multivariate, hyperparams=hyperparams)
                    # Create time series plot
                    fig, ax = plt.subplots(figsize=(14, 7))
                    # Plot historical data
                    if len(ts) > 0 and 'y' in ts.columns:
                        ax.plot(ts['ds'], ts['y'], 'b-', label='Historical Data', linewidth=2, marker='o', markersize=4)
                    # Plot forecast data
                    if len(fcst) > 0 and 'yhat' in fcst.columns:
                        ax.plot(fcst['ds'], fcst['yhat'], 'r--', label='Forecast', linewidth=3, marker='s', markersize=5)
                        if 'yhat_lower' in fcst.columns and 'yhat_upper' in fcst.columns:
                            ax.fill_between(fcst['ds'], fcst['yhat_lower'], fcst['yhat_upper'],
                                            color='red', alpha=0.3, label='Confidence Interval')
                    # Add vertical line to separate historical from forecast
                    if len(ts) > 0 and len(fcst) > 0:
                        last_hist_date = ts['ds'].max()
                        ax.axvline(x=last_hist_date, color='gray', linestyle='--', alpha=0.7, label='Forecast Start')
                    ax.set_title(f'Time Series Forecast: {actual_model.upper()} ({metric.replace("_", " ").title()})',
                                 fontsize=16, fontweight='bold', pad=20)
                    ax.set_xlabel('Date', fontsize=14)
                    ax.set_ylabel(metric.replace('_', ' ').title(), fontsize=14)
                    ax.legend(loc='upper left', fontsize=12)
                    ax.grid(True, alpha=0.3)
                    # Format x-axis dates
                    import matplotlib.dates as mdates
                    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
                    ax.xaxis.set_major_locator(mdates.DayLocator(interval=max(1, len(ts) // 10)))
                    plt.xticks(rotation=45, ha='right')
                    plt.tight_layout()
                    # Save forecast results
                    mode = 'multivariate' if multivariate else 'univariate'
                    if multivariate and actual_model not in ['lstm', 'bilstm', 'gru']:
                        mode += ' (fallback: model does not support multivariate)'
                    out_file = ROOT / 'outputs' / f'forecast_{metric}_{actual_model}_{mode.replace(" ", "_")}.csv'
                    out_file.parent.mkdir(exist_ok=True)
                    fcst.to_csv(out_file, index=False)
                    status = f"Forecasting completed using {actual_model.upper()} ({mode}). Historical data: {len(ts)} days, Forecast: {len(fcst)} days."
                    if multivariate and actual_model not in ['lstm', 'bilstm', 'gru']:
                        status += " Note: Model does not support multivariate - used univariate instead."
                    return ts, fcst, fig, status, str(out_file)
                except Exception as e:
                    import matplotlib.pyplot as plt
                    fig, ax = plt.subplots(figsize=(14, 7))
                    ax.text(0.5, 0.5, f'Forecasting Error:\n{str(e)}',
                            transform=ax.transAxes, ha='center', va='center',
                            fontsize=14, bbox=dict(boxstyle="round,pad=0.3", facecolor="lightcoral"))
                    ax.set_title('Time Series Forecast - Error Occurred', fontsize=16, fontweight='bold')
                    ax.set_xlim(0, 1)
                    ax.set_ylim(0, 1)
                    ax.axis('off')
                    return pd.DataFrame(), pd.DataFrame(), fig, f'Forecasting failed: {e}', None
            model_type_fc.change(fn=update_multivariate_visibility, inputs=[model_type_fc], outputs=[multivariate_fc])
            model_type_fc.change(fn=update_model_state, inputs=[model_type_fc], outputs=[current_model_state])
            model_type_fc.change(fn=update_forecast_hyperparams_visibility, inputs=[model_type_fc], outputs=[prophet_changepoint_prior, prophet_seasonality_prior, prophet_seasonality_mode, dl_seq_length, dl_epochs, dl_batch_size, dl_learning_rate, dl_units, dl_dropout])
            run_fc.click(fn=run_forecast_ui, inputs=[csv_in_fc, metric_fc, model_type_fc, periods_fc, multivariate_fc, current_model_state, prophet_changepoint_prior, prophet_seasonality_prior, prophet_seasonality_mode, dl_seq_length, dl_epochs, dl_batch_size, dl_learning_rate, dl_units, dl_dropout], outputs=[hist_out, fcst_out, plot_out, fc_status, fc_download])


if __name__ == '__main__':
    demo.launch(server_name="0.0.0.0", server_port=7860)