WeLM-v4.5-80B-A3B-Instruct-0608:同一份权重,transformers 4.x 推理正常、5.x 坍塌 的根因排查

问题域:trust_remote_code 自研模型 · YaRN + partial rotary · device_map 加载 · transformers 4.57.1 vs 5.6.0(torch / Python 完全相同)
结论:根因是 RoPE 的 inv_freq buffer 在 transformers 5.x 的 meta-device 初始化下被算成全 0,导致 RoPE 完全失效。已修复并端到端验证。
目录
  1. 一、现象
  2. 二、结论速览(TL;DR)
  3. 三、排查方法与约束的确立
  4. 四、逐步定位(含走过的弯路与纠正)
  5. 五、根因机理
  6. 六、决定性证明
  7. 七、修复方案与落点
  8. 八、影响范围:哪些场景会中招、哪些不会(含 ZeRO-2/3)
  9. 附录:实验脚本清单

一、现象

同一个 checkpoint(WeLM-v4.5-80B-A3B-Instruct-0608),用 AutoModelForCausalLM.from_pretrained(..., device_map="auto", low_cpu_mem_usage=True) 加载后做贪心解码:

环境输入「你好」末层 logitstop1 token生成结果
transformers 4.57.1max=30.0 std=2.98你好«你好!有什么可以帮你的吗?»正常
transformers 5.6.0max=19.25 std=3.14\n\n坍塌成重复 \n\n / </think>坍塌

直观对比同一个问题、同一份权重,只换 transformers 版本,输出就从「正常对话」退化成「无意义重复」:

4.57.1 · 正确

>>> 输入:你好
你好!有什么可以帮你的吗?

>>> 输入:用一句话介绍你自己
我是 WeLM-v4.5-80B-A3B-Instruct,
一个由腾讯训练的对话助手。

5.6.0 · 错误

>>> 输入:你好
\n\n\n\n\n\n\n\n …(无限重复)

>>> 输入:用一句话介绍你自己
</think></think></think> …
/ 重复乱码,完全不成句
现象本质:两个版本加载的是完全相同的权重文件,torch / Python / 输入 / 解码参数也全相同——唯一差别只有 transformers 版本。4.x 输出可正常对话的结果,5.x 直接坍塌成重复 token(\n\n</think> 等)、不成句、无法使用。这种「换个库版本就坏,但权重一字未改」的现象,强烈指向加载阶段而非权重本身的问题。

注意:5.6 的 logits std 没有下降(2.98 → 3.14),不是「整体被压平」,而是 排序完全错乱——这个细节后来成了重要线索。

二、结论速览(TL;DR)

根因(一句话):WeLMV4MoeRotaryEmbedding.inv_freq 被注册为 persistent=False 的 buffer(不进 checkpoint,靠 __init__ 现算)。transformers 5.x 在 meta-device 加载device_map / low_cpu_mem_usage)下,meta 上算出的 inv_freq 落地成全 0,且加载后不再重新物化它 → freqs = inv_freq(全0) @ position_ids = 0cos≡1 / sin≡0RoPE 完全失效 → 注意力丢失全部位置信息 → logits 坍塌。transformers 4.57.x 加载后会正确物化该 buffer,故无此问题。

子步(L0)4.57.15.6.0是否一致
embedding 输入norm=10.428norm=10.428一致 ✓
input_ln / q / k / v 投影逐位完全相同一致 ✓
attention_mask(1,1,13,13) 严格 causal,78 个 -inf一致 ✓
RoPE inv_freqnorm=1.337norm=0.000(全 0)唯一分叉 ✗
→ cos / sin27.19 / 9.6328.84 / 0.00sin≡0 ✗
→ self_attn 输出起55.76855.557(级联坍塌)
末层 logits30.0 / 你好19.25 / \n\n

修复:加载后对每个 rotary 用 rope_init_fn 重算 inv_freq(仅在异常时)。5.x 自动修复、4.57.x 为 no-op。已实测:修复后 5.6 logits 与 4.57 逐位一致(30.0 / 你好)。

