Spaces:
Running
on
A10G
Running
on
A10G
| import glob | |
| import os | |
| import shutil | |
| import sys | |
| import re | |
| import tempfile | |
| import zipfile | |
| from pathlib import Path | |
| import gradio as gr | |
| from finetune import finetune_model, log | |
| from language import languages | |
| from task import tasks | |
| import matplotlib.pyplot as plt | |
def load_markdown(path="intro.md"):
    """Return the text of the Markdown file rendered at the bottom of the UI.

    Args:
        path: File to read. Defaults to ``intro.md``, resolved against the
            current working directory as before.

    Returns:
        The whole file contents as a string.
    """
    # Explicit UTF-8 so the result does not depend on the host's locale.
    with open(path, "r", encoding="utf-8") as f:
        return f.read()
def read_logs(temp_dir):
    """Return the current contents of ``<temp_dir>/output.log`` for display.

    The UI polls this function periodically to refresh the log textbox.

    Args:
        temp_dir: Session scratch directory that holds ``output.log``.

    Returns:
        The log text; ``"Log file not found."`` when the file (or directory)
        does not exist yet; ``None`` if it exists but cannot be read.
    """
    log_path = os.path.join(temp_dir, "output.log")
    try:
        # errors="replace": the file can be polled while it is being appended
        # to, so a multi-byte character may be cut off at the end.
        with open(log_path, "r", encoding="utf-8", errors="replace") as f:
            return f.read()
    except (FileNotFoundError, NotADirectoryError):
        return "Log file not found."
    except OSError:
        # Present but unreadable (permissions, is a directory, ...).
        return None
# Metric lines written to output.log, e.g. "[3] - loss: 0.1234 - acc: 0.9876".
_METRIC_LINE_RE = re.compile(r"^\[\d+\] - loss: \d+\.\d+ - acc: \d+\.\d+$")


