Fine-tuning Gemma-4-E2B on MacBook M3

Hi. I’m trying to fine-tune Gemma-4-E2B on a MacBook M3 but haven’t been able to. I had previously fine-tuned Llama and Qwen models with no issues, but Gemma-4 is presenting real challenges. Having resolved the linear-layer selection issue and the tokenizer chat-template issue, I’m now stuck with a training loss above 40 that refuses to decrease. Some information I found blames bfloat16 support on MacBooks. I tried float16 and even float32, but the model exploded after a few epochs.

I wonder if there’s any Gemma-4-E2B/E4B best settings info or gotchas list to watch out for or maybe a guide that helps overcome these issues? Any guidance would be truly appreciated!


Since Gemma 4 is such a new model, I personally suspect it might be due to a minor bug in the framework:


The short version is that Gemma-4-E2B/E4B are harder to fine-tune locally than Llama or Qwen right now, and on a MacBook M3 the usual cause is a stack problem, not just “bad hyperparameters.” Gemma 4 is a multimodal family, not a plain text-only family: E2B and E4B support text, image, and audio, use a 262K vocabulary, and have 128K context. Google’s own Hugging Face examples for Gemma 4 use multimodal model classes and recent library versions, which is already a hint that the training path is different from the simpler Llama/Qwen recipes you may be used to. (Google AI for Developers)

What your loss pattern likely means

A loss that sits above 40 and refuses to come down is not what I would treat as a normal Gemma-4 baseline quirk. With a vocabulary of about 262K, random-token cross-entropy is only about ln(262000) ≈ 12.5. So a run that stays around 40+ usually means one of four things: the forward path is broken, the labels are misaligned, the masks are wrong, or numerics are unstable. It is much more likely to be one of those than “Gemma just needs more epochs.” (Google AI for Developers)
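That ceiling is easy to verify: uniform-random predictions over a vocabulary of V tokens give a cross-entropy of ln(V). A quick sanity check, using the approximate 262K vocabulary size mentioned above:

```python
import math

# Cross-entropy of uniform-random guessing over the vocabulary.
# Any model with intact weights should start well below this value.
VOCAB_SIZE = 262_000  # approximate Gemma 4 vocabulary size

random_baseline = math.log(VOCAB_SIZE)
print(f"random-guess cross-entropy: {random_baseline:.2f}")  # about 12.48
```

A loss of 40+ is therefore more than three times worse than pure guessing, which is only achievable when something structural is wrong, not merely under-tuned.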

That interpretation also lines up with current public bug reports. There is an open Transformers issue stating that for Gemma 4, use_cache=False can corrupt attention computation and produce garbage logits. There is another open Transformers issue stating that some Gemma 4 training paths still expect mm_token_type_ids even for text-only fine-tuning. And there is an open PEFT issue showing that Gemma 4’s Gemma4ClippableLinear breaks normal QLoRA target selection on affected versions. Those are not subtle tuning issues. Those are “training can look pathological even when the script runs” issues. (GitHub)

Why Gemma 4 feels harder than Llama and Qwen

The first reason is architectural. Google’s model card and overview make clear that Gemma 4 is built as a multimodal family, and the official text fine-tuning guide still loads the model with AutoModelForImageTextToText, not a plain causal-LM-only path. That means even “text-only SFT” may still be going through code paths and assumptions that are more multimodal than you expect. That is a big contrast with many small Llama/Qwen tutorials, where text-only fine-tuning is usually closer to the simplest path. (Google AI for Developers)

The second reason is ecosystem maturity. Gemma 4 support is new. Hugging Face’s Gemma 4 launch article says Gemma 4 is supported across Transformers, TRL, and Unsloth, and the model pages tell users to install the latest Transformers. The current Transformers release page shows we are already in a rapid sequence of Gemma-4-related patch releases, and the 5.5.2 release specifically mentions a Gemma 4 fix for the use_cache=False path. That is a strong sign that the stack is still stabilizing. (Hugging Face)

The third reason is your hardware and backend. Apple’s own MPS page still says the MPS backend is in beta, and PyTorch’s AdamW docs state that the MPS implementation supports torch.float32 and torch.float16. There is also a live PyTorch issue showing BFloat16 is not supported on MPS in at least some MPS environments. So the common CUDA assumption that “bf16 is the safe default for modern fine-tuning” does not carry over cleanly to Mac MPS. (Apple Developer)

What I think is happening in your case

My best read is that you are not hitting one isolated problem. You are probably hitting several layers at once.

1. You already hit the first layer: Gemma-4-specific PEFT friction

You said you solved the linear-layer selection issue. That lines up directly with the current PEFT Gemma 4 issue: PEFT does not properly recognize Gemma4ClippableLinear as a supported target module in the affected path. So your first obstacle was real and upstream. (GitHub)

2. Your current symptom is more consistent with the training path than with dtype alone

Once training actually starts, the strongest suspect is the Gemma 4 use_cache=False bug. The issue report is explicit: the uncached path can corrupt attention and produce garbage logits. That matters because TRL defaults gradient_checkpointing=True, and many training setups effectively run Gemma through the uncached path when checkpointing is enabled. TRL also defaults bf16=True if fp16 is not set. On NVIDIA this is often fine. On MPS it can be a trap. So if you used near-default TRL settings, your run may have been pushed into the worst possible combination for Gemma 4 on a Mac: uncached path plus MPS precision friction. (GitHub)

3. Your batch or masking may still be wrong even if tokenization works

Gemma 4 prompt formatting is stricter than it looks. Google says thinking mode is activated by <|think|> in the system instruction. Google also says that Gemma’s thought channel should be stripped between normal turns, and for no-thinking fine-tuning on larger Gemma 4 models, adding an empty thought channel can stabilize behavior. TRL separately says that assistant_only_loss=True depends on a chat template that can emit the correct assistant-token mask, and packing changes the sequence construction. If your pipeline already needed chat-template fixes, I would assume there is still a real chance the loss mask or conversation structure is off, even though the model “runs.” (Google AI for Developers)
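To make the loss-mask concern concrete, here is a minimal plain-Python sketch of the invariant that assistant-only training is supposed to enforce: labels copy the input ids only inside the assistant span and are -100 everywhere else. The helper name is mine; in a real pipeline the mask would come from a {% generation %}-aware template, e.g. via apply_chat_template(..., return_assistant_tokens_mask=True) in recent Transformers.

```python
IGNORE_INDEX = -100  # standard "no loss here" label value in Hugging Face training code

def mask_non_assistant(input_ids, assistant_mask):
    """Toy illustration: keep labels only where assistant_mask is 1."""
    return [tok if m == 1 else IGNORE_INDEX
            for tok, m in zip(input_ids, assistant_mask)]

# Example: six tokens, only the last three belong to the assistant reply.
ids  = [101, 7, 8, 9, 10, 102]
mask = [0,   0, 0, 1, 1,  1]
print(mask_non_assistant(ids, mask))  # [-100, -100, -100, 9, 10, 102]
```

If you decode the non-ignored label positions and they do not reproduce exactly the assistant replies, including the end-of-turn marker handling you intended, the mask is wrong even though training "runs".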

4. For E2B/E4B, text-only fine-tuning can still need multimodal plumbing

The most non-obvious current gotcha is the open Transformers issue saying mm_token_type_ids may still be required for text-only fine-tuning. Google’s own vision fine-tuning guide also shows a custom collator, manual label masking, dataset_kwargs={"skip_prepare_dataset": True}, and remove_unused_columns=False. So if your current training stack is text-only but uses a generic text collator, I would treat the batch contract itself as suspect. (GitHub)

The background behind the official settings

The official Google text QLoRA guide uses a very conservative baseline: batch size 1, max length 512, learning rate 5e-5, LoRA rank 16, LoRA alpha 16, LoRA dropout 0.05, and max_grad_norm=0.3. It also loads the tokenizer from google/gemma-4-E2B-it to use the official chat template. That is the useful part of the official recipe. (Google AI for Developers)

But you should not copy the rest of the official examples blindly to a Mac. Google’s vision QLoRA guide uses batch size 1, learning rate 2e-4, bf16, and a custom collator with skip_prepare_dataset and remove_unused_columns=False. That recipe is clearly aimed at the standard GPU path. It is useful as documentation of the expected data plumbing, not as a direct MPS recipe. (Google AI for Developers)

Google’s general Gemma fine-tuning guidance also says that PEFT / LoRA is the less resource-intensive path, while full tuning is compute- and memory-intensive. For a MacBook M3, that means LoRA should be your baseline, not full fine-tuning. Full FT may be possible in some cases, but it is the wrong baseline while the stack is still unstable. (Google AI for Developers)

My recommended baseline for a MacBook M3

This is the baseline I would use before trying anything more ambitious.

Software stack

Use the latest Transformers 5.5.x, not an older Gemma 4 stack. Gemma 4 requires recent Transformers, and the 5.5.2 patch specifically includes a Gemma-4-focused fix for the use_cache=False issue. Keep PEFT current, but assume Gemma 4 support is still rough if you broadly target every linear wrapper. (Hugging Face)

Precision

On MPS, treat fp16 as the normal training dtype and fp32 as the diagnostic dtype for short smoke tests. Do not assume bf16 is the right choice on Mac just because many GPU tutorials use it. Apple still calls MPS beta, PyTorch only documents Adam/AdamW MPS support for fp32/fp16, and public MPS issues still show bf16 failures. (Apple Developer)
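Rather than trusting any tutorial's dtype default, you can probe what your own backend actually supports. This small sketch runs a trivial matmul in each dtype on whatever device is available; the helper name is mine, and a failed probe only shows that this op/dtype/device combination is unsupported in your install, not a general verdict on the backend.

```python
import torch

def probe_dtype(device, dtype):
    """Return True if a trivial matmul in `dtype` runs on `device`."""
    try:
        x = torch.ones((2, 2), device=device, dtype=dtype)
        (x @ x).sum().item()  # force actual execution on the device
        return True
    except Exception:
        return False

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    status = "ok" if probe_dtype(device, dtype) else "unsupported"
    print(f"{device} {dtype} -> {status}")
```

Run this once before training; if bfloat16 fails the probe on your MPS build, that alone rules out the bf16 defaults many GPU recipes assume.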

Optimization settings

Start with LoRA only, batch size 1, sequence length 512 or 1024, learning rate 2e-5 to 5e-5, and max_grad_norm=0.3. The official Google text guide already uses batch size 1, max length 512, LR 5e-5, and grad clip 0.3. I would go even more conservative on MPS by starting at the bottom of that range and only increasing after the run is clearly stable. (Google AI for Developers)

Trainer settings

For the first stable run, I would keep gradient_checkpointing=False, packing=False, assistant_only_loss=False, and torch_compile=False. The reason is not that these settings are always wrong. It is that each one introduces another place for Gemma-4-specific behavior or MPS immaturity to bite you. TRL defaults checkpointing on, and the current Gemma 4 uncached-path bug makes that a bad default to trust blindly. The community Apple-silicon Gemma 4 guide also recommends torch_compile=False on MPS and often attn_implementation="eager" for stability. That last part is community guidance rather than official Google guidance, but it is plausible and matches MPS’s current rough edges. (Hugging Face)

Data and formatting

Use the official chat template only. Keep the dataset in standard system / user / assistant roles. Decide one thinking strategy and stick to it: either final visible answers only, or a consistent thought-aware format. Do not mix multiple incompatible thought formats. For most production-style assistants, answer-only fine-tuning is the simplest stable choice. (Google AI for Developers)
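For reference, one record of the plain conversational format described above looks like this as a JSONL line (the content here is purely illustrative):

```python
import json

# One training record in the standard messages format that SFT pipelines accept.
record = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},
    ]
}

line = json.dumps(record)  # one JSON object per line in a .jsonl file
print(line)
```

Keeping every record in exactly this shape, with one consistent thinking strategy, removes a whole class of template-mismatch failures before training even starts.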

Multimodal gotchas even for text-only SFT

Inspect the first collated batch. Check that you have input_ids, attention_mask, labels, and, if your path requires them, token_type_ids and mm_token_type_ids. If those token-type tensors are missing, the current public workaround is to add them as zero tensors in a custom collator and keep remove_unused_columns=False. That is now an important Gemma 4 debugging step, not an exotic edge case. (GitHub)
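A tiny contract check on the first batch makes this inspection systematic. The sketch below uses plain lists for brevity; in a real pipeline the values would be tensors, but only the keys matter here, and the key lists simply mirror the fields discussed above.

```python
# Minimal batch-contract check: run it on the first batch your collator yields.
REQUIRED = ("input_ids", "attention_mask", "labels")
GEMMA4_EXTRAS = ("token_type_ids", "mm_token_type_ids")

def check_batch_contract(batch):
    """Return (missing required keys, missing Gemma-4-specific keys)."""
    missing = [k for k in REQUIRED if k not in batch]
    missing_extras = [k for k in GEMMA4_EXTRAS if k not in batch]
    return missing, missing_extras

batch = {"input_ids": [[1, 2]], "labels": [[1, 2]], "attention_mask": [[1, 1]]}
missing, extras = check_batch_contract(batch)
print("missing required:", missing)        # []
print("missing Gemma-4 extras:", extras)   # ['token_type_ids', 'mm_token_type_ids']
```

If the second list is non-empty on an affected Transformers version, that is exactly the situation where the zero-tensor workaround and remove_unused_columns=False come into play.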

What I would do first in your exact situation

I would do a tiny overfit test on 32 to 128 examples. Same format everywhere. No packing. No assistant-only masking. No checkpointing. LoRA only. Batch 1. Length 512. LR 2e-5 or 5e-5. Official chat template. If that tiny set does not overfit cleanly, the problem is not “Gemma needs different training philosophy.” The problem is pipeline correctness. That conclusion follows directly from how far your current loss is from a sane baseline.

Then I would make one change at a time in this order:

  1. Upgrade Transformers to latest 5.5.x.
  2. Disable checkpointing for the first stability test.
  3. Verify mm_token_type_ids / collator behavior.
  4. Switch to fp16 on MPS, fp32 only for short diagnostics.
  5. Try attn_implementation="eager" and keep torch_compile=False.
  6. Only then reintroduce checkpointing, assistant-only loss, packing, or longer sequences. (NewReleases)

TRL vs Unsloth for your case

The cleanest way to think about it is this:

TRL is the training framework. It gives you SFTTrainer, dataset handling, packing, assistant-only loss, and the overall fine-tuning API. Hugging Face says Gemma 4 is fully supported in TRL. (Hugging Face)

Unsloth is an optimized Hugging Face-compatible path. Hugging Face’s docs say it is fully compatible with SFTTrainer, and the Transformers integration page says Unsloth patches internal Transformers methods for speed. In other words, Unsloth is usually not a different fine-tuning method. It is more like a faster, lower-memory implementation route around the same ecosystem. (Hugging Face)

For your case, I would use them like this:

  • Use TRL / plain Transformers first if your priority is correctness and debugging clarity.
  • Use Unsloth second if your priority is local memory pressure, speed, and convenience, especially once the pipeline is known-good. Hugging Face’s Gemma 4 launch post explicitly includes Unsloth Studio as a Gemma 4 fine-tuning option, and Unsloth’s current Gemma 4 guide gives very relevant practical advice for E2B/E4B: start with multimodal loading, keep vision layers off first, fine-tune language/attention/MLP layers first, use small batches, and keep your thinking format consistent. (Hugging Face)

One caveat: Unsloth’s current release notes say they fixed Gemma 4 training issues including exploding losses from gradient accumulation and Gemma 4 use_cache=False behavior. That is useful signal, but it is still vendor-provided guidance. I would treat it as encouraging, not as proof that every Mac/MPS problem disappears by switching libraries. (GitHub)

My bottom-line view

For your specific setup, I would rank the likely causes like this:

  1. Gemma 4 training-path issue, especially uncached / checkpointed path
  2. Missing or malformed multimodal batch fields in text-only training
  3. Template / thinking / masking mismatch
  4. MPS precision and attention-path instability
  5. Ordinary hyperparameter tuning (GitHub)

So no, I do not think your situation is just “MacBook bf16 is bad.” That is part of it, but it is not the center of it. The center is that Gemma 4 is new, multimodal, and still receiving training-path fixes, and MPS makes those early-stack problems more visible.


The simplest isolation is to split the problem into three tests:

  1. Forward-path sanity
  2. Batch-contract sanity
  3. Tiny overfit sanity

That works because the current Gemma 4 failure reports are concentrated in those exact places: the use_cache=False path, the mm_token_type_ids / token_type_ids batch fields, and trainer defaults that can silently push you into the broken path. Hugging Face’s current Gemma 4 docs say to use the latest Transformers, and the recent Transformers 5.5.2 patch specifically fixed Gemma-4 use_cache=False behavior. TRL’s SFTTrainer also defaults gradient_checkpointing=True and bf16=True if fp16 is not set, which is risky on Mac MPS. Apple still labels MPS as beta, and PyTorch’s AdamW docs say MPS AdamW support is for float32 and float16. (Hugging Face)

What this code will tell you

If use_cache=True looks sane but use_cache=False looks bad, you likely hit the known Gemma 4 cache-path bug. If the model only works when you manually add zero-filled token_type_ids and mm_token_type_ids, you likely hit the current Gemma 4 text-only batch-contract issue. If both of those pass, but the model still cannot overfit a single short sample, the remaining suspects are mainly template / label masking or MPS numeric instability. (GitHub)

Script 1: forward-path and batch-contract sanity check

Run this first. It does not train. It only checks whether the model’s forward pass is healthy.

# diag_gemma4_forward.py
import torch
import transformers
from transformers import AutoProcessor, AutoModelForCausalLM

MODEL_ID = "google/gemma-4-E2B-it"

def print_env():
    print("torch:", torch.__version__)
    print("transformers:", transformers.__version__)
    print("mps available:", torch.backends.mps.is_available())
    print("cuda available:", torch.cuda.is_available())
    print("device:", "mps" if torch.backends.mps.is_available() else "cpu")

def build_inputs(processor, device):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in a formal way."},
    ]

    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    batch = processor(text=text, return_tensors="pt")
    batch = {k: v.to(device) for k, v in batch.items()}
    return text, batch

def add_missing_token_type_fields(batch):
    # Current Gemma 4 issue says these may be required even for text-only SFT.
    if "token_type_ids" not in batch:
        batch["token_type_ids"] = torch.zeros_like(batch["input_ids"])
    if "mm_token_type_ids" not in batch:
        batch["mm_token_type_ids"] = torch.zeros_like(batch["input_ids"])
    return batch

def topk_report(processor, logits, k=5):
    probs = torch.softmax(logits[0, -1].float(), dim=-1)
    vals, idxs = torch.topk(probs, k)
    rows = []
    for p, idx in zip(vals.tolist(), idxs.tolist()):
        token = processor.tokenizer.decode([idx])
        rows.append((repr(token), p))
    return rows

def safe_forward(model, batch, use_cache):
    with torch.no_grad():
        out = model(**batch, use_cache=use_cache)
    return out

def main():
    print_env()

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    # Use float16 first on MPS. Switch to float32 only for a short diagnostic run if needed.
    dtype = torch.float16 if device.type == "mps" else torch.float32

    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
    ).to(device)

    # Explicitly keep cache ON for this diagnostic.
    model.config.use_cache = True
    model.eval()

    text, batch = build_inputs(processor, device)
    print("\nPrompt:")
    print(text)

    print("\nOriginal batch keys:", sorted(batch.keys()))
    batch = add_missing_token_type_fields(batch)
    print("Patched batch keys:", sorted(batch.keys()))

    print("\n=== Forward with use_cache=True ===")
    out_true = safe_forward(model, batch, use_cache=True)
    report_true = topk_report(processor, out_true.logits, k=5)
    for token, prob in report_true:
        print(f"{token:>20}  {prob:.6f}")

    print("\n=== Forward with use_cache=False ===")
    try:
        out_false = safe_forward(model, batch, use_cache=False)
        report_false = topk_report(processor, out_false.logits, k=5)
        for token, prob in report_false:
            print(f"{token:>20}  {prob:.6f}")
    except Exception as e:
        print("FAILED with use_cache=False:", repr(e))

    # Optional: compare losses on the same short target.
    labels = batch["input_ids"].clone()
    with torch.no_grad():
        loss_true = model(**batch, labels=labels, use_cache=True).loss.item()
    print(f"\nLoss with use_cache=True : {loss_true:.6f}")

    try:
        with torch.no_grad():
            loss_false = model(**batch, labels=labels, use_cache=False).loss.item()
        print(f"Loss with use_cache=False: {loss_false:.6f}")
    except Exception as e:
        print("Loss with use_cache=False failed:", repr(e))

if __name__ == "__main__":
    main()

How to read the result

A healthy run should show sensible next-token candidates for the formal hello prompt and a loss that is not absurd. If use_cache=True gives reasonable tokens but use_cache=False gives junk, that is almost a direct fingerprint of the known Gemma 4 cache-path bug. If the forward pass fails until you add zero-filled token_type_ids and mm_token_type_ids, that matches the current open Gemma 4 text-only training issue. (GitHub)

Script 2: tiny overfit sanity check

Only run this after Script 1 looks sane. The point is not to get a good model. The point is to see whether the training loop can reduce loss on one tiny sample. Keep it extremely small. No TRL. No checkpointing. No packing. No assistant-only masking.

# diag_gemma4_tiny_overfit.py
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

MODEL_ID = "google/gemma-4-E2B-it"

def make_batch(processor, device):
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
        enable_thinking=False,
    )
    batch = processor(text=text, return_tensors="pt")
    batch = {k: v.to(device) for k, v in batch.items()}
    if "token_type_ids" not in batch:
        batch["token_type_ids"] = torch.zeros_like(batch["input_ids"])
    if "mm_token_type_ids" not in batch:
        batch["mm_token_type_ids"] = torch.zeros_like(batch["input_ids"])
    batch["labels"] = batch["input_ids"].clone()
    return batch

def main():
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    # For a tiny diagnostic, float32 is acceptable if memory allows.
    # If memory is too tight, switch this to torch.float16.
    dtype = torch.float32

    processor = AutoProcessor.from_pretrained(MODEL_ID)
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
    ).to(device)

    base_model.config.use_cache = True
    if hasattr(base_model, "gradient_checkpointing_disable"):
        base_model.gradient_checkpointing_disable()

    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj"
        ],
    )
    model = get_peft_model(base_model, lora_config)
    model.train()

    batch = make_batch(processor, device)

    # Only train LoRA params
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=2e-5)

    print("Starting tiny overfit test...")
    for step in range(20):
        optimizer.zero_grad(set_to_none=True)
        out = model(**batch, use_cache=True)
        loss = out.loss
        if not torch.isfinite(loss):
            print(f"Step {step:02d} loss is non-finite:", loss)
            break
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params, 0.3)
        optimizer.step()
        print(f"step={step:02d} loss={loss.item():.6f}")

if __name__ == "__main__":
    main()

How to read the overfit test

If the loss does not move down at all on one trivial sample, the problem is almost never “needs better hyperparameters.” It means the training path is still broken. If the loss drops cleanly here but explodes in your full training job, the difference is probably in your dataset formatting, masking, sequence length, trainer defaults, or MPS pressure. That conclusion is consistent with the Gemma 4 cache-path bug report, the mm_token_type_ids report, and Google’s requirement to use the model’s own chat formatting. (GitHub)

If you are using TRL later, override these defaults

If your manual test works but SFTTrainer does not, check TRL first. TRL currently documents that SFTTrainer defaults gradient_checkpointing=True, bf16=True if fp16 is not set, and learning_rate=2e-5. For Gemma 4 on MPS, the important one is checkpointing, because it can route you straight into the uncached path that was recently fixed in Transformers 5.5.2. (Hugging Face)

A minimal safe TRL config for diagnosis looks like this:

from trl import SFTConfig

cfg = SFTConfig(
    output_dir="out",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=2e-5,
    max_seq_length=256,
    gradient_checkpointing=False,  # important for Gemma 4 diagnosis
    bf16=False,
    fp16=False,                    # manual tests first; use fp16 only if you must
    packing=False,
    max_grad_norm=0.3,
    logging_steps=1,
)

What results map to what root cause

If Script 1 fails only when use_cache=False, you found the known Gemma 4 path bug. Upgrade Transformers first, and avoid forcing use_cache=False while isolating. Transformers 5.5.2 specifically shipped a Gemma-4 use_cache=False fix. (GitHub)

If Script 1 only works after adding token_type_ids and mm_token_type_ids, your issue is likely the current batch-contract problem for Gemma 4 text-only fine-tuning. In that case, keep the custom collator and remove_unused_columns=False when you move back to trainer code. (GitHub)

If Script 1 is fine and Script 2 cannot overfit one sample, suspect MPS numerics or a remaining formatting bug. Apple still documents MPS as beta, and PyTorch only documents MPS AdamW support for float16 and float32. That is why I would debug on float32 for tiny tests if memory allows, then move to float16 for real runs on MPS. (Apple Developer)

If both scripts pass, your full run is probably failing because of dataset construction, long sequence length, assistant-only masking, packing, or TRL defaults rather than because “Gemma 4 cannot fine-tune on Mac.” Google’s own Gemma 4 docs emphasize using the built-in processor and chat template, and Gemma 4’s prompt structure is more specific than older Llama/Qwen-style plain chat formatting. (Google AI for Developers)


Hi @John6666 thank you so much for your amazing support and feedback and especially the time you take to provide comprehensive complete answers. I’m truly grateful for your support! :folded_hands: :heart:

To give you some feedback regarding a few of the points you mentioned:

  • If your pipeline already needed chat-template fixes: Yes I had to use a custom chat-template that includes {% generation %}.

  • So if your current training stack is text-only but uses a generic text collator: I found out that in the latest trl, the data collator was deprecated in favor of assistant_only_loss with the custom chat-template that includes {% generation %}.

  • It also loads the tokenizer from google/gemma-4-E2B-it: I tried but it didn’t work. The tokenizer from google/gemma-4-E2B-it seems to be designed for inference only and does not include the right chat-template required for fine-tuning. I had to use the tokenizer from google/gemma-4-E2B and add a chat-template manually.

  • For a MacBook M3, that means LoRA should be your baseline, not full fine-tuning: Indeed I’m using LoRA.

  • Use the official chat template only: I clarified above, the official google/gemma-4-E2B-it chat-template doesn’t seem to work for fine-tuning and google/gemma-4-E2B doesn’t include a chat-template. So manual adding one is required.

  • Keep the dataset in standard system / user / assistant roles: Indeed I did away with the multi-modal dataset format and opted for a simple format as follows: {"messages": [{"role": "user", "content": "…"}, {"role": "assistant", "content": "…"}]}

  • Inspect the first collated batch. Check that you have input_ids, attention_mask, labels, and, if your path requires them, token_type_ids and mm_token_type_ids: My batch output is as shown below:

Batch Keys: dict_keys(['input_ids', 'labels', 'attention_mask'])

--- Decoded Input IDs ---
<bos><start_of_turn>user
What products does the solution support?<end_of_turn>
<start_of_turn>model
Meetings, Chat, Docs, Notes, Workflows, Videos.<end_of_turn>


--- Labels (Tokens the model is trained to predict) ---
<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>Meetings, Chat, Docs, Notes, Workflows, Videos.<<pad><pad><pad><pad><pad><pad><pad>
  • TRL vs Unsloth for your case: Indeed I’m using trl.

I hope this helps give you feedback. Thank you so much! :folded_hands: :heart:


Oh… Since Transformers and TRL are currently undergoing fairly major renovations, there are just too many potential culprits in this case…


Your feedback changes the diagnosis a lot.

The main point

Your problem now looks much less like “bad hyperparameters” and much more like a mismatch between Gemma 4’s training input contract and TRL’s assistant-only training path. The strongest clue is your batch: it only contains input_ids, labels, and attention_mask, while there is an active Transformers issue showing that Gemma 4 text-only fine-tuning may still require token_type_ids and mm_token_type_ids during training, with a custom collator and remove_unused_columns=False called out as the practical workaround. (GitHub)

That means your current setup may be doing two things at once: the loss mask is partially right, but the model inputs are still incomplete. Those are different layers. TRL’s assistant_only_loss=True controls which tokens contribute to loss. It does not guarantee that every extra tensor Gemma 4 expects at training time is present in the batch. (Hugging Face)

What your label dump suggests

Your decoded labels are actually useful. They suggest that the assistant answer is being supervised while the user turn is masked out. That is a good sign. It means your custom {% generation %} template is not obviously failing at the most basic level. TRL’s docs say assistant-only loss depends on templates that support {% generation %} / {% endgeneration %} masking, so your workaround is conceptually in the right direction. (Hugging Face)

But that does not prove the template is fully correct. A template can look fine when decoded and still produce a slightly wrong assistant mask, wrong end-of-turn coverage, or a mismatch with how Gemma 4 expects turns to be structured during training. TRL’s own docs make clear that assistant-only loss depends on the template returning the right assistant-token mask, and a current TRL tracking issue exists precisely because many models do not ship training-ready templates with those generation markers. (Hugging Face)
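To make the shape of such a template concrete, here is a minimal sketch of a generation-marked chat template. This is NOT Gemma's official template, just an illustration of where the {% generation %} / {% endgeneration %} markers sit relative to the turn markers visible in your decoded batch; whitespace and end-of-turn placement are exactly the details you would need to verify against your own mask.

```python
# A minimal *sketch* of a generation-marked Jinja chat template for
# assistant-only loss. Illustrative only, not Gemma's official template.
GEMMA4_SFT_TEMPLATE = (
    "{{ bos_token }}"
    "{% for message in messages %}"
    "{% if message['role'] == 'assistant' %}"
    "<start_of_turn>model\n"
    "{% generation %}{{ message['content'] }}<end_of_turn>\n{% endgeneration %}"
    "{% else %}"
    "<start_of_turn>{{ message['role'] }}\n{{ message['content'] }}<end_of_turn>\n"
    "{% endif %}"
    "{% endfor %}"
)

# Wiring it up (assumes `tokenizer` is already loaded):
# tokenizer.chat_template = GEMMA4_SFT_TEMPLATE
```

Whether the assistant mask should cover the trailing <end_of_turn> is one of those small decisions that can silently change training behavior, which is why decoding the masked labels back to text is worth doing even when the template looks right.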

On your tokenizer observation

I cannot verify the claim that google/gemma-4-E2B-it is “for inference only.” The official Google Gemma 4 QLoRA guide still uses the tokenizer from google/gemma-4-E2B-it explicitly to get the official template. But your practical observation still makes sense in context: the official template may be fine for inference or standard conversational formatting, while still not being enough for TRL assistant-only SFT, which specifically needs generation markers for training masks. So your manual template patch is not inherently suspicious. It is just another place where mistakes become easy. (Google AI for Developers)

What I now think is happening in your case

1. Missing Gemma-4-specific batch tensors is the top suspect

This is now my number one suspect, by a clear margin. Your current batch shape matches the public Gemma 4 issue almost too well. The issue explicitly says Gemma 4 text-only fine-tuning may require both token_type_ids and mm_token_type_ids, even when your dataset is only user and assistant text. That is unusual compared with Llama or Qwen, but it is exactly the kind of multimodal-family residue that Gemma 4 currently exposes. (GitHub)

2. The use_cache=False / checkpointing path is the second suspect

This is the other big one. There is a recent Transformers issue saying Gemma 4 training with use_cache=False can corrupt attention and produce garbage logits, and Transformers 5.5.2 specifically shipped a fix for Gemma 4 in that area. TRL also defaults gradient_checkpointing=True, and that often pushes training into the uncached path. So even if your template and labels are mostly fine, the actual forward pass may still be unstable for reasons upstream of your data. (Hugging Face)

3. Your template may be “good enough to run” but still not exact

Because you had to create a custom {% generation %} template, I would treat it as plausible, not trusted. Gemma-family prompt structure uses <start_of_turn>user, <start_of_turn>model, and <end_of_turn> markers, so the format you showed is not obviously wrong. But the training mask has to align with the assistant response exactly, and tiny mismatches there can produce stubbornly bad loss without causing a clean crash. (Hugging Face)

4. MPS is likely amplifying the failure, not causing it first

Apple’s MPS backend is still documented as beta, and PyTorch’s MPS optimizer docs focus on float32 and float16, while public MPS issues still exist around bf16 support. So your suspicion about Mac precision is reasonable. I just no longer think it is the root cause. I think it is turning a brittle Gemma-4-specific setup into an explosive one. (Google AI for Developers)

Why TRL alone is not enough here

You are right that TRL now encourages assistant_only_loss rather than older custom completion collators in many setups. But that guidance is about loss masking, not about every model family’s extra forward inputs. Gemma 4 is precisely the kind of model where those two concerns diverge. So “the collator is deprecated” does not really invalidate the need for a small custom collator in your case. It only means the collator should not be the thing deciding the loss region. It can still be the thing adding missing tensors. (Hugging Face)

That distinction is the key insight for your case:

  • assistant_only_loss decides where loss applies
  • custom collator can still decide which tensors the model receives (Hugging Face)

My recommended fix order

First fix: add the missing tensors

Keep your conversational dataset and keep assistant_only_loss=True, but add a tiny collator that injects zero-filled token_type_ids and mm_token_type_ids, then set remove_unused_columns=False. That is the most directly evidence-backed change you can make right now. (GitHub)

import torch
from transformers import default_data_collator

class Gemma4TextCollator:
    """Stacks features as usual, then injects the Gemma-4-specific
    token-type tensors that text-only batches otherwise lack."""

    def __call__(self, features):
        batch = default_data_collator(features)

        # A zero fill marks every position as plain text, which is
        # what a text-only fine-tune needs for both tensors.
        if "token_type_ids" not in batch:
            batch["token_type_ids"] = torch.zeros_like(batch["input_ids"])

        if "mm_token_type_ids" not in batch:
            batch["mm_token_type_ids"] = torch.zeros_like(batch["input_ids"])

        return batch
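Before wiring the collator into the trainer, it is worth a standalone smoke test. The class is repeated here so the snippet runs on its own, and the fake features are stand-ins for real tokenizer output:

```python
import torch
from transformers import default_data_collator

class Gemma4TextCollator:  # repeated from above so this snippet runs standalone
    def __call__(self, features):
        batch = default_data_collator(features)
        if "token_type_ids" not in batch:
            batch["token_type_ids"] = torch.zeros_like(batch["input_ids"])
        if "mm_token_type_ids" not in batch:
            batch["mm_token_type_ids"] = torch.zeros_like(batch["input_ids"])
        return batch

# Fake pre-tokenized rows standing in for real dataset output.
features = [
    {"input_ids": [2, 10, 11, 1], "labels": [-100, 10, 11, 1]},
    {"input_ids": [2, 12, 13, 1], "labels": [-100, 12, 13, 1]},
]
batch = Gemma4TextCollator()(features)

# Both Gemma-4-specific tensors now exist, zero-filled, input_ids-shaped.
print(tuple(batch["token_type_ids"].shape))   # (2, 4)
print(int(batch["mm_token_type_ids"].sum()))  # 0
```

Pass an instance as data_collator to your trainer and remember remove_unused_columns=False, or the injected columns get stripped before the model ever sees them.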

Second fix: disable the risky training path

For diagnosis, use:

  • gradient_checkpointing=False
  • do not force use_cache=False
  • packing=False
  • bf16=False
  • fp16=True on MPS, or very short fp32 smoke tests only
  • latest Transformers with the Gemma 4 fixes included (Hugging Face)
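Collected in one place, those diagnostic settings might look like this. I am writing them as plain kwargs rather than a constructed SFTConfig, since exact field names can shift between TRL versions:

```python
# Diagnostic settings sketch; field names follow TRL's SFTConfig and may
# differ on your installed version -- apply as SFTConfig(**diag_kwargs).
diag_kwargs = dict(
    output_dir="gemma4-diagnosis",
    gradient_checkpointing=False,  # keep the cached attention path
    packing=False,                 # no sample packing while debugging
    bf16=False,
    fp16=True,                     # on MPS; use fp32 only for short smoke tests
    remove_unused_columns=False,   # keep any injected token_type_ids columns
    per_device_train_batch_size=1,
)
```

Do not pass use_cache=False anywhere in this configuration while you are still diagnosing.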

Third fix: verify the assistant mask mechanically

Do not trust the decoded labels alone. Ask the tokenizer/template path to return the assistant mask and check that it is non-empty and lines up exactly with the answer span. TRL’s masking behavior depends on that. (Hugging Face)
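A minimal sketch of that mechanical check, assuming a recent Transformers version where apply_chat_template accepts return_assistant_tokens_mask (which only works when the template contains {% generation %} blocks):

```python
# Hypothetical helper: pass in your loaded tokenizer and one `messages`
# row from your dataset; returns the decoded supervised span.
def check_assistant_mask(tokenizer, messages):
    out = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        return_dict=True,
        return_assistant_tokens_mask=True,
    )
    mask = out["assistant_masks"]
    assert sum(mask) > 0, "empty assistant mask: no tokens receive loss"
    # Decode only the masked span and eyeball that it is exactly the answer.
    answer_ids = [t for t, m in zip(out["input_ids"], mask) if m]
    return tokenizer.decode(answer_ids)
```

If the returned string is anything other than the assistant answer, possibly plus its end-of-turn marker, the template is the problem.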

Fourth fix: shrink the task until it overfits

Run a tiny overfit test on 32–128 samples, sequence length 256–512, batch size 1, low LR. If it still cannot overfit, that is further evidence that the stack is wrong, not the data scale. Google’s own Gemma 4 QLoRA guide already uses small defaults like batch size 1 and max length 512, which supports this conservative debugging strategy. (Google AI for Developers)

The practical alternative I would seriously consider

For your case, I would strongly consider switching from conversational messages + assistant_only_loss=True to prompt-completion format for the first stable run.

Why? Because right now you are depending on a custom Jinja training template to infer the supervised region. Prompt-completion makes the target explicit. That removes one moving part. TRL supports both, and completion-style supervision is simply less fragile when the model family’s chat-template training path is still maturing. (Hugging Face)
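An illustrative messages-to-prompt-completion conversion; the turn markers follow the Gemma chat format mentioned above, but verify them against your tokenizer's template before relying on them:

```python
# Hypothetical converter for single-turn rows: makes the supervised
# target explicit instead of inferring it from a Jinja template.
def to_prompt_completion(messages):
    user = messages[0]["content"]
    assistant = messages[1]["content"]
    return {
        "prompt": f"<start_of_turn>user\n{user}<end_of_turn>\n<start_of_turn>model\n",
        "completion": f"{assistant}<end_of_turn>",
    }

row = to_prompt_completion([
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
])
print(row["completion"])  # 4<end_of_turn>
```

With rows in this shape, loss is applied to the completion by construction, so the custom training template drops out of the loop entirely.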

This does not mean your current approach is wrong in theory. It means it is fragile in practice, especially on Gemma 4 + TRL + MPS.

My bottom-line conclusion

Based on everything you shared, my current ranking is:

  1. Missing token_type_ids / mm_token_type_ids
  2. gradient_checkpointing / use_cache=False path
  3. Subtle custom-template mask mismatch
  4. MPS precision instability
  5. Only then: learning-rate or ordinary tuning issues (GitHub)

So my advice is no longer “try different settings.” It is:

  • keep LoRA
  • keep the simple messages dataset for now
  • add a Gemma-4-specific collator
  • set remove_unused_columns=False
  • turn checkpointing off
  • avoid the uncached path while diagnosing
  • verify the assistant mask directly
  • only then return to hyperparameter tuning (GitHub)
1 Like

Sharing some success with Gemma-4-E4B (Q8_0) tuning!

I’ve been experimenting with parameter tweaks on my headless ROCm server, and it has a real “overclocking” feel to it. I found that small, deliberate changes to the inference settings really tightened up the model’s output.

For those running the E4B variant, these settings significantly cut back on “rambling” and reduced thinking latency without the logic falling apart:

  • Temperature: 0.8 (keeps it opinionated)

  • Top_P: 0.85 / Top_K: 40 (narrower, faster search area)

  • Repeat Penalty: 1.1 (just enough to kill the logic loops)

It feels like tuning a GPU—if you push too hard, it breaks, but hitting this sweet spot makes it feel much more surgical. Anyone else found a “magic” parameter set for the Gemma-4 family?
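For anyone running the same experiment on the transformers side rather than a GGUF runtime, those sampler settings map onto a GenerationConfig roughly like this (an illustrative translation, not the poster's exact setup):

```python
# The reply's sampler settings expressed via transformers' GenerationConfig;
# in llama.cpp-style runtimes the equivalent knobs are the temperature,
# top-p, top-k, and repeat-penalty options.
from transformers import GenerationConfig

gen = GenerationConfig(
    do_sample=True,
    temperature=0.8,
    top_p=0.85,
    top_k=40,
    repetition_penalty=1.1,
)
```

The same values should then behave comparably when passed to model.generate.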

1 Like