Mixed Channels Version of PDE-Transformer
The mixed channels (MC) version of PDE-Transformer is an efficient implementation that embeds different physical channels within the same token. This approach offers several advantages and trade-offs compared to the separate channels version.
Key Characteristics
Architecture
- Embeds different physical channels (e.g., velocity, density) within the same token (see the sketch below)
- Uses a shared token representation for all physical quantities
- Maintains a single attention mechanism across all channels
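As a concrete illustration of what mixing channels means for the input layout, the following minimal sketch stacks several physical quantities along the channel dimension of a single tensor before tokenization. The field names and shapes are hypothetical and not part of the library API.

import torch

# Hypothetical 2D fields on a 64x64 grid (batch size 1)
velocity_x = torch.randn(1, 1, 64, 64)
velocity_y = torch.randn(1, 1, 64, 64)
density = torch.randn(1, 1, 64, 64)
pressure = torch.randn(1, 1, 64, 64)

# Mixed channels: all quantities share the channel dimension of one tensor,
# so every token embeds information from all physical channels.
x = torch.cat([velocity_x, velocity_y, density, pressure], dim=1)
print(x.shape)  # torch.Size([1, 4, 64, 64])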
Advantages
- Computational Efficiency
  - Higher information density per token
  - Reduced memory footprint
  - Faster training and inference times
  - More efficient use of model parameters
- Simplified Architecture
  - Single token stream for all physical quantities
  - Simpler implementation and maintenance
Limitations
- Transfer Learning Constraints
  - Channel types must be known at training time
  - Less flexible for transfer learning applications
  - May require retraining for new channel configurations
- Representation Learning
  - Less disentangled representation of physical channels
  - May require more training data to learn complex relationships
Performance
The mixed channels version achieves comparable performance to the separate channels version while being more computationally efficient. However, it may show reduced performance in transfer learning scenarios or when dealing with novel channel configurations.
Initialization
Here's how to initialize the PDETransformer model:
from pdetransformer.core.mixed_channels import PDETransformer

# Initialize the model
model = PDETransformer(
    sample_size=256,
    in_channels=4,
    out_channels=4,
    type="PDE-S",
    periodic=True,
    carrier_token_active=False,
    window_size=8,
    patch_size=4,
)
# Optional parameters (kwargs)
model = PDETransformer(
    # ... required parameters ...
    hidden_size=96,
    num_heads=4,
    mlp_ratio=4.0,
    class_dropout_prob=0.1,
    num_classes=1000,
)
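As a quick sanity check after initialization, you can count the trainable parameters. This is a minimal sketch assuming PDETransformer is a standard torch.nn.Module (as the usage example below suggests); it is not a library utility.

num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {num_params:,}")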
Parameter Explanation
- sample_size: The spatial dimensions of the input/output grid (e.g., 64 for a 64x64 grid). This is a positional argument, but it is currently unused and only exists to provide a unified initialization across different models. It can therefore be set to an arbitrary value, and the model can be applied to grids of variable sizes.
- in_channels: Number of input physical channels (e.g., velocity components, density)
- out_channels: Number of output physical channels
- type: Defines the model config, i.e. PDE-S, PDE-B or PDE-L
- periodic: Whether to use periodic boundary conditions in the simulation
- carrier_token_active: Whether to use carrier tokens for global information exchange. This enables hierarchical attention (https://arxiv.org/abs/2306.06189). Not compatible with periodic boundary conditions at the moment; default: False.
- window_size: Size of the local attention window (larger windows capture more context but require more computation)
- patch_size: Size of patches used for token embedding (smaller patches preserve more spatial detail but increase computation); see the sketch after this list for how patch_size and window_size partition the input
- hidden_size: Dimension of the hidden representations (determines model capacity)
- num_heads: Number of attention heads (more heads allow for different types of attention patterns)
- mlp_ratio: Ratio of MLP hidden dimension to embedding dimension
- class_dropout_prob: Dropout probability for class conditioning
- num_classes: Number of classes for conditioning (e.g., different PDE types or simulation parameters)
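To get a feel for how patch_size and window_size interact, the following is a plain arithmetic sketch for a 64x64 input with the settings used above. It assumes the window size is counted in tokens, as in Swin-style windowed attention, and does not call the library.

# A 64x64 grid with patch_size=4 is embedded into a 16x16 token grid;
# window_size=8 then groups these tokens into 2x2 local attention windows
# (under the assumption that window_size is measured in tokens).
height, width = 64, 64
patch_size, window_size = 4, 8

tokens_h, tokens_w = height // patch_size, width // patch_size
windows_h, windows_w = tokens_h // window_size, tokens_w // window_size
print(f"token grid: {tokens_h}x{tokens_w}, attention windows: {windows_h}x{windows_w}")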
Usage
Here's an example of how to use the model for forward prediction:
import torch

model = model.cuda()

# prepare input data
batch_size = 4
height, width = 64, 64
num_channels = 4
x = torch.randn(batch_size, num_channels, height, width).cuda()
class_labels = torch.zeros(batch_size).int().cuda()
timestep = torch.zeros(batch_size).cuda()

output = model(
    hidden_states=x,
    timestep=timestep,
    class_labels=class_labels,
)
print('Output shape: ', output.sample.shape)
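For multi-step prediction, a common pattern is to feed the prediction back in as the next input. The following is a minimal sketch that assumes the model maps the state at one timestep to the state at the next; the rollout loop itself is not a library utility.

# Autoregressive rollout by feeding the prediction back in (hedged sketch)
num_steps = 10
trajectory = [x]
with torch.no_grad():
    for _ in range(num_steps):
        pred = model(
            hidden_states=trajectory[-1],
            timestep=timestep,
            class_labels=class_labels,
        ).sample
        trajectory.append(pred)

# trajectory is a list of num_steps + 1 tensors of shape (B, C, H, W)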
Forward Pass Parameters
- x: tensor of shape (B, C, H, W), passed as hidden_states in the example above
  - B: Batch size
  - C: Input channels
  - H, W: Height and width of the spatial grid
- timestep: tensor of shape (B,). Can be set to None. Used as diffusion time when training probabilistic models.
- class_labels: tensor of shape (B,). Needs to be int or long. Can be set to None. Used to indicate the type of PDE; see the datasets for the class labels of the APE 2D PDEs. An unconditional call with both timestep and class_labels set to None is sketched after this list.
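When no conditioning is needed, both timestep and class_labels can be passed as None. This is a minimal sketch based only on the parameter description above:

# Unconditional forward pass (hedged sketch): timestep and class_labels are
# documented as optional and are passed as None here.
output = model(
    hidden_states=x,
    timestep=None,
    class_labels=None,
)
print('Output shape: ', output.sample.shape)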
Example Notebook
An example notebook showing how to run inference with a pretrained PDE-Transformer, along with additional explanations and code examples, can be found at notebooks/visualization_mc_ape2d.ipynb.