MarkTechPost AI · Generative AI

How to Build Memory-Efficient Transformers with xFormers Using Packed Sequences, GQA, ALiBi, SwiGLU, and Causal Attention

June 17, 2026, 00:02


Site AI-compiled version

In this tutorial, we use xFormers, a practical toolkit for building fast, memory-efficient Transformer models on GPUs. We begin by validating memory-efficient attention against a standard attention implementation, then compare their speed and memory consumption across different sequence lengths. We then examine causal masking, packed variable-length sequences, grouped-query attention, and custom ALiBi positional biases. Finally, we combine these techniques into a trainable GPT-style model that uses xFormers attention, SwiGLU feed-forward layers, and automatic mixed-precision training.

Setting Up xFormers and Validating Memory-Efficient Attention

```python
import subprocess, sys

def _pip(*a):
    subprocess.run([sys.executable, "-m", "pip", "install", *a], check=False)

# Install xFormers only if it is not already importable.
try:
    import xformers
except Exception:
    _pip("-q", "-U", "xformers")

import math, time
import torch, torch.nn as nn, torch.nn.functional as F
import xformers, xformers.ops as xops
from xformers.ops import fmha

ab = fmha.attn_bias

assert torch.cuda.is_available(), (
    "No GPU detected. In Colab: Runtime → Change runtime type → GPU, then re-run.")
device = "cuda"
torch.manual_seed(0)

print("torch :", torch.__version__)
print("xformers :", xformers.__version__)
print("GPU :", torch.cuda.get_device_name(0))

print("\n--- xformers.info (which kernels are built/available) ---")
try:
    subprocess.run([sys.executable, "-m", "xformers.info"], check=False)
except Exception as e:
    print("xformers.info unavailable:", e)

def cuda_time(fn, iters=20, warmup=5):
    """Average milliseconds per call, measured with CUDA events after a warmup."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    s, e = (torch.cuda.Event(enable_timing=True) for _ in range(2))
    s.record()
    for _ in range(iters):
        fn()
    e.record()
    torch.cuda.synchronize()
    return s.elapsed_time(e) / iters

def peak_mem_mb(fn):
    """Peak allocated GPU memory (in MB) while running fn once."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1e6

def vanilla_attention(q, k, v, causal=False):
    """Reference attention that MATERIALIZES the [B,H,M,M] score matrix.
    Inputs are xformers-layout [B, M, H, K]."""
    q, k, v = (t.transpose(1, 2).float() for t in (q, k, v))
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    if causal:
        M = scores.shape[-1]
        # True above the diagonal marks the future positions to mask out.
        m = torch.triu(torch.ones(M, M, device=q.device, dtype=torch.bool), 1)
        scores = scores.masked_fill(m, float("-inf"))
    out = scores.softmax(-1) @ v
    return out.transpose(1, 2)

print("\n" + "="*70 + "\n1. memory_efficient_attention basics + correctness\n" + "="*70)
B, M, H, K = 2, 512, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
out_xf = xops.memory_efficient_attention(q, k, v)
out_ref = vanilla_attention(q, k, v).half()
print("output shape :", tuple(out_xf.shape), "(layout B, M, H, K)")
print("max abs diff vs ref : {:.2e}".format((out_xf - out_ref).abs().max().item()))
print("-> it's EXACT attention (fp16 rounding only), just computed without")
print(" ever storing the full MxM score matrix.")
```

We install and import xFormers, verify GPU availability, and inspect the attention kernels supported by the environment. We define helper functions for measuring CUDA execution time and peak memory consumption. We then validate memory-efficient attention against standard attention and confirm that the two outputs agree to within fp16 rounding error.

Benchmarking Memory and Speed Against Naive Causal Attention

```python
print("\n" + "="*70 + "\n2. Memory & speed vs naive attention (fwd+bwd)\n" + "="*70)
print(f"{'seqlen':>8} | {'naive MB':>10} | {'xformers MB':>12} | {'naive ms':>9} | {'xf ms':>7}")
print("-"*60)
for M in [512, 1024, 2048, 4096]:
    q, k, v = (torch.randn(2, M, 8, 64, device=device, dtype=torch.float16,
                           requires_grad=True) for _ in range(3))

    def run_xf():
        o = xops.memory_efficient_attention(q, k, v)
        o.sum().backward()

    def run_naive():
        o = vanilla_attention(q, k, v)
        o.sum().backward()

    try:
        nm = peak_mem_mb(run_naive)
        nt = cuda_time(run_naive, 8, 3)
    except RuntimeError:
        # Most likely an out-of-memory error at long sequence lengths.
        nm, nt = float("nan"), float("nan")
        torch.cuda.empty_cache()
    xm = peak_mem_mb(run_xf)
    xt = cuda_time(run_xf, 8, 3)
    print(f"{M:>8} | {nm:>10.0f} | {xm:>12.0f} | {nt:>9.2f} | {xt:>7.2f}")
print("-> naive memory grows ~4x per doubling of M (it stores BxHxMxM);")
print(" xformers grows ~linearly and stays fast.")

print("\n" + "="*70 + "\n3. Causal attention via LowerTriangularMask\n" + "="*70)
B, M, H, K = 2, 256, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
out_causal = xops.memory_efficient_attention(q, k, v, attn_bias=ab.LowerTriangularMask())
ref_causal = vanilla_attention(q, k, v, causal=True).half()
print("causal max abs diff : {:.2e}".format((out_causal - ref_causal).abs().max().item()))
print("-> the mask is implicit; no MxM boolean tensor is allocated.")
```

We benchmark naive attention and xFormers attention across progressively longer sequences, timing a combined forward and backward pass. We compare execution time and peak GPU memory usage to observe how xFormers avoids the quadratic memory growth of the naive implementation. We also apply an implicit lower-triangular mask and verify causal attention against the reference implementation.

Packing Variable-Length Sequences and Running Grouped-Query Attention

```python
print("\n" + "="*70 + "\n4. Variable-length packed batch: no padding waste\n" + "="*70)
seqlens = [37, 120, 8, 200]
total = sum(seqlens)
H, K = 8, 64
# All sequences are concatenated along the token axis of a single batch row.
q = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
k = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
v = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
try:
    bias = ab.BlockDiagonalMask.from_seqlens(seqlens)
    out_packed = xops.memory_efficient_attention(q, k, v, attn_bias=bias)
    s0 = seqlens[0]
    ref0 = vanilla_attention(q[:, :s0], k[:, :s0], v[:, :s0]).half()
    print("packed shape :", tuple(out_packed.shape), "(all", total, "tokens, no pad)")
    print("segment-0 max diff : {:.2e}".format((out_packed[:, :s0] - ref0).abs().max().item()))
    cbias = ab.BlockDiagonalCausalMask.from_seqlens(seqlens)
    _ = xops.memory_efficient_attention(q, k, v, attn_bias=cbias)
    print("-> also did a packed CAUSAL pass. This is how vLLM-style engines")
    print(" batch requests of different lengths with zero padding overhead.")
    splits = bias.split(out_packed)
    print("recovered segments :", [tuple(t.shape) for t in splits])
except Exception as e:
    print("BlockDiagonalMask path skipped on this version/backend:", repr(e))

print("\n" + "="*70 + "\n5. Grouped-query attention (5-D BMGHK layout)\n" + "="*70)
B, M, K = 2, 256, 64
n_q_heads, n_kv_heads = 8, 2
G, Hq = n_kv_heads, n_q_heads // n_kv_heads
try:
    qg = torch.randn(B, M, G, Hq, K, device=device, dtype=torch.float16)
    # One KV head per group; expand() is a stride-0 view (no copy) that makes
    # the key/value shapes match the query's [B, M, G, Hq, K] layout.
    kg = torch.randn(B, M, G, 1, K, device=device, dtype=torch.float16).expand(B, M, G, Hq, K)
    vg = torch.randn(B, M, G, 1, K, device=device, dtype=torch.float16).expand(B, M, G, Hq, K)
    out_gqa = xops.memory_efficient_attention(qg, kg, vg)
    print("GQA output shape :", tuple(out_gqa.shape), "= [B, M, G, Hq, K]")
    print(f"-> {n_q_heads} query heads, only {n_kv_heads} KV heads: smaller KV-cache,")
    print(" which is exactly what Llama-/Mistral-class models use at inference.")
except Exception as e:
    print("GQA 5-D path skipped on this version/backend:", repr(e))
```

We concatenate variable-length sequences and use BlockDiagonalMask to prevent attention from crossing sequence boundaries, so no padding tokens are needed. We recover the individual per-sequence outputs with split and also run packed causal attention for decoder-style workloads. We then demonstrate grouped-query attention, in which several query heads share each key-value head (here 8 query heads over 2 KV heads), reducing KV-cache requirements.

Adding a Custom ALiBi Additive Positional Bias

```python
print("\n" + "="*70 + "\n6. Custom ALiBi additive bias\n" + "="*70)
B, M, H, K = 1, 128, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
try:
    slopes = (2.0 ** (-8.0 / H)) ** torch.arange(1, H + 1, device=device)
    # [source truncated here, mid-statement: "po"]
```
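The ALiBi listing is cut off in the source right after the slopes are computed, so the rest of the tutorial's construction is not visible. As an illustration of the general idea only, the following dependency-free sketch reuses the slope formula shown in the listing and builds a per-head additive bias that penalizes attention scores linearly with token distance. The function names `alibi_slopes` and `alibi_bias` are hypothetical, and the symmetric `-slope * |i - j|` form is an assumption; it is not the tutorial's own code.

```python
def alibi_slopes(num_heads):
    """Per-head slopes: base ** 1, base ** 2, ..., with base = 2 ** (-8 / num_heads).

    Mirrors the slope expression in the truncated listing, in pure Python.
    """
    base = 2.0 ** (-8.0 / num_heads)
    return [base ** (h + 1) for h in range(num_heads)]


def alibi_bias(num_heads, seq_len):
    """Nested list of shape [num_heads][seq_len][seq_len].

    Entry [h][i][j] is -slope_h * |i - j|, so more distant keys receive a
    larger penalty. (Assumption: symmetric distance; a causal model would
    additionally mask out positions with j > i.)
    """
    return [
        [[-slope * abs(i - j) for j in range(seq_len)] for i in range(seq_len)]
        for slope in alibi_slopes(num_heads)
    ]


if __name__ == "__main__":
    slopes = alibi_slopes(8)
    bias = alibi_bias(8, 4)
    print("slopes:", slopes)
    print("head 0, query position 3:", bias[0][3])
```

Heads with larger slopes concentrate on nearby tokens, while heads with small slopes stay close to unbiased attention. In a GPU workflow the same values would typically live in a `[B, H, M, M]` tensor that is added to the attention scores before the softmax.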

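The benchmark's remark that naive memory grows about 4x per doubling of the sequence length can be checked with quick arithmetic. The sketch below counts only the bytes of a single `[B, H, M, M]` score tensor, using the benchmark's batch size of 2 and 8 heads and assuming 4 bytes per element because the reference implementation casts to float32. The helper name `score_matrix_bytes` is hypothetical, and the real peak is higher, since the softmax output and gradients are stored as well.

```python
def score_matrix_bytes(batch, heads, seq_len, bytes_per_elem=4):
    """Bytes needed to hold one [batch, heads, seq_len, seq_len] score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


if __name__ == "__main__":
    # Same sequence lengths as the benchmark loop; B=2 and H=8 as in the tutorial.
    for m in (512, 1024, 2048, 4096):
        mib = score_matrix_bytes(2, 8, m) / 2**20
        print(f"M={m:>5}: {mib:8.0f} MiB for the score matrix alone")
```

Doubling `seq_len` quadruples the result because the sequence length appears twice in the product, which is the quadratic growth that memory-efficient attention avoids by never materializing this tensor.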