
AI Hardware Accelerators: TPUs, Trainium, and Beyond

Technical Overview

The GPU's dominance in AI hardware is being challenged by purpose-built AI ASICs optimized for the matrix multiplication and tensor contraction operations that dominate neural network workloads. Google's TPU, AWS Trainium/Inferentia, Cerebras WSE, Graphcore IPU, and Intel Gaudi represent fundamentally different architectural approaches, each making distinct tradeoffs between computation density, memory bandwidth, programmability, and power efficiency. Understanding these architectures requires the "roofline model": an operation is compute-bound when its arithmetic intensity (FLOPs per byte of memory traffic) exceeds the hardware's ratio of peak FLOP/s to memory bandwidth, and memory-bandwidth-bound (loading weights and activations dominates) when it falls below that ratio. By specializing both the datapath and the memory system, ASICs can raise both ceilings for their target workloads and often outperform GPUs there.

Prerequisites

  • Understanding of matrix multiplication complexity and BLAS operations
  • Familiarity with neural network forward pass operations (matmul, softmax, layernorm)
  • Knowledge of DRAM/HBM memory hierarchy and bandwidth constraints
  • Understanding of systolic arrays and SIMD execution models
  • Basic digital design concepts (pipeline stages, dataflow computing)

Core Content

Google TPU (Tensor Processing Unit)

The TPU is Google's custom ASIC optimized for matrix multiplication. It is programmed through the XLA compiler (from TensorFlow and JAX) and is used to train and serve much of Google's production AI (Search, Translate, Assistant, Gemini).

TPU design philosophy: Rather than a general-purpose processor with vector units tacked on, the TPU is organized around a Matrix Multiply Unit (MXU), which is a fixed-size systolic array. All other operations (activations, normalization, element-wise ops) are secondary.

Systolic Array Architecture:

Systolic Array (256×256 MACs in TPUv1; 128×128 per MXU in TPUv2 through TPUv4), weight-stationary:

       W[0,0]  W[0,1]  W[0,2]  W[0,3]     (each MAC holds one pre-loaded W[i,j])
x[0] ─▶ [MAC]──▶[MAC]──▶[MAC]──▶[MAC]
          │       │       │       │
x[1] ─▶ [MAC]──▶[MAC]──▶[MAC]──▶[MAC]
          │       │       │       │
x[2] ─▶ [MAC]──▶[MAC]──▶[MAC]──▶[MAC]
          │       │       │       │
          ▼       ▼       ▼       ▼
       Sum[0]  Sum[1]  Sum[2]  Sum[3]     Sum[j] = Σ_i x[i] × W[i,j]

Weights are pre-loaded into the MACs and stay stationary while inputs stream through
Input activations enter at the left edge and flow RIGHT along rows (skewed by one cycle per row)
Partial sums flow DOWN each column; each MAC adds x[i] × W[i,j] to the sum passing through it

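The key insight: in a systolic array, data flows through a grid of Processing Elements (PEs). Each PE multiplies its incoming input by its stored weight, adds the product to the partial sum arriving from its neighbor, and passes the input to the next PE in its row and the updated partial sum to the next PE in its column. Every value read from memory is reused by each PE it passes through, without re-reading from DRAM. This maximizes arithmetic intensity: the number of MACs per byte of DRAM access.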
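The dataflow can be made concrete with a small cycle-level simulation. This is a minimal NumPy sketch, assuming a weight-stationary array like the TPU's: each PE keeps one weight, activations move one PE to the right per cycle (row k is fed k cycles late), and partial sums move one PE down per cycle. The function systolic_matmul is hypothetical; it ignores weight loading, numeric precision, and the real MXU size.

```python
import numpy as np

def systolic_matmul(X, W):
    """Simulate a K x N weight-stationary systolic array computing X @ W.

    X: (M, K) inputs, W: (K, N) weights; PE (k, n) holds W[k, n].
    Returns (Y, cycles).
    """
    M, K = X.shape
    K_w, N = W.shape
    assert K == K_w, "inner dimensions must match"
    act = np.zeros((K, N))    # activation register in each PE
    psum = np.zeros((K, N))   # partial-sum register in each PE
    Y = np.zeros((M, N))
    cycles = M + K + N - 2    # pipeline fill + streaming + drain
    for t in range(cycles):
        new_act = np.zeros((K, N))
        new_psum = np.zeros((K, N))
        for k in range(K):
            for n in range(N):
                if n == 0:
                    m = t - k                      # row k is skewed by k cycles
                    new_act[k, n] = X[m, k] if 0 <= m < M else 0.0
                else:
                    new_act[k, n] = act[k, n - 1]  # activations move right
                above = psum[k - 1, n] if k > 0 else 0.0  # partial sums move down
                new_psum[k, n] = above + new_act[k, n] * W[k, n]
        act, psum = new_act, new_psum
        for n in range(N):        # bottom row emits finished dot products
            m = t - (K - 1) - n
            if 0 <= m < M:
                Y[m, n] = psum[K - 1, n]
    return Y, cycles

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 4))
W = rng.standard_normal((4, 3))
Y, cycles = systolic_matmul(X, W)
print(np.allclose(Y, X @ W), cycles)  # prints: True 10
```

A K×N array processes M input rows in M + K + N − 2 cycles. The K + N − 2 term is pipeline fill and drain, which is amortized only when M is large.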

TPUv4 specifications:
  • 275 TFLOP/s (BF16, MXUs)
  • 32 GB HBM per chip
  • 1,200 GB/s HBM bandwidth (vs the 80 GB A100's ~2 TB/s)
  • 2 TensorCores per chip, each with four 128×128 MXUs (matrix multiply units)
  • ICI bandwidth: 1,200 GB/s chip-to-chip (ICI = Inter-Chip Interconnect)
  • Power: 170 W per chip

TPU Pod (v4): 4,096 TPUv4 chips connected via 3D toroidal mesh ICI. The mesh provides 600 GB/s per chip to neighbors, enabling extremely fast collective operations. Total Pod: ~1.1 ExaFLOP/s BF16.

TPU ICI Topology:

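TPUv4 Pod: 3D Toroidal Mesh

Each chip has 6 ICI links (±X, ±Y, ±Z directions)
In a 4×4×4 sub-cube (64 chips; 64 such cubes form a 4,096-chip pod):

        Z
        │   Y
        │  /
        │ /
        └────── X

Links are bidirectional: 1,200 GB/s total per chip
Wraparound links join the two ends of each axis, turning the mesh into a torus
Mesh routing: packets hop between nearest neighbors until they reach the destination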
Collective AllReduce: uses in-network aggregation on ICI mesh
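The payoff of the torus wraparound links shows up in hop counts. A minimal sketch, assuming minimal nearest-neighbor routing in which a packet covers each axis independently; mesh_hops and torus_hops are hypothetical helpers that ignore link contention and other routing details.

```python
def mesh_hops(src, dst):
    """Hops between two chips on a mesh without wraparound links."""
    return sum(abs(s - d) for s, d in zip(src, dst))

def torus_hops(src, dst, dims):
    """Minimal hops on a torus: each axis may use its wraparound link."""
    hops = 0
    for s, d, size in zip(src, dst, dims):
        delta = abs(s - d)
        hops += min(delta, size - delta)  # go directly or wrap around
    return hops

dims = (4, 4, 4)  # one 4x4x4 cube
corner_a, corner_b = (0, 0, 0), (3, 3, 3)
print(mesh_hops(corner_a, corner_b))         # 9 hops across the open mesh
print(torus_hops(corner_a, corner_b, dims))  # 3 hops using wraparound links

# Worst case over all destinations (the network diameter) on the 4x4x4 torus
diameter = max(torus_hops((0, 0, 0), (x, y, z), dims)
               for x in range(4) for y in range(4) for z in range(4))
print(diameter)  # 6
```

Wraparound roughly halves the worst-case distance along each axis, which shortens the paths that collective operations such as AllReduce must traverse.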

XLA (Accelerated Linear Algebra) Compiler: TPUs are programmed via XLA, a whole-program compiler. Unlike CUDA, where the programmer writes individual kernels, XLA takes the entire computation graph, performs global optimization, and compiles it to TPU instructions. Key optimizations:
  • Operator fusion: merge element-wise ops into a single kernel to avoid HBM round trips
  • Memory layout optimization: choose each tensor's memory layout (row-major vs column-major) to maximize MXU throughput
  • Collective scheduling: overlap all-reduce with local computation automatically
  • No manual memory management: XLA allocates and frees HBM automatically
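Why fusion matters can be shown with a byte count. This is a back-of-the-envelope sketch for a hypothetical element-wise chain y = tanh(a * b + c), assuming each unfused op runs as its own kernel that reads its inputs from and writes its output to HBM, and ignoring caches. XLA's actual fusion decisions depend on the graph and backend.

```python
def hbm_bytes_unfused(n, bytes_per_elem=2):
    """y = tanh(a * b + c) as three separate kernels, each round-tripping HBM."""
    mul = 2 * n + n   # read a, b; write t1 = a * b
    add = 2 * n + n   # read t1, c; write t2 = t1 + c
    act = n + n       # read t2; write y = tanh(t2)
    return (mul + add + act) * bytes_per_elem

def hbm_bytes_fused(n, bytes_per_elem=2):
    """Same expression as one fused kernel: intermediates stay on-chip."""
    return (3 * n + n) * bytes_per_elem   # read a, b, c; write y

n = 4096 * 4096  # elements per tensor (2-byte BF16)
print(hbm_bytes_unfused(n) / 2**20, "MiB unfused")  # 256.0
print(hbm_bytes_fused(n) / 2**20, "MiB fused")      # 128.0
```

Because an element-wise chain like this is memory-bound, halving its HBM traffic roughly halves its runtime.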

TPU programming model (JAX):

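import jax
import jax.numpy as jnp

LR = 1e-3  # learning rate (illustrative value)

def loss_fn(params, batch):
    # Placeholder linear model with squared-error loss; substitute your model
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit  # traced once, compiled by XLA, runs on TPU (or GPU/CPU)
def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)  # autodiff w.r.t. params
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

# Multi-device (pmap = parallel map, one replica per device)
def parallel_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    grads = jax.lax.pmean(grads, axis_name='batch')  # all-reduce across devices
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)

train_step_parallel = jax.pmap(parallel_step, axis_name='batch')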

AWS Trainium and Inferentia

Amazon designed two distinct ASICs: Trainium for training and Inferentia for inference, reflecting their divergent optimization targets.

AWS Inferentia (2019): 128 TOPS (INT8), 4 NeuronCores, on-chip SRAM. Used for high-throughput inference of BERT, ResNet, GPT-2 at low cost. Custom memory model with neuron-SRAM for weight caching.

AWS Inferentia2 (2023): 190 TFLOP/s (FP16), 2 NeuronCores-v2, 32 GB HBM2e per chip, 820 GB/s bandwidth. NeuronLink interconnect (1.6 TB/s chip-to-chip).

AWS Trainium (2021): Training-focused ASIC. 210 TFLOP/s (BF16), 2 NeuronCores-v2, 32 GB HBM. NeuronCore architecture includes a scalable matrix math engine, vector engine, and scalar engine.

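Trainium2 (2024): AWS cites up to 4× the training performance and up to 2× the energy efficiency of first-generation Trainium. Deployed as 16-chip Trn2 instances with NeuronLink-v3 (36 TB/s chip-to-chip); Trn2 UltraServers combine four such instances (64 chips).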

AWS Neuron SDK: Compiler + runtime framework. PyTorch models are traced with example inputs (torch_neuronx.trace follows the torch.jit.trace model and returns a TorchScript module), then compiled to a Neuron executable. Supports PyTorch (via torch_neuronx) and TensorFlow.

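import torch
import torch_neuronx

model = MyModel().eval()                   # any torch.nn.Module (placeholder)
sample_input = torch.rand(1, 3, 224, 224)  # example input with the deployment shape
# Compile once ahead of time; the result is a TorchScript module
trace = torch_neuronx.trace(model, sample_input)
torch.jit.save(trace, "model.pt")
# At deployment, load the cached artifact instead of recompiling
neuron_model = torch.jit.load("model.pt")
output = neuron_model(sample_input)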

Tradeoff vs GPUs: Inferentia2 achieves ~3× lower cost-per-inference for BERT-large at batch=8. Trainium is cost-competitive with A100 for LLM fine-tuning but less flexible (requires the Neuron compiler; some operations fall back to CPU).

Cerebras WSE (Wafer-Scale Engine)

The most radical departure from conventional chip design: Cerebras builds a single processor from one 300 mm silicon wafer, using the largest square die the wafer can hold instead of dicing it into many chips.

WSE-2 (2021): 850,000 AI-optimized cores, 40 GB on-chip SRAM (no off-chip DRAM for weights), 20 PB/s memory bandwidth, 7nm TSMC, 46,225 mm² (a large GPU die is ~800 mm²).

WSE-3 (2024): 5 nm TSMC, 900,000 cores, 44 GB on-chip SRAM, 21 PB/s memory bandwidth, 125 PetaFLOP/s peak AI compute (vendor figure).

The wafer-scale advantage: On a GPU, weights that don't fit in on-chip caches must be streamed from HBM, at about 3.35 TB/s on an H100. On the WSE, weights live in on-chip SRAM with about 20 PB/s of bandwidth (WSE-2), roughly 6,000× more. For memory-bandwidth-limited operations (inference at small batch sizes, embedding lookups), the WSE can be dramatically faster.

Manufacturing challenge: At 7nm, defect density makes any very large die unlikely to be defect-free. Cerebras's solution is yield-tolerant design: each wafer contains hundreds of thousands of small cores, including redundant spares, connected by a 2D mesh. Defective cores are identified at test time and disabled, and the routing mesh routes around them. Because each defect removes only one tiny core, the loss is proportional: if 2% of cores were disabled, the wafer would still deliver about 98% of its nominal compute.
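Why a wafer-scale die cannot be defect-free, and why per-core redundancy fixes this, follows from the standard first-order Poisson yield model, P(zero defects) = exp(−area × D0). The defect density D0 below is an assumed illustrative value, not a published TSMC figure, and the sketch assumes each defect disables exactly one core.

```python
import math

def poisson_yield(area_cm2, defects_per_cm2):
    """Probability that a die of the given area has zero defects
    (first-order Poisson yield model)."""
    return math.exp(-area_cm2 * defects_per_cm2)

D0 = 0.1              # ASSUMED defect density, defects/cm^2 (illustrative only)
GPU_DIE_CM2 = 8.14    # ~814 mm^2, a large GPU die
WSE_DIE_CM2 = 462.25  # 46,225 mm^2 wafer-scale die
CORES = 850_000       # WSE-2 core count

print(f"GPU-sized die defect-free:   {poisson_yield(GPU_DIE_CM2, D0):.2f}")
print(f"Wafer-scale die defect-free: {poisson_yield(WSE_DIE_CM2, D0):.1e}")

# With redundancy, a defect disables only the small core it lands in
expected_defects = WSE_DIE_CM2 * D0
print(f"Expected defects per wafer: {expected_defects:.0f} "
      f"({expected_defects / CORES:.4%} of cores disabled)")
```

Under these assumptions a GPU-sized die is defect-free a little under half the time, while a defect-free wafer-scale die is essentially impossible; disabling a few dozen cores costs only a tiny fraction of the 850,000.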

Cerebras programming model: Uses Cerebras Software Platform (CSP) with PyTorch front-end. Models are compiled to a dataflow graph that maps onto the 2D mesh.

Use cases: Cerebras is particularly effective for models whose weights fit entirely in the 40 GB of SRAM (up to ~20B parameters in FP16, or ~40B in INT8). For models larger than SRAM capacity, the system-level architecture (a cluster of CS-2 systems) stores parameters in an external MemoryX module.
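The capacity figures follow from dividing SRAM size by bytes per parameter. A minimal sketch that counts weights only (activations, KV cache, and optimizer state would lower these numbers); max_params_billions is a hypothetical helper.

```python
def max_params_billions(capacity_gb, bytes_per_param):
    """Largest parameter count (billions) whose weights alone fit in capacity_gb.
    Uses 1 GB = 1e9 bytes; activations and optimizer state are ignored."""
    return capacity_gb / bytes_per_param

SRAM_GB = 40  # WSE-2 on-chip SRAM
for fmt, nbytes in [("FP32", 4), ("FP16/BF16", 2), ("INT8", 1)]:
    print(f"{fmt:10s} ~{max_params_billions(SRAM_GB, nbytes):.0f}B parameters")
```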

Cerebras CS-2 cluster: Multiple CS-2 units connected to a shared MemoryX (an off-wafer parameter store built from DRAM and flash) and SwarmX (all-reduce fabric). Demonstrated 1.2T parameter model training.

Graphcore IPU (Intelligence Processing Unit)

Graphcore's IPU takes a different approach: the Bulk Synchronous Parallel (BSP) execution model, with massive fine-grained parallelism and SRAM-only on-chip memory.

IPU MK2 (Colossus GC200):
  • 1,472 independent processor tiles
  • 900 MB on-chip SRAM (638 MB for data, 262 MB for code)
  • 58 TFLOP/s (FP16 bulk compute)
  • 47.5 TB/s total memory bandwidth (to on-chip SRAM)
  • No HBM: all computation runs out of on-chip SRAM

BSP execution model:

BSP: Bulk Synchronous Parallel
  Phase 1: Local Compute — all 1,472 tiles compute independently, no sync
  Phase 2: Global Barrier — all tiles synchronize
  Phase 3: Communication — tiles exchange data via on-chip exchange fabric

  Repeat for each BSP step (one transformer layer ≈ several BSP steps)

BSP avoids the latency of cache-coherency protocols (as in CPUs) by making communication explicit and periodic. The program (in practice, the code the Poplar compiler generates) specifies exactly what data each tile sends and receives, and when.
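The three phases can be illustrated with a toy simulation in which Python threads stand in for tiles, threading.Barrier plays the global sync, and a shared list plays the exchange fabric. This is a hypothetical sketch, not the Poplar API: in each superstep every tile sends its message to its right neighbor and adds what it received, so after n − 1 supersteps every tile holds the global sum (a ring all-reduce).

```python
import threading

def run_bsp(values, supersteps):
    """Toy BSP ring all-reduce: one thread per tile, barriers between phases."""
    n = len(values)
    totals = list(values)   # each tile's local memory (written only by its owner)
    inbox = [0] * n         # exchange buffers, written only in the exchange phase
    barrier = threading.Barrier(n)

    def tile(i):
        msg = values[i]                  # data this tile will send next
        for _ in range(supersteps):
            barrier.wait()               # Phase 2: global barrier, compute is done
            inbox[(i + 1) % n] = msg     # Phase 3: explicit send to right neighbor
            barrier.wait()               # every exchange has landed
            msg = inbox[i]               # Phase 1 of the next superstep:
            totals[i] += msg             #   local compute on received data

    threads = [threading.Thread(target=tile, args=(i,)) for i in range(n)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return totals

print(run_bsp([1, 2, 3, 4], supersteps=3))  # every tile ends with 10
```

The second barrier is what makes the exchange safe: no tile reads its inbox until every sender has written, and no sender overwrites an inbox until its reader has reached the next barrier.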

IPU architecture:

1,472 Tiles on Colossus chip:

Tile 0  │ Tile 1  │ Tile 2  │  ...  │ Tile 1471
[ALU]   │ [ALU]   │ [ALU]   │       │ [ALU]
[SRAM]  │ [SRAM]  │ [SRAM]  │       │ [SRAM]
        │         │         │       │
        └─────────┴─────────┴───────┘
              Exchange Fabric (IPU-Exchange, on-chip)
              All-to-all, 6 TB/s aggregate

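Poplar and PopART: Poplar is Graphcore's graph compiler and runtime; PopART (Poplar Advanced Runtime) builds on it to import and run models from ONNX and other frameworks. The compiler operates on computation graphs and maps operations onto tiles using a partitioning algorithm that minimizes inter-tile communication.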

IPU Pod: 16 IPUs connected via IPU-Link (1.6 TB/s per direction). POD-64 = 64 IPUs, effective 3.7 PFlop/s.

IPU strengths: Fine-grained parallelism, low-latency memory, ideal for sparse operations and models with irregular data dependencies. Demonstrated advantages for graph neural networks (GNN) and sparse transformers.

Intel Gaudi 2

Gaudi 2 (2022): 24 Tensor Processor Cores (TPCs) plus matrix multiplication engines (MMEs), 96 GB HBM2e, 2.45 TB/s bandwidth, roughly 432 TFLOP/s (BF16, matrix engines). Notably includes 24 integrated 100GbE ports for scale-out without external NICs.

Gaudi 3 (2024): 64 TPCs and 8 MMEs, 128 GB HBM2e, 3.7 TB/s bandwidth, 1,835 TFLOP/s (BF16 and FP8 matrix compute), 24 integrated 200GbE ports. Targeting H100 competitor pricing (~50% lower cost per TFLOP).

Intel Habana SynapseAI SDK: PyTorch integration via habana_frameworks. Models run on Gaudi with minimal code changes: model.to('hpu') (Habana Processing Unit).

Gaudi strength: Highly competitive price-performance for inference. The Hugging Face Optimum-Habana library supports Llama models (including Llama 3) and Mistral on Gaudi 2/3.

SambaNova SN40L

SambaNova SN40L (2023): Reconfigurable Dataflow Unit (RDU). 520 TOPS (INT8), 8 GB HBM3 + 128 GB DRAM. The "reconfigurable" aspect means the fabric of memory-compute units can be rewired at runtime to match different model topologies. Targets inference at large scale with model parallelism built in hardware.

Comparison Table

Accelerator            BF16 TFLOP/s            Memory (GB)   Bandwidth (TB/s)   Power (W)            Primary Use
NVIDIA H100 SXM5       989 (1,979 sparse)      80 HBM3       3.35               700                  Training + Inference
Google TPUv4           275                     32 HBM        1.2                170                  Training (JAX/TF)
AWS Trainium           210                     32 HBM        0.82               350                  Training
AWS Inferentia2        190                     32 HBM        0.82               300                  Inference
Intel Gaudi 3          1,835                   128 HBM2e     3.7                900                  Training + Inference
Cerebras WSE-3         125,000 (vendor peak)   44 SRAM       21,000             ~23,000 (CS-3 system)  Training (specific)
Graphcore IPU (16×)    3,700                   14.4 SRAM     ~760 (16 × 47.5)   ~2,400               Training

The Memory Bandwidth Wall

The most fundamental challenge for AI accelerators is the "memory bandwidth wall": as neural networks grow larger, the time spent loading weights can exceed the time spent computing with them. A 70B-parameter model in FP16 holds 140 GB of weights, all of which must be read on every forward pass. Even at 3.35 TB/s (H100), that read takes about 42 ms, a lower bound on the latency of each generated token at batch size 1. In this memory-bound regime, doubling memory bandwidth roughly halves per-token latency, while extra compute TFLOP/s barely helps.
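The per-token figure is a simple lower bound: at batch size 1, every weight is read once per generated token. A minimal sketch, assuming peak bandwidth is achieved and ignoring KV-cache and activation traffic (and the fact that 140 GB must be sharded across more than one 80 GB H100); min_decode_latency_ms is a hypothetical helper, and H200's 4.8 TB/s is included for comparison.

```python
def min_decode_latency_ms(params, bytes_per_param, bandwidth_tb_s):
    """Lower bound on per-token latency at batch 1: every weight is read
    from memory once per generated token."""
    weight_bytes = params * bytes_per_param
    return weight_bytes / (bandwidth_tb_s * 1e12) * 1e3

# 70B parameters in FP16 (2 bytes each) on H100 (3.35 TB/s) vs H200 (4.8 TB/s)
for name, bw in [("H100", 3.35), ("H200", 4.8)]:
    ms = min_decode_latency_ms(70e9, 2, bw)
    print(f"{name}: {ms:.1f} ms/token, at most {1000 / ms:.0f} tokens/s")
```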

Roofline model:

Arithmetic Intensity = FLOPs / Bytes_accessed

For a linear layer Y = XA (B=1, S=1, H=4096, H'=4096, 2-byte BF16/FP16 values):
  FLOPs = 2 × 4096 × 4096 ≈ 33.5 MFLOPs
  Bytes = 4096 × 4096 × 2 (weight) + 4096 × 2 (input) + 4096 × 2 (output) ≈ 33.6 MB
  Arithmetic intensity = 33.5 MFLOPs / 33.6 MB ≈ 1 FLOP/Byte

H100 SXM peak: 989 TFLOP/s dense BF16 (1,979 with 2:4 sparsity), 3.35 TB/s bandwidth
  Compute ceiling: 989,000 GFLOP/s
  Bandwidth ceiling: 3,350 GB/s × 1 FLOP/Byte = 3,350 GFLOP/s
  Ridge point: 989 TFLOP/s ÷ 3.35 TB/s ≈ 295 FLOP/Byte

  The layer is memory-bandwidth limited at B=1!
  For general B: intensity = B × H × H' / (H × H' + B × (H + H')), which is ≈ B while B ≪ H
  B=32 → intensity ≈ 31.5 FLOP/Byte → still bandwidth-limited
  B≈345 → intensity ≈ 295 FLOP/Byte → reaches the ridge point
  B=2000 → intensity ≈ 1,012 FLOP/Byte → compute-limited

This drives the trend toward ever-higher HBM bandwidth (H200: 4.8 TB/s) and co-design of memory+compute.
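The roofline arithmetic can be reproduced in a few lines. This sketch assumes the H100 SXM's dense BF16 peak of 989 TFLOP/s (NVIDIA's 1,979 TFLOP/s figure assumes 2:4 structured sparsity) and 3.35 TB/s, counts only weight, input, and output traffic for a BF16 linear layer, and uses hypothetical helper names.

```python
def linear_intensity(B, H, H_out, bytes_per_elem=2):
    """FLOPs per byte for Y = X @ A with X: (B, H) and A: (H, H_out)."""
    flops = 2 * B * H * H_out
    bytes_moved = bytes_per_elem * (H * H_out + B * H + B * H_out)  # weights + input + output
    return flops / bytes_moved

def attainable_tflops(intensity, peak_tflops, bandwidth_tb_s):
    """Roofline: capped by compute or by bandwidth x intensity, whichever is lower."""
    return min(peak_tflops, bandwidth_tb_s * intensity)

PEAK, BW = 989.0, 3.35   # H100 SXM dense BF16 TFLOP/s, TB/s
ridge = PEAK / BW        # intensity where the two ceilings meet
print(f"ridge point ~{ridge:.0f} FLOP/Byte")
for B in (1, 32, 345, 2000):
    ai = linear_intensity(B, 4096, 4096)
    bound = "compute" if ai >= ridge else "memory"
    perf = attainable_tflops(ai, PEAK, BW)
    print(f"B={B:5d}: {ai:7.1f} FLOP/Byte -> {perf:6.1f} TFLOP/s ({bound}-bound)")
```

The ridge point is the single number to compare against: for this layer, intensity tracks the batch size until B becomes comparable to H, so the layer turns compute-bound near B ≈ 345.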

Historical Context

Google announced the first TPU deployment in a data center in 2016 (TPUv1 for inference only). TPUv2 added training capability in 2017. The first external availability was TPUv2/v3 on Google Cloud in 2018. Graphcore shipped its first IPU systems in 2018. Cerebras unveiled the WSE-1 at Hot Chips 2019, attracting massive attention for its wafer-scale approach. AWS launched Inferentia in 2019 and Trainium in 2021. Intel acquired Habana Labs (Gaudi) for $2B in 2019. By 2023, non-NVIDIA AI accelerators had captured a meaningful fraction of cloud inference workloads (est. 15–20% of AWS AI compute).

Production Examples

Google Search ranking and Translate: Served entirely on TPU inference pods. Every Google Search result is ranked by TPU-accelerated ML models.

Gemini training (Google, 2023): Trained on TPUv4 Pods. Gemini Ultra reportedly used 4,096+ TPUv4 chips. The entire training infrastructure runs JAX on TPUs.

Amazon Alexa and Bedrock: Inferentia serves some Alexa NLU workloads. AWS Bedrock's lowest-cost Claude inference tiers use Inferentia2.

Meta's MTIA (Meta Training and Inference Accelerator, 2023): Custom ASIC for recommendation model inference. 102.4 TOPS (INT8), used for ranking billions of Facebook posts/day.

Debugging Notes

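XLA compilation time: On JAX/TPU, the first call of a jitted function triggers XLA compilation, which can take 10–60 minutes for large models. Inspect the program with jax.jit(f).lower(*args).as_text() (before optimization) and jax.jit(f).lower(*args).compile().as_text() (optimized HLO); older JAX releases offered jax.xla_computation() for this, but it has since been removed. Cache compiled artifacts across runs with JAX's persistent compilation cache, e.g. jax.config.update("jax_compilation_cache_dir", "/path/to/cache") (older releases exposed it via jax.experimental.compilation_cache).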
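A minimal JAX sketch of the timing and inspection workflow, runnable on CPU as well as TPU. It uses the ahead-of-time API (jax.jit(f).lower(...) and .compile()); layer is a hypothetical toy function, and the printed text formats vary between JAX versions.

```python
import time
import jax
import jax.numpy as jnp

def layer(x, w):
    # A matmul followed by element-wise ops that XLA can fuse
    return jax.nn.gelu(x @ w) * 2.0 + 1.0

x = jnp.ones((128, 256), dtype=jnp.float32)
w = jnp.full((256, 512), 0.01, dtype=jnp.float32)

jitted = jax.jit(layer)
t0 = time.perf_counter()
jitted(x, w).block_until_ready()        # first call: trace + XLA compile + run
t1 = time.perf_counter()
out = jitted(x, w).block_until_ready()  # reuses the in-memory compiled executable
t2 = time.perf_counter()
print(f"first call {t1 - t0:.3f} s, second call {t2 - t1:.5f} s")

lowered = jax.jit(layer).lower(x, w)    # program handed to XLA, before optimization
compiled = lowered.compile()            # XLA-optimized executable
print(lowered.as_text()[:300])          # pre-optimization module text
print(compiled.as_text()[:300])         # optimized HLO (look for fusion ops)
```

For a large model, the gap between the first and second call is the compilation cost that the persistent cache avoids on later runs.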

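TPU silent NaN: Neither CUDA nor XLA traps on NaN by default, and on TPU a NaN propagates silently through the rest of the computation. Use jax.debug.print and explicit jnp.any(jnp.isnan(x)) checks in training loops, or enable jax.config.update("jax_debug_nans", True) while debugging to raise an error at the first operation that produces a NaN (at a significant speed cost).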
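A minimal JAX sketch of an explicit NaN check inside a jitted step. The toy step function is hypothetical and produces a NaN on purpose via the log of a negative number; the check is an ordinary on-device reduction, so it is cheap to leave enabled.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    y = jnp.log(x)                     # log of a negative number -> NaN, no error raised
    has_nan = jnp.any(jnp.isnan(y))    # explicit check, computed on-device
    jax.debug.print("step: has_nan={h}", h=has_nan)  # prints from inside jit
    return y, has_nan

y, has_nan = step(jnp.array([1.0, -1.0, 2.0]))
if bool(has_nan):
    print("NaN detected: skip this update or reload the last checkpoint")

# While debugging, JAX can instead raise at the first NaN-producing op:
# jax.config.update("jax_debug_nans", True)
```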

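IPU tile memory overflow: If a model's tensors don't fit in a tile's SRAM budget, Poplar compilation fails with an out-of-memory error for that tile. Fitting the model then requires moving tensors (for example, optimizer state) to off-chip Streaming Memory, recomputation, or pipelining across IPUs, which adds exchange and host-link traffic and can degrade performance 10–100×. Inspect per-tile memory use with the PopVision Graph Analyser reports; popart.Session.modelToHost() only writes the model's weights back to the host.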

Cerebras compilation failure for dynamic shapes: WSE compilation requires static shapes. PyTorch dynamic shapes (variable-length sequences) must be padded to the maximum length at compile time. Models with highly variable sequence lengths suffer from padding waste.
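The cost of padding to a static maximum length is easy to estimate. A minimal sketch with hypothetical request lengths; the bucketing variant assumes the toolchain lets you compile a few fixed shapes and route each sequence to the smallest one that fits, a common mitigation for static-shape compilers rather than a documented Cerebras feature.

```python
def padding_waste(lengths, padded_len):
    """Fraction of token slots that are padding when every sequence
    is padded to padded_len."""
    return 1 - sum(lengths) / (padded_len * len(lengths))

def bucket_length(length, buckets):
    """Smallest bucket that fits the sequence (buckets sorted ascending)."""
    return next(b for b in buckets if b >= length)

lengths = [37, 120, 512, 64, 900, 2048, 15, 300]   # hypothetical request lengths
max_len = 2048
print(f"pad-to-max waste: {padding_waste(lengths, max_len):.0%}")

# Bucketing: compile a few static shapes and pad each sequence to its bucket
buckets = [128, 512, 2048]
padded = [bucket_length(n, buckets) for n in lengths]
bucket_waste = 1 - sum(lengths) / sum(padded)
print(f"bucketed waste:   {bucket_waste:.0%}")
```

With these lengths, about three quarters of the compute goes to padding under pad-to-max, which is the waste the text describes for highly variable sequence lengths.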

Security Implications

Supply chain risk for custom ASICs: Google, AWS, and others designing custom chips face supply chain vulnerabilities at TSMC (geopolitical risk, fab capacity). Unlike GPU procurement (multiple vendors), custom ASICs are single-source. NVIDIA H100 shortage in 2023 was mitigated by TPU availability for Google but not for others.

Side-channel via ICI timing: TPU's shared ICI mesh is used by multiple customer workloads on the same Pod. In a multi-tenant cloud setting, ICI contention timing could theoretically leak cross-tenant information. Google mitigates with dedicated Pod allocation for sensitive workloads.

Firmware trust on AWS Nitro/Trainium: Trainium instances run on Nitro hypervisor with hardware root of trust. The NeuronCore firmware is loaded by AWS. Customers must trust AWS to not inject backdoors via firmware updates. AWS Nitro's published security whitepaper describes the attestation chain.

Performance Implications

Compilation-time cost amortization: XLA/Neuron/PopART compilation overhead is significant but one-time. For long training runs (days to months), compilation overhead is <0.1%. For short inference workloads (hundreds of requests), compilation overhead can dominate—use compiled model caching.

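Batch size optimization for TPU: The TPU MXU is most efficient when matrix dimensions are multiples of 128 (TPUv2 through TPUv4 MXUs are 128×128). Pad batch sizes, sequence lengths, and hidden dimensions to multiples of 128 for full MXU utilization. At non-aligned sizes, MXU utilization drops sharply: a dimension of 129 occupies two 128-wide tiles, so utilization along that dimension falls to about 50%.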
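A small NumPy sketch of padding to the MXU tile size, with a simplified occupancy estimate. It assumes 128-wide tiles and ignores how XLA actually tiles and schedules work; the helper names are hypothetical, and in JAX code the same padding can be done with jnp.pad.

```python
import numpy as np

TILE = 128  # MXU width assumed from the text

def round_up(n, multiple=TILE):
    return -(-n // multiple) * multiple

def pad_to_tile(x, multiple=TILE):
    """Zero-pad both axes of a 2-D array up to multiples of the tile size."""
    pads = [(0, round_up(d, multiple) - d) for d in x.shape]
    return np.pad(x, pads)

def tile_utilization(rows, cols, multiple=TILE):
    """Fraction of occupied tile slots doing useful work (simplified model)."""
    return (rows * cols) / (round_up(rows, multiple) * round_up(cols, multiple))

for r, c in [(128, 256), (129, 256), (100, 300)]:
    print(f"{r}x{c}: {tile_utilization(r, c):.0%} of MXU slots useful")

# Zero padding leaves the product unchanged once the extra rows/cols are sliced off
rng = np.random.default_rng(0)
A, B = rng.standard_normal((100, 300)), rng.standard_normal((300, 50))
C = (pad_to_tile(A) @ pad_to_tile(B))[:100, :50]
print(np.allclose(C, A @ B))
```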

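WSE efficiency for LLM inference: Cerebras claims 60× faster LLM inference latency than H100 for a 13B model at batch=1, because the weights sit in on-chip SRAM. Memory bandwidth arithmetic supports the direction of this claim: a 13B model in FP16 has 26 GB of weights, which take 26 GB / 3.35 TB/s ≈ 7.8 ms to read on an H100 but 26 GB / 21 PB/s ≈ 1.2 µs on WSE-3. The raw bandwidth ratio (~6,000×) is far larger than the claimed 60×, so factors other than weight bandwidth limit the WSE's end-to-end latency. Caveat: models larger than 44 GB require off-wafer memory access.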

Failure Modes and Real Incidents

Incident: TPU silent matrix multiply error (Google 2021, reported in XLA bug tracker): A specific combination of tensor shapes triggered an edge case in the code XLA generated for the MXU, producing incorrect partial sums. The bug was in the compiler, not the hardware. Detection: numerical comparison with a CPU reference. Mitigation: XLA compiler patch.

Incident: Cerebras WSE tile failures at scale: Early WSE-1 deployments showed defective tiles causing synchronization timeouts (a defective tile never signals completion). The reconfiguration system correctly routed around bad tiles after re-test, but recovery took 10+ minutes of downtime. WSE-2 improved its yield-tolerance algorithms.

Incident: Trainium NeuronCore SRAM ECC error during training: A single-bit error in NeuronCore on-chip SRAM was detected, corrected by SECDED ECC, and flagged in the logs; because SECDED corrects single-bit errors, the step's gradients were not corrupted and training continued. The error recurred, suggesting a marginal SRAM cell, and the instance was replaced.

Modern Usage

TPUv4/v5 for Gemini training (2023–2024): Google uses TPU v4 and v5p (the "p" denotes the performance-optimized variant) for Gemini training, and TPU v5e (efficiency variant) for fine-tuning and inference.

AWS Trainium2 for LLaMA fine-tuning (2024): AWS customers using SageMaker HyperPod for distributed fine-tuning. Reported 2× cost reduction vs comparable GPU instances.

Intel Gaudi 2 for Llama inference (2024): Hugging Face TGI supports Gaudi 2 natively. Deployed in Intel Developer Cloud for LLM serving at ~40% lower cost than H100.

Future Directions

  • In-memory computing: Resistive RAM (RRAM) and phase-change memory arrays that compute matrix multiply in-situ, eliminating data movement entirely; IBM demonstrating 64-core analog AIMC chips
  • Optical neural network processors: Lightmatter Envise uses photonic MZI arrays for ultra-low-power analog matrix multiply; announced Passage interconnect for photonic chip-to-chip communication
  • CXL-attached accelerator memory: CXL 3.0 enables memory pooling where accelerators access a shared coherent DRAM pool, breaking the per-chip memory wall
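  • Neuromorphic and near-memory designs for sparse workloads: Intel Loihi 2 is event-driven spiking hardware, with reported efficiency gains of up to 10,000× on sparse spiking networks; IBM NorthPole is a digital (non-spiking) inference chip that places memory alongside each compute core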
  • Co-design with model architecture: TPU architecture directly influenced BF16 adoption, systolic-array-friendly dense matmul in transformers, and chunked attention patterns

Exercises

  1. Roofline analysis: For a 70B-parameter GPT-style model's inference forward pass at batch=1, sequence_length=1, calculate: (a) total FLOPs for one transformer layer (attention + FFN), (b) total bytes accessed (weights + KV cache + activations), (c) arithmetic intensity, (d) whether compute-bound or memory-bound on H100 and TPUv4. Repeat for batch=64.

  2. Systolic array simulation: Implement a 4×4 systolic array simulator in Python. Given weight matrix W and input matrix X, simulate the cycle-by-cycle data flow and verify output matches numpy's matrix multiply. Measure cycles required vs naive triple-nested loop.

  3. XLA compilation experiment: Set up a JAX environment on a machine with TPU (or CPU fallback). Define a simple transformer layer. Measure: (a) first execution time (includes XLA compilation), (b) subsequent execution times. Profile with jax.profiler.trace and inspect the HLO (High-Level Operations) with jax.jit(f).lower(*args).as_text().

  4. Memory bandwidth wall measurement: On any GPU, benchmark the throughput of a simple element-wise operation (e.g., output = a * b + c) as a function of tensor size. At what size does it become memory-bandwidth limited? Compare theoretical bandwidth to achieved bandwidth measured with NVIDIA's bandwidthTest CUDA sample or the nvbandwidth tool.

  5. TPU vs GPU cost-performance: For a specific workload (e.g., fine-tuning LLaMA 7B for 1B tokens), use AWS/GCP pricing to calculate total cost on: (a) NVIDIA A100 (p4d.24xlarge), (b) H100 (p5.48xlarge), (c) TPUv4-8 (GCP), (d) Trainium (trn1.32xlarge). Estimate time-to-completion and cost for each.

References

  • Jouppi et al., "In-Datacenter Performance Analysis of a Tensor Processing Unit," ISCA 2017
  • Jouppi et al., "TPU v4: An Optically Reconfigurable Supercomputer for Machine Learning with Hardware Support for Embeddings," ISCA 2023
  • Amazon AWS Trainium and Inferentia documentation: https://aws.amazon.com/machine-learning/trainium/
  • Cerebras WSE-3 architecture whitepaper: https://www.cerebras.net/chip/
  • Graphcore IPU architecture overview: https://www.graphcore.ai/products/ipu
  • Intel Gaudi 3 AI Accelerator Architecture, Intel 2024
  • Williams et al., "Roofline: An Insightful Visual Performance Model for Multicore Architectures," CACM 2009
  • Patterson et al., "Carbon Emissions and Large Neural Network Training," arXiv 2021