ReViT: Rotational-equivariant Vision Transformers for Neural PDE Solvers

Technical University of Munich
Oral at ICML 2026
Teaser panels: MHD velocity, MHD magnetic field, TCF velocity.

ReViT is the first Vision Transformer framework that enforces strict rotational equivariance on grid-based physical fields. By mapping scalar and vector inputs into locally invariant representations derived from physics-based canonical bases, ReViT enables standard self-attention without symmetry violations—yielding significant accuracy gains across 2D and 3D PDE benchmarks.


Abstract

Physical laws obey strict symmetries such as rotational equivariance. However, the standard Transformer architectures widely used in physics foundation models do not enforce these constraints by construction. We introduce ReViT, a Vision Transformer framework for neural PDE solvers on grid-based physical fields that strictly enforces rotational equivariance.

ReViT maps scalar and vector inputs into locally invariant representations derived from physics-based canonical bases, enabling the use of standard self-attention without symmetry violations. Built on a hierarchical Swin-style backbone with a precomputed reference basis pyramid, ReViT preserves equivariance across multi-scale operations.

We evaluate ReViT on a wide range of 2D and 3D PDE benchmarks, such as Magnetohydrodynamics and Turbulent Channel Flows, demonstrating significant gains over state-of-the-art baselines. ReViT generalizes strongly and reduces MSE by up to 65% compared with the best-performing alternatives.


Theoretical Analysis: Why ViTs Fail

We consider a physical field \(\mathbf{f}: \Omega \rightarrow \mathcal{V}\) on a spatial domain \(\Omega \subset \mathbb{R}^d\). For a vector field \(\mathbf{u}\), the rotation group acts on both the domain and the values: \([L_g \mathbf{u}](x) = \mathbf{R}(g)\mathbf{u}(\mathbf{R}(g)^{-1}x)\). A neural network \(\Phi\) is equivariant if \(\Phi(L_g \mathbf{u}) = L_g \Phi(\mathbf{u})\) for all rotations \(g\). We identify three distinct mechanisms by which standard ViTs violate rotational equivariance:

C1. The Tokenization Barrier. Standard patch embeddings flatten the vector field within a patch \(P_i\) of side length \(K\) into a raster vector \(\mathbf{v} \in \mathbb{R}^{K^d \cdot d}\) and apply a learnable linear map \(\mathbf{E}\). Under a rotation \(g\), the pixels are permuted (and their vector values rotated), giving \(\pi_g \mathbf{v}\). Because \(\mathbf{E}\) is unconstrained, \(\mathbf{E}(\pi_g \mathbf{v}) \neq \mathbf{E}(\mathbf{v})\): a pattern and its rotated copy are projected into unrelated regions of the latent space.
C1: Tokenization barrier — flattening a rotated patch maps it to a different raster vector.
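A small numerical illustration of C1, assuming a 2×2 patch (\(K = 2\)) of a 2D vector field and a random matrix standing in for the learnable embedding \(\mathbf{E}\); the sizes and names are illustrative, not ReViT's configuration. Rotating the patch by 90° turns its raster vector into a signed permutation of the original, and a generic \(\mathbf{E}\) maps the two vectors to different tokens.

```python
import numpy as np

J = np.array([[0.0, -1.0], [1.0, 0.0]])   # 90-degree rotation of a 2D vector
K, d, embed_dim = 2, 2, 16                # patch side, spatial dim, token width (illustrative)

rng = np.random.default_rng(1)
patch = rng.normal(size=(K, K, d))             # K x K pixels with d vector components each
E = rng.normal(size=(embed_dim, K**d * d))     # stand-in for the learnable embedding E

v = patch.reshape(-1)                                          # raster vector in R^{K^d * d}
v_rot = (np.rot90(patch, k=1, axes=(0, 1)) @ J.T).reshape(-1)  # same patch rotated by 90 deg

# Same numbers up to order and sign: v_rot is a signed permutation of v ...
print(np.allclose(np.sort(np.abs(v)), np.sort(np.abs(v_rot))))
# ... yet it is a different raster vector, which E maps to a different token.
print(np.allclose(v, v_rot), np.allclose(E @ v, E @ v_rot))
```

The first check prints `True` and the second prints `False False`: the embedding has no mechanism for recognizing the rotated patch as the same physical pattern.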
C2. Loss of Spatial Equivariance. Absolute PEs break equivariance since \(\mathbf{P}(\mathbf{R}(g)x) \neq \boldsymbol{\pi}_g \mathbf{P}(x)\). Relative PEs depend on \(\boldsymbol{\delta}_{ij} = x_j - x_i\), which restores translation equivariance. However, \(\boldsymbol{\delta}_{ij}\) itself rotates with the input, so a rotation-invariant relative PE must fall back to an isotropic function such as \(\|x_i - x_j\|\), which discards all directional information.
C2: PE dilemma — existing PEs are either directional or equivariant, never both.
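The dilemma can be seen with three token positions, chosen arbitrarily for illustration: the raw displacement rotates with the grid, so a PE that reads it directly changes under rotation, while the distance is invariant but assigns the same value to two perpendicular neighbors.

```python
import numpy as np

J = np.array([[0.0, -1.0], [1.0, 0.0]])   # 90-degree rotation

x_i = np.array([2.0, 3.0])
x_a = np.array([2.0, 4.0])   # neighbor one step along axis 1
x_b = np.array([3.0, 3.0])   # neighbor one step along axis 0

delta = x_a - x_i
delta_rot = J @ x_a - J @ x_i   # displacement after rotating both positions

# A directional PE reads delta, which changes under the rotation.
print(delta, delta_rot)
# An isotropic PE reads the distance: invariant, but blind to direction,
# because the two perpendicular neighbors get the same value.
print(np.linalg.norm(delta), np.linalg.norm(delta_rot), np.linalg.norm(x_b - x_i))
```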
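C3. The Representational Mismatch. For vector fields, rotations act on both coordinates and values. A linear map \(\mathbf{W}\) acting on vector values is compatible with rotations only if \(\mathbf{W}(\mathbf{R}(g)\mathbf{u}) = \mathbf{R}(g)(\mathbf{W}\mathbf{u})\) for all \(g\). In 3D, Schur's Lemma then forces \(\mathbf{W} = \lambda \mathbf{I}\) (in 2D, where rotations commute with one another, the admissible maps are \(\mathbf{W} = a\mathbf{I} + b\mathbf{J}\), with \(\mathbf{J}\) the 90° rotation). Standard ViT projections are dense and unconstrained, violate this condition, and therefore break equivariance.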
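The condition in C3 can be tested with the group action \(L_g\) defined above, here for 90° rotations of a 2D field on a square grid, together with a matrix-level check in 3D. This is an illustrative sketch, not ReViT code: `rotate_field`, `is_equivariant`, and `commutes_with` are hypothetical helpers, array axes 0 and 1 serve as the two coordinates, and all matrices are arbitrary examples.

```python
import numpy as np

J = np.array([[0.0, -1.0], [1.0, 0.0]])        # 90-degree rotation: (a, b) -> (-b, a)

def rotate_field(u):
    """[L_g u](x) = R u(R^{-1} x) for a 90-degree rotation; u has shape (n, n, 2)."""
    # np.rot90 moves the sample at (a, b) to (n-1-b, a), i.e. it rotates the domain;
    # right-multiplying by J^T rotates each stored vector, i.e. it rotates the values.
    return np.rot90(u, k=1, axes=(0, 1)) @ J.T

def is_equivariant(phi, u):
    """Numerically check Phi(L_g u) == L_g Phi(u) for one input field."""
    return np.allclose(phi(rotate_field(u)), rotate_field(phi(u)))

rng = np.random.default_rng(0)
u = rng.normal(size=(8, 8, 2))
W_dense = rng.normal(size=(2, 2))               # unconstrained per-pixel projection
W_comm = 0.7 * np.eye(2) + 0.3 * J              # a I + b J commutes with every 2D rotation

print(is_equivariant(lambda f: f @ W_dense.T, u))   # dense W breaks equivariance
print(is_equivariant(lambda f: f @ W_comm.T, u))    # a I + b J preserves it

# 3D: two 90-degree generators of the cube's rotation group.
Rz = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
Rx = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])

def commutes_with(W, rotations):
    """True if W R == R W for every R in rotations."""
    return all(np.allclose(W @ R, R @ W) for R in rotations)

print(commutes_with(1.5 * np.eye(3), [Rz, Rx]))            # lambda * I passes
print(commutes_with(rng.normal(size=(3, 3)), [Rz, Rx]))    # a dense 3x3 map fails
print(commutes_with(Rz, [Rz, Rx]))                         # even Rz fails against Rx
```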

Methodology of ReViT

By adapting local canonicalization to ViTs, we decouple basis transformations from feature learning, solving challenges C1–C3. The model alternates between invariant processing (features in local canonical frames) and global transitions (physical basis transformations). The architecture consists of three stages: (1) Local Canonicalization, (2) Invariant Transformer Processing, and (3) Equivariant Decoding.


Figure 2. Overview of the ReViT architecture. The hierarchical encoder-decoder alternates between invariant processing (blue) and global transitions (orange). The Reference Basis Pyramid (purple) supplies local bases \(\mathcal{B}^{(l)}\) to mediate resolution changes (Merge, Expand) and the final Equivariant Rebase.

Local Rebasing: Rotate Globally, Learn Locally

The local basis is built from a patch-aggregated vector and used to rebase global vectors into local coordinates. When the field rotates, the basis co-rotates with it, so the vectors projected with \(\mathbf{B}_i^{\mathsf T}\) are unchanged: the local representation is invariant to the rotation.

Figure: a selected patch (red box) shows the patch-level basis. Global vectors \(\mathbf{u}\) (orange) are drawn in the fixed \(X\)-\(Y\) basis; rebased vectors \(\mathbf{u}^{\text{local}}=\mathbf{B}_i^{\mathsf T}\mathbf{u}\) (green) are drawn in the co-rotating local basis. Panels: global patch field (fixed \(X\)-\(Y\) basis), rebased field (local basis), global token sequence (co-rotating), local basis (co-rotating), and rebased token sequence (invariant to rotation).
$$\bar{\mathbf{u}}_i = \frac{1}{|P_i|}\sum_{k\in P_i}\mathbf{u}_k, \qquad \mathbf{b}_{1,i} = \frac{\bar{\mathbf{u}}_i}{\|\bar{\mathbf{u}}_i\|}, \qquad \mathbf{u}^{\text{local}}_{i,k} = \mathbf{B}_i^{\mathsf T}\mathbf{u}_{i,k}$$
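The three formulas can be sketched in 2D as follows. Completing the basis with \(\mathbf{b}_{2,i} = \mathbf{J}\mathbf{b}_{1,i}\) (the 90° rotation of \(\mathbf{b}_{1,i}\)) is our assumption for the 2D case, as are the helper names and the small `eps` guard. The check rotates every vector of a patch by an arbitrary angle and shuffles the pixel order; the rebased vectors come out unchanged up to that order.

```python
import numpy as np

J = np.array([[0.0, -1.0], [1.0, 0.0]])

def local_basis_2d(patch_vectors, eps=1e-12):
    """B_i = [b1, b2] in SO(2) from the patch-mean vector (hypothetical helper)."""
    u_bar = patch_vectors.mean(axis=0)               # \bar{u}_i
    b1 = u_bar / max(np.linalg.norm(u_bar), eps)     # b_{1,i}
    b2 = J @ b1                                      # assumed completion: b1 rotated by 90 deg
    return np.stack([b1, b2], axis=1)                # columns are the basis vectors

def rebase(patch_vectors):
    """u^local_{i,k} = B_i^T u_{i,k} for every pixel k of the patch."""
    B = local_basis_2d(patch_vectors)
    return patch_vectors @ B                         # row k equals (B^T u_k)^T

rng = np.random.default_rng(3)
patch = rng.normal(size=(16, 2)) + np.array([1.0, 0.5])   # a 4x4 patch as 16 vectors

theta = 0.83
R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
perm = rng.permutation(16)
rotated = (patch @ R.T)[perm]        # rotate every vector, then reorder the pixels

print(np.allclose(rebase(rotated), rebase(patch)[perm]))   # invariant up to pixel order
```

Because the mean is order-independent and 2D rotations commute with \(\mathbf{J}\), the rotated patch yields the basis \(\mathbf{R}\mathbf{B}_i\), which is exactly the co-rotation assumed in the invariance derivation.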

1. Local Canonicalization — solves C3

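For each patch \(P_i\), we compute a Local Canonical Basis \(\mathbf{B}_i \in SO(d)\) deterministically from the field values, so that rotating the input field by \(g\) turns \(\mathbf{B}_i\) into \(\mathbf{R}(g)\mathbf{B}_i\). Vectors are projected into local frames, \(\mathbf{u}^{\text{local}}_{i,k} = \mathbf{B}_i^T \mathbf{u}_{i,k}\). Since \(\mathbf{R}(g)^T\mathbf{R}(g) = \mathbf{I}\), this projection is provably invariant: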

$$\begin{aligned} (\mathbf{R}(g)\mathbf{B}_i)^{T} (\mathbf{R}(g)\mathbf{u}_{i,k}) &= \mathbf{B}_i^{T} \mathbf{R}(g)^{T} \mathbf{R}(g) \mathbf{u}_{i,k} \\ &= \mathbf{B}_i^{T} \mathbf{u}_{i,k} \\ &= \mathbf{u}^{\text{local}}_{i,k} \end{aligned}$$

In 3D, the basis is derived from the mean velocity \(\bar{\mathbf{u}}_i\) and mean vorticity \(\bar{\boldsymbol{\omega}}_i\), using stabilized analytical orthogonalization based on sequential cross-products.

S3: Local canonical basis — project vectors via B_i^T from global to local frame.
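One way to realize the cross-product construction described above is sketched below. The exact stabilization is not spelled out here, so the ordering (\(\mathbf{b}_1\) from the mean velocity, \(\mathbf{b}_3\) from \(\bar{\mathbf{u}}_i \times \bar{\boldsymbol{\omega}}_i\), \(\mathbf{b}_2 = \mathbf{b}_3 \times \mathbf{b}_1\)) and the `eps` guard are assumptions. Because \(\mathbf{R}(\mathbf{a}\times\mathbf{b}) = (\mathbf{R}\mathbf{a})\times(\mathbf{R}\mathbf{b})\) for proper rotations, the resulting basis co-rotates with the field.

```python
import numpy as np

def normalize(v, eps=1e-12):
    return v / max(np.linalg.norm(v), eps)

def local_basis_3d(u_bar, w_bar):
    """Right-handed orthonormal basis from mean velocity and mean vorticity (assumed scheme)."""
    b1 = normalize(u_bar)
    b3 = normalize(np.cross(u_bar, w_bar))   # orthogonal to b1 by construction
    b2 = np.cross(b3, b1)                    # completes a right-handed frame
    return np.stack([b1, b2, b3], axis=1)    # columns b1, b2, b3

def rotation_from_axis_angle(axis, angle):
    """Rodrigues' formula for a proper rotation matrix."""
    k = normalize(axis)
    Kx = np.array([[0.0, -k[2], k[1]], [k[2], 0.0, -k[0]], [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * Kx + (1 - np.cos(angle)) * Kx @ Kx

u_bar = np.array([1.0, 0.2, -0.3])   # example patch-mean velocity
w_bar = np.array([0.1, 0.9, 0.4])    # example patch-mean vorticity
B = local_basis_3d(u_bar, w_bar)
R = rotation_from_axis_angle(np.array([0.3, -1.0, 0.5]), 1.1)

print(np.allclose(B.T @ B, np.eye(3)), np.isclose(np.linalg.det(B), 1.0))  # orthonormal, in SO(3)
print(np.allclose(local_basis_3d(R @ u_bar, R @ w_bar), R @ B))           # basis co-rotates
```

The construction degenerates when \(\bar{\mathbf{u}}_i\) vanishes or is parallel to \(\bar{\boldsymbol{\omega}}_i\), which is presumably what a stabilized variant must handle.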

2. Invariant Patch Aggregation — solves C1

Instead of flattening patches (which breaks under rotation-induced permutations, C1), we treat each patch as a permutation-invariant set:

Step 1. Map each local vector through an MLP to obtain invariant features:

$$\mathcal{X}_i = \{\mathbf{h}_{i,k} \mid k \in P_i\}, \quad \mathbf{h}_{i,k} = \text{MLP}(\mathbf{u}^{\text{local}}_{i,k})$$

Step 2. Aggregate with a Set-Transformer: concatenate a learnable query token \(\mathbf{h}_{\text{query}}\) with the set, apply self-attention, and extract the query output:

$$\mathbf{H} = [\mathbf{h}_{\text{query}},\, \mathbf{h}_{i,1}, \dots, \mathbf{h}_{i,K^d}], \quad \mathbf{z}_i = \text{SA}(\mathbf{H})[0]$$

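Result. Self-attention without positional encodings is permutation-equivariant, so the output at the query token does not depend on the order of the set elements. The aggregation is therefore strictly permutation-invariant, and pixel ordering within the patch is irrelevant: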

$$\text{Agg}(\pi(\mathcal{X}_i)) = \text{Agg}(\mathcal{X}_i)$$
S1: Set-Transformer aggregation — patch vectors → MLP → set → query token → z_i.
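A minimal single-head version of the two steps, with random weights standing in for trained parameters. The layer sizes, the one-layer ReLU MLP, and all names are illustrative assumptions, not ReViT's configuration. No positional encoding is used inside the patch, so shuffling the pixels leaves \(\mathbf{z}_i\) unchanged.

```python
import numpy as np

rng = np.random.default_rng(4)
d_in, d_h = 2, 8
W1, b1 = rng.normal(size=(d_in, d_h)), rng.normal(size=d_h)     # Step 1 MLP (one layer)
Wq, Wk, Wv = (rng.normal(size=(d_h, d_h)) for _ in range(3))    # attention projections
h_query = rng.normal(size=(1, d_h))                             # learnable query token

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def aggregate(u_local):
    """u_local: (K^d, d_in) rebased vectors of one patch -> z_i of shape (d_h,)."""
    h = np.maximum(u_local @ W1 + b1, 0.0)        # h_{i,k} = MLP(u^local_{i,k})
    H = np.concatenate([h_query, h], axis=0)      # H = [h_query, h_{i,1}, ..., h_{i,K^d}]
    Q, K_mat, V = H @ Wq, H @ Wk, H @ Wv
    A = softmax(Q @ K_mat.T / np.sqrt(d_h))       # plain self-attention, no positional encoding
    return (A @ V)[0]                             # z_i = SA(H)[0]

u_local = rng.normal(size=(16, d_in))             # a 4x4 patch in 2D
perm = rng.permutation(16)
print(np.allclose(aggregate(u_local), aggregate(u_local[perm])))   # pixel order is irrelevant
```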

3. Rebased Relative Positional Encoding — solves C2

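We project displacement vectors into the local basis of the query token \(i\): \(\mathbf{p}_{ij \to i} = \mathbf{B}_i^T(x_j - x_i)\). Since both the displacement and \(\mathbf{B}_i\) co-rotate with the input, this encoding is invariant to global rotations while preserving local anisotropy. It enters self-attention as a bias matrix \(\mathbf{P}\) whose entries \(P_{ij}\) are computed from \(\mathbf{p}_{ij \to i}\):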

$$\text{SA}(\mathbf{X}) = \text{softmax}\left(\frac{\mathbf{X}\mathbf{W}_Q (\mathbf{X}\mathbf{W}_K)^T}{\sqrt{d_k}}+\mathbf{P}\right)\mathbf{X}\mathbf{W}_V$$
S2: Rebased PE — project displacement δ_ij into local basis B_i^T for directional yet equivariant encoding.
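The invariance of \(\mathbf{p}_{ij \to i}\) can be checked in 2D, assuming per-token bases that co-rotate with the field (here random rotation matrices, transformed by hand) and leaving out the learned map from \(\mathbf{p}_{ij \to i}\) to \(P_{ij}\). Unlike the plain distance, the rebased displacement keeps the full direction relative to the local frame; only its length coincides with \(\|x_j - x_i\|\).

```python
import numpy as np

def rot(theta):
    return np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])

def rebased_displacements(x, B):
    """p[i, j] = B_i^T (x_j - x_i); x: (N, 2) positions, B: (N, 2, 2) local bases."""
    delta = x[None, :, :] - x[:, None, :]             # delta[i, j] = x_j - x_i
    return np.einsum('iab,ija->ijb', B, delta)         # contract with the columns of B_i

rng = np.random.default_rng(5)
N = 6
x = rng.normal(size=(N, 2))
B = np.stack([rot(t) for t in rng.uniform(0, 2 * np.pi, size=N)])

R = rot(0.6)
x_rot = x @ R.T                                # positions rotate: x -> R x
B_rot = np.einsum('ab,nbc->nac', R, B)         # bases co-rotate: B_i -> R B_i

p, p_rot = rebased_displacements(x, B), rebased_displacements(x_rot, B_rot)
print(np.allclose(p, p_rot))                   # invariant to the global rotation
print(np.allclose(np.linalg.norm(p, axis=-1),
                  np.linalg.norm(x[None] - x[:, None], axis=-1)))   # lengths preserved
```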

4. Equivariant Decoder

The transformer outputs invariant tokens \(\mathbf{h}_i^{(L)}\). A local query decoder embeds a canonical grid \(\mathcal{G} = \{\boldsymbol{\xi}_m \in [-1,1]^d\}\) with Fourier features and uses the result as spatial queries \(\mathbf{Q}_{\text{grid}}\). Cross-attention then reconstructs dense spatial details:

$$\mathbf{z}^{\text{local}}_{i} = \text{CrossAttn}\left(\mathbf{Q}_{\text{grid}},\, \mathbf{h}_i^{(L)}\right)$$

Predictions are lifted back to global coordinates: \(\mathbf{u}'_{i} = \mathbf{B}_i \cdot \mathbf{z}^{\text{local}}_{i}\). Since \(\mathbf{z}^{\text{local}}\) is invariant and \(\mathbf{B}_i\) co-rotates with the input, the output is strictly equivariant.

Equivariant Decoder: canonical grid → Fourier features → CrossAttn → z^local (invariant) → B_i lift → u' (equivariant).
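The lifting step can be sketched for a single patch in 2D. The 4×4 grid, the dyadic Fourier frequencies, and the stand-in for the cross-attention output are assumptions: a fixed random linear read-out replaces CrossAttn here, since only the final multiplication by \(\mathbf{B}_i\) matters for equivariance. Given invariant local outputs, rotating the basis rotates every predicted vector by the same \(\mathbf{R}\).

```python
import numpy as np

def canonical_grid(K):
    """K x K grid of points xi_m in [-1, 1]^2, flattened to (K*K, 2)."""
    t = np.linspace(-1.0, 1.0, K)
    return np.stack(np.meshgrid(t, t, indexing='ij'), axis=-1).reshape(-1, 2)

def fourier_features(xi, n_freq=4):
    """[sin(2^f pi xi), cos(2^f pi xi)] for f = 0..n_freq-1 (assumed encoding)."""
    freqs = np.pi * 2.0 ** np.arange(n_freq)        # (n_freq,)
    angles = xi[:, :, None] * freqs                 # (M, 2, n_freq)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1).reshape(len(xi), -1)

rng = np.random.default_rng(6)
q_grid = fourier_features(canonical_grid(4))        # Q_grid: (16, 16)
readout = rng.normal(size=(q_grid.shape[1], 2))     # stand-in for CrossAttn(Q_grid, h_i)
z_local = q_grid @ readout                          # invariant local predictions, (16, 2)

theta = 1.2
R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
B = np.array([[0.6, -0.8], [0.8, 0.6]])             # an example local basis B_i in SO(2)

u_out = z_local @ B.T                               # u'_i = B_i z^local_i at every grid point
u_out_rot = z_local @ (R @ B).T                     # same invariant z, co-rotated basis R B_i
print(np.allclose(u_out_rot, u_out @ R.T))          # every output vector rotates by R
```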

5. Reference Basis Pyramid

A pyramid of local canonical bases \(\mathcal{B}^{(l)} = \{\mathbf{B}_k^{(l)}\}\) is pre-computed at each resolution \(l\) from the input field. Standard patch merging is invalid for invariant tokens (averaging vectors in disparate local bases breaks physical consistency). Resolution changes use a “globalize–resample–localize” procedure: features are projected back to the global frame, undergo valid spatial operations (pooling/interpolation), and are re-projected into the target resolution's local bases—ensuring rotation-invariance across all scales.
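A 2D sketch of the globalize–resample–localize step for a 2× downsampling, assuming each token carries one vector-valued feature and that the coarse bases have already been read from the pyramid; `merge_2x2` and the shapes are illustrative. The check applies the same global rotation to all bases, which is what happens to \(\mathcal{B}^{(l)}\) when the input rotates, and the coarse local features stay unchanged. Averaging `v_local` directly would instead mix coordinates expressed in different frames.

```python
import numpy as np

def rot(theta):
    """Rotation matrices for an array of angles: (...,) -> (..., 2, 2)."""
    c, s = np.cos(theta), np.sin(theta)
    return np.stack([np.stack([c, -s], -1), np.stack([s, c], -1)], -2)

def merge_2x2(v_local, B_fine, B_coarse):
    """Globalize -> 2x2 average pooling -> localize (hypothetical helper).

    v_local: (H, W, 2) features in each token's local frame
    B_fine: (H, W, 2, 2), B_coarse: (H//2, W//2, 2, 2); columns are basis vectors
    """
    v_global = np.einsum('hwab,hwb->hwa', B_fine, v_local)      # globalize: B v_local
    H, W, _ = v_global.shape
    pooled = v_global.reshape(H // 2, 2, W // 2, 2, 2).mean(axis=(1, 3))  # pool in global frame
    return np.einsum('hwab,hwa->hwb', B_coarse, pooled)         # localize: B_coarse^T v

def co_rotate(R, B):
    """Apply one global rotation to every basis: B -> R B."""
    return np.einsum('ab,...bc->...ac', R, B)

rng = np.random.default_rng(7)
v_local = rng.normal(size=(8, 8, 2))
B_fine = rot(rng.uniform(0, 2 * np.pi, size=(8, 8)))
B_coarse = rot(rng.uniform(0, 2 * np.pi, size=(4, 4)))
R = rot(np.array(0.9))

out = merge_2x2(v_local, B_fine, B_coarse)
out_rot = merge_2x2(v_local, co_rotate(R, B_fine), co_rotate(R, B_coarse))
print(np.allclose(out, out_rot))   # coarse local features are rotation-invariant
```

This check rotates only the bases and vector values; for 90° rotations, 2×2 pooling windows also map onto pooling windows, so the spatial part commutes with pooling as well.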

Shifted Windows: Same Tokens, Same Windows Across Rotation

ReViT relies on the permutation equivariance of the self-attention applied within each window. For global rotation equivariance, the shifted-window pipeline therefore only has to guarantee one property: the set of tokens that share a window is the same in the normal lane and in the rotated lane.

Figure: an 8×8 token grid, partitioned into four 4×4 windows, is traced through the shifted-window cycle in a normal lane and a rotated lane. Token numbers serve only as identity labels; a tracked set of tokens (36, 37, 44, 45) shows that window membership matches across the two lanes.
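The window-set property can be checked on token IDs alone. This sketch makes two simplifying assumptions: the shifted stage is modeled as a plain cyclic shift of the 8×8 grid (Swin's attention mask is ignored), and the rotated lane is the ID grid rotated by 90°. Under these assumptions, a shift of half a window (2 for 4×4 windows) yields the same token sets in both lanes, whereas a shift of 1 does not.

```python
import numpy as np

def window_sets(ids, window, shift):
    """Token-ID sets of all windows after a cyclic shift (Swin's mask ignored)."""
    shifted = np.roll(ids, shift=(-shift, -shift), axis=(0, 1))
    n = ids.shape[0]
    return {
        frozenset(shifted[r:r + window, c:c + window].ravel().tolist())
        for r in range(0, n, window)
        for c in range(0, n, window)
    }

def windows_match(n=8, window=4, shift=2):
    """Do the normal lane and the 90-degree-rotated lane group the same tokens?"""
    ids = np.arange(n * n).reshape(n, n)     # token identities on the grid
    rotated = np.rot90(ids)                  # rotated lane: same tokens, rotated layout
    return window_sets(ids, window, shift) == window_sets(rotated, window, shift)

print(windows_match(shift=0))   # regular 4x4 windows
print(windows_match(shift=2))   # shift by half a window
print(windows_match(shift=1))   # a shift that breaks the property
```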

Equivariance Analysis:

ReViT achieves exact equivariance to the chiral octahedral group \(O\) (the 24 rotations that map the cubic grid onto itself) and approximate \(SO(3)\) equivariance. The gap stems from grid-based constraints: unlike \(\frac{\pi}{2}\) rotations, arbitrary rotations do not map the grid onto itself, so resampling introduces interpolation bias and fixed patch and window boundaries introduce discretization artifacts. Data augmentation helps dampen these artifacts.
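The distinction between the two regimes can be made concrete. Closing two 90° rotations under composition yields the 24 elements of \(O\), each a signed permutation matrix that sends integer grid points to integer grid points, whereas a generic rotation does not. The sketch below only illustrates this group-theoretic fact and is not part of ReViT.

```python
import numpy as np

Rz = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]], dtype=int)   # 90 deg about z
Rx = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], dtype=int)   # 90 deg about x

def generate_group(generators):
    """Close a set of integer matrices under multiplication (breadth-first search)."""
    identity = np.eye(3, dtype=int)
    seen = {identity.tobytes(): identity}
    frontier = [identity]
    while frontier:
        new = []
        for g in frontier:
            for s in generators:
                h = s @ g
                if h.tobytes() not in seen:
                    seen[h.tobytes()] = h
                    new.append(h)
        frontier = new
    return list(seen.values())

O = generate_group([Rz, Rx])
print(len(O))                                               # 24
print(all(round(np.linalg.det(g)) == 1 for g in O))         # all proper rotations
# Each element is a signed permutation matrix, so grid points map to grid points ...
print(all((np.abs(g).sum(axis=0) == 1).all() for g in O))
# ... whereas a 30-degree rotation sends the grid point (1, 2, 3) off the grid.
c, s = np.cos(np.pi / 6), np.sin(np.pi / 6)
R30 = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
print(R30 @ np.array([1.0, 2.0, 3.0]))
```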


Results

Classification: RotMNIST

ReViT achieves state-of-the-art accuracy (98.26%) while delivering a ~4× speedup and a ~53× memory reduction (1.81 GB vs. 95.5 GB) compared with the lifted \(R_{12}\) baselines. The baselines' inefficiency stems from the lifting operation, which expands self-attention complexity to \(\mathcal{O}(N^2 |\mathcal{H}|^2)\) for \(N\) tokens and a lifted group of size \(|\mathcal{H}|\).

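| Model | Acc (%) | Train (ms) | Infer (ms) | Mem (GB) |
|---|---|---|---|---|
| GSA-Nets (\(R_4\)) | 97.46 | 298.8 ± 0.9 | 110.0 ± 0.1 | 5.27 |
| GSA-Nets (\(R_8\)) | 97.90 | 144.2 ± 2.1 | 65.2 ± 0.2 | 29.9 |
| GSA-Nets (\(R_{12}\)) | 97.97 | 272.6 ± 0.6 | 118.7 ± 0.5 | 95.5 |
| GE-ViT (\(R_{12}\)) | 98.01 | 281.0 ± 0.7 | 118.9 ± 0.3 | 95.5 |
| ReViT (Ours) | 98.26 | 67.7 ± 0.2 | 31.0 ± 0.7 | 1.81 |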
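The headline ratios follow directly from the table; the snippet below recomputes them against the two \(R_{12}\) baselines, using the mean values copied from the table.

```python
baselines = {
    "GSA-Nets (R12)": {"train": 272.6, "infer": 118.7, "mem": 95.5},
    "GE-ViT (R12)": {"train": 281.0, "infer": 118.9, "mem": 95.5},
}
revit = {"train": 67.7, "infer": 31.0, "mem": 1.81}

# Baseline cost divided by ReViT cost, per metric (higher means ReViT is cheaper).
ratios = {
    name: {key: round(vals[key] / revit[key], 2) for key in revit}
    for name, vals in baselines.items()
}
for name, r in ratios.items():
    print(name, r)
```

Training and inference times come out between roughly 3.8× and 4.2× lower, and memory about 52.8× lower, matching the "~4×" and "~53×" figures.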

2D Advection (Adv)

ReViT achieves the lowest MSE (\(\approx 10^{-4}\)) among all compared methods. It consistently outperforms the non-equivariant PDETrans, highlighting the contribution of its equivariant mechanisms: at a computational cost only 11.6% higher than PDETrans, it reduces MSE by 37.2%.

Robustness on Arbitrary Angles (2D Kolmogorov Flow)

We analyze prediction accuracy over 20 rollout steps at angular intervals of \(\frac{\pi}{12}\) within \((0, \pi)\), focusing on orthogonal angle pairs (\(\theta\) and \(\theta + \frac{\pi}{2}\)). ReViT is perfectly equivariant for all orthogonal pairs: the relative error change between the two angles is exactly +0.0%, regardless of the input angle \(\theta\). PDETrans, in contrast, shows high variance, with relative error changes of up to +162.8% on unseen angles.

3D Magnetohydrodynamics (MHD)

ReViT achieves the lowest MSE (\(0.82 \times 10^{-2}\)) and highest \(R^2\) (0.98), reducing MSE by approximately 63% relative to the strongest baseline, AViT (\(2.20 \times 10^{-2}\)). ReViT preserves sharp, high-frequency structures of the magnetic field and velocity eddies, and its predictions remain virtually indistinguishable from the reference.

3D Turbulent Channel Flow (TCF)

TCF represents a symmetry-starved regime with severe spatial anisotropy. ReViT performs best, with an MSE of \(0.21 \times 10^{-2}\) and an \(R^2\) of 0.96, a 65% error reduction compared with the next-best models (Swin3D and AViT, both \(0.60 \times 10^{-2}\)).

3D Quantitative Results

Metrics are computed over all 24 rotations of the chiral octahedral group \(O\) with three different seeds and reported as mean ± std:

| Model | MHD MSE (\(\times 10^{-2}\)) ↓ | MHD \(R^2\) ↑ | TCF MSE (\(\times 10^{-2}\)) ↓ | TCF \(R^2\) ↑ |
|---|---|---|---|---|
| AFNO | 16.40 ± 42.30 | 0.60 ± 1.00 | 28.40 ± 56.40 | -3.79 ± 9.49 |
| P3D | 10.20 ± 6.24 | 0.73 ± 0.15 | 5.72 ± 2.94 | 0.04 ± 0.05 |
| UNet3D | 3.64 ± 0.93 | 0.90 ± 0.03 | 7.12 ± 3.67 | -0.20 ± 0.62 |
| Swin3D | 3.58 ± 1.19 | 0.90 ± 0.03 | 0.60 ± 0.22 | 0.90 ± 0.04 |
| AViT | 2.20 ± 0.36 | 0.94 ± 0.01 | 0.60 ± 0.23 | 0.90 ± 0.04 |
| ReViT-3D (Ours) | 0.82 ± 0.00 | 0.98 ± 0.00 | 0.21 ± 0.00 | 0.96 ± 0.00 |

Ablation Study

A systematic ablation study confirms that each ReViT component is necessary: removing any single component leads to measurable degradation in both accuracy and equivariance.


BibTeX

@inproceedings{ReViT2026,
  title     = {{ReViT}: Rotational-equivariant Vision Transformers for Neural {PDE} Solvers},
  author    = {Hao Wei and Bjoern List and Nils Thuerey},
  booktitle = {Forty-Third International Conference on Machine Learning},
  year      = {2026},
}