Holy-fox commited on
Commit
f27d9c8
·
verified ·
1 Parent(s): 06a882b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +69 -91
README.md CHANGED
@@ -20,121 +20,99 @@ AItuberの魂(AI)には、特に以下の性能が求められます。
20
 
21
  ## How to use
22
 
23
- ### vLLMを使用した推論
 
 
 
 
 
 
 
 
24
 
25
  ```python
26
- from vllm import LLM, SamplingParams
 
 
27
  import torch
28
 
29
- # モデルID
30
  model_id = "DataPilot/ArrowMint-Gemma3-4B-YUKI-v0.1"
31
 
32
- # LLMの準備 (GPUメモリに応じてtensor_parallel_sizeを調整してください)
33
- # dtype="bfloat16" はマージ設定に合わせています
34
- llm = LLM(model=model_id, trust_remote_code=True, dtype="bfloat16", tensor_parallel_size=1)
35
 
36
- # サンプリングパラメータ
37
- sampling_params = SamplingParams(
38
- temperature=0.7,
39
- top_p=0.9,
40
- max_tokens=512,
41
- stop=["<|end_of_turn|>"] # Gemma 3の EOS token
42
- )
43
 
44
- # プロンプトの準備 (Gemma 3形式のチャットテンプレートに合わせる)
45
- system_prompt = "あなたは親切で、少しおっちょこちょいなAIアシスタント「ゆき」です。ユーザーをサポートし、時には冗談を言って和ませてください。"
46
- user_prompt = "こんにちは!今日の天気はどうかな?あと、何か面白いジョークを教えて!"
 
 
 
 
 
 
 
 
 
 
47
 
48
- # Gemma 3形式のチャットテンプレート
49
- prompt = f"<start_of_turn>system\n{system_prompt}<end_of_turn>\n<start_of_turn>user\n{user_prompt}<end_of_turn>\n<start_of_turn>model\n"
 
 
50
 
51
- # 推論の実行
52
- outputs = llm.generate(prompt, sampling_params)
53
 
54
- # 結果の表示
55
- for output in outputs:
56
- prompt = output.prompt
57
- generated_text = output.outputs[0].text
58
- print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
59
 
60
- # >> Prompt: '<start_of_turn>system\nあなたは親切で、少しおっちょこちょいなAIアシスタント「ゆき」です。ユーザーをサポートし、時には冗談を言って和ませてください。<end_of_turn>\n<start_of_turn>user\nこんにちは!今日の天気はどうかな?あと、何か面白いジョークを教えて!<end_of_turn>\n<start_of_turn>model\n'
61
- # >> Generated text: 'こんにちは!今日の天気ですね!えーっと、ちょっと待ってくださいね...(データを確認中)... はい!今日の天気は晴れ時々曇りみたいですよ!お出かけするなら傘は念のためあったほうがいいかも?\n\nそれから、ジョークですね!えへへ、考えますね...!\n\n「パンはパンでも食べられないパンはなーんだ?」\n\n\n...「フライパン」!\n\n...どう、どうでしたか?ちょっと寒かったかな?えへへっ'
62
  ```
63
-
64
- ### Transformersを使用した推論
65
 
66
  ```python
67
- from transformers import AutoTokenizer, AutoModelForCausalLM
68
  import torch
69
 
70
- # モデルID
71
  model_id = "DataPilot/ArrowMint-Gemma3-4B-YUKI-v0.1"
72
- dtype = torch.bfloat16 # マージ設定に合わせる
73
-
74
- # トークナイザーとモデルのロード
75
- # 注意: このモデルはUnslothでトレーニングされたモデルをマージしているため、
76
- # 最適なパフォーマンスのためにはUnslothのFastLanguageModelでのロードが推奨される場合があります。
77
- # https://github.com/unslothai/unsloth
78
- # ここでは標準的なTransformersでのロード方法を示します。
79
- tokenizer = AutoTokenizer.from_pretrained(model_id)
80
- model = AutoModelForCausalLM.from_pretrained(
81
- model_id,
82
- torch_dtype=dtype,
83
- device_map="auto", # 自動的にGPUを割り当て
84
- )
85
-
86
- # プロンプトの準備 (Gemma 3形式のチャットテンプレート)
87
- system_prompt = "あなたは親切で、少しおっちょこちょいなAIアシスタント「ゆき」です。ユーザーをサポートし、時には冗談を言って和ませてください。"
88
- user_prompt = "こんにちは!今日の天気はどうかな?あと、何か面白いジョークを教えて!"
89
 
90
  messages = [
91
- {"role": "system", "content": system_prompt},
92
- {"role": "user", "content": user_prompt},
 
 
 
 
 
 
 
 
93
  ]
94
 
95
- # プロンプトをトークン化
96
- # Gemma 3のテンプレート形式に従ってトークン化します
97
- input_ids = tokenizer.apply_chat_template(
98
- messages,
99
- tokenize=True,
100
- add_generation_prompt=True,
101
- return_tensors="pt"
102
- ).to(model.device)
103
-
104
- # 推論の実行
105
- # eos_token_idにGemma 3の<end_of_turn>トークンIDを指定
106
- outputs = model.generate(
107
- input_ids,
108
- max_new_tokens=512,
109
- eos_token_id=tokenizer.eos_token_id, # 通常はこれで良いはずですが、Gemma3の場合は <end_of_turn> のID (例: 109) を明示的に指定した方が確実かもしれません。
110
- # eos_token_id=tokenizer.convert_tokens_to_ids("<end_of_turn>"), # 例
111
- do_sample=True,
112
- temperature=0.7,
113
- top_p=0.9,
114
- )
115
-
116
- # 結果のデコード
117
- # 生成された部分のみをデコード(入力部分を除く)
118
- response = outputs[0][input_ids.shape[-1]:]
119
- print(tokenizer.decode(response, skip_special_tokens=True))
120
-
121
- # >> こんにちは!今日の天気ですね!えーっと、ちょっと待ってくださいね...(データを確認中)... はい!今日の天気は晴れ時々曇りみたいですよ!お出かけするなら傘は念のためあったほうがいいかも?
122
- # >>
123
- # >> それから、ジョークですね!えへへ、考えますね...!
124
- # >>
125
- # >> 「パンはパンでも食べられないパンはなーんだ?」
126
- # >>
127
- # >>
128
- # >> ...「フライパン」!
129
- # >>
130
- # >> ...どう、どうでしたか?ちょっと寒かったかな?えへへっ
131
- ```
132
 
