Since Gemma 4 is such a new model, I suspect your loss problem is at least partly a framework bug rather than a tuning mistake:
The short version is that Gemma-4-E2B/E4B are harder to fine-tune locally than Llama or Qwen right now, and on a MacBook M3 the usual cause is a stack problem, not just “bad hyperparameters.” Gemma 4 is a multimodal family, not a plain text-only family: E2B and E4B support text, image, and audio, use a 262K vocabulary, and have 128K context. Google’s own Hugging Face examples for Gemma 4 use multimodal model classes and recent library versions, which is already a hint that the training path is different from the simpler Llama/Qwen recipes you may be used to. (Google AI for Developers)
What your loss pattern likely means
A loss that sits above 40 and refuses to come down is not what I would treat as a normal Gemma-4 baseline quirk. With a vocabulary of about 262K, random-token cross-entropy is only about ln(262000) ≈ 12.5. So a run that stays around 40+ usually means one of four things: the forward path is broken, the labels are misaligned, the masks are wrong, or numerics are unstable. It is much more likely to be one of those than “Gemma just needs more epochs.” (Google AI for Developers)
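For intuition, that baseline is just the log of the vocabulary size: a model that assigns uniform probability over every token should land near ln(V), and a trained model should start below that.

```python
import math

# Cross-entropy of a uniform guess over a ~262K-token vocabulary.
# A model that knows nothing should sit near this value, not near 40.
vocab_size = 262_000
random_baseline = math.log(vocab_size)
print(f"ln({vocab_size}) = {random_baseline:.2f}")  # ≈ 12.48
```

A reported loss of 40+ is therefore more than three times worse than random guessing, which is only possible if something structural (labels, masks, forward path, numerics) is broken.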
That interpretation also lines up with current public bug reports. There is an open Transformers issue stating that for Gemma 4, use_cache=False can corrupt attention computation and produce garbage logits. There is another open Transformers issue stating that some Gemma 4 training paths still expect mm_token_type_ids even for text-only fine-tuning. And there is an open PEFT issue showing that Gemma 4’s Gemma4ClippableLinear breaks normal QLoRA target selection on affected versions. Those are not subtle tuning issues. Those are “training can look pathological even when the script runs” issues. (GitHub)
Why Gemma 4 feels harder than Llama and Qwen
The first reason is architectural. Google’s model card and overview make clear that Gemma 4 is built as a multimodal family, and the official text fine-tuning guide still loads the model with AutoModelForImageTextToText, not a plain causal-LM-only path. That means even “text-only SFT” may still be going through code paths and assumptions that are more multimodal than you expect. That is a big contrast with many small Llama/Qwen tutorials, where text-only fine-tuning is usually closer to the simplest path. (Google AI for Developers)
The second reason is ecosystem maturity. Gemma 4 support is new. Hugging Face’s Gemma 4 launch article says Gemma 4 is supported across Transformers, TRL, and Unsloth, and the model pages tell users to install the latest Transformers. The current Transformers release page shows we are already in a rapid sequence of Gemma-4-related patch releases, and the 5.5.2 release specifically mentions a Gemma 4 fix for the use_cache=False path. That is a strong sign that the stack is still stabilizing. (Hugging Face)
The third reason is your hardware and backend. Apple’s own MPS page still says the MPS backend is in beta, and PyTorch’s AdamW docs state that the MPS implementation supports torch.float32 and torch.float16. There is also a live PyTorch issue showing BFloat16 is not supported on MPS in at least some MPS environments. So the common CUDA assumption that “bf16 is the safe default for modern fine-tuning” does not carry over cleanly to Mac MPS. (Apple Developer)
What I think is happening in your case
My best read is that you are not hitting one isolated problem. You are probably hitting three layers at once.
1. You already hit the first layer: Gemma-4-specific PEFT friction
You said you solved the linear-layer selection issue. That lines up directly with the current PEFT Gemma 4 issue: PEFT does not properly recognize Gemma4ClippableLinear as a supported target module in the affected path. So your first obstacle was real and upstream. (GitHub)
2. Your current symptom is more consistent with the training path than with dtype alone
Once training actually starts, the strongest suspect is the Gemma 4 use_cache=False bug. The issue report is explicit: the uncached path can corrupt attention and produce garbage logits. That matters because TRL defaults gradient_checkpointing=True, and many training setups effectively run Gemma through the uncached path when checkpointing is enabled. TRL also defaults bf16=True if fp16 is not set. On NVIDIA this is often fine. On MPS it can be a trap. So if you used near-default TRL settings, your run may have been pushed into the worst possible combination for Gemma 4 on a Mac: uncached path plus MPS precision friction. (GitHub)
3. Your batch or masking may still be wrong even if tokenization works
Gemma 4 prompt formatting is stricter than it looks. Google says thinking mode is activated by <|think|> in the system instruction. Google also says that Gemma’s thought channel should be stripped between normal turns, and for no-thinking fine-tuning on larger Gemma 4 models, adding an empty thought channel can stabilize behavior. TRL separately says that assistant_only_loss=True depends on a chat template that can emit the correct assistant-token mask, and packing changes the sequence construction. If your pipeline already needed chat-template fixes, I would assume there is still a real chance the loss mask or conversation structure is off, even though the model “runs.” (Google AI for Developers)
4. For E2B/E4B, text-only fine-tuning can still need multimodal plumbing
The most non-obvious current gotcha is the open Transformers issue saying mm_token_type_ids may still be required for text-only fine-tuning. Google’s own vision fine-tuning guide also shows a custom collator, manual label masking, dataset_kwargs={"skip_prepare_dataset": True}, and remove_unused_columns=False. So if your current training stack is text-only but uses a generic text collator, I would treat the batch contract itself as suspect. (GitHub)
The background behind the official settings
The official Google text QLoRA guide uses a very conservative baseline: batch size 1, max length 512, learning rate 5e-5, LoRA rank 16, LoRA alpha 16, LoRA dropout 0.05, and max_grad_norm=0.3. It also loads the tokenizer from google/gemma-4-E2B-it to use the official chat template. That is the useful part of the official recipe. (Google AI for Developers)
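One way to read the rank-16 / alpha-16 pairing: the LoRA update is applied as W + (alpha / r) · (B @ A), so alpha == r means the adapter contributes at scale 1.0. This scaling rule is standard LoRA arithmetic, not anything Gemma-specific:

```python
# What the official r=16, alpha=16 choice means numerically:
# the LoRA update is W + (alpha / r) * (B @ A), so alpha == r gives scale 1.0.
r, alpha = 16, 16
scale = alpha / r
print(f"LoRA update scale alpha/r = {scale}")  # 1.0

# Doubling alpha without changing r doubles the effective update strength,
# which behaves roughly like doubling the learning rate for the adapter.
print(f"alpha=32, r=16 -> scale {32 / 16}")  # 2.0
```

That is why you should change rank and alpha together, or at least deliberately: bumping one without the other silently rescales every adapter update.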
But you should not copy the rest of the official examples blindly to a Mac. Google’s vision QLoRA guide uses batch size 1, learning rate 2e-4, bf16, and a custom collator with skip_prepare_dataset and remove_unused_columns=False. That recipe is clearly aimed at the standard GPU path. It is useful as documentation of the expected data plumbing, not as a direct MPS recipe. (Google AI for Developers)
Google’s general Gemma fine-tuning guidance also says that PEFT / LoRA is the less resource-intensive path, while full tuning is compute- and memory-intensive. For a MacBook M3, that means LoRA should be your baseline, not full fine-tuning. Full FT may be possible in some cases, but it is the wrong baseline while the stack is still unstable. (Google AI for Developers)
My recommended baseline for a MacBook M3
This is the baseline I would use before trying anything more ambitious.
Software stack
Use the latest Transformers 5.5.x, not an older Gemma 4 stack. Gemma 4 requires recent Transformers, and the 5.5.2 patch specifically includes a Gemma-4-focused fix for the use_cache=False issue. Keep PEFT current, but assume Gemma 4 support is still rough if you broadly target every linear wrapper. (Hugging Face)
Precision
On MPS, treat fp16 as the normal training dtype and fp32 as the diagnostic dtype for short smoke tests. Do not assume bf16 is the right choice on Mac just because many GPU tutorials use it. Apple still calls MPS beta, PyTorch only documents Adam/AdamW MPS support for fp32/fp16, and public MPS issues still show bf16 failures. (Apple Developer)
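To see what your own MPS build actually accepts, a small probe like the following helps; this is a diagnostic sketch assuming nothing beyond stock PyTorch, and on an affected build the bfloat16 case is the one expected to fail:

```python
import torch

def probe_mps_dtype(dtype):
    """Try a tiny matmul on the MPS device in the given dtype and report the outcome."""
    if not torch.backends.mps.is_available():
        return "mps unavailable"
    try:
        a = torch.randn(8, 8, device="mps").to(dtype)
        total = (a @ a).float().sum().item()  # force the kernel to actually run
        return f"ok (sum={total:.2f})"
    except Exception as e:
        return f"failed: {type(e).__name__}"

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    print(f"{dtype}: {probe_mps_dtype(dtype)}")
```

If bfloat16 fails here, no amount of trainer configuration will save a bf16 run, so settle the dtype question before touching hyperparameters.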
Optimization settings
Start with LoRA only, batch size 1, sequence length 512 or 1024, learning rate 2e-5 to 5e-5, and max_grad_norm=0.3. The official Google text guide already uses batch size 1, max length 512, LR 5e-5, and grad clip 0.3. I would go even more conservative on MPS by starting at the bottom of that range and only increasing after the run is clearly stable. (Google AI for Developers)
Trainer settings
For the first stable run, I would keep gradient_checkpointing=False, packing=False, assistant_only_loss=False, and torch_compile=False. The reason is not that these settings are always wrong. It is that each one introduces another place for Gemma-4-specific behavior or MPS immaturity to bite you. TRL defaults checkpointing on, and the current Gemma 4 uncached-path bug makes that a bad default to trust blindly. The community Apple-silicon Gemma 4 guide also recommends torch_compile=False on MPS and often attn_implementation="eager" for stability. That last part is community guidance rather than official Google guidance, but it is plausible and matches MPS’s current rough edges. (Hugging Face)
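As a configuration sketch, this is how those first-stable-run choices look at load time. The model id mirrors the one used elsewhere in this answer, and attn_implementation="eager" is a standard Transformers load kwarg; treat the whole block as a starting point, not a verified recipe.

```python
import torch
from transformers import AutoModelForCausalLM

# First-stable-run load: eager attention, cache on, checkpointing off.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-4-E2B-it",
    torch_dtype=torch.float16,           # fp16 on MPS, per the precision notes above
    attn_implementation="eager",         # community-recommended for MPS stability
)
model.config.use_cache = True            # steer clear of the reported uncached-path bug
model.gradient_checkpointing_disable()   # TRL would default checkpointing on; keep it off first
```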
Data and formatting
Use the official chat template only. Keep the dataset in standard system / user / assistant roles. Decide one thinking strategy and stick to it: either final visible answers only, or a consistent thought-aware format. Do not mix multiple incompatible thought formats. For most production-style assistants, answer-only fine-tuning is the simplest stable choice. (Google AI for Developers)
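Concretely, keeping the data in standard roles means every record has the same three-part shape; the contents below are illustrative placeholders, not real training data:

```python
import json

# One answer-only training record in standard chat roles.
# No thought channel, no custom role names, one consistent format throughout.
record = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize LoRA in one sentence."},
        {
            "role": "assistant",
            "content": "LoRA fine-tunes a model by training small low-rank "
                       "adapter matrices instead of the full weight matrices.",
        },
    ]
}
print(json.dumps(record, indent=2))
```

The point is uniformity: every record should render through the official chat template without per-record special cases.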
Multimodal gotchas even for text-only SFT
Inspect the first collated batch. Check that you have input_ids, attention_mask, labels, and, if your path requires them, token_type_ids and mm_token_type_ids. If those token-type tensors are missing, the current public workaround is to add them as zero tensors in a custom collator and keep remove_unused_columns=False. That is now an important Gemma 4 debugging step, not an exotic edge case. (GitHub)
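A sketch of such a collator, assuming simple pre-tokenized features with `input_ids` and `labels` lists; the `token_type_ids` / `mm_token_type_ids` field names follow the issue discussion, and the exact batch contract may differ on your versions:

```python
import torch

def text_only_collator(features, pad_token_id=0):
    """Pad a text-only batch and add zero-filled token-type fields.

    Hypothetical workaround collator for the reported Gemma 4 text-only
    batch-contract issue. Zero token-type tensors declare "no image/audio
    spans anywhere in this sequence."
    """
    max_len = max(len(f["input_ids"]) for f in features)

    def pad(seq, value):
        return seq + [value] * (max_len - len(seq))

    input_ids = torch.tensor([pad(f["input_ids"], pad_token_id) for f in features])
    attention_mask = torch.tensor(
        [pad([1] * len(f["input_ids"]), 0) for f in features]
    )
    labels = torch.tensor([pad(f["labels"], -100) for f in features])  # -100 = ignored by the loss
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "token_type_ids": torch.zeros_like(input_ids),
        "mm_token_type_ids": torch.zeros_like(input_ids),
    }
```

If you hand something like this to a trainer, remember to also pass `remove_unused_columns=False`, or the extra fields get dropped before they ever reach the model.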
What I would do first in your exact situation
I would do a tiny overfit test on 32 to 128 examples. Same format everywhere. No packing. No assistant-only masking. No checkpointing. LoRA only. Batch 1. Length 512. LR 2e-5 or 5e-5. Official chat template. If that tiny set does not overfit cleanly, the problem is not “Gemma needs different training philosophy.” The problem is pipeline correctness. That conclusion follows directly from how far your current loss is from a sane baseline.
Then I would make one change at a time in this order:
- Upgrade Transformers to latest 5.5.x.
- Disable checkpointing for the first stability test.
- Verify mm_token_type_ids / collator behavior.
- Switch to fp16 on MPS, fp32 only for short diagnostics.
- Try attn_implementation="eager" and keep torch_compile=False.
- Only then reintroduce checkpointing, assistant-only loss, packing, or longer sequences. (NewReleases)
TRL vs Unsloth for your case
The cleanest way to think about it is this:
TRL is the training framework. It gives you SFTTrainer, dataset handling, packing, assistant-only loss, and the overall fine-tuning API. Hugging Face says Gemma 4 is fully supported in TRL. (Hugging Face)
Unsloth is an optimized Hugging Face-compatible path. Hugging Face’s docs say it is fully compatible with SFTTrainer, and the Transformers integration page says Unsloth patches internal Transformers methods for speed. In other words, Unsloth is usually not a different fine-tuning method. It is more like a faster, lower-memory implementation route around the same ecosystem. (Hugging Face)
For your case, I would use them like this:
- Use TRL / plain Transformers first if your priority is correctness and debugging clarity.
- Use Unsloth second if your priority is local memory pressure, speed, and convenience, especially once the pipeline is known-good. Hugging Face’s Gemma 4 launch post explicitly includes Unsloth Studio as a Gemma 4 fine-tuning option, and Unsloth’s current Gemma 4 guide gives very relevant practical advice for E2B/E4B: start with multimodal loading, keep vision layers off first, fine-tune language/attention/MLP layers first, use small batches, and keep your thinking format consistent. (Hugging Face)
One caveat: Unsloth’s current release notes say they fixed Gemma 4 training issues including exploding losses from gradient accumulation and Gemma 4 use_cache=False behavior. That is useful signal, but it is still vendor-provided guidance. I would treat it as encouraging, not as proof that every Mac/MPS problem disappears by switching libraries. (GitHub)
My bottom-line view
For your specific setup, I would rank the likely causes like this:
- Gemma 4 training-path issue, especially uncached / checkpointed path
- Missing or malformed multimodal batch fields in text-only training
- Template / thinking / masking mismatch
- MPS precision and attention-path instability
- Ordinary hyperparameter tuning (GitHub)
So no, I do not think your situation is just “MacBook bf16 is bad.” That is part of it, but it is not the center of it. The center is that Gemma 4 is new, multimodal, and still receiving training-path fixes, and MPS makes those early-stack problems more visible.
The simplest isolation is to split the problem into three tests:
- Forward-path sanity
- Batch-contract sanity
- Tiny overfit sanity
That works because the current Gemma 4 failure reports are concentrated in those exact places: the use_cache=False path, the mm_token_type_ids / token_type_ids batch fields, and trainer defaults that can silently push you into the broken path. Hugging Face’s current Gemma 4 docs say to use the latest Transformers, and the recent Transformers 5.5.2 patch specifically fixed Gemma-4 use_cache=False behavior. TRL’s SFTTrainer also defaults gradient_checkpointing=True and bf16=True if fp16 is not set, which is risky on Mac MPS. Apple still labels MPS as beta, and PyTorch’s AdamW docs say MPS AdamW support is for float32 and float16. (Hugging Face)
What this code will tell you
If use_cache=True looks sane but use_cache=False looks bad, you likely hit the known Gemma 4 cache-path bug. If the model only works when you manually add zero-filled token_type_ids and mm_token_type_ids, you likely hit the current Gemma 4 text-only batch-contract issue. If both of those pass, but the model still cannot overfit a single short sample, the remaining suspects are mainly template / label masking or MPS numeric instability. (GitHub)
Script 1: forward-path and batch-contract sanity check
Run this first. It does not train. It only checks whether the model’s forward pass is healthy.
```python
# diag_gemma4_forward.py
import torch
import transformers
from transformers import AutoProcessor, AutoModelForCausalLM

MODEL_ID = "google/gemma-4-E2B-it"


def print_env():
    print("torch:", torch.__version__)
    print("transformers:", transformers.__version__)
    print("mps available:", torch.backends.mps.is_available())
    print("cuda available:", torch.cuda.is_available())
    print("device:", "mps" if torch.backends.mps.is_available() else "cpu")


def build_inputs(processor, device):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in a formal way."},
    ]
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    batch = processor(text=text, return_tensors="pt")
    batch = {k: v.to(device) for k, v in batch.items()}
    return text, batch


def add_missing_token_type_fields(batch):
    # The current Gemma 4 issue says these may be required even for text-only SFT.
    if "token_type_ids" not in batch:
        batch["token_type_ids"] = torch.zeros_like(batch["input_ids"])
    if "mm_token_type_ids" not in batch:
        batch["mm_token_type_ids"] = torch.zeros_like(batch["input_ids"])
    return batch


def topk_report(processor, logits, k=5):
    # Top-k next-token candidates at the final position.
    probs = torch.softmax(logits[0, -1].float(), dim=-1)
    vals, idxs = torch.topk(probs, k)
    rows = []
    for p, idx in zip(vals.tolist(), idxs.tolist()):
        token = processor.tokenizer.decode([idx])
        rows.append((repr(token), p))
    return rows


def safe_forward(model, batch, use_cache):
    with torch.no_grad():
        return model(**batch, use_cache=use_cache)


def main():
    print_env()
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # Use float16 first on MPS. Switch to float32 only for a short diagnostic run if needed.
    dtype = torch.float16 if device.type == "mps" else torch.float32

    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
    ).to(device)
    # Explicitly keep cache ON for this diagnostic.
    model.config.use_cache = True
    model.eval()

    text, batch = build_inputs(processor, device)
    print("\nPrompt:")
    print(text)
    print("\nOriginal batch keys:", sorted(batch.keys()))
    batch = add_missing_token_type_fields(batch)
    print("Patched batch keys:", sorted(batch.keys()))

    print("\n=== Forward with use_cache=True ===")
    out_true = safe_forward(model, batch, use_cache=True)
    for token, prob in topk_report(processor, out_true.logits, k=5):
        print(f"{token:>20} {prob:.6f}")

    print("\n=== Forward with use_cache=False ===")
    try:
        out_false = safe_forward(model, batch, use_cache=False)
        for token, prob in topk_report(processor, out_false.logits, k=5):
            print(f"{token:>20} {prob:.6f}")
    except Exception as e:
        print("FAILED with use_cache=False:", repr(e))

    # Optional: compare losses on the same short target.
    labels = batch["input_ids"].clone()
    with torch.no_grad():
        loss_true = model(**batch, labels=labels, use_cache=True).loss.item()
    print(f"\nLoss with use_cache=True : {loss_true:.6f}")
    try:
        with torch.no_grad():
            loss_false = model(**batch, labels=labels, use_cache=False).loss.item()
        print(f"Loss with use_cache=False: {loss_false:.6f}")
    except Exception as e:
        print("Loss with use_cache=False failed:", repr(e))


if __name__ == "__main__":
    main()
```
How to read the result
A healthy run should show sensible next-token candidates for the formal hello prompt and a loss that is not absurd. If use_cache=True gives reasonable tokens but use_cache=False gives junk, that is almost a direct fingerprint of the known Gemma 4 cache-path bug. If the forward pass fails until you add zero-filled token_type_ids and mm_token_type_ids, that matches the current open Gemma 4 text-only training issue. (GitHub)
Script 2: tiny overfit sanity check
Only run this after Script 1 looks sane. The point is not to get a good model. The point is to see whether the training loop can reduce loss on one tiny sample. Keep it extremely small. No TRL. No checkpointing. No packing. No assistant-only masking.
```python
# diag_gemma4_tiny_overfit.py
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

MODEL_ID = "google/gemma-4-E2B-it"


def make_batch(processor, device):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
        enable_thinking=False,
    )
    batch = processor(text=text, return_tensors="pt")
    batch = {k: v.to(device) for k, v in batch.items()}
    if "token_type_ids" not in batch:
        batch["token_type_ids"] = torch.zeros_like(batch["input_ids"])
    if "mm_token_type_ids" not in batch:
        batch["mm_token_type_ids"] = torch.zeros_like(batch["input_ids"])
    batch["labels"] = batch["input_ids"].clone()
    return batch


def main():
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    # For a tiny diagnostic, float32 is acceptable if memory allows.
    # If memory is too tight, switch this to torch.float16.
    dtype = torch.float32

    processor = AutoProcessor.from_pretrained(MODEL_ID)
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
    ).to(device)
    base_model.config.use_cache = True
    if hasattr(base_model, "gradient_checkpointing_disable"):
        base_model.gradient_checkpointing_disable()

    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
    )
    model = get_peft_model(base_model, lora_config)
    model.train()

    batch = make_batch(processor, device)

    # Only train the LoRA params.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=2e-5)

    print("Starting tiny overfit test...")
    for step in range(20):
        optimizer.zero_grad(set_to_none=True)
        out = model(**batch, use_cache=True)
        loss = out.loss
        if not torch.isfinite(loss):
            print(f"Step {step:02d} loss is non-finite:", loss)
            break
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, 0.3)
        optimizer.step()
        print(f"step={step:02d} loss={loss.item():.6f}")


if __name__ == "__main__":
    main()
```
How to read the overfit test
If the loss does not move down at all on one trivial sample, the problem is almost never “needs better hyperparameters.” It means the training path is still broken. If the loss drops cleanly here but explodes in your full training job, the difference is probably in your dataset formatting, masking, sequence length, trainer defaults, or MPS pressure. That conclusion is consistent with the Gemma 4 cache-path bug report, the mm_token_type_ids report, and Google’s requirement to use the model’s own chat formatting. (GitHub)
If you are using TRL later, override these defaults
If your manual test works but SFTTrainer does not, check TRL first. TRL currently documents that SFTTrainer defaults gradient_checkpointing=True, bf16=True if fp16 is not set, and learning_rate=2e-5. For Gemma 4 on MPS, the important one is checkpointing, because it can route you straight into the uncached path that was recently fixed in Transformers 5.5.2. (Hugging Face)
A minimal safe TRL config for diagnosis looks like this:
```python
from trl import SFTConfig

cfg = SFTConfig(
    output_dir="out",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    max_seq_length=256,
    gradient_checkpointing=False,  # important for Gemma 4 diagnosis
    bf16=False,
    fp16=False,                    # manual tests first; use fp16 only if you must
    packing=False,
    max_grad_norm=0.3,
    logging_steps=1,
)
```
What results map to what root cause
If Script 1 fails only when use_cache=False, you found the known Gemma 4 path bug. Upgrade Transformers first, and avoid forcing use_cache=False while isolating. Transformers 5.5.2 specifically shipped a Gemma-4 use_cache=False fix. (GitHub)
If Script 1 only works after adding token_type_ids and mm_token_type_ids, your issue is likely the current batch-contract problem for Gemma 4 text-only fine-tuning. In that case, keep the custom collator and remove_unused_columns=False when you move back to trainer code. (GitHub)
If Script 1 is fine and Script 2 cannot overfit one sample, suspect MPS numerics or a remaining formatting bug. Apple still documents MPS as beta, and PyTorch only documents MPS AdamW support for float16 and float32. That is why I would debug on float32 for tiny tests if memory allows, then move to float16 for real runs on MPS. (Apple Developer)
If both scripts pass, your full run is probably failing because of dataset construction, long sequence length, assistant-only masking, packing, or TRL defaults rather than because “Gemma 4 cannot fine-tune on Mac.” Google’s own Gemma 4 docs emphasize using the built-in processor and chat template, and Gemma 4’s prompt structure is more specific than older Llama/Qwen-style plain chat formatting. (Google AI for Developers)