MarkTechPost / AI / AI Tools & Products

How to Speed Up Transformer Training Using NVIDIA Apex (FusedAdam, FusedLayerNorm) and Native torch.amp

June 2, 2026, 01:39


In this tutorial, we work through a hands-on implementation with NVIDIA Apex, focusing on the components that still matter in modern GPU training workflows. Instead of treating Apex as a general mixed-precision library, we separate its legacy parts from the still-useful ones and test each directly. We begin by checking the CUDA runtime, building Apex with the required CUDA and C++ extensions, and detecting which fused kernels are actually available in the environment. This matters because a Python-only Apex installation can appear to succeed while silently lacking the compiled kernels that make Apex fast. After the setup, we benchmark FusedAdam against PyTorch AdamW, compare FusedLayerNorm and FusedRMSNorm with standard normalization layers, and run both the legacy apex.amp and the modern torch.amp APIs. We then bring everything together in a small Transformer training experiment, comparing a vanilla FP32 PyTorch path with a fused Apex-plus-AMP path to measure the real effect on throughput.

```python
import os, sys, time, subprocess, importlib
import torch

assert torch.cuda.is_available(), (
    "No CUDA GPU found. "
    "In Colab: Runtime > Change runtime type > Hardware accelerator = GPU"
)
DEV = torch.device("cuda")
print(f"[env] torch {torch.__version__} | CUDA {torch.version.cuda} | GPU {torch.cuda.get_device_name(0)}")


def _module_present(name: str) -> bool:
    """True if `name` imports cleanly (an extension that fails to load counts as absent)."""
    try:
        importlib.import_module(name)
        return True
    except Exception:
        return False


def _build_apex():
    print("[apex] building from source with CUDA + C++ extensions "
          "(~10-20 min on first run; grab a coffee)...")
    subprocess.run([sys.executable, "-m", "pip", "install", "-q", "ninja", "packaging"], check=True)
    if not os.path.isdir("apex"):
        subprocess.run(["git", "clone", "--depth", "1", "https://github.com/NVIDIA/apex"], check=True)
    env = os.environ.copy()
    env["APEX_CPP_EXT"] = "1"    # build the C++ extensions
    env["APEX_CUDA_EXT"] = "1"   # build the CUDA extensions (amp_C, fused_layer_norm_cuda, ...)
    env["MAX_JOBS"] = "4"        # limit parallel compile jobs to keep memory use manageable
    env["NVCC_APPEND_FLAGS"] = "--threads 4"
    cmd = [sys.executable, "-m", "pip", "install", "-v", "--no-build-isolation",
           "--no-cache-dir", "./apex"]
    proc = subprocess.run(cmd, env=env)
    if proc.returncode != 0:
        print("[apex] CUDA build failed -> falling back to PYTHON-ONLY install "
              "(fused kernels will be unavailable, tutorial still runs).")
        # Same command without the APEX_*_EXT variables: Python-only install
        subprocess.run([sys.executable, "-m", "pip", "install", "-v", "--no-build-isolation",
                        "--no-cache-dir", "./apex"], check=False)


if not _module_present("amp_C"):
    _build_apex()
    importlib.invalidate_caches()  # let this process see the freshly installed modules

HAS_AMP_C = _module_present("amp_C")                # multi-tensor kernels behind FusedAdam
HAS_FLN = _module_present("fused_layer_norm_cuda")  # kernels behind FusedLayerNorm
HAS_RMS = False
try:
    import apex
    from apex.optimizers import FusedAdam
    from apex.normalization import FusedLayerNorm
    try:
        from apex.normalization import FusedRMSNorm
        HAS_RMS = True
    except Exception:
        HAS_RMS = False
    from apex import amp  # legacy AMP (deprecated)
    APEX_OK = True
except Exception as e:
    print(f"[apex] import failed: {e}")
    APEX_OK = False

print("\n[capabilities]")
print(f"  apex importable    : {APEX_OK}")
print(f"  FusedAdam kernels  : {HAS_AMP_C}")
print(f"  FusedLayerNorm krnl: {HAS_FLN}")
print(f"  FusedRMSNorm       : {APEX_OK and HAS_RMS}")
print("=" * 78)


def bench(fn, iters=50, warmup=10):
    """Mean wall-clock time of fn() in milliseconds, measured after warm-up."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()  # make sure warm-up kernels have finished
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()  # wait for all queued GPU work before stopping the clock
    return (time.perf_counter() - t0) / iters * 1e3
```

We start by preparing the CUDA environment: we assert that a GPU is available and print the active PyTorch, CUDA, and GPU details. If the compiled amp_C extension is not already importable, we build NVIDIA Apex from source with APEX_CPP_EXT=1 and APEX_CUDA_EXT=1 so that the fused kernels are compiled, rather than relying on a limited Python-only installation; if that build fails, we fall back to a Python-only install so the rest of the tutorial still runs. We then probe the extension modules (amp_C for FusedAdam, fused_layer_norm_cuda for FusedLayerNorm), check whether FusedRMSNorm and legacy apex.amp can be imported, and define a reusable bench() helper that runs warm-up iterations and synchronizes the GPU before and after timing, so the reported milliseconds reflect completed kernel work.

```python
print("\n### SECTION A: FusedAdam vs AdamW ###")

def make_many_param_model(n_layers=60, dim=512):
    # Many Linear layers -> many parameter tensors, so per-tensor optimizer overhead shows up
    return torch.nn.Sequential(*[torch.nn.Linear(dim, dim) for _ in range(n_layers)]).to(DEV)

def opt_step_factory(optimizer, model, dim=512):
    x = torch.randn(64, dim, device=DEV)
    def step():
        optimizer.zero_grad(set_to_none=True)
        out = model(x).pow(2).mean()
        out.backward()
        optimizer.step()
    return step

m1 = make_many_param_model()
torch_adam = torch.optim.AdamW(m1.parameters(), lr=1e-3, weight_decay=1e-2)
ms_torch = bench(opt_step_factory(torch_adam, m1))
print(f"  torch.optim.AdamW  : {ms_torch:6.2f} ms / step")

if HAS_AMP_C and APEX_OK:
    m2 = make_many_param_model()
    m2.load_state_dict(m1.state_dict())
    # FusedAdam defaults to weight_decay=0.0; match AdamW's 1e-2 so both do the same math
    fused_adam = FusedAdam(m2.parameters(), lr=1e-3, weight_decay=1e-2)
    ms_fused = bench(opt_step_factory(fused_adam, m2))
    print(f"  apex.FusedAdam     : {ms_fused:6.2f} ms / step "
          f"(~{ms_torch/ms_fused:0.2f}x on optimizer-bound step)")
else:
    print("  apex.FusedAdam     : SKIPPED (cuda ext not built)")
```

We benchmark PyTorch AdamW against Apex FusedAdam on a stack of 60 Linear(512, 512) layers, which yields 120 separate parameter tensors and makes per-tensor optimizer overhead visible. Both optimizers run the same training-step pattern with the same weight decay, and the second model is initialized from the first model's weights, so the comparison isolates update speed rather than model differences. We then report the step time and speedup to judge whether the fused multi-tensor optimizer provides a practical benefit in the current GPU runtime. Two caveats apply: the timed step also includes the forward and backward passes, so the ratio understates the optimizer-only speedup, and recent PyTorch releases already default to a multi-tensor (foreach) AdamW implementation on CUDA, so the gap may be modest.
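A fused optimizer does not change the arithmetic of an update, only how it is scheduled on the GPU. As a plain-Python illustration, the sketch below spells out one AdamW step over flat lists of scalar parameters; the function name adamw_step is hypothetical, and its defaults mirror torch.optim.AdamW (lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2). Apex FusedAdam in its default adam_w_mode applies the same decoupled weight decay, but runs the per-element math for many parameter tensors inside a few multi-tensor kernel launches instead of issuing separate kernels for each tensor, as a simple per-tensor loop would.

```python
import math


def adamw_step(params, grads, exp_avg, exp_avg_sq, step,
               lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
    """Apply one AdamW update in place to flat lists of floats (illustrative sketch).

    `step` is the 1-based step count used for bias correction.
    """
    beta1, beta2 = betas
    bias_c1 = 1.0 - beta1 ** step
    bias_c2 = 1.0 - beta2 ** step
    for i, g in enumerate(grads):
        # Decoupled weight decay: shrink the weight directly, not through the gradient
        params[i] *= 1.0 - lr * weight_decay
        # Exponential moving averages of the gradient and the squared gradient
        exp_avg[i] = beta1 * exp_avg[i] + (1.0 - beta1) * g
        exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1.0 - beta2) * g * g
        # Bias-corrected moments, then the normalized step
        m_hat = exp_avg[i] / bias_c1
        v_hat = exp_avg_sq[i] / bias_c2
        params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return params


if __name__ == "__main__":
    for wd in (0.0, 1e-2):
        p = adamw_step([1.0], [0.5], [0.0], [0.0], step=1, weight_decay=wd)
        print(f"weight_decay={wd}: p = {p[0]:.8f}")
```

On the first step the bias-corrected update is roughly lr times the sign of the gradient, so the two runs differ only by the decay term lr * weight_decay * p (here 1e-5). That is why a head-to-head benchmark should give both optimizers the same weight_decay: torch.optim.AdamW defaults to 1e-2, whereas Apex FusedAdam defaults to 0.0.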
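The bench() helper calls torch.cuda.synchronize() around the timed loop because CUDA kernel launches are asynchronous: without the sync, the clock could stop while queued work was still running. The same warm-up-then-measure pattern can be tried without a GPU. The sketch below is a hypothetical CPU-only variant (bench_cpu is our name, not part of the tutorial) that also repeats the measurement and reports the median, which is less sensitive to one-off stalls than a single mean.

```python
import statistics
import time


def bench_cpu(fn, iters=50, warmup=10, repeats=5):
    """Median per-call time of fn() in milliseconds (hypothetical CPU-only helper)."""
    for _ in range(warmup):
        fn()  # warm caches and any lazy initialization before timing
    samples = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        for _ in range(iters):
            fn()
        # Average time per call for this batch, in milliseconds
        samples.append((time.perf_counter() - t0) / iters * 1e3)
    return statistics.median(samples)


if __name__ == "__main__":
    ms = bench_cpu(lambda: sum(i * i for i in range(1_000)))
    print(f"sum of squares: {ms:.4f} ms / call")
```

For GPU work, keep the synchronize calls from the original helper; this CPU version needs none because each Python call has finished by the time it returns.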
```python
print("\n### SECTION B: FusedLayerNorm / FusedRMSNorm ###")
B, T, H = 32, 512, 1024  # batch, sequence length, hidden size
x = torch.randn(B, T, H, device=DEV, requires_grad=True)

torch_ln = torch.nn.LayerNorm(H).to(DEV)
def ln_torch():
    y = torch_ln(x)
    y.sum().backward()
ms_ln_torch = bench(ln_torch)
print(f"  nn.LayerNorm       : {ms_ln_torch:6.2f} ms / fwd+bwd")

if HAS_FLN and APEX_OK:
    fused_ln = FusedLayerNorm(H).to(DEV)
    with torch.no_grad():
        # Copy the affine parameters so both layers compute the same function
        fused_ln.weight.copy_(torch_ln.weight)
        fused_ln.bias.copy_(torch_ln.bias)
        diff = (fused_ln(x.detach()) - torch_ln(x.detach())).abs().max().item()
    print(f"  max|fused - torch| = {diff:.2e}  (should be ~1e-3 or smaller)")

    def ln_fused():
        y = fused_ln(x)
        y.sum().backward()
    ms_ln_fused = bench(ln_fused)
    print(f"  apex.FusedLayerNorm: {ms_ln_fused:6.2f} ms / fwd+bwd "
          f"(~{ms_ln_torch/ms_ln_fused:0.2f}x)")

    if HAS_RMS:
        fused_rms = FusedRMSNorm(H).to(DEV)
        def rms_fused():
            y = fused_rms(x)
            y.sum().backward()
        print(f"  apex.FusedRMSNorm  : {bench(rms_fused):6.2f} ms / fwd+bwd "
              f"(RMSNorm: no mean-subtraction, used by LLaMA-style models)")
else:
    print("  apex.FusedLayerNorm: SKIPPED (cuda ext not built)")
```

We compare the standard PyTorch LayerNorm with Apex FusedLayerNorm on a (32, 512, 1024) tensor resembling transformer hidden states. We first check numerical correctness by copying the same affine parameters into both layers and measuring the maximum absolute difference between their outputs. We then benchmark the combined forward and backward pass and, when available, time FusedRMSNorm, the normalization used in LLaMA-style models (no PyTorch RMSNorm baseline is timed here).
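The difference between the two normalizations benchmarked above is easiest to see on a single feature vector. The stdlib sketch below (layer_norm and rms_norm are hypothetical names) follows the standard definitions: LayerNorm subtracts the mean, divides by the standard deviation, and applies a learnable scale and shift; RMSNorm skips the mean subtraction, divides by the root mean square, and applies only a scale. The eps=1e-5 default matches nn.LayerNorm; RMSNorm implementations choose their own eps.

```python
import math


def layer_norm(x, weight=None, bias=None, eps=1e-5):
    """LayerNorm over one vector: (x - mean) / sqrt(var + eps) * w + b."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # biased variance, as LayerNorm uses
    inv_std = 1.0 / math.sqrt(var + eps)
    w = weight if weight is not None else [1.0] * n
    b = bias if bias is not None else [0.0] * n
    return [(v - mean) * inv_std * wi + bi for v, wi, bi in zip(x, w, b)]


def rms_norm(x, weight=None, eps=1e-5):
    """RMSNorm over one vector: x / sqrt(mean(x^2) + eps) * w (no mean, no bias)."""
    n = len(x)
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / n + eps)
    w = weight if weight is not None else [1.0] * n
    return [v * inv_rms * wi for v, wi in zip(x, w)]


if __name__ == "__main__":
    x = [1.0, 2.0, 3.0, 4.0]
    print("LayerNorm:", [round(v, 4) for v in layer_norm(x)])
    print("RMSNorm:  ", [round(v, 4) for v in rms_norm(x)])
```

Because RMSNorm never computes a mean, adding a constant to every input changes its output, whereas LayerNorm is invariant to such shifts; skipping that reduction is part of why RMSNorm is slightly cheaper.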
```python
print("\n### SECTION C: mixed precision (apex.amp opt-levels, DEPRECATED) ###")

def tiny_net():
    return torch.nn.Sequential(
        torch.nn.Linear(256, 256), torch.nn.ReLU(),
        torch.nn.Linear(256, 256), torch.nn.ReLU(),
        torch.nn.Linear(256, 10),
    ).to(DEV)

if APEX_OK:
    for level in ["O0", "O1", "O2"]:
        net = tiny_net()
        optimizer = (FusedAdam(net.parameters(), lr=1e-3) if HAS_AMP_C
                     else torch.optim.AdamW(net.parameters(), lr=1e-3))
        net, optimizer = amp.initialize(net, optimizer, opt_level=level, verbosity=0)
        xb = torch.randn(128, 256, device=DEV)
        yb = torch.randint(0, 10, (128,), device=DEV)
        lossfn = torch.nn.CrossEntropyLoss()
        for _ in range(20):
            optimizer.zero_grad()
            loss = lossfn(net(xb), yb)
            # apex.amp multiplies the loss by the current loss scale before backward
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()
        print(f"  opt_level={level}: final loss = {loss.item():.4f}")
else:
    print("  apex.amp: SKIPPED (apex not importable)")

print("\n  >> Modern recommended equivalent (torch.amp, no Apex needed):")
net = tiny_net()
optimizer = torch.optim.AdamW(net.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler("cuda")
xb = torch.randn(128, 256, device=DEV)
yb = torch.randint(0, 10, (128,), device=DEV)
lossfn = torch.nn.CrossEntropyLoss()
for _ in range(20):
    optimizer.zero
```
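Both apex.amp (at O1 and O2) and torch.amp.GradScaler rely on dynamic loss scaling, because small FP16 gradients underflow to zero. The stdlib sketch below illustrates the idea under simplifying assumptions: to_fp16 rounds a Python float through IEEE half precision using struct's "e" format, and the hypothetical DynamicLossScaler class imitates the policy described in the GradScaler documentation (start at 2**16, halve the scale and skip the step when gradients contain inf/NaN, double it after growth_interval consecutive clean steps, default 2000). It is a model for intuition, not the PyTorch or Apex implementation.

```python
import math
import struct


def to_fp16(value):
    """Round a Python float to the nearest IEEE-754 half-precision value."""
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        # Beyond the FP16 range (max finite 65504): model it as overflow to inf
        return math.copysign(math.inf, value)


class DynamicLossScaler:
    """Simplified model of a GradScaler-style dynamic loss-scaling policy."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, grads_are_finite):
        """Adjust the scale after one step; return True if the optimizer step should run."""
        if not grads_are_finite:
            self.scale *= self.backoff_factor  # overflow: shrink the scale, skip this step
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            self.scale *= self.growth_factor  # long clean streak: try a larger scale
            self._good_steps = 0
        return True


if __name__ == "__main__":
    tiny_grad = 1e-8
    print("unscaled grad in fp16:", to_fp16(tiny_grad))
    scaled = to_fp16(tiny_grad * 2.0 ** 16)
    print("scaled, then unscaled:", scaled / 2.0 ** 16)
```

A gradient of 1e-8 is below half the smallest FP16 subnormal (about 6e-8), so it rounds to zero on its own; multiplied by 2**16 it lands comfortably in the normal FP16 range and survives the round trip to within FP16's roughly three significant digits.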