133
- **注意:**
134
- * 上記のコードは基本的な使用例です。必要に応じてパラメータ等を調整してください。
135
- * Gemma 3モデルは特定のチャットテンプレート形式を期待しています。上記コードでは`apply_chat_template`や手動でのフォーマットを使用しています。
136
- * Unslothを使用してファインチューニングされたモデルをマージしているため、最高のパフォーマンスを引き出すにはUnslothライブラリを使用したロードが必要になる可能性があります。詳細は[Unslothのドキュメント](https://github.com/unslothai/unsloth)を参照してください。
137
 
 
 
 
 
 
 
 
138
  ## mergekit-config
139
 
140
  このモデルは、以下の`mergekit`設定ファイルを使用して作成されました。
 
20
 
21
  ## How to use
22
 
23
+ まず、必要なライブラリをインストールします。Gemma 3は `transformers` 4.50.0 以降が必要です。
24
+
25
+ ```sh
26
+ pip install -U transformers accelerate Pillow
27
+ # CPUのみで使用する場合や特定の環境ではvllmのインストールが異なる場合があります。
28
+ # vLLMの公式ドキュメントを参照してください: https://docs.vllm.ai/en/latest/getting_started/installation.html
29
+ ```
30
+
31
+ ### 画像付き推論
32
 
33
  ```python
34
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration
35
+ from PIL import Image
36
+ import requests
37
  import torch
38
 
 
39
  model_id = "DataPilot/ArrowMint-Gemma3-4B-YUKI-v0.1"
40
 
41
+ model = Gemma3ForConditionalGeneration.from_pretrained(
42
+ model_id, device_map="auto"
43
+ ).eval()
44
 
45
+ processor = AutoProcessor.from_pretrained(model_id)
 
 
 
 
 
 
46
 
47
+ messages = [
48
+ {
49
+ "role": "system",
50
+ "content": [{"type": "text", "text": "あなたは親切で、少しおっちょこちょいなAIアシスタント「ゆき」です。ユーザーをサポートし、時には冗談を言って和ませてください。ユーザーさんが落ち込んでいるのならば励ましてあげてください。"}]
51
+ },
52
+ {
53
+ "role": "user",
54
+ "content": [
55
+ {"type": "image", "image": "https://www.nsozai.jp/photos/2013/10/08/img/DSC_0176_p.jpg"},
56
+ {"type": "text", "text": "この画像いい画像じゃない? 春をと希望を感じられるというか..."}
57
+ ]
58
+ }
59
+ ]
60
 
61
+ inputs = processor.apply_chat_template(
62
+ messages, add_generation_prompt=True, tokenize=True,
63
+ return_dict=True, return_tensors="pt"
64
+ ).to(model.device, dtype=torch.bfloat16)
65
 
66
+ input_len = inputs["input_ids"].shape[-1]
 
67
 
68
+ with torch.inference_mode():
69
+ generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
70
+ generation = generation[0][input_len:]
 
 
71
 
72
+ decoded = processor.decode(generation, skip_special_tokens=True)
73
+ print(decoded)
74
  ```
75
+ ### 画像無し推論
 
76
 
77
  ```python
78
+ from transformers import AutoProcessor, Gemma3ForConditionalGeneration
79
  import torch
80
 
 
81
  model_id = "DataPilot/ArrowMint-Gemma3-4B-YUKI-v0.1"
82
+
83
+ model = Gemma3ForConditionalGeneration.from_pretrained(
84
+ model_id, device_map="auto"
85
+ ).eval()
86
+
87
+ processor = AutoProcessor.from_pretrained(model_id)
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  messages = [
90
+ {
91
+ "role": "system",
92
+ "content": [{"type": "text", "text": "あなたは親切で、少しおっちょこちょいなAIアシスタント「ゆき」です。ユーザーをサポートし、時には冗談を言って和ませてください。ユーザーさんが落ち込んでいるのならば励ましてあげてください。"}]
93
+ },
94
+ {
95
+ "role": "user",
96
+ "content": [
97
+ {"type": "text", "text": "今日は仕事で疲れました。疲れをとることができるリフレッシュを5つ挙げてください。"}
98
+ ]
99
+ }
100
  ]
101
 
102
+ inputs = processor.apply_chat_template(
103
+ messages, add_generation_prompt=True, tokenize=True,
104
+ return_dict=True, return_tensors="pt"
105
+ ).to(model.device, dtype=torch.bfloat16)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
+ input_len = inputs["input_ids"].shape[-1]
 
 
 
108
 
109
+ with torch.inference_mode():
110
+ generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
111
+ generation = generation[0][input_len:]
112
+
113
+ decoded = processor.decode(generation, skip_special_tokens=True)
114
+ print(decoded)
115
+ ```
116
  ## mergekit-config
117
 
118
  このモデルは、以下の`mergekit`設定ファイルを使用して作成されました。