三、排查方法与约束的确立

这套 modeling 是 trust_remote_code 加载的本地代码(与 transformers 版本无关)。因此版本差异只可能来自 modeling 调用的 transformers 包内组件。排查前先钉死了两个关键约束,避免被噪声误导:

约束 1运行环境唯一变量是 transformers。

venv457 :  python 3.12.3   torch 2.11.0+cu130   transformers 4.57.1
system  :  python 3.12.3   torch 2.11.0+cu130   transformers 5.6.0

torch / Python 完全相同 → 排除「底层算子(softmax/matmul/MoE)随 torch 版本变化」这一大类干扰。任何差异都归于 transformers。

约束 2两版各自完全确定性。同进程连续两次 forward:logits #1 == #2,maxdiff=0.000000(两版皆然)。

→ 之前误判的「非确定性噪声」其实是中途反复改文件 / 清 HF 缓存造成的脏状态。确立确定性后,任何跨版本差异都是 100% 系统性的,逐子步指纹对比才可信。

四、逐步定位(含走过的弯路与纠正)

本节如实记录排查过程,包括两次被污染数据误导的弯路,以及最终如何用干净实验纠正。

1

排除分词器

逐 token id 对比:你好→[44205]、chat 模板套用后 13 个 id 两版逐位一致。唯一差别是类名(Qwen2TokenizerFast vs Qwen2Tokenizer),输出 id 序列完全相同。分词器排除。

2

排除 attention 派发 / sink / mask

逐一验证两版相同:embedding、RMSNorm、attention_mask(create_*_mask 输出逐值相同)。又测:两版默认 _attn_implementation 都是 eager(走本地 eager_attention_forward),attn_sink 也相同(sum=2.3861)。强制切 sdpa 两版都会变坏(但坏法不同),说明 sink 处理确实重要,但默认路径两版一致,并非版本差异来源。

3

弯路 A:误判 RoPE「inv_freq 维度算错」

早期一个实验(脏缓存状态)显示 5.6 的 inv_freq.shape=(128,) 而 4.57 是 (32,),一度认为是 5.x 忽略了 partial_rotary_factor=0.25后被证伪——干净状态下两版 shape 都是 (32,),维度并没错。

4

弯路 B:误判是 @dynamic_rope_update 装饰器,并写了无效补丁

另一个实验显示「同一份正确 inv_freq,被装饰的 forward → sin=0,手动绕过装饰器 → sin=9.63」,于是认为是 @dynamic_rope_update(223 行)在 5.x 下的副作用,并实现了「绕过装饰器」补丁加进 monkey patch。

这个补丁是错的(已撤销)。终极证明里把 49 个 forward 全部替换为「绕过装饰器」版本后,logits 仍是 19.25——因为绕过版仍然读 self.inv_freq,而它本身就是全 0。真正的问题不在 forward 怎么算,而在 inv_freq 这个输入就已经是 0。
5

干净定位:L0 逐子步指纹

在确定性前提下,hook L0 的每个子步打印指纹(norm / absum / first4)。结果第一个、也是唯一的分叉点是 rope 的 cos/sin:embedding、input_ln、q/k/v 投影(rope 之前)全部逐位相同,唯独 cos.norm=27.19 vs 28.84sin.norm=9.63 vs 0.00。其中 cos.norm=28.844=√(13×64) 正是「所有 cos 元素都=1」的范数 → 确认 5.6 的 RoPE 完全不旋转。

6

钉死根因:inv_freq 本身是全 0

直接 dump 真实加载模型的 rotary 状态:

rope_type=yarn  attention_scaling=1.00000
inv_freq.shape=(32,) norm=0.00000  first4=[0.0, 0.0, 0.0, 0.0]

shape 是对的(32),但值全是 0attention_scaling(纯 Python float、与 device 无关)正确算出 = 1.0,只有 inv_freq 这个 tensor 中招——这正是「meta-device 上 tensor 落地为 0、float 标量却正常」的典型特征。

