Deep Ensemble (Toy)

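This example shows how to train deep ensembles with PusH on a toy 1-D regression dataset, and compares ensembles of 2, 4, 8, and 16 models by their predictive mean, predictive standard deviation, and individual member predictions.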

[1]:
# Imports, a global seed for reproducibility, and the toy dataset
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

# Seed the global torch RNG so the training inputs below are reproducible.
torch.manual_seed(0)

# 100 training inputs drawn uniformly from [0, 0.5), stored as a (100, 1) column,
# plus a dense (1000, 1) evaluation grid spanning [-5, 5].
X = 0.5 * torch.rand(100, 1)
x_grid = torch.linspace(-5, 5, 1000).unsqueeze(-1)

# Define function
def target_toy(x, seed):
    torch.manual_seed(seed)
    epsilons = torch.randn(3) * 0.02
    return (
        x + 0.3 * torch.sin(2 * torch.pi * (x + epsilons[0])) +
        0.3 * torch.sin(4 * torch.pi * (x + epsilons[1])) + epsilons[2]
    )

# One noisy target per training input; sample i draws its jitter with seed i.
Y = torch.stack([target_toy(x_i, i) for i, x_i in enumerate(X)])
[2]:
class GenericNet(nn.Module):
    """Small fully connected regression network with a single output.

    Architecture: ``input_dim -> hidden_dim -> hidden_dim -> 1`` with ELU
    activations between the linear layers. Every ``nn.Linear`` layer gets
    Xavier-uniform weights and a constant bias of 0.01.

    Parameters:
        input_dim (int): Number of input features.
        hidden_dim (int): Width of both hidden layers (default 16).
    """

    def __init__(self, input_dim, hidden_dim=16):
        super().__init__()
        self.input_dim = input_dim

        self.net = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=hidden_dim),
            nn.ELU(),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.ELU(),
            nn.Linear(in_features=hidden_dim, out_features=1),
        )
        # Module.apply visits every submodule recursively. Passing self.net
        # straight to init_weights would do nothing, because self.net is an
        # nn.Sequential rather than an nn.Linear.
        self.net.apply(self.init_weights)

    def forward(self, x):
        """Return the network output for ``x`` (last dimension ``input_dim``)."""
        return self.net(x)

    def init_weights(self, m):
        """Xavier-uniform weights and 0.01 bias for ``nn.Linear``; other modules are left alone."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.01)
[3]:
import push.bayes.ensemble
# Wrap the toy data in a TensorDataset / DataLoader for PusH's training loop.
dataset = TensorDataset(X, Y)

# Full-batch training: a single batch holds every sample.
# Lower batch_size for mini-batch training.
batch_size = len(dataset)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training hyper-parameters shared by every ensemble.
epochs = 3000
lr = 0.03
num_devices = 2  # passed to PusH as num_devices

# Ensemble sizes to compare; one ensemble is trained per size.
ensemble_sizes = [2, 4, 8, 16]

# Trained ensembles keyed by ensemble size (insertion order follows ensemble_sizes).
trained_ensembles = {}

for size in ensemble_sizes:
    print(f"Training ensemble with {size} models...")
    ensemble = push.bayes.ensemble.train_deep_ensemble(
        train_loader,
        torch.nn.MSELoss(),
        epochs,
        # Model class followed by its positional args; presumably PusH builds
        # each member as GenericNet(1) -- confirm against the PusH docs.
        GenericNet, 1,
        lr=lr,
        num_devices=num_devices,
        num_ensembles=size,
        bootstrap=False,
    )
    trained_ensembles[size] = ensemble
    print(f"Completed training for ensemble size {size}.")
# After the loop, `ensemble` still refers to the last (largest) ensemble trained.
Training ensemble with 2 models...
100%|██████████| 3000/3000 [00:31<00:00, 95.07it/s, loss=tensor(0.0025)]
Completed training for ensemble size 2.
Training ensemble with 4 models...
100%|██████████| 3000/3000 [00:45<00:00, 66.02it/s, loss=tensor(0.0025)]
Completed training for ensemble size 4.
Training ensemble with 8 models...
100%|██████████| 3000/3000 [01:15<00:00, 39.63it/s, loss=tensor(0.0025)]
Completed training for ensemble size 8.
Training ensemble with 16 models...
100%|██████████| 3000/3000 [05:02<00:00,  9.90it/s, loss=tensor(0.0024)]
Completed training for ensemble size 16.

[14]:
def get_ensemble_outputs(ensembles, x_grid):
    """
    Get predictions, means, and standard deviations from multiple ensembles.

    Parameters:
        ensembles (dict): Ensembles keyed by their number of models. Each value
            must provide ``posterior_pred(x, f_reg=..., mode=[...])`` returning a
            dict that contains at least the keys "mean", "std" and "pred".
        x_grid (torch.Tensor): The input tensor for which predictions are made.

    Returns:
        dict: Maps each key of ``ensembles`` (in the same order) to a dict with
        exactly the keys "mean", "std" and "pred" taken from ``posterior_pred``.
    """
    modes = ("mean", "std", "pred")
    results = {}
    for size, ensemble in ensembles.items():
        # f_reg=True is forwarded unchanged (presumably selects regression
        # output -- confirm against the PusH docs). A fresh list is passed on
        # each call.
        output = ensemble.posterior_pred(x_grid, f_reg=True, mode=list(modes))
        # posterior_pred returns a dict keyed by mode name, so entries are
        # selected by key, not by position.
        results[size] = {key: output[key] for key in modes}
    return results

# Posterior mean / std / member predictions on x_grid for every trained ensemble.
outputs = get_ensemble_outputs(ensembles=trained_ensembles, x_grid=x_grid)
[4]:
# Raw posterior output of the largest ensemble (the last entry of ensemble_sizes).
# It is looked up explicitly rather than relying on the loop variable
# `ensemble` leaking out of the training cell.
ensemble_output = trained_ensembles[ensemble_sizes[-1]].posterior_pred(
    x_grid, f_reg=True, mode=["mean", "std", "pred"]
)
[21]:
import matplotlib.pyplot as plt
import torch

def plot_all_toys(outputs, x_grid, X, Y):
    """Plot each ensemble's predictive distribution next to its members' predictions.

    The figure has one row per entry of ``outputs`` and two columns:

    * left  -- training data, predictive mean (dashed red) and shaded
      +/-1 and +/-2 standard-deviation bands;
    * right -- training data and each individual member's prediction.

    Parameters:
        outputs (dict): Maps the number of models in an ensemble to a dict with
            keys "mean", "std" and "pred". "pred" is indexed as
            ``pred[grid_index][model_index]``.
        x_grid (torch.Tensor): Evaluation inputs, shape (G, 1).
        X (torch.Tensor): Training inputs.
        Y (torch.Tensor): Training targets.
    """
    n_rows = len(outputs)
    # The row count follows `outputs` instead of being fixed at 4, and
    # squeeze=False keeps `axs` two-dimensional even for a single ensemble.
    fig, axs = plt.subplots(nrows=n_rows, ncols=2, figsize=[12, 6 * n_rows], squeeze=False)
    x_flat = x_grid.squeeze()

    for row, (num_models, data) in enumerate(outputs.items()):
        mean, std, preds = data["mean"], data["std"], data["pred"]
        ax_mean, ax_members = axs[row]

        # Left column: predictive mean with +/-1 and +/-2 std bands.
        ax_mean.plot(X, Y, "kx", label="Toy data", markersize=6)
        ax_mean.set_xlim(-0.5, 1)
        ax_mean.set_ylim(-2, 2)
        ax_mean.plot(x_grid, mean, "r--", linewidth=2, label="Predictive mean")
        ax_mean.fill_between(
            x_flat, (mean - std).squeeze(), (mean + std).squeeze(),
            alpha=0.5, color="red", label="±1 std",
        )
        ax_mean.fill_between(
            x_flat, (mean - 2 * std).squeeze(), (mean + 2 * std).squeeze(),
            alpha=0.2, color="red", label="±2 std",
        )
        ax_mean.set(
            title=f"{num_models} Model Ensemble: Predictive Mean and Std",
            xlabel="x",
            ylabel="y",
        )
        ax_mean.legend(loc="upper left")

        # Right column: one dashed curve per ensemble member.
        ax_members.plot(X, Y, "kx", label="Toy data", markersize=6)
        ax_members.set_xlim(-0.5, 1)
        ax_members.set_ylim(-1.5, 2)
        for j in range(num_models):
            member_pred = torch.tensor([point[j] for point in preds])
            ax_members.plot(x_grid, member_pred, linestyle="--", linewidth=2)
        ax_members.set(
            title=f"{num_models} Model Ensemble: Individual Outputs",
            xlabel="x",
            ylabel="y",
        )

    fig.tight_layout()
    plt.show()

# Plot every trained ensemble, one row per size, in the order of ensemble_sizes
# (derived from the training configuration rather than a second hard-coded list).
plot_all_toys({size: outputs[size] for size in ensemble_sizes}, x_grid, X, Y)

../../../_images/examples_01_Bayesian_Deep_Learning_01_Deep_Ensemble_de_toy_6_0.png