# Patch MPT

def test_rotary_cache_recomputes_on_seq_len_change():
    """Smoke-test that the rotary cos/sin cache follows the requested sequence length.

    Wrapped in a function so nothing runs at import time; ``PatchedRotaryEmbedding``
    is defined elsewhere in this file and is only looked up when this is called.
    NOTE(review): assumes calling the module returns ``(cos, sin)`` whose first
    dimension is the sequence length — confirm against the class's forward().
    """
    rotary = PatchedRotaryEmbedding(dim=64, max_seq_len=512)
    x = torch.randn(1, 10, 64)

    cos1, sin1 = rotary(x, seq_len=10)
    cos2, sin2 = rotary(x, seq_len=20)  # seq_len changes -> cache must be recomputed
    assert cos1.shape[0] == 10 and sin1.shape[0] == 10
    assert cos2.shape[0] == 20 and sin2.shape[0] == 20

    # Going back to a shorter length must rebuild (not reuse) the cache as well.
    cos3, sin3 = rotary(x, seq_len=10)
    assert cos3.shape[0] == 10 and sin3.shape[0] == 10

    print("Rotary cache patch: OK")

def _update_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype): if seq_len == self._cached_seq_len: return inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)) t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) freqs = torch.einsum("i,j->ij", t, inv_freq) emb = torch.cat((freqs, freqs), dim=-1) self._cached_cos = emb.cos().to(dtype) self._cached_sin = emb.sin().to(dtype) self._cached_seq_len = seq_len patch mpt

# If already 4D, assume correct if attention_mask.dim() == 4: return attention_mask.to(dtype) x2 = x.chunk(2

# patches/mpt_patch_rotary_cache.py
"""
Patch for MPT model:
- Fix rotary embedding cache when sequence length changes between forward passes.
- Correct attention mask broadcasting for cross-attention layers.
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple


# ----------------------------------------------------------------------
# 1. Patch Rotary Embedding Cache
# ----------------------------------------------------------------------
def patched_rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension of ``x`` by half: ``[x1, x2] -> [-x2, x1]``.

    The last dimension is split into two chunks with ``torch.chunk`` and
    recombined with the first half negated-and-swapped, which is the
    "rotate half" step used when applying rotary position embeddings.

    NOTE(review): assumes the last dimension is even; for odd sizes
    ``torch.chunk`` yields unequal halves and the result is not a rotation.

    Args:
        x: tensor whose last dimension holds the rotary features.

    Returns:
        A new tensor with the same shape and dtype as ``x``.
    """
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# Convert to additive mask (0 = keep, -inf = mask) return mask.to(dtype).masked_fill(mask == 0, 0.0).masked_fill(mask == 1, float("-inf")) 3. Monkey-patch into existing MPT model (example) ---------------------------------------------------------------------- def apply_mpt_patches(model: nn.Module): """Replace rotary and mask functions in an existing MPT model.""" # Patch rotary class if found for name, module in model.named_modules(): if "rotary" in name.lower() and hasattr(module, "cos_cached"): module. class = PatchedRotaryEmbedding print(f"[PATCH] Replaced rotary in name")