五、根因机理

modeling 里 RoPE 的定义(节选):

class WeLMV4MoeRotaryEmbedding(nn.Module):
    def __init__(self, config, device=None):
        ...
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]        # 从 transformers 导入
        self.config.partial_rotary_factor = config.qk_rope_head_dim / config.head_dim   # 0.25
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)   # ← 不进 checkpoint

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        inv_freq_expanded = self.inv_freq[None,:,None].float()...        # ← 读 self.inv_freq
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1,2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos() * self.attention_scaling
        sin = emb.sin() * self.attention_scaling
        return cos, sin
失效链条:
  1. inv_freqpersistent=False buffer → 不存进 checkpoint,只能靠 __init__rope_init_fn 现算。
  2. from_pretrained(device_map="auto", low_cpu_mem_usage=True) 先在 meta device 上构造模型 → __init__ 在 meta 上算 inv_freq → 落地成全 0
  3. 因为它不在 checkpoint 里,transformers 5.x 加载完权重后不会重新物化它inv_freq 恒为 0。
  4. forward:freqs = 0 @ position_ids = 0cos = cos(0) = 1sin = sin(0) = 0RoPE 等于没旋转
  5. q/k 不带位置信息 → 注意力无法区分 token 先后 → 整个网络在「无位置编码」下前向 → 末层 logits 排序错乱 → 贪心解码坍塌。

为什么 4.57.x 没问题

transformers 4.57.x 在 from_pretrained 完成后,会对这类不在 checkpoint 中的非持久 buffer 正确重新物化(按真实 device 重算 inv_freq)——这一步在 4.x 里是加载流程中独立的一环,与 _init_weights 无关,总会执行。transformers 5.0 重写了加载流程,把这一步搬进了基类 PreTrainedModel._init_weights(含 rope_scaling→rope_parametersstandardize_rope_params 等)。这一改动本身没问题——只要 modeling 的 _init_weights 会经过基类逻辑。本模型恰恰重写了 _init_weights 又没调 super()(详见下一小节的源码层定论),于是该步骤被静默绕过,inv_freq 留在 meta 落地的全 0 状态。这是唯一的版本差异点。

注意:与「sin 为何为 0」的两种解释。表面看 @dynamic_rope_update 也会让 forward 输出 sin=0,但那只是表象——真正原因是它的输入 self.inv_freq 已经是 0。绕过装饰器并不能修复(已实测无效);必须修复 inv_freq 本身

源码层定论:根因是自研 _init_weights 不调 super(),绕过了基类的 rope 重物化

进一步实地核对本机 transformers 5.8.1 与自研 modeling_welmv4_moe.py 源码后,把根因钉死在源码层。先纠正一个一度的误判:

更正:缺 original_inv_freq 并不是原因。核对 modeling_welmv4_moe.py:218-219,自研 rotary 其实注册了 original_inv_freq,类名也含 "RotaryEmbedding":
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq          # ← 实际是有的

所以那道 hasattr(module, "original_inv_freq") 门槛它能过。真正的问题在更上游——它的 _init_weights 根本没跑到那段基类逻辑

关键在于:transformers 5.x 把 inv_freq 的重物化逻辑搬进了基类 PreTrainedModel._init_weightsmodeling_utils.py:2402-2411 那个 "RotaryEmbedding" in name and hasattr(original_inv_freq) 分支)。谁的 _init_weights 能跑到这段基类逻辑,谁就安全。

而自研 modeling_welmv4_moe.py:869-884 重写了 _init_weights 却没调 super(),且只处理 nn.Linear / nn.Embedding完全没有 RotaryEmbedding 分支

def _init_weights(self, module):
    ...
    if isinstance(module, nn.Linear):
        ...
    elif isinstance(module, nn.Embedding):
        ...
    # ✗ 没有 super()._init_weights(module),也没有 rotary 分支
    # → 基类那段 rope inv_freq 重物化从不执行
