How do I fuse Gemma 4 E4B adapters after training?

I managed to do it with Qwen3.5 and Gemma 4 E2B, but with E4B I get nothing but errors after many AI-assisted attempts: more than one million tokens with DeepSeek, many scripts, and I even retrained the model with both mlx-vlm and mlx-lm. My last adapters were produced with this:

cmd = [
    VENV, "-m", "mlx_vlm.lora",
    "--model-path", MODEL,
    "--dataset", DATA,
    "--batch-size", "1",
    "--iters", "300",
    "--learning-rate", "1e-5",
    "--grad-checkpoint",
    "--gradient-accumulation-steps", "4",
    "--steps-per-save", "50",
    "--output-path", OUTPUT,
    "--lora-rank", "8",
    "--lora-alpha", "16",
    "--train-on-completions",
    "--assistant-id", "4368",
]

And this is the fusion script:

#!/usr/bin/env python3
"""Fuse LoRA adapters into the Gemma 4 E4B 8-bit model.
Run DIRECTLY on the Mini, not over SSH.
"""
import json, shutil
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
from mlx_vlm import load
from mlx_vlm.trainer.lora import LoRaLayer

BASE = Path("/Users/hal9000/Desktop/AI/modelos/gemma_4_e4b_it_8bit")
ADAPTER = Path("/Volumes/ssd./ssd_gemma4/adaptadores_v2")
SALIDA = Path("/Users/hal9000/Desktop/AI/modelos/gemma_4_e4b_it_8bit_ssd_fused")

# adapter_config.json
with open(ADAPTER / "adapter_config.json", "w") as f:
    json.dump({"rank": 8, "alpha": 16.0, "dropout": 0.0}, f)

print("Cargando modelo con adapters…")
model, processor = load(str(BASE), adapter_path=str(ADAPTER))

print("Fusionando capas LoRA…")
to_update = {}
for name, module in model.named_modules():
    if not isinstance(module, LoRaLayer):
        continue
    orig = module.original_layer
    if isinstance(orig, nn.QuantizedLinear):
        w = mx.dequantize(orig.weight, orig.scales, orig.biases, orig.group_size, orig.bits)
    else:
        w = orig.weight
    lu = module.scale * (module.A @ module.B)
    od, id_ = w.shape
    if lu.shape == (id_, od):
        wf = w + lu.T
    elif lu.shape == (od, id_):
        wf = w + lu
    else:
        print(" SKIP shape mismatch:", name, w.shape, lu.shape)
        continue
    to_update[name] = wf
print(" %d capas fusionadas" % len(to_update))

print("Cargando pesos originales…")
all_w = {}
for sf in sorted(BASE.glob("model-*.safetensors")):
    all_w.update(mx.load(str(sf)))
print(" %d tensores totales" % len(all_w))

# Replace the fused weights
for name, wf in to_update.items():
    mk = name + ".weight"
    if mk in all_w:
        all_w[mk] = wf
        for s in [".scales", ".biases"]:
            all_w.pop(name + s, None)

# VERIFY conv shapes against the originals and restore them if they changed
print("Verificando shapes conv…")
restored = 0
for sf in sorted(BASE.glob("model-*.safetensors")):
    orig = mx.load(str(sf))
    for k, v in orig.items():
        if len(v.shape) == 4 and k in all_w and all_w[k].shape != v.shape:
            print(" RESTAURANDO %s: %s → %s" % (k, list(all_w[k].shape), list(v.shape)))
            all_w[k] = v
            restored += 1
print(" %d conv weights restaurados" % restored)

# Save
if SALIDA.exists():
    shutil.rmtree(SALIDA)
SALIDA.mkdir(parents=True)

print("Guardando %d tensores…" % len(all_w))
mx.save_safetensors(str(SALIDA / "model.safetensors"), all_w)

# Copy configs
for f in BASE.glob("*"):
    if f.suffix in (".json", ".txt", ".md") or "tokenizer" in f.name:
        shutil.copy2(f, SALIDA)

# Remove index.json
idx = SALIDA / "model.safetensors.index.json"
if idx.exists():
    idx.unlink()

print("OK " + str(SALIDA))

# Final verification
print("Verificando conv weights en output…")
d = mx.load(str(SALIDA / "model.safetensors"))
ok = True
for k, v in d.items():
    if "conv" in k.lower() and len(v.shape) == 4:
        print(" %s: %s" % (k, list(v.shape)))
        # Check against the original
        for sf in sorted(BASE.glob("model-*.safetensors")):
            orig = mx.load(str(sf))
            if k in orig and v.shape != orig[k].shape:
                print(" ERROR: deberia ser %s" % list(orig[k].shape))
                ok = False
if ok:
    print("TODOS los conv weights tienen shapes correctos!")

Oh... this looks like a fairly complex case, with several known layers of compound drift.


TL;DR

I would not read this as “Gemma 4 E4B adapters cannot be fused”.

A closer reading suggests something narrower:

Your training recipe is not obviously impossible.
The risky part is the custom fuse/export path, especially because it partially dequantizes an MLX quantized Gemma 4 E4B checkpoint, deletes some quantization metadata, writes a single model.safetensors, and removes the shard index.

There are nearby success examples. The closest one I found is deadbydawn101/gemma-4-E4B-Agentic-Opus-Reasoning-GeminiCLI-mlx-4bit, which says it trained with mlx_vlm.lora, rank 8, alpha 16, and fused 378 LoRA pairs into the base weights. But its merge path is different: it explicitly dequantizes and saves the result as a BF16 3-shard safetensors model.

So my current hypothesis is:

The LoRA math may be roughly right, but the output checkpoint is probably not a consistent Gemma 4 E4B MLX checkpoint.


Why this case is probably tricky

This is sitting at the intersection of several moving parts:

| Layer | Why it matters here |
|---|---|
| Gemma 4 E4B architecture | E2B/E4B have Gemma 4-specific projection, shared-KV, and multimodal structure. |
| mlx-vlm Gemma 4 support | Recent releases include several Gemma 4-specific fixes. |
| LoRA scaling | mlx-vlm had a known alpha vs alpha/rank scaling issue, fixed in PR #846. |
| Quantized MLX checkpoint layout | .weight, .scales, .biases, the shard index, and the config must stay consistent. |
| VLM vs text-only loader paths | mlx_vlm and mlx_lm can show different practical behavior. |
| Adapter vs fused behavior | Adapter-loaded inference can work while fused checkpoints fail. |
| Server vs direct CLI inference | mlx_vlm server --adapter-path has had an adapter-dropping cache issue: issue #907. |

The mlx-vlm v0.5.0 release notes are also worth reading because they include multiple relevant fixes: Gemma 4 quantized per-layer projection loading, Gemma 4 audio fixes, LoRA alpha/rank scaling, Gemma 4 LoRA training fixes, etc. See Blaizzy/mlx-vlm releases.


Nearby success example vs this case

The closest success example I found is deadbydawn101/gemma-4-E4B-Agentic-Opus-Reasoning-GeminiCLI-mlx-4bit.

That example is not an official proof that all E4B LoRA fuse workflows are safe, but it is useful because it is very close.

| Item | Your case | Nearby success example |
|---|---|---|
| Model family | Gemma 4 E4B | Gemma 4 E4B |
| Runtime family | MLX / mlx-vlm | MLX / mlx-vlm / mlx_lm |
| Training command family | mlx_vlm.lora | mlx_vlm.lora |
| LoRA rank | 8 | 8 |
| LoRA alpha | 16 | 16 |
| LR | 1e-5 | 1e-5 |
| Training style | completions-only SFT | completions-only |
| Base precision | local E4B 8-bit | E4B MLX 4-bit |
| Fuse count | unknown from the post | 378 LoRA pairs |
| Merge scale | module.scale | explicitly alpha / rank |
| Output format | single model.safetensors, index removed | BF16 3-shard safetensors |
| Quant metadata handling | deletes selected .scales / .biases | dequantized merged checkpoint |
| Loader path shown | mlx_vlm.load(...) | mostly mlx_lm text-generation examples |

The most important difference is not 4-bit vs 8-bit by itself. The important difference is:

The success example appears to turn the merged model into a coherent BF16 checkpoint.
Your script may be producing a mixed quantized/floating checkpoint.


The biggest red flags in the custom fuse script

1. adapter_config.json is overwritten

This part is risky:

with open(ADAPTER / "adapter_config.json", "w") as f:
    json.dump({"rank": 8, "alpha": 16.0, "dropout": 0.0}, f)

Even if those values are correct this time, the fuse script should not rewrite the adapter metadata. If the adapter config contains target module information, naming conventions, or version-specific metadata, this can silently destroy useful information.

Safer:

print((ADAPTER / "adapter_config.json").read_text())

Do not modify it during fuse.
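If you want the fuse script to confirm the values instead of writing them, a read-only check is enough. This is a minimal sketch; it assumes the file uses top-level "rank" and "alpha" keys (the same keys your script writes), which other mlx-vlm versions may name or nest differently.

```python
import json
from pathlib import Path


def check_adapter_config(adapter_dir, expected_rank=8, expected_alpha=16.0):
    """Read adapter_config.json WITHOUT modifying it and report rank/alpha.

    Assumption: rank/alpha live under top-level "rank" and "alpha" keys.
    """
    cfg_path = Path(adapter_dir) / "adapter_config.json"
    cfg = json.loads(cfg_path.read_text())  # read-only: the file is never opened with "w"
    rank = cfg.get("rank")
    alpha = cfg.get("alpha")
    problems = []
    if rank != expected_rank:
        problems.append(f"rank={rank!r}, expected {expected_rank}")
    if alpha is None or float(alpha) != float(expected_alpha):
        problems.append(f"alpha={alpha!r}, expected {expected_alpha}")
    return cfg, problems


if __name__ == "__main__":
    cfg, problems = check_adapter_config("/Volumes/ssd./ssd_gemma4/adaptadores_v2")
    print(sorted(cfg))
    print(problems or "rank/alpha match")
```

If `problems` is non-empty, fix the training run or the loader call, not the adapter file.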


2. The output may become mixed quantized/floating-point

This is the biggest issue:

all_w[name + ".weight"] = wf
all_w.pop(name + ".scales", None)
all_w.pop(name + ".biases", None)

For each LoRA target layer, you are replacing the quantized layer with a dequantized/fused .weight and removing its .scales / .biases.

But unless you do the same coherently for the whole checkpoint and update the config accordingly, you can end up with something like:

| Part of model | Possible state after script |
|---|---|
| LoRA-target layers | floating-point .weight, no .scales / .biases |
| non-target quantized layers | still quantized .weight + .scales / .biases |
| config | still copied from the original quantized model |
| shard index | removed |
| output file layout | single model.safetensors |

That is not obviously a valid MLX Gemma 4 E4B checkpoint layout.

The nearby fused example says it dequantized the result to BF16 and saved as 3-shard safetensors. That is a much cleaner contract.
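One cheap way to see whether you are in that mixed state is to compare key names only. This sketch assumes the MLX quantized layout your script already relies on, where a quantized layer stores `<prefix>.scales` next to `<prefix>.weight`; the key sets can come from `mx.load(...)` on each shard.

```python
def quantization_state(base_keys, fused_keys):
    """Find modules that were quantized in the base but lost .scales in the fused output.

    Returns (lost, kept, mixed):
      lost  -> prefixes that had .scales in base but not in fused
      kept  -> prefixes still quantized in fused
      mixed -> True when both groups are non-empty (half-dequantized checkpoint)
    """
    base_keys, fused_keys = set(base_keys), set(fused_keys)
    base_q = {k[: -len(".scales")] for k in base_keys if k.endswith(".scales")}
    fused_q = {k[: -len(".scales")] for k in fused_keys if k.endswith(".scales")}
    lost = sorted(base_q - fused_q)
    kept = sorted(base_q & fused_q)
    return lost, kept, bool(lost) and bool(kept)
```

A fully BF16 export should give `kept == []`; a fully preserved quantized export should give `lost == []`. Anything in between is the mixed state described above.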


3. The shard/index behavior differs from the success example

Your script writes:

mx.save_safetensors(str(SALIDA / "model.safetensors"), all_w)

idx = SALIDA / "model.safetensors.index.json"
if idx.exists():
    idx.unlink()

That might work for some small/simple models, but Gemma 4 E4B MLX checkpoints have enough architecture-specific structure that I would avoid this unless I knew the loader accepted exactly this layout.

The nearby success example says:

Result dequantized to bfloat16 and saved as 3-shard safetensors.

So I would try to reproduce that style instead of collapsing everything into one file.
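If you write shards yourself, the bookkeeping is small. This pure-Python sketch groups tensors into size-limited shards and builds an index in the common Hugging Face `model.safetensors.index.json` layout (`metadata.total_size` plus a `weight_map` from tensor name to file). The shard size limit and the `model-0000N-of-0000M.safetensors` naming are assumptions; check them against what your loader expects.

```python
def plan_shards(tensor_sizes, max_shard_bytes):
    """Assign tensors to shards and build an index dict.

    tensor_sizes: dict mapping tensor name -> size in bytes.
    Returns (shards, index): shards is a list of lists of tensor names,
    index is {"metadata": {"total_size": ...}, "weight_map": {name: filename}}.
    """
    shards, current, current_bytes = [], [], 0
    for name in sorted(tensor_sizes):
        size = tensor_sizes[name]
        # Start a new shard if adding this tensor would exceed the limit.
        if current and current_bytes + size > max_shard_bytes:
            shards.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += size
    if current:
        shards.append(current)

    n = len(shards)
    weight_map = {}
    for i, names in enumerate(shards, start=1):
        filename = f"model-{i:05d}-of-{n:05d}.safetensors"
        for name in names:
            weight_map[name] = filename
    index = {"metadata": {"total_size": sum(tensor_sizes.values())},
             "weight_map": weight_map}
    return shards, index
```

With MLX you would pass `{k: v.nbytes for k, v in all_w.items()}` as `tensor_sizes`, save each shard with `mx.save_safetensors` using the file name from `weight_map`, and write `index` with `json.dump`. Keeping the index instead of deleting it is the point.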


4. module.scale must be checked explicitly

Your script does:

lu = module.scale * (module.A @ module.B)

That is only safe if module.scale == alpha / rank.

For rank 8 and alpha 16, the expected value is:

alpha / rank = 16 / 8 = 2.0

This matters because mlx-vlm had a known scaling bug where LoRA used raw alpha instead of alpha/rank. See PR #846: fix alpha/rank scaling in LoRaLayer.

With rank 8 / alpha 16:

| Scale used | Effective LoRA strength |
|---|---|
| 2.0 | expected standard LoRA scaling |
| 16.0 | 8x too strong |

So please print it:

for name, module in model.named_modules():
    if isinstance(module, LoRaLayer):
        print(name, "scale=", float(module.scale), "A=", module.A.shape, "B=", module.B.shape)
        break

If this prints 16.0, that is a serious problem.
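To make the 8x figure concrete, here is a small NumPy sketch (not mlx-vlm code) of the standard LoRA delta for an MLX-style Linear weight of shape (out, in). It assumes A is (in, rank) and B is (rank, out), matching the `module.A @ module.B` product in your script; the dimensions and random values are arbitrary.

```python
import numpy as np


def lora_delta(A, B, alpha, rank):
    """Standard LoRA update for a Linear weight of shape (out, in).

    A: (in, rank), B: (rank, out), so (A @ B) is (in, out) and is transposed.
    """
    return (alpha / rank) * (A @ B).T


rng = np.random.default_rng(0)
in_dim, out_dim, rank, alpha = 32, 16, 8, 16.0
A = rng.standard_normal((in_dim, rank))
B = rng.standard_normal((rank, out_dim))

good = lora_delta(A, B, alpha, rank)  # scale = 16 / 8 = 2.0
bad = alpha * (A @ B).T               # raw alpha, scale = 16.0
ratio = np.linalg.norm(bad) / np.linalg.norm(good)
print(good.shape, round(float(ratio), 6))  # prints: (16, 32) 8.0
```

The delta direction is identical in both cases; only its magnitude changes, which is why an over-scaled merge tends to produce degraded output rather than an obvious error.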


Known related drift points

This is why I think this is a compound drift case rather than one simple bug.

| Area | Link | Relevance |
|---|---|---|
| LoRA scaling in mlx-vlm | PR #846 | Fixes raw alpha vs alpha/rank. Directly relevant to rank 8 / alpha 16. |
| Gemma 4 quantized projection loading | PR #935 | Shows Gemma 4 quantized projection loading was recently touched. |
| Gemma 4 embedding scaling | PR #893 | Earlier Gemma 4 MLX conversion/embedding behavior was not completely stable. |
| Gemma 4 LoRA training NaN / freeze leak | PR #1052 | Important if training used image/audio branches or if adapter size is unexpectedly large. |
| mlx_vlm server drops adapter after first request | issue #907 | Can make adapter testing look like base-model behavior if testing through the server. |
| Gemma 4 checkpoint round-trip/shared-KV divergence | mlx-lm issue #1210 | Shows Gemma 4 checkpoint structure can diverge across MLX runtimes. |
| Gemma 4 E4B 4bit/8bit load drift | mlx-lm issue #1242 | Shows E4B quantized checkpoints can be sensitive to version/key expectations. |
| General PEFT LoRA guide | PEFT LoRA developer guide | Useful for the standard LoRA merge mental model. |
| Gemma official docs | Google Gemma docs | Background on Gemma variants and tuning/deployment. |
| Gemma 4 release history | Gemma releases | Useful for tracking how recent Gemma 4 is. |

What I would test before changing more code

1. Does the adapter work before fuse?

This is the most important split.

Run direct CLI inference, not server-based inference:

python -m mlx_vlm.generate \
  --model /Users/hal9000/Desktop/AI/modelos/gemma_4_e4b_it_8bit \
  --adapter-path /Volumes/ssd./ssd_gemma4/adaptadores_v2 \
  --prompt "Use a fixed validation prompt here." \
  --max-tokens 128 \
  --temperature 0.0

Also try the text path if the task is text-only:

mlx_lm.generate \
  --model /Users/hal9000/Desktop/AI/modelos/gemma_4_e4b_it_8bit \
  --adapter-path /Volumes/ssd./ssd_gemma4/adaptadores_v2 \
  --prompt "Use a fixed validation prompt here." \
  --max-tokens 128 \
  --temp 0.0

Interpretation:

| Result | Meaning |
|---|---|
| base and adapter output are identical | adapter is not being applied, or target names/config are wrong |
| adapter works but fused is base-like | fuse did not merge the important deltas |
| adapter works but fused is garbage | scaling / dtype / quant metadata / checkpoint layout issue |
| adapter is already garbage | training/data/template/NaN issue, not a fuse issue |
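If you run this for several checkpoints, the interpretation table can be encoded as a tiny helper so reports stay consistent. This is a hypothetical helper, not part of any library: it treats "base-like" as an exact text match under greedy (temperature 0) decoding, which is a simplification, and the "is this output sane" judgments still have to be made by you.

```python
def classify_failure(base_out, adapter_out, fused_out, adapter_ok, fused_ok):
    """Map outputs from one fixed prompt onto the interpretation table.

    base_out / adapter_out / fused_out: generated text at temperature 0.
    adapter_ok / fused_ok: your own judgment that the output is not garbage.
    """
    if adapter_out == base_out:
        return "adapter not applied, or target names/config are wrong"
    if not adapter_ok:
        return "training/data/template/NaN issue, not a fuse issue"
    if fused_out == base_out:
        return "fuse did not merge the important deltas"
    if not fused_ok:
        return "scaling / dtype / quant metadata / checkpoint layout issue"
    return "adapter and fused model both look plausible"
```

The checks run in the same order as the debugging advice: adapter first, then fuse coverage, then serialization.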

2. Print the number of fused LoRA layers

The nearby success example says 378 LoRA pairs.

Your script already has:

print(" %d capas fusionadas" % len(to_update))

Please report this number.

Expected ballpark, if your target coverage matches the nearby E4B example:

len(to_update) ≈ 378

If it is much lower, the fuse script is not seeing all LoRA layers.
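You can also count pairs directly from the adapter file, independently of the fuse script. This sketch assumes the saved keys end in ".A" / ".B" (the attribute names your script reads from LoRaLayer); if your mlx-vlm version saves different names, change the suffixes (mlx-lm's LoRALinear, for example, uses lora_a / lora_b).

```python
from collections import defaultdict


def count_lora_pairs(keys, a_suffix=".A", b_suffix=".B"):
    """Return (complete, incomplete) module prefixes found in adapter keys.

    complete: prefixes that have both the A and B matrices.
    incomplete: prefixes missing one of them (a sign of a damaged save).
    """
    parts = defaultdict(set)
    for k in keys:
        if k.endswith(a_suffix):
            parts[k[: -len(a_suffix)]].add("A")
        elif k.endswith(b_suffix):
            parts[k[: -len(b_suffix)]].add("B")
    complete = sorted(p for p, s in parts.items() if s == {"A", "B"})
    incomplete = sorted(p for p, s in parts.items() if s != {"A", "B"})
    return complete, incomplete
```

Feed it the keys of `mx.load(...)` for each adapter `*.safetensors` file; `len(complete)` should roughly match `len(to_update)`. If the adapter has many more pairs than the fuse script merged, the module names seen by `named_modules()` are probably not lining up with the adapter or checkpoint keys.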


3. Print module.scale

Expected:

rank = 8
alpha = 16
scale = 2.0

Minimal check:

for name, module in model.named_modules():
    if isinstance(module, LoRaLayer):
        print(name, "scale=", float(module.scale))
        break

If it prints 16.0, then you are likely applying an 8x-too-large LoRA delta.


4. Check adapter size and contents

The nearby E4B adapter example reports an adapter size of around 658 MB.

Check yours:

du -h /Volumes/ssd./ssd_gemma4/adaptadores_v2/*

Then inspect whether it is really only LoRA tensors or if other weights got saved too:

from pathlib import Path
import mlx.core as mx

adapter_dir = Path("/Volumes/ssd./ssd_gemma4/adaptadores_v2")

for f in adapter_dir.glob("*.safetensors"):
    print("FILE", f)
    w = mx.load(str(f))
    print("tensor count:", len(w))

    suspicious = []
    for k in w:
        if any(s in k for s in ["audio", "vision", "embed_audio", "embed_vision"]):
            suspicious.append(k)

    print("suspicious audio/vision/embed keys:", len(suspicious))
    for k in suspicious[:50]:
        print(" ", k, w[k].shape)

Why this matters: PR #1052 mentions a Gemma 4 LoRA training issue involving vision backward NaNs and an audio_tower freeze leak. If non-LoRA weights were saved into the adapter, a simple LoRaLayer-only fuse script may silently drop them.


5. Check for NaN/Inf in the adapter

from pathlib import Path
import mlx.core as mx

adapter_dir = Path("/Volumes/ssd./ssd_gemma4/adaptadores_v2")

for f in adapter_dir.glob("*.safetensors"):
    w = mx.load(str(f))
    bad = []
    for k, v in w.items():
        vf = v.astype(mx.float32)
        if bool(mx.any(mx.isnan(vf)).item()) or bool(mx.any(mx.isinf(vf)).item()):
            bad.append(k)

    print(f, "bad tensors:", len(bad))
    for k in bad[:20]:
        print(" ", k, w[k].shape)

If this finds NaN/Inf tensors, the adapter is already compromised before fusion.


6. Compare base vs fused key sets

This is where the mixed-checkpoint issue should become visible.

from pathlib import Path
import mlx.core as mx

BASE = Path("/Users/hal9000/Desktop/AI/modelos/gemma_4_e4b_it_8bit")
FUSED = Path("/Users/hal9000/Desktop/AI/modelos/gemma_4_e4b_it_8bit_ssd_fused")

def load_dir(p):
    out = {}
    for sf in sorted(Path(p).glob("*.safetensors")):
        out.update(mx.load(str(sf)))
    return out

base = load_dir(BASE)
fused = load_dir(FUSED)

base_keys = set(base)
fused_keys = set(fused)

print("base only:", len(base_keys - fused_keys))
for k in sorted(base_keys - fused_keys)[:100]:
    print("BASE_ONLY", k, base[k].shape)

print("fused only:", len(fused_keys - base_keys))
for k in sorted(fused_keys - base_keys)[:100]:
    print("FUSED_ONLY", k, fused[k].shape)

scale_keys = [k for k in fused if k.endswith(".scales")]
bias_keys = [k for k in fused if k.endswith(".biases")]
print("fused .scales:", len(scale_keys))
print("fused .biases:", len(bias_keys))

If only some quantization metadata remains, the checkpoint is probably not coherent.


What I would change in the fuse strategy

I would not try to preserve the original 8-bit checkpoint format in the first pass.

Instead, I would mimic the nearby success example:

  1. Load the base model.
  2. Load the adapter.
  3. Dequantize the base weights.
  4. Merge LoRA with alpha / rank.
  5. Save a coherent BF16 checkpoint.
  6. Only after that, optionally quantize again.

Conceptually:

quantized base + adapter
        ↓
dequantized BF16 base
        ↓
BF16 merged/fused model
        ↓
optional re-quantization

Not:

quantized base
        ↓
replace only some layers with floating-point fused weights
        ↓
delete selected .scales/.biases
        ↓
single model.safetensors with original config

The latter is much more likely to break loader expectations.


Practical recommendation

Short-term

Do not fuse.

Use the adapter directly:

python -m mlx_vlm.generate \
  --model /Users/hal9000/Desktop/AI/modelos/gemma_4_e4b_it_8bit \
  --adapter-path /Volumes/ssd./ssd_gemma4/adaptadores_v2 \
  --prompt "your prompt" \
  --max-tokens 256 \
  --temperature 0.0

If this works, your adapter is probably okay and the problem is mainly fuse/export.


Medium-term

Build a BF16 fused model rather than an in-place-ish quantized fused model.

The closest public success example says it used:

W_merged = dequantize(W_base) + (A @ B).T × (alpha / rank)

and saved the result as BF16 3-shard safetensors.
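Here is a NumPy sketch of that formula, including an affine dequantization step (w = scale * q + bias per group of columns). It is a simplification under stated assumptions: the quantized codes are taken as an already-unpacked (out, in) integer array, whereas real MLX checkpoints pack them into uint32 words, so on real weights you would use `mx.dequantize` for that step. NumPy also has no bfloat16, so the final cast (for example `.astype(mx.bfloat16)` in MLX) is left out.

```python
import numpy as np


def dequantize_affine(q, scales, biases, group_size):
    """Affine dequantization of UNPACKED codes q with shape (out, in).

    scales / biases: shape (out, in // group_size), one value per group.
    """
    out_dim, in_dim = q.shape
    groups = q.reshape(out_dim, in_dim // group_size, group_size).astype(np.float32)
    w = groups * scales[..., None] + biases[..., None]
    return w.reshape(out_dim, in_dim)


def merge_lora(w_base, A, B, alpha, rank):
    """W_merged = W_base + (A @ B).T * (alpha / rank), with W of shape (out, in)."""
    return w_base + (A @ B).T * (alpha / rank)


rng = np.random.default_rng(0)
out_dim, in_dim, group_size, rank, alpha = 4, 8, 4, 2, 4.0
q = rng.integers(0, 256, size=(out_dim, in_dim))  # 8-bit codes, unpacked
scales = rng.random((out_dim, in_dim // group_size)).astype(np.float32)
biases = rng.random((out_dim, in_dim // group_size)).astype(np.float32)
A = rng.standard_normal((in_dim, rank)).astype(np.float32)
B = np.zeros((rank, out_dim), dtype=np.float32)  # zero B: merged must equal base

w = dequantize_affine(q, scales, biases, group_size)
merged = merge_lora(w, A, B, alpha, rank)
```

The important property is that every weight in the output goes through the same path, so nothing is left half-quantized.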

So I would aim for:

| Requirement | Target |
|---|---|
| LoRA pairs | near 378, if matching that E4B coverage |
| scale | alpha / rank = 2.0 |
| output dtype | BF16 |
| quant metadata | no half-removed mixed state |
| file layout | sharded safetensors if large |
| config | consistent with a BF16 model, not a stale 8-bit quant config |
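For the config row, copying config.json unchanged is what leaves the stale quant settings behind. A minimal sketch, assuming MLX-converted checkpoints record quantization under a top-level "quantization" key (and sometimes "quantization_config"); Gemma 4 multimodal configs may carry extra nested settings, so inspect the result by hand.

```python
import json
from pathlib import Path


def make_bf16_config(config):
    """Return a shallow copy of a config dict without top-level quantization entries."""
    cleaned = dict(config)
    for key in ("quantization", "quantization_config"):
        cleaned.pop(key, None)
    return cleaned


def write_bf16_config(base_dir, out_dir):
    """Read base_dir/config.json and write the cleaned copy to out_dir/config.json."""
    cfg = json.loads((Path(base_dir) / "config.json").read_text())
    (Path(out_dir) / "config.json").write_text(json.dumps(make_bf16_config(cfg), indent=2))
```

Use this instead of the blanket `*.json` copy for config.json only; tokenizer files can still be copied as before.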

Long-term

Wait for, or request, an official mlx-vlm fuse/export path for Gemma 4 E4B VLM adapters.

There are enough Gemma 4-specific fixes around MLX loading/training/quantization that a hand-written fuse script is fragile. The relevant upstream surface is still moving.


My final read

I think your result is probably close, but the current fuse script is crossing too many contracts at once.

The training settings are not obviously wrong because a nearby E4B example uses similar settings:

| Setting | You | Nearby example |
|---|---|---|
| mlx_vlm.lora | yes | yes |
| rank | 8 | 8 |
| alpha | 16 | 16 |
| LR | 1e-5 | 1e-5 |
| completions-only | yes | yes |

The main divergence is after training:

| Fuse/export detail | Your script | Nearby example |
|---|---|---|
| base precision | local 8-bit | MLX 4-bit |
| merge result | partially dequantized, mixed state possible | BF16 merged checkpoint |
| LoRA pair count | unknown | 378 |
| scale | module.scale, must verify | alpha / rank |
| output files | single model.safetensors | 3-shard safetensors |
| quant metadata | selectively deleted | dequantized output |

So I would debug this as:

adapter correctness first, then LoRA coverage/scaling, then checkpoint serialization.

Not as:

Gemma 4 E4B is impossible to fuse.

The fastest useful data points to post back would be:

mlx-vlm version:
mlx-lm version:
mlx version:
adapter size:
dynamic adapter output differs from base? yes/no
len(to_update):
first few module.scale values:
NaN/Inf in adapter? yes/no
number of .scales/.biases left in fused checkpoint:
base-only/fused-only key count:
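For the first three lines, the installed versions can be read from package metadata with the standard library, run inside the same venv used for training and fusing.

```python
from importlib.metadata import PackageNotFoundError, version


def collect_versions(packages=("mlx", "mlx-lm", "mlx-vlm")):
    """Return {distribution name: installed version or 'not installed'}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = "not installed"
    return found


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name} version: {ver}")
```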

Those numbers should make the failure class much clearer.