The FeatureMixing layer applies feed-forward MLPs across the feature (channel) dimension to mix information between different time series while preserving temporal structure. This enables cross-series learning and correlation discovery, complementing the TemporalMixing layer's temporal processing.
This layer is essential for capturing relationships between features in multivariate forecasting, allowing the model to learn feature interactions through a two-layer feed-forward network with a configurable hidden dimension.
## 🔍 How It Works
The FeatureMixing layer processes data through feature-space transformations:
1. **Flatten**: Reshapes from `(batch, time, features)` to `(batch, time × features)`
2. **Batch Normalization**: Normalizes across all dimensions (epsilon=0.001, momentum=0.01)
3. **Reshape**: Restores to `(batch, time, features)`
4. **First Dense Layer**: Projects to hidden dimension `ff_dim` with ReLU activation
5. **First Dropout**: Applies regularization after the first projection
6. **Second Dense Layer**: Projects back to the original feature dimension
7. **Second Dropout**: Applies final stochastic regularization
8. **Residual Connection**: Adds the input to the output to preserve gradient flow
```mermaid
graph LR
    A["Input<br/>(batch, time, feat)"] --> B["Flatten<br/>→ (batch, t×f)"]
    B --> C["BatchNorm<br/>ε=0.001"]
    C --> D["Reshape<br/>→ (batch, t, f)"]
    D --> E["Dense(ff_dim)<br/>ReLU"]
    E --> F["Dropout 1<br/>rate=dropout"]
    F --> G["Dense(feat)<br/>Linear"]
    G --> H["Dropout 2<br/>rate=dropout"]
    H --> I["Residual<br/>output + input"]
    I --> J["Output<br/>(batch, t, f)"]
    style A fill:#e6f3ff,stroke:#4a86e8
    style J fill:#e8f5e9,stroke:#66bb6a
    style E fill:#fff9e6,stroke:#ffb74d
    style I fill:#f3e5f5,stroke:#9c27b0
```
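To make the sequence concrete, here is a minimal sketch of the same transformation written as a custom Keras layer. This is illustrative only: the class name is hypothetical, and the actual kerasfactory implementation may differ in detail.

```python
import keras
from keras import layers

class FeatureMixingSketch(layers.Layer):
    """Illustrative re-implementation of the steps diagrammed above."""

    def __init__(self, n_series, input_size, dropout, ff_dim, **kwargs):
        super().__init__(**kwargs)
        self.n_series = n_series
        self.input_size = input_size
        self.norm = layers.BatchNormalization(epsilon=0.001, momentum=0.01)
        self.dense1 = layers.Dense(ff_dim, activation="relu")
        self.drop1 = layers.Dropout(dropout)
        self.dense2 = layers.Dense(n_series)  # project back to feature dim
        self.drop2 = layers.Dropout(dropout)

    def call(self, inputs, training=None):
        # Flatten to (batch, time * features), normalize, restore shape
        x = keras.ops.reshape(inputs, (-1, self.input_size * self.n_series))
        x = self.norm(x, training=training)
        x = keras.ops.reshape(x, (-1, self.input_size, self.n_series))
        # Two-layer MLP applied across the feature (last) dimension
        x = self.drop1(self.dense1(x), training=training)
        x = self.drop2(self.dense2(x), training=training)
        # Residual connection for gradient flow
        return inputs + x

layer = FeatureMixingSketch(n_series=7, input_size=96, dropout=0.1, ff_dim=64)
print(layer(keras.random.normal((32, 96, 7))).shape)  # (32, 96, 7)
```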
## 💡 Why Use This Layer?
| Challenge | Traditional Approach | FeatureMixing Solution |
|---|---|---|
| Feature Correlation | Independent processing | 🎯 Joint feature learning |
| Cross-Series Learning | Ignores relationships | 🔗 Learnable cross-series interactions |
| Non-Linear Interactions | Linear combinations | 🧠 Non-linear MLPs for expressiveness |
| Flexibility | Fixed architectures | 🎛️ Configurable `ff_dim` for capacity |
## 📊 Use Cases
- **Cross-Series Correlation**: Discovering relationships between multiple time series
- **Feature Interactions**: Learning non-linear interactions between features
- **Dimensionality Modulation**: Using `ff_dim` to compress or expand the feature space (see the parameter-count sketch after this list)
- **Multivariate Forecasting**: When features have strong interdependencies
- **Transfer Learning**: Feature extraction with learned cross-series patterns
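Since `ff_dim` sets the hidden width of the mixing MLP, capacity and parameter count scale with it. A brief sketch, assuming `FeatureMixing` is a standard Keras layer so `count_params()` applies once it is built:

```python
import tensorflow as tf
from kerasfactory.layers import FeatureMixing

x = tf.random.normal((8, 96, 7))
for ff_dim in [4, 16, 64]:
    layer = FeatureMixing(n_series=7, input_size=96, dropout=0.1, ff_dim=ff_dim)
    _ = layer(x)  # build the layer so its weights exist
    print(f"ff_dim={ff_dim}: {layer.count_params()} parameters")
```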
The example below contrasts training mode (stochastic dropout, updating batch norm statistics) with deterministic inference mode:

```python
import keras
import tensorflow as tf
from kerasfactory.layers import FeatureMixing

layer = FeatureMixing(n_series=7, input_size=96, dropout=0.2, ff_dim=64)
x = keras.random.normal((32, 96, 7))

# Training mode: dropout active, batch norm learning
output_train1 = layer(x, training=True)
output_train2 = layer(x, training=True)
train_variance = tf.reduce_mean(tf.abs(output_train1 - output_train2))
print(f"Training variance (due to dropout): {train_variance:.6f}")

# Inference mode: deterministic, batch norm frozen
output_infer1 = layer(x, training=False)
output_infer2 = layer(x, training=False)
tf.debugging.assert_near(output_infer1, output_infer2)
print("Inference: outputs are identical ✓")
```
### Analyzing Feature Interactions
```python
import keras
import tensorflow as tf
from kerasfactory.layers import FeatureMixing

# Create a layer to study single-feature impact on the mixed output
layer = FeatureMixing(n_series=5, input_size=48, dropout=0.1, ff_dim=16)

# Create a test input, then pin feature 0 to a constant value
x_base = keras.random.normal((1, 48, 5))
x_modified = x_base.numpy().copy()
x_modified[0, :, 0] = 1.0  # set feature 0 to a constant

out_base = layer(x_base, training=False)
out_modified = layer(x_modified, training=False)

# Check how much the output changed
impact = tf.reduce_mean(tf.abs(out_base - out_modified))
print(f"Feature 0 impact on output: {impact:.6f}")
```
The following checks verify shape preservation, the mixing effect, and several `ff_dim` variants:

```python
import tensorflow as tf
from kerasfactory.layers import FeatureMixing

layer = FeatureMixing(n_series=7, input_size=96, dropout=0.1, ff_dim=64)
x = tf.random.normal((32, 96, 7))

# Test 1: Shape preservation
output = layer(x)
assert output.shape == x.shape, "Shape mismatch!"

# Test 2: Feature mixing effect
output = layer(x, training=False)
diff = tf.reduce_max(tf.abs(output - x))
assert diff > 0, "Output should differ due to feature mixing"

# Test 3: Different ff_dim variants
for ff_dim in [4, 7, 14, 32]:
    layer_var = FeatureMixing(n_series=7, input_size=96, dropout=0.1, ff_dim=ff_dim)
    out_var = layer_var(x)
    assert out_var.shape == x.shape, f"Failed for ff_dim={ff_dim}"

print("✅ All tests passed!")
```
## ⚠️ Common Issues & Solutions
| Issue | Cause | Solution |
|---|---|---|
| NaN outputs | Unstable batch norm or extreme inputs | Normalize inputs; check weight initialization |
| Slow convergence | `ff_dim` too small or dropout too high | Increase `ff_dim`; reduce dropout to 0.05-0.1 |
| Memory issues | Large `ff_dim` or batch size | Reduce `ff_dim`; use smaller batches |
| Poor feature learning | Insufficient mixing capacity | Increase `ff_dim`; use 1.5-2x `n_series` |
| Overfitting | Insufficient regularization | Increase dropout to 0.2-0.3 |
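As a hedged starting point that follows the table's guidance (illustrative values, not library defaults):

```python
import tensorflow as tf
from kerasfactory.layers import FeatureMixing

n_series = 7
layer = FeatureMixing(
    n_series=n_series,
    input_size=96,
    dropout=0.1,           # raise toward 0.2-0.3 if overfitting appears
    ff_dim=2 * n_series,   # ~1.5-2x n_series gives reasonable mixing capacity
)

# Per-series standardization guards against NaN outputs from batch norm
x = tf.random.normal((32, 96, n_series))
x = (x - tf.reduce_mean(x, axis=1, keepdims=True)) / (
    tf.math.reduce_std(x, axis=1, keepdims=True) + 1e-8
)
output = layer(x, training=True)
```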
## 🔗 Related Layers & Components
- **TemporalMixing**: Complements FeatureMixing by mixing along the temporal dimension (a composition sketch follows below)
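A sketch of how the two layers might be composed into a TSMixer-style block. The `TemporalMixing` constructor arguments below are assumed by analogy with `FeatureMixing` and may not match the actual API:

```python
import tensorflow as tf
from kerasfactory.layers import FeatureMixing, TemporalMixing

# Assumed TemporalMixing signature; check its own documentation
temporal = TemporalMixing(n_series=7, input_size=96, dropout=0.1)
feature = FeatureMixing(n_series=7, input_size=96, dropout=0.1, ff_dim=64)

x = tf.random.normal((32, 96, 7))
mixed = feature(temporal(x))  # mix across time, then across features
print(mixed.shape)  # (32, 96, 7)
```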