模型_init_weights 行为基类 rope 重物化分支5.x meta-load 结果
Qwen3.6-27B
(库内置 qwen3_5)
def _init_weights:
  super()._init_weights(module)
  ...
(modeling_qwen3_5.py:813-814 先调 super)
✅ 执行 inv_freq 被重算 → 安全
WeLM-v4.5
(自研)
重写 _init_weights
只管 Linear/Embedding
不调 super、无 rotary 分支
❌ 从不执行 inv_freq 维持 meta 落地的 0 → 坍塌

这就解释了 4.x vs 5.x 的全部差异:4.57.x 时 inv_freq 的重物化是加载流程里独立的一步(与 _init_weights 无关,总会跑);5.x 重构后把它搬进了 _init_weights,于是任何「重写 _init_weights 又不调 super」的自研 modeling 就静默漏掉了这一步

最终结论(源码层确定):
治本修法(比当前兜底更干净,可二选一或都做)。modeling_welmv4_moe.py_init_weights 里补一行让基类逻辑生效:
def _init_weights(self, module):
    super()._init_weights(module)   # ← 关键:让 5.x 的 rope inv_freq 重物化分支跑起来
    cfg = self.config
    ...

现有「加载后手动重算 inv_freq」的兜底(第七节)依然有效且更保险——不依赖 transformers 内部行为、跨版本稳定,可保留作双保险

六、决定性证明

在 5.6 单次加载内做 A/B:默认 vs 加载后重算 inv_freq。

===== [baseline] 默认加载(device_map=auto/meta) =====
[baseline] inv_freq = shape=(32,) norm=0.00000  first4=[0.0, 0.0, 0.0, 0.0]
[baseline] logits.max=19.2500  top1='\n\n'              ← 坍塌

===== [realfix] 加载后重新物化 inv_freq =====
[realfix] inv_freq = shape=(32,) norm=1.33674  first4=[1.0, 0.6636, 0.44037, 0.29223]
[realfix] logits.max=30.0000   top1='你好'              ← 完全恢复

修复后 5.6 的 top5 与 4.57 逐位一致[30.0, 22.25, 20.75, 17.25, 16.875] → ['你好','您好','Hello','�','好的']

端到端验证(5.6.0 + 修复,真实 smoke_test):
[rope-fix] WeLMV4MoeRotaryEmbedding 共 49 个,重物化 inv_freq 49 个
[TEXT] OUTPUT:
<report>你好!有什么可以帮你的吗?</report>

七、修复方案与落点

核心思路:加载后重新物化 inv_freq,且仅在异常时(meta / 全 0 / NaN / 缺失)才动手,保证 4.57.x 等正常路径为 no-op、不引入任何行为变化,也无需锁版本。

def _rematerialize_rope_inv_freq(model, is_rank_0=True):
    for m in model.modules():
        if m.__class__.__name__ != "WeLMV4MoeRotaryEmbedding":
            continue
        inv = getattr(m, "inv_freq", None)
        needs_fix = (inv is None or getattr(inv,"is_meta",False)
                     or not bool(torch.isfinite(inv).all())
                     or float(inv.float().abs().sum()) == 0.0)
        if not needs_fix:
            continue                       # 4.57.x 正常 -> no-op
        inv_freq, scaling = m.rope_init_fn(m.config, device=torch.device("cpu"))
        m.register_buffer("inv_freq", inv_freq.to(torch.float32), persistent=False)
        m.attention_scaling = scaling
        m.original_inv_freq = m.inv_freq   # 5.x meta-load -> 重算覆盖全 0
文件改动调用时机
monkey_patches/
monkey_patch_welm_v45_oe.py
定义 _rematerialize_rope_inv_freq + 主入口调用(3.2 段)加载 + flash 修复后、deepspeed.initialize
monkey_patches/
monkey_patch_welm_v45.py
修掉坏 import(原 import 已删除的 _patch_rotary_no_dynamic_update,会 ImportError)→ 改 import 正确函数并调用同上
script/welm_v45_assets/
smoke_test_welm_v45.py
推理侧 load() 内调用(已端到端验证)加载后、eval()
顺带修掉的隐患:上一轮撤销「绕过装饰器」补丁时漏删了 OFF 版里的 import 行,它仍引用已删除的 _patch_rotary_no_dynamic_update——只要训练走到 OFF 路径就会直接 ImportError。本次一并修正。

