The TokenEmbedding layer embeds raw time series values using a 1D convolution with learnable filters and bias, transforming raw numerical inputs into rich feature representations suitable for transformers and other deep learning architectures.
This layer is inspired by the TokenEmbedding component used in state-of-the-art time series forecasting models like Informer and TimeMixer. It provides a learnable alternative to fixed embeddings, allowing the model to discover optimal feature representations during training.
🔍 How It Works
The TokenEmbedding layer processes data through a 1D convolutional transformation:
1. Input Reception: Receives raw time series values of shape (batch, time_steps, channels)
2. Transposition: Rearranges to (batch, channels, time_steps) for Conv1D
3. 1D Convolution: Applies learnable 3×1 kernels across the time dimension
4. Same Padding: Preserves the temporal dimension using "same" padding
5. Output Generation: Returns embedded features of shape (batch, time_steps, d_model)
```mermaid
graph TD
    A["Input: (batch, time, c_in)"] -->|Transpose| B["(batch, c_in, time)"]
    B -->|Conv1D kernel=3<br/>filters=d_model| C["(batch, d_model, time)"]
    C -->|Transpose| D["Output: (batch, time, d_model)"]
    style A fill:#e6f3ff,stroke:#4a86e8
    style D fill:#e8f5e9,stroke:#66bb6a
    style B fill:#fff9e6,stroke:#ffb74d
    style C fill:#f3e5f5,stroke:#9c27b0
```
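The transposes in the diagram reflect the channels-first layout of the original PyTorch implementations; in channels-last Keras the same computation reduces to a single Conv1D over the time axis. A minimal sketch of the equivalent transform (illustrative only, not the library's actual internals):

```python
import keras

# Sketch of the transform TokenEmbedding applies (assumes Keras 3).
# Keras Conv1D convolves over axis 1 of (batch, time, channels) input,
# so the two transposes from the diagram cancel out here.
d_model, c_in = 64, 7
conv = keras.layers.Conv1D(
    filters=d_model,
    kernel_size=3,                   # 3x1 kernels over the time dimension
    padding="same",                  # preserves the number of time steps
    kernel_initializer="he_normal",  # Kaiming normal, as noted above
)

x = keras.random.normal((8, 96, c_in))  # (batch, time_steps, c_in)
y = conv(x)                             # (batch, time_steps, d_model)
print(y.shape)                          # (8, 96, 64)
```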
💡 Why Use This Layer?
| Challenge | Fixed Embeddings | Learnable Tokens | TokenEmbedding's Solution |
|-----------|------------------|------------------|---------------------------|
| Feature Learning | No learning | Limited | ✨ Learnable 1D convolution |
| Contextual Awareness | No context | Local only | 🎯 Kernel-size receptive field |
| Adaptation | Static | Slow | ⚡ Trained end-to-end |
| Multivariate Support | Single channel | Per-channel | 🔗 True multi-channel learning |
| Initialization | Random/fixed | Basic | 🔧 Kaiming normal init |
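The kernel-size receptive field in the table is easy to verify empirically: with kernel_size=3 and "same" padding, perturbing a single input time step changes only the three neighboring output positions. A quick probe, using a plain Conv1D as a stand-in for the layer:

```python
import numpy as np
import keras
from keras import ops

# Receptive-field probe with a plain Conv1D stand-in for TokenEmbedding.
conv = keras.layers.Conv1D(64, kernel_size=3, padding="same")

x = np.random.randn(1, 96, 7).astype("float32")
x_perturbed = x.copy()
x_perturbed[0, 50, :] += 1.0  # perturb a single time step

# Total absolute output change at each time step
delta = ops.convert_to_numpy(ops.abs(conv(x) - conv(x_perturbed))).sum(axis=(0, 2))
print(np.nonzero(delta)[0])  # [49 50 51] -- only the 3-step neighborhood reacts
```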
🚀 Use Cases
- Time Series Forecasting: Embedding raw values in LSTM/Transformer models
- Anomaly Detection: Feature extraction for anomaly detection models
- Time Series Classification: Converting raw series to embeddings for classification
- Multivariate Analysis: Processing multiple correlated time series simultaneously
- Feature Engineering: Automatic feature extraction from raw temporal data
- Preprocessing Pipeline: As the first layer in deep time series models
- Pre-training: For self-supervised learning on time series
```python
import keras
from kerasfactory.layers import TokenEmbedding, PositionalEmbedding

# Build forecasting model
def create_forecasting_model():
    inputs = keras.Input(shape=(96, 7))  # 96 time steps, 7 features

    # Embed raw values
    x = TokenEmbedding(c_in=7, d_model=64)(inputs)

    # Add positional encoding
    x = x + PositionalEmbedding(max_len=96, d_model=64)(x)

    # Process with transformers
    x = keras.layers.MultiHeadAttention(num_heads=8, key_dim=8)(x, x)
    x = keras.layers.Dense(128, activation='relu')(x)
    x = keras.layers.Dense(32, activation='relu')(x)

    # Forecast future values
    outputs = keras.layers.Dense(7)(x)  # forecast next 7 features
    return keras.Model(inputs, outputs)

model = create_forecasting_model()
model.compile(optimizer='adam', loss='mse')
```
With Multivariate Time Series
```python
import keras
from kerasfactory.layers import TokenEmbedding, TemporalEmbedding

# Multi-feature time series embedding
token_emb = TokenEmbedding(c_in=12, d_model=96)
temporal_emb = TemporalEmbedding(d_model=96, embed_type='fixed')

# Input data
x = keras.random.normal((32, 100, 12))  # 12 features
x_mark = keras.random.randint((32, 100, 5), minval=0, maxval=24)  # integer time marks

# Embed values
x_embedded = token_emb(x)

# Add temporal context
temporal_features = temporal_emb(x_mark)
combined = x_embedded + temporal_features

print(f"Combined embedding shape: {combined.shape}")  # (32, 100, 96)
```
⚡ Performance

- Training Efficiency: Fast convergence with proper initialization
- Inference Speed: Optimized for batch processing
🎨 Advanced Usage
Custom Initialization
```python
import keras
from kerasfactory.layers import TokenEmbedding

# Create layer with custom initialization
token_emb = TokenEmbedding(c_in=8, d_model=64)

# Access the conv sublayer and swap its initializer; do this before the
# layer is built, since Keras applies initializers in build()
conv_layer = token_emb.conv
conv_layer.kernel_initializer = keras.initializers.HeNormal()
```
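If the layer has already been built, reassigning kernel_initializer has no effect on existing weights. Continuing the example above, one way to re-initialize in place is to overwrite the kernel variable directly (a sketch that assumes token_emb.conv is a standard Conv1D exposing a kernel attribute):

```python
# Force the layer to build, then re-initialize the kernel in place.
token_emb(keras.random.normal((1, 96, 8)))
init = keras.initializers.HeNormal()
token_emb.conv.kernel.assign(init(token_emb.conv.kernel.shape))
```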
Serialization

```python
import json

# Get layer configuration
config = token_emb.get_config()

# Save to file
with open('token_embedding_config.json', 'w') as f:
    json.dump(config, f)

# Recreate from config
new_layer = TokenEmbedding.from_config(config)
```
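get_config/from_config round-trips the constructor arguments. Whole-model saving should also work; if kerasfactory does not register TokenEmbedding as a serializable Keras object (an assumption here), pass it via custom_objects on load:

```python
# Save/load round trip, continuing the example above.
model = keras.Sequential([keras.Input(shape=(96, 8)), token_emb])
model.save('token_model.keras')
restored = keras.models.load_model(
    'token_model.keras',
    custom_objects={'TokenEmbedding': TokenEmbedding},
)
```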
🧪 Testing & Validation
```python
import keras
from kerasfactory.layers import TokenEmbedding

# Test with different input sizes
token_emb = TokenEmbedding(c_in=7, d_model=64)

# Small batch
x_small = keras.random.normal((1, 96, 7))
out_small = token_emb(x_small)
assert out_small.shape == (1, 96, 64)

# Large batch
x_large = keras.random.normal((256, 96, 7))
out_large = token_emb(x_large)
assert out_large.shape == (256, 96, 64)

# Different time steps
x_diff_time = keras.random.normal((32, 200, 7))
out_diff_time = token_emb(x_diff_time)
assert out_diff_time.shape == (32, 200, 64)

print("✅ All shape tests passed!")
```
Last Updated: 2025-11-04 | Version: 1.0 | Keras: 3.0+ | Status: ✅ Production Ready