---
library_name: peft
license: llama3.1
base_model: tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5
tags:
- generated_from_trainer
datasets:
- Coldog2333/JMedBench
- humanalysis-square/KokushiMD-10
- EQUES/YakugakuQA
- Henrychur/MMedBench
- IgakuQA
model-index:
- name: outputs/qlora-out_swallow-8b_4
  results: []
language:
- ja
metrics:
- accuracy
---

# MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5

**MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5** is a QLoRA fine-tune of [tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5](https://huggingface.co/tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5) that enhances its capability on Japanese medical exams.

- We trained a QLoRA adapter for Japanese medical exams: **MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5**.
- Since Llama-3.1 provides strong foundational capabilities and Llama-3.1-Swallow-8B-Instruct-v0.5 is well fine-tuned for Japanese, we chose the latter as our base model.
- We used multiple-choice question data from the MIT-licensed portion of JMedBench (218,912 samples) and pharmacy-related data from KokushiMD-10 (1,386 samples) for training. We evaluated the model on IgakuQA, which includes 2,000 medical exam questions from 2018–2022. Given the relatively large training set, we trained for only one epoch and adopted the lightweight QLoRA technique.
- After fine-tuning, the model's accuracy improved from 55.75% to 62.40%, a gain of **6.65 percentage points**. Despite being only an **8B** model, it outperforms the ChatGPT baselines reported in IgakuQA (ChatGPT: 53.95%, Translate-ChatGPT: 56.60%).

## Model Overview

- Developer: Ingenta Inc.
- Base Model: [Llama-3.1-Swallow-8B-Instruct-v0.5](https://huggingface.co/tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5), [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)
- Training Tool: [Axolotl](https://axolotl.ai/)
- Supported Languages: Japanese
- License: [Llama 3.1](https://www.llama.com/llama3_1/license/) and [Gemma](https://ai.google.dev/gemma/terms)

### Base Model Reference

The following table lists the models related to our work:

| Used Model | License |
|---------|----------------|
| [Llama-3.1-Swallow-8B-Instruct-v0.5](https://huggingface.co/tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5) | [Llama 3.1](https://www.llama.com/llama3_1/license/) and [Gemma](https://ai.google.dev/gemma/terms) |
| [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) | [Llama 3.1](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/LICENSE) |

## Training Configuration

We use Axolotl to train the QLoRA adapter. [Built with Axolotl](https://github.com/axolotl-ai-cloud/axolotl)
See the Axolotl training config below (axolotl version: `0.10.0`):

```yaml
base_model: tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: true

datasets:
  - path: merged_medical_qa_MIT.json
    type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./outputs/qlora-out_swallow-8b

adapter: qlora
lora_model_dir:

sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_32bit
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: false

gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
  pad_token: "<|end_of_text|>"
```
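The `datasets` entry above points to `merged_medical_qa_MIT.json` in Axolotl's `alpaca` format (`instruction`/`input`/`output` fields). As a minimal sketch only, one such record could be prepared as shown below; the concrete wording of the record is an illustrative assumption, not the actual training file.

```python
import json

# Minimal sketch of one record in merged_medical_qa_MIT.json (Axolotl `alpaca`
# format: instruction/input/output). The concrete contents below are an
# illustrative assumption, not the actual training data.
question = "What is the most common cause of acute myocardial infarction?"
options = (
    "A. Coronary artery spasm\n"
    "B. Atherosclerotic plaque rupture\n"
    "C. Coronary artery embolism\n"
    "D. Coronary artery dissection\n"
    "E. Takotsubo cardiomyopathy"
)

record = {
    "instruction": (
        "Answer this medical multiple choice question by selecting the correct "
        f"option letter (A, B, C, D, or E).\n\nQuestion: {question}\n\n"
        f"Options:\n{options}\n\nAnswer:"
    ),
    "input": "",    # unused for these multiple-choice records
    "output": "B",  # gold option letter
}

# Write a JSON list of such records; Axolotl loads local JSON/JSONL files via HF datasets.
with open("merged_medical_qa_MIT.json", "w", encoding="utf-8") as f:
    json.dump([record], f, ensure_ascii=False, indent=2)
```

With the data file in place, training can be launched with the Axolotl CLI, e.g. `axolotl train <config>.yaml`.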

## Usage

- Preparing the environment

```bash
conda create --name medexamdoc python=3.11
conda activate medexamdoc
pip install peft==0.15.2 transformers==4.52.3 torch==2.5.1 datasets==3.6.0 tokenizers==0.21.2 accelerate
hf download IngentaAITeam/MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5
```

- Python inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re
import time

# Adapter repository; loading it with AutoModelForCausalLM pulls in the base
# model and applies the PEFT adapter automatically (requires `peft` to be installed).
model_path = "IngentaAITeam/MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5"


def load_model(model_name, device="cuda"):
    """Load the tokenizer and model."""
    print(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None
    )
    model.eval()
    return tokenizer, model


def create_medical_prompt_template():
    """Create the prompt template for medical questions."""
    template = """Answer this medical multiple choice question by selecting the correct option letter (A, B, C, D, or E).

Question: {question}

Options:
{options}

Answer:"""
    return template


def format_question(question_data):
    """Format a question into a prompt."""
    question = question_data['question']
    options = question_data['options']

    # Format the options (exactly matching the training data format)
    options_text = ""
    for key, value in options.items():
        options_text += f"{key}. {value}\n"

    template = create_medical_prompt_template()
    prompt = template.format(
        question=question,
        options=options_text.rstrip()  # drop the trailing newline
    )
    return prompt


def extract_answer(response):
    """Extract the option letter from the model's response."""
    pattern = r'\b[A-E]\b'
    matches = re.findall(pattern, response.upper())
    return matches[0] if matches else None


def extract_multiple_answers(response):
    """Extract multiple option letters and remove duplicates."""
    pattern = r'\b[A-E]\b'
    matches = re.findall(pattern, response.upper())
    # Remove duplicates while preserving order, then sort
    unique_matches = list(dict.fromkeys(matches))
    return ''.join(sorted(unique_matches)) if unique_matches else None


def generate_answer(tokenizer, model, prompt, device="cuda"):
    """Generate an answer with greedy decoding."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Measure inference time
    start_time = time.time()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    end_time = time.time()

    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    inference_time = end_time - start_time
    return response.strip(), inference_time


# Example usage
tokenizer, model = load_model(model_path)

# Example medical question
question_data = {
    'question': 'What is the most common cause of acute myocardial infarction?',
    'options': {
        'A': 'Coronary artery spasm',
        'B': 'Atherosclerotic plaque rupture',
        'C': 'Coronary artery embolism',
        'D': 'Coronary artery dissection',
        'E': 'Takotsubo cardiomyopathy'
    }
}

# Format the prompt
prompt = format_question(question_data)
print("Prompt:")
print(prompt)
print("---" * 40)

# Generate the answer
response, inference_time = generate_answer(tokenizer, model, prompt)
predicted_answer = extract_answer(response)

print("Model response:")
print(response)
print(f"Extracted option: {predicted_answer}")
print(f"Inference time: {inference_time:.3f} s")
```
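Some exam questions have more than one correct option; the `extract_multiple_answers` helper defined above handles that case. The continuation below reuses `tokenizer`, `model`, and the helper functions from the snippet above with an illustrative multi-answer question (an assumption for demonstration, not taken from IgakuQA):

```python
# Continuation of the example above: a question with two correct options.
# The question and gold answer here are illustrative only.
multi_question_data = {
    'question': 'Which of the following are typical ECG findings of hyperkalemia? Select two.',
    'options': {
        'A': 'Peaked T waves',
        'B': 'Prominent U waves',
        'C': 'Widened QRS complex',
        'D': 'Short PR interval',
        'E': 'ST-segment elevation in all leads'
    }
}

prompt = format_question(multi_question_data)
response, inference_time = generate_answer(tokenizer, model, prompt)

# extract_multiple_answers() deduplicates and sorts the option letters,
# e.g. a response like "A and C" becomes "AC".
predicted_answers = extract_multiple_answers(response)

print("Model response:")
print(response)
print(f"Extracted options: {predicted_answers}")
print(f"Inference time: {inference_time:.3f} s")
```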
## Training and evaluation data

### Training Data

| Dataset | # of Training Samples | Data Source | License | Data Selection Method | Version (Commit ID) |
|---------|-----------------------|-------------|---------|-----------------------|---------------------|
| [JMedBench](https://huggingface.co/datasets/Coldog2333/JMedBench) | 218,912 | medmcqa_jp (translated from MedMCQA)<br>usmleqa_jp (translated from MedQA)<br>medqa_jp (translated from MedQA)<br>mmlu_medical_jp (translated from MMLU)<br>pubmedqa_jp (translated from PubMedQA) | MIT | MultipleChoiceQA | fe772d4fb76c11a4b24e06a2d06c72a7e3e32ef5 |
| [KokushiMD-10](https://huggingface.co/datasets/humanalysis-square/KokushiMD-10) | 1,386 | Japanese national healthcare licensing examinations (2020–2024) | MIT | text_only=True & profession=pharmacy (questions from the medicine profession overlap with the IgakuQA test set, so they are excluded) | c381c014c6769d0a8ca40356d7c30a969a12816d |

### Testing Data

| Dataset | # of Testing Samples | Data Source | License | Data Selection Method | Version (Commit ID) |
|---------|----------------------|-------------|---------|-----------------------|---------------------|
| [IgakuQA](https://github.com/jungokasai/IgakuQA) | 2,000 | Japanese medical licensing examinations (2018–2022) | Public ([Ministry of Health, Labour and Welfare](https://www.mhlw.go.jp/chosakuken/)) | All data | 2bc4c3d159cf5505f6253d24a909fbd53237e239 |

## Model Performance

- We evaluate our model on the IgakuQA dataset; baseline results provided by IgakuQA are also listed below. As the table shows, although MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5 is only an 8B-parameter model, it outperforms the ChatGPT results reported by IgakuQA.
- The fine-tuned JPharmatron-7B was used only as a baseline and is not the primary model proposed in this work, so we do not upload it. Its training parameters are the same as those used for MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5.

### Evaluation Results on IgakuQA Benchmark

For our experiments, the prompt used to generate answers is as follows:

```python
"""Answer this medical multiple choice question by selecting the correct option letter (A, B, C, D, or E).

Question: {question}

Options:
{options}

Answer:"""
```
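For reference, the sketch below illustrates exact-match scoring with the helper functions from the Usage section. It is a minimal sketch under our assumptions about the question structure, not the actual evaluation script behind the numbers reported in the table that follows; `igakuqa_questions` is a hypothetical list that would have to be built from the IgakuQA data.

```python
# Hypothetical evaluation loop: `questions` is assumed to be a list of dicts with
# the same structure as `question_data` in the Usage example, plus a gold
# "answer" string such as "B" or "AC". Exact-match scoring is an assumption here.
def evaluate(questions, tokenizer, model):
    correct = total = 0
    for q in questions:
        prompt = format_question(q)
        response, _ = generate_answer(tokenizer, model, prompt)
        gold = ''.join(sorted(q['answer']))
        if len(gold) > 1:
            # Multiple-answer question: every gold option must be matched exactly.
            pred = extract_multiple_answers(response)
        else:
            pred = extract_answer(response)
        correct += int(pred == gold)
        total += 1
    return correct / total if total else 0.0

# Example: accuracy = evaluate(igakuqa_questions, tokenizer, model)
```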
See the table below for evaluation results:

| Model Configuration | Overall Accuracy | Single-choice Accuracy | Multiple-choice Accuracy | Notes |
|---------------------|------------------|------------------------|--------------------------|-------|
| [Llama-3.1-Swallow-8B-Instruct-v0.5](https://huggingface.co/tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.5) | 55.75% | 60.33% | 30.87% | Base model |
| **MedExamDoc-Llama-3.1-Swallow-8B-Instruct-v0.5** | **62.40%** | **66.31%** | **41.16%** | Our fine-tuned model |
| [JPharmatron-7B](https://huggingface.co/EQUES/JPharmatron-7B) | 61.25% | 67.02% | 29.90% | Open-source model evaluated with our script |
| [JPharmatron-7B](https://huggingface.co/EQUES/JPharmatron-7B) + finetune | 65.90% | 71.11% | 37.62% | Fine-tuned by us and evaluated with our script; fine-tuning improves accuracy by 4.65 percentage points |
| student_majority | 93.90% | 94.24% | 91.95% | Provided by IgakuQA; selects the option most frequently chosen by students |
| GPT-4 | 76.60% | 77.97% | 68.79% | Provided by the IgakuQA benchmark |
| translate_chatgpt | 56.60% | 60.11% | 36.58% | Provided by the IgakuQA benchmark; approximately ChatGPT (2023) with translation |
| ChatGPT | 53.95% | 56.99% | 36.58% | Provided by the IgakuQA benchmark; approximately ChatGPT (2023) |
| GPT-3 | 40.35% | 43.13% | 24.50% | Provided by the IgakuQA benchmark |

### Framework versions

- PEFT 0.15.2
- Transformers 4.52.3
- Pytorch 2.5.1
- Datasets 3.6.0
- Tokenizers 0.21.2