八、影响范围:哪些场景会中招、哪些不会

触发条件可以归纳为一个充要组合transformers 5.x 模型在 meta device 上被构造(构造 __init__inv_freq 落地为全 0),加载完权重后又没有重新物化它。任何不满足这个组合的场景都不受影响

场景是否中招原因
transformers 4.57.x(任意加载方式)否 ✓加载完成后会对「不在 checkpoint 中的非持久 buffer」重新物化,inv_freq 按真实 device 重算正确
transformers 5.x + 普通加载
low_cpu_mem_usage=False 且无 device_map
否 ✓不经过 meta,__init__ 直接在真实 device 上算出 inv_freq(norm=1.337),全程正常
transformers 5.x + device_map="auto" / low_cpu_mem_usage=True是 ✗本报告的坍塌场景:meta 上算出 inv_freq=0,加载后不再物化 → cos≡1/sin≡0
transformers 5.x + DeepSpeed ZeRO-2取决于加载方式inv_freq 是 buffer,ZeRO-2 不切分它;是否中招只看 from_pretrained 本身是否走 meta(带 device_map/low_cpu_mem_usage 即中招,否则安全)
transformers 5.x + DeepSpeed ZeRO-3zero.Init()需注意zero.Init 只分片 parameter不分片 buffer,所以 ZeRO-3 本身不会破坏 inv_freq;但 ZeRO-3 常配合 meta/low_cpu_mem_usage 构造模型,一旦走 meta 路径,inv_freq 仍会落地为 0。判据依旧是「是否经过 meta 构造」
训练侧为何更危险:训练用自研 cute flash_attn attention,但 cos/sin 仍由 self.rotary_emb(...) 生成后喂入(monkey_patch_welmv4_5_moe_v2.py)。一旦 inv_freq=0,模型会在「无位置编码」下静默训练——loss 可能照降、不报错,但模型实际学坏,比推理直接坍塌更难察觉。

统一解法:加载后兜底物化,与场景无关

修复不依赖「判断当前是哪种场景」,而是在所有路径统一加一道兜底:模型加载完成后、deepspeed.initialize / ZeRO 包装之前,遍历每个 WeLMV4MoeRotaryEmbedding,仅当 inv_freq 异常(is_meta / 全 0 / NaN / 缺失)时用 rope_init_fn 在真实 device 上重算覆盖(见第七节代码)。

补充:历史 checkpoint-216 能在 4.57 正常推理,说明它当初是用正常 RoPE 训出来的(4.57 或未触发此坑的加载方式),不是被这个 bug 训坏的权重。另:sliding window、attention_mask(cute 路径用 cu_seqlens,不消费 transformers 的 4D mask)等均已验证与 transformers 版本无关,不受影响。

附录:实验脚本清单

均在 SFT/debug/,通过 bridge 在远程对 4.57.1 / 5.6.0 两环境跑同一脚本对比。

脚本作用关键结论
_dbg_attn_impl_smoking_gun.py对比默认 / 强制 eager / 强制 sdpa 的 logits两版默认都是 eager;派发不是版本差异来源
_dbg_mask_into_attn.pyhook attention 实际收到的 mask + 逐层残差流 + 确定性自检mask 两版相同严格 causal;两版确定性;残差流从 L0 即分叉
_dbg_L0_micro.pyL0 逐子步指纹(embed→ln→qkv→attn→mlp)唯一分叉 = rope cos/sin(sin=0)
_dbg_rope_fix_proof.py验证「绕过 @dynamic_rope_update」是否有效无效(揭示 inv_freq 本身=0)
_dbg_rope_realfix_proof.py验证「重算 inv_freq」修复有效:logits 19.25→30.0 / \n\n→你好