The ReversibleInstanceNormMultivariate layer extends reversible instance normalization to multivariate time series by computing shared per-feature statistics across the batch and time dimensions. This is essential when you need consistent normalization across multiple series with different scales.
Key features:
- Batch-Level Normalization: Computes mean/std across all samples in the batch
- Reversible: Exact denormalization preserves interpretability
- Multivariate Support: Handles multiple features simultaneously
- Optional Affine: Learnable scale and shift parameters
- Training Stability: Improves convergence with diverse scaling
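The reversible behavior in the list above can be sketched in a few lines: a `norm` call stores the statistics it used, and a `denorm` call reuses them exactly. Below is a plain-NumPy stand-in (not the library implementation; the `mode` argument merely mirrors the documented interface):

```python
import numpy as np

# Minimal stand-in for the reversible pattern: 'norm' stores the batch
# statistics, 'denorm' reuses them exactly. Not the kerasfactory code.
class TinyRevNorm:
    def __init__(self, eps: float = 1e-5):
        self.eps = eps
        self.mean = None
        self.std = None

    def __call__(self, x: np.ndarray, mode: str) -> np.ndarray:
        if mode == 'norm':
            self.mean = x.mean(axis=(0, 1))  # per feature, over batch + time
            self.std = x.std(axis=(0, 1))
            return (x - self.mean) / (self.std + self.eps)
        if mode == 'denorm':
            return x * (self.std + self.eps) + self.mean
        raise ValueError(f"unknown mode: {mode}")

norm = TinyRevNorm()
x = np.random.default_rng(0).normal(10.0, 4.0, size=(8, 20, 3))
x_rec = norm(norm(x, mode='norm'), mode='denorm')
print(np.allclose(x, x_rec))  # exact round trip
```

Because `denorm` multiplies by the same `(std + eps)` that `norm` divided by, the round trip is algebraically exact.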
How It Works
```
Input Time Series
(batch=B, time=T, features=F)
          |
          v
+--------------------------------+
|   Compute Batch Statistics     |
|  mean = mean(x, axis=[0,1])    |  <- batch + time
|  std  = std(x, axis=[0,1])     |     shape (F,)
+--------------------------------+
          |
          v
Normalize: (x - mean) / (std + eps)
          |
          v
Optional Affine: y = gamma * x_norm + beta
          |
          v
Normalized Output (B, T, F)
```
The normalization uses statistics computed across both the batch and time dimensions (one mean and std per feature), so every sample in the batch is normalized with the same shared statistics.
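The statistics computation described above can be verified with a short NumPy sketch (the `eps` value here is an assumption):

```python
import numpy as np

# Batch-level statistics: one mean/std per feature, reduced over
# batch and time, as the diagram above describes.
rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(32, 100, 5))  # (B, T, F)

eps = 1e-5                           # assumed default
mean = x.mean(axis=(0, 1))           # shape (F,)
std = x.std(axis=(0, 1))             # shape (F,)
x_norm = (x - mean) / (std + eps)    # broadcast over (B, T, F)

print(mean.shape)        # (5,)
print(x_norm.mean())     # ~0
print(x_norm.std())      # ~1
```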
Why Use This Layer?
| Scenario | RevIN | RevIN Multivariate | Result |
|----------|-------|--------------------|--------|
| Single Series | ✅ Perfect | ⚠️ Overkill | Use RevIN |
| Multiple Series | ⚠️ Independent | ✅ Unified | Use RevINMulti |
| Cross-Dataset | ❌ Poor | ✅ Consistent | Use RevINMulti |
| Scale Normalization | ⚠️ Per-series | ✅ Global | Use RevINMulti |
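The "Multiple Series" row can be seen numerically: per-sample statistics (RevIN-style) force every series to unit scale, while shared batch statistics preserve the relative scale between series. A NumPy illustration, not library code:

```python
import numpy as np

# Two series with very different amplitudes in one batch.
rng = np.random.default_rng(1)
small = rng.normal(0.0, 1.0, size=(1, 100, 1))    # low-amplitude series
large = rng.normal(0.0, 100.0, size=(1, 100, 1))  # high-amplitude series
batch = np.concatenate([small, large], axis=0)    # (B=2, T=100, F=1)

# RevIN-style: statistics per sample -> both series end up with std ~1,
# erasing the scale difference between them.
per_sample = (batch - batch.mean(axis=1, keepdims=True)) \
             / batch.std(axis=1, keepdims=True)

# Multivariate: shared batch statistics -> relative scale is preserved.
shared = (batch - batch.mean(axis=(0, 1))) / batch.std(axis=(0, 1))

print(per_sample[0].std(), per_sample[1].std())  # both ~1.0
print(shared[0].std(), shared[1].std())          # small stays small
```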
Use Cases
- Multi-Sensor Forecasting: Normalize multiple sensor readings together
- Portfolio Returns: Normalize stocks with different volatilities
- Traffic Networks: Normalize flows across multiple routes
- Power Grids: Normalize consumption across multiple substations
- Healthcare: Normalize vital signs from multiple patients
Quick Start
```python
import keras
from kerasfactory.layers import ReversibleInstanceNormMultivariate

# Create normalization layer for multivariate data
normalizer = ReversibleInstanceNormMultivariate(num_features=5, affine=True)

# Input: batch of multivariate time series
x = keras.random.normal((32, 100, 5))  # 32 samples, 100 timesteps, 5 features

# Normalize for training
x_norm = normalizer(x, mode='norm')

# Use in model
# ... model forward pass ...

# Denormalize predictions
y_pred_norm = model(x_norm)
y_pred = normalizer(y_pred_norm, mode='denorm')
```
```python
from kerasfactory.layers import ReversibleInstanceNormMultivariate

# Multiple scales with shared normalization
normalizer = ReversibleInstanceNormMultivariate(
    num_features=8,
    eps=1e-6,
    affine=True,
    name='multi_scale_norm',
)

# Different time scales
short_term = keras.random.normal((64, 24, 8))    # hourly
medium_term = keras.random.normal((64, 168, 8))  # weekly
long_term = keras.random.normal((64, 730, 8))    # yearly

# Normalize all with same statistics
short_norm = normalizer(short_term, mode='norm')
medium_norm = normalizer(medium_term, mode='norm')
long_norm = normalizer(long_term, mode='norm')

# Process separately
short_pred = short_model(short_norm)
medium_pred = medium_model(medium_norm)
long_pred = long_model(long_norm)

# Denormalize with same statistics
short_denorm = normalizer(short_pred, mode='denorm')
medium_denorm = normalizer(medium_pred, mode='denorm')
long_denorm = normalizer(long_pred, mode='denorm')
```
Normalization Forward Pass

```
mean_f = (1 / (B×T)) × Σ_{b,t} x[b, t, f]                  # per feature, shape (F,)
std_f  = sqrt((1 / (B×T)) × Σ_{b,t} (x[b, t, f] - mean_f)²)
x_norm = (x - mean) / (std + eps)
if affine: y = gamma * x_norm + beta
```
Denormalization Reverse Pass
```
if affine: x_temp = (y - beta) / gamma
x_denorm = x_temp * (std + eps) + mean
```
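The reverse pass inverts the forward pass exactly, because `eps` enters both symmetrically. A quick NumPy check of the two passes (the `gamma` and `beta` values are arbitrary for the demonstration):

```python
import numpy as np

rng = np.random.default_rng(2)
x = rng.normal(3.0, 2.0, size=(4, 10, 3))  # (B, T, F)
eps, gamma, beta = 1e-5, 1.5, 0.25

# Forward pass: batch statistics, normalize, affine
mean = x.mean(axis=(0, 1))
std = x.std(axis=(0, 1))
y = (x - mean) / (std + eps) * gamma + beta

# Reverse pass: undo affine, then rescale and shift back
x_rec = (y - beta) / gamma * (std + eps) + mean
print(np.max(np.abs(x - x_rec)))  # tiny reconstruction error
```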
Serialization
```python
import keras
from kerasfactory.layers import ReversibleInstanceNormMultivariate

# Build and compile model
model = keras.Sequential([
    ReversibleInstanceNormMultivariate(num_features=8),
    # ... other layers ...
])
model.compile(optimizer='adam', loss='mse')

# Save model (includes layer configuration)
model.save('model.keras')

# Load model
loaded_model = keras.models.load_model('model.keras')
```
Testing & Validation
```python
import keras
import numpy as np
from kerasfactory.layers import ReversibleInstanceNormMultivariate

# Test exact reconstruction
normalizer = ReversibleInstanceNormMultivariate(num_features=8)
x_original = keras.random.normal((32, 100, 8))

# Normalize
x_norm = normalizer(x_original, mode='norm')

# Denormalize
x_reconstructed = normalizer(x_norm, mode='denorm')

# Check reconstruction error
error = keras.ops.mean(keras.ops.abs(x_original - x_reconstructed))
print(f"Reconstruction error: {error:.2e}")  # Should be < 1e-5

# Verify mean/std after normalization
mean_norm = keras.ops.mean(x_norm)
std_norm = keras.ops.std(x_norm)
print(f"Normalized mean: {mean_norm:.6f}")  # Should be close to 0
print(f"Normalized std: {std_norm:.6f}")    # Should be close to 1
```
Performance Characteristics
| Metric | Value |
|--------|-------|
| Time Complexity | O(B×T×F) |
| Space Complexity | O(F) for affine params |
| Memory Per Sample | O(F) |
| Training Speed | Fast |
| Inference Speed | Fast |
Last Updated: 2025-11-04 | Keras: 3.0+ | Status: ✅ Production Ready