def _parse_metrics(log_path):
    """Return ``(losses, accs)`` from the metric lines of *log_path*; other lines are ignored."""
    losses, accs = [], []
    with open(log_path, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            if not _METRIC_LINE_RE.match(line):
                continue
            # The regex guarantees exactly three " - "-separated fields.
            _, loss, acc = line.split(" - ")
            losses.append(float(loss.split(":")[1].strip()))
            accs.append(float(acc.split(":")[1].strip()))
    return losses, accs


def plot_loss_acc(temp_dir, log_every):
    """Render accuracy and loss curves from the session's training log.

    Metric lines in ``<temp_dir>/output.log`` of the form
    ``[<n>] - loss: <decimal> - acc: <decimal>`` are plotted against
    ``k * log_every`` for the k-th metric line (k starting at 1).

    Args:
        temp_dir: Session scratch directory holding ``output.log``; the PNG
            files are written into the same directory.
        log_every: Interval between metric lines (the UI's ``log_every``
            value), used to scale the x-axis.

    Returns:
        ``(acc_png_path, loss_png_path)``, or ``(None, None)`` when the log
        does not exist yet or contains no metric lines.
    """
    # Kept from the original; it only flushes this process's stdout and does
    # not affect output.log.
    sys.stdout.flush()

    log_path = os.path.join(temp_dir, "output.log")
    try:
        losses, accs = _parse_metrics(log_path)
    except FileNotFoundError:
        return None, None
    if not losses:
        return None, None

    # Pyplot-free Figure API: pyplot keeps a process-global "current figure"
    # that concurrent sessions polling this function would otherwise share.
    from matplotlib.figure import Figure

    x = [i * log_every for i in range(1, len(losses) + 1)]
    acc_path = os.path.join(temp_dir, "acc.png")
    loss_path = os.path.join(temp_dir, "loss.png")
    for values, label, out_path in (
        (losses, "loss", loss_path),
        (accs, "acc", acc_path),
    ):
        fig = Figure()
        ax = fig.subplots()
        ax.plot(x, values, label=label)
        # Half an interval of padding on each side of the data.
        ax.set_xlim(log_every // 2, x[-1] + log_every // 2)
        ax.legend()
        fig.savefig(out_path)
    return acc_path, loss_path
def _reject_upload(temp_dir, message):
    """Write *message* to the session log, then abort the Gradio event with it."""
    log(temp_dir, message)
    raise gr.Error(message)


def _read_utterance_ids(path):
    """Return the first whitespace-separated field of every non-blank line of *path*."""
    with open(path, "r", encoding="utf-8") as f:
        # Blank lines (e.g. a trailing empty line) carry no id; skipping them
        # avoids an IndexError on ``split()[0]``.
        return [line.split(maxsplit=1)[0] for line in f if line.strip()]


def upload_file(fileobj, temp_dir):
    """
    Upload a file and check the uploaded zip file.

    The archive is unpacked into *temp_dir* and must contain top-level
    ``text``, ``text_ctc`` and ``audio`` entries. ``text`` and ``text_ctc``
    must list the same ids (the first field of each line), and every file in
    ``audio/`` must have a file stem that is one of those ids.

    Args:
        fileobj: Uploaded file, either an object with a ``.name`` path
            attribute or a plain path string.
        temp_dir: Session scratch directory to unpack into and log to.

    Returns:
        ``[fileobj]``, for display in the file output component.

    Raises:
        gr.Error: If any check fails; the message is also logged.
    """
    archive_path = getattr(fileobj, "name", fileobj)

    # First check if a file is a zip file.
    if not zipfile.is_zipfile(archive_path):
        _reject_upload(temp_dir, "Please upload a zip file.")

    # Then unzip it. The format is given explicitly because the content has
    # already been verified; unpack_archive would otherwise guess from the
    # file extension.
    log(temp_dir, "Unzipping file...")
    shutil.unpack_archive(archive_path, temp_dir, format="zip")

    # Check the required layout.
    for entry in ("text", "text_ctc", "audio"):
        if not os.path.exists(os.path.join(temp_dir, entry)):
            _reject_upload(temp_dir, "Please upload a valid zip file.")

    # Check if all texts and audio match.
    log(temp_dir, "Checking if all texts and audio matches...")
    audio_ids = _read_utterance_ids(os.path.join(temp_dir, "text"))
    ctc_audio_ids = _read_utterance_ids(os.path.join(temp_dir, "text_ctc"))
    if len(audio_ids) != len(ctc_audio_ids):
        _reject_upload(
            temp_dir,
            f"Length of `text` ({len(audio_ids)}) and `text_ctc` ({len(ctc_audio_ids)}) is different.",
        )
    known_ids = set(audio_ids)
    if known_ids != set(ctc_audio_ids):
        _reject_upload(temp_dir, "`text` and `text_ctc` have different audio ids.")
    for audio_path in glob.glob(os.path.join(temp_dir, "audio", "*")):
        audio_id = Path(audio_path).stem
        if audio_id not in known_ids:
            _reject_upload(
                temp_dir, f"Audio id {audio_id} is not in `text` or `text_ctc`."
            )

    log(temp_dir, "Successfully uploaded and validated zip file.")
    gr.Info("Successfully uploaded and validated zip file.")
    return [fileobj]
def delete_tmp_dir(tmp_dir):
    """Remove the session scratch directory *tmp_dir* and everything in it.

    Registered as the ``delete_callback`` of the session ``gr.State``. Prints
    a one-line status message whether or not the directory still existed.
    """
    if not os.path.exists(tmp_dir):
        print("Temporary directory already deleted")
        return
    shutil.rmtree(tmp_dir)
    print(f"Deleted temporary directory: {tmp_dir}")
def create_tmp_dir():
    """Create a fresh, empty scratch directory, log its path and return it."""
    new_dir = tempfile.mkdtemp()
    # print's default " " separator yields the same text as the f-string form.
    print("Created temporary directory:", new_dir)
    return new_dir
with gr.Blocks(title="OWSM-finetune") as demo:
    # Per-session scratch directory. create_tmp_dir supplies the initial value,
    # and delete_tmp_dir is registered as the State's delete callback with a
    # 600-second time-to-live.
    # NOTE(review): delete_tmp_dir removes the whole directory (unpacked data,
    # output.log, plots); confirm the TTL cannot expire during a long run.
    tempdir_path = gr.State(
        create_tmp_dir, delete_callback=delete_tmp_dir, time_to_live=600
    )
    gr.Markdown(
        """# OWSM finetune demo!
Finetune `owsm_v3.1_ebf_base` with your own dataset!
Due to resource limitation, you can only train 5 epochs on maximum.
## Upload dataset and define settings
"""
    )

    # Main contents: dataset upload (validated by upload_file) and
    # language/task selection.
    with gr.Row():
        with gr.Column():
            file_output = gr.File()
            upload_button = gr.UploadButton(
                "Click to Upload a File", file_count="single"
            )
            upload_button.upload(
                upload_file, [upload_button, tempdir_path], [file_output]
            )
        with gr.Column():
            lang = gr.Dropdown(
                languages["espnet/owsm_v3.1_ebf_base"],
                label="Language",
                info="Choose language!",
                value="jpn",
                interactive=True,
            )
            task = gr.Dropdown(
                tasks["espnet/owsm_v3.1_ebf_base"],
                label="Task",
                info="Choose task!",
                value="asr",
                interactive=True,
            )

    gr.Markdown("## Set training settings")
    with gr.Row():
        with gr.Column():
            log_every = gr.Number(value=10, label="log_every", interactive=True)
            # The 1-5 range enforces, in the UI, the epoch limit stated above.
            max_epoch = gr.Slider(1, 5, step=1, label="max_epoch", interactive=True)
            # This dropdown chooses the scheduler; it was mislabelled "warmup",
            # clashing with the separate warmup_steps field below.
            scheduler = gr.Dropdown(
                ["warmuplr"], label="scheduler", value="warmuplr", interactive=True
            )
            warmup_steps = gr.Number(
                value=100, label="warmup_steps", interactive=True
            )
        with gr.Column():
            optimizer = gr.Dropdown(
                ["adam", "adamw", "sgd", "adadelta", "adagrad", "adamax", "asgd", "rmsprop"],
                label="optimizer",
                value="adam",
                interactive=True
            )
            learning_rate = gr.Number(
                value=1e-4, label="learning_rate", interactive=True
            )
            weight_decay = gr.Number(
                value=0.000001, label="weight_decay", interactive=True
            )

    gr.Markdown("## Logs and plots")
    with gr.Row():
        with gr.Column():
            log_output = gr.Textbox(
                show_label=False,
                interactive=False,
                max_lines=23,
                lines=23,
            )
            # Refresh the log view every 2 seconds.
            demo.load(read_logs, [tempdir_path], log_output, every=2)
        with gr.Column():
            log_acc = gr.Image(label="Accuracy", show_label=True, interactive=False)
            log_loss = gr.Image(label="Loss", show_label=True, interactive=False)
            # Re-plot every 10 seconds; plot_loss_acc returns (acc, loss)
            # image paths, matching this output order.
            demo.load(
                plot_loss_acc, [tempdir_path, log_every], [log_acc, log_loss], every=10
            )

    # Fine-tuning results: reference / baseline / fine-tuned hypothesis texts
    # and the downloadable model.
    with gr.Row():
        with gr.Column():
            ref_text = gr.Textbox(
                label="Reference text",
                show_label=True,
                interactive=False,
                max_lines=10,
                lines=10,
            )
        with gr.Column():
            base_text = gr.Textbox(
                label="Baseline text",
                show_label=True,
                interactive=False,
                max_lines=10,
                lines=10,
            )
    with gr.Row():
        with gr.Column():
            hyp_text = gr.Textbox(
                label="Hypothesis text",
                show_label=True,
                interactive=False,
                max_lines=10,
                lines=10,
            )
        with gr.Column():
            trained_model = gr.File(
                label="Trained model",
                interactive=False,
            )

    with gr.Row():
        finetune_btn = gr.Button("Finetune Model", variant="primary")
        # NOTE(review): finetune_model is defined elsewhere; it is presumably
        # expected to return values in the order of the outputs list below.
        finetune_btn.click(
            finetune_model,
            [
                lang,
                task,
                tempdir_path,
                log_every,
                max_epoch,
                scheduler,
                warmup_steps,
                optimizer,
                learning_rate,
                weight_decay,
            ],
            [trained_model, ref_text, base_text, hyp_text]
        )

    gr.Markdown(load_markdown())
if __name__ == "__main__":
    try:
        demo.queue().launch()
    except Exception:
        # Top-level boundary: report the error type, then re-raise so the
        # process still fails loudly. Ctrl-C (KeyboardInterrupt) is no longer
        # misreported as an unexpected error.
        print("Unexpected error:", sys.exc_info()[0])
        raise
    finally:
        # TEMP_DIR is not set anywhere in the code visible here; only clean it
        # up when some other component has exported it (presumably finetune,
        # TODO confirm). os.environ.get avoids a KeyError that would mask the
        # original exception. ignore_errors keeps cleanup failures from doing
        # the same.
        temp_root = os.environ.get("TEMP_DIR")
        if temp_root:
            shutil.rmtree(temp_root, ignore_errors=True)