MDN PyTorch

A PyTorch implementation of Mixture Density Networks (MDN) for modelling multi-modal distributions in regression tasks.

Overview

A Mixture Density Network (MDN) is a neural network that predicts the parameters of a mixture distribution rather than a single point estimate. This makes MDNs particularly useful for regression problems where:

  • The relationship between inputs and outputs is one-to-many, as is common in inverse problems.
  • The target distribution is multi-modal or highly uncertain.
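A classic toy illustration of a one-to-many relationship is the inverted sinusoid often used to motivate MDNs: sample t uniformly, compute x = t + 0.3 sin(2πt) plus a little noise, then try to predict t from x. Without noise, x = 0.5 is produced by three different t values (roughly 0.21, 0.5 and 0.79). The sketch below is plain NumPy and independent of this package; the constants, seed and the window around x = 0.5 are illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 2_000

# Forward mapping t -> x is single-valued (plus uniform noise)
t = rng.uniform(0.0, 1.0, size=n)
x = t + 0.3 * np.sin(2 * np.pi * t) + rng.uniform(-0.1, 0.1, size=n)

# Inverse problem: which t values produce an x close to 0.5?
window = np.abs(x - 0.5) < 0.02
print(
    f"{window.sum()} points with x near 0.5; "
    f"t ranges from {t[window].min():.2f} to {t[window].max():.2f}"
)
```

A narrow slice of x values corresponds to t values spread across most of [0, 1], so no single point estimate can represent all valid answers.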

This implementation uses mixtures of isotropic Gaussians: each component has a single standard deviation that is shared across all output dimensions.
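To make the isotropic assumption concrete, the per-sample training objective for such a mixture is the negative log-likelihood −log Σₖ πₖ N(y | μₖ, σₖ² I), where each component's log-density is −‖y − μₖ‖² / (2σₖ²) − D log σₖ − (D/2) log 2π. The following is a minimal sketch of that formula, not necessarily how `mdn_pytorch.loss.mdn_loss` is implemented; the function name `isotropic_mdn_nll` and the tensor shapes (`log_pi`: (batch, K), `mu`: (batch, K, D), `sigma`: (batch, K)) are assumptions.

```python
import math

import torch


def isotropic_mdn_nll(
    log_pi: torch.Tensor,  # (B, K) log mixing weights (assumed normalised)
    mu: torch.Tensor,      # (B, K, D) component means
    sigma: torch.Tensor,   # (B, K) one standard deviation per component
    y: torch.Tensor,       # (B, D) targets
) -> torch.Tensor:
    """Hypothetical helper: mean NLL of y under an isotropic Gaussian mixture."""
    d = y.shape[-1]
    # Squared Euclidean distance from each target to each component mean: (B, K)
    sq_dist = ((y.unsqueeze(1) - mu) ** 2).sum(dim=-1)
    # log N(y | mu_k, sigma_k^2 I) for a D-dimensional isotropic Gaussian: (B, K)
    log_norm = (
        -0.5 * sq_dist / sigma**2
        - d * torch.log(sigma)
        - 0.5 * d * math.log(2 * math.pi)
    )
    # log sum_k pi_k N_k, computed stably in log space
    log_lik = torch.logsumexp(log_pi + log_norm, dim=-1)
    return -log_lik.mean()
```

Working in log space with `logsumexp` avoids underflow when a target lies far from every component mean.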

Features

  • Multiple inference modes:
    • weighted_mean: Expected value E[Y|X], computed as the average of the component means weighted by the mixing coefficients.
    • argmax_mean: Mean of the component with the largest mixing coefficient; a fast approximation that ignores the other components.
    • sample_mean / sample_median: Mean or median of samples drawn from the mixture distribution.
  • PyTorch Lightning integration: Built-in training module for easy training and validation.
  • Pydantic-based configuration: Type-safe model and training setup.
  • C-Mixup support: Data augmentation that improves generalisation on regression tasks.
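The inference modes listed above can be illustrated with a small, self-contained sketch that works directly on mixture parameters. The function names and tensor shapes (`log_pi`: (batch, K), `mu`: (batch, K, D), `sigma`: (batch, K)) are assumptions for illustration, not the package's internal API; `model.predict` may compute these differently.

```python
import torch


def weighted_mean(log_pi, mu):
    # E[Y|X] = sum_k pi_k * mu_k  -> (B, D)
    return (log_pi.exp().unsqueeze(-1) * mu).sum(dim=1)


def argmax_mean(log_pi, mu):
    # Mean of the component with the largest mixing weight -> (B, D)
    k = log_pi.argmax(dim=-1)
    return mu[torch.arange(mu.shape[0]), k]


def sample_mixture(log_pi, mu, sigma, n_samples):
    # Pick a component per draw, then add isotropic Gaussian noise -> (S, B, D)
    b, _, d = mu.shape
    k = torch.distributions.Categorical(logits=log_pi).sample((n_samples,))  # (S, B)
    batch_idx = torch.arange(b).expand(n_samples, b)
    chosen_mu = mu[batch_idx, k]        # (S, B, D)
    chosen_sigma = sigma[batch_idx, k]  # (S, B)
    return chosen_mu + chosen_sigma.unsqueeze(-1) * torch.randn(n_samples, b, d)


def sample_summary(log_pi, mu, sigma, n_samples=1_000, how="median"):
    # sample_mean / sample_median: summarise draws along the sample axis
    samples = sample_mixture(log_pi, mu, sigma, n_samples)
    if how == "mean":
        return samples.mean(dim=0)
    return samples.median(dim=0).values
```

With two components weighted 0.9 and 0.1 and means at 0 and 10, `weighted_mean` returns 1.0, a point where the mixture has almost no density, while `argmax_mean` and the sample median stay on the dominant mode. This is why the expected value can mislead for multi-modal targets.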

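C-Mixup adapts mixup to regression by choosing mixing partners whose labels are close: the probability of pairing example i with example j follows a Gaussian kernel on the label distance, and inputs and labels are then interpolated with λ ~ Beta(α, α). The sketch below shows that idea on a single mini-batch. It is not this package's implementation; the function name `c_mixup_batch`, the default `bandwidth` and `alpha`, sampling partners within the batch, and excluding self-pairs are all assumptions.

```python
import torch


def c_mixup_batch(x, y, bandwidth=1.0, alpha=2.0):
    """Hypothetical C-Mixup-style augmentation; x: (N, ...), y: (N, D)."""
    n = x.shape[0]
    # Gaussian kernel on pairwise label distances: nearby labels are likelier partners
    d2 = torch.cdist(y, y) ** 2
    # Floor keeps every row samplable even if all kernel values underflow to zero
    weights = torch.exp(-d2 / (2 * bandwidth**2)).clamp_min(1e-12)
    weights.fill_diagonal_(0.0)  # never pair an example with itself
    partner = torch.multinomial(weights, num_samples=1).squeeze(1)  # (N,)
    # One mixing coefficient per example, lambda ~ Beta(alpha, alpha)
    lam = torch.distributions.Beta(alpha, alpha).sample((n,))
    lam_x = lam.view(n, *([1] * (x.dim() - 1)))
    lam_y = lam.view(n, *([1] * (y.dim() - 1)))
    x_mix = lam_x * x + (1 - lam_x) * x[partner]
    y_mix = lam_y * y + (1 - lam_y) * y[partner]
    return x_mix, y_mix
```

The kernel bandwidth controls how local the mixing is: a small bandwidth mixes only near-identical labels, while a very large one approaches ordinary mixup.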
Installation

Add the package to your project with uv:

uv add git+https://github.com/nimanzik/mdn-pytorch.git

Quick Start

import torch

from mdn_pytorch.loss import mdn_loss
from mdn_pytorch.model import MixtureDensityNetwork

input_dim = 12  # Number of input features
output_dim = 3  # Number of output dimensions
batch_size = 32

model = MixtureDensityNetwork(
    input_dim=input_dim,
    hidden_dims=[128, 64, 32],
    output_dim=output_dim,
    n_components=5,
    activation_type="gelu",
)

# Forward pass: get mixture parameters
x = torch.randn(batch_size, input_dim)
log_pi, mu, sigma = model(x)

# Compute loss
y_true = torch.randn(batch_size, output_dim)
loss = mdn_loss(log_pi, mu, sigma, y_true)

# Generate predictions
predictions = model.predict(x, inference_type="sample_median")

# Sample from the mixture distribution
samples = model.generate_samples(x, n_samples=1_000)

# Predict specific quantiles
quantiles = model.predict_quantiles(x, quantiles=[0.05, 0.5, 0.95])
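For isotropic components, the marginal along output dimension d is a 1-D mixture with means μₖ,d, standard deviations σₖ and weights πₖ, so per-dimension quantiles can be found by bisection on its CDF Σₖ πₖ Φ((y − μₖ,d) / σₖ). The sketch below shows that approach; it is not necessarily how `predict_quantiles` works (it may use sampling instead), and the function name `mixture_quantiles` and the parameter shapes are assumptions.

```python
import torch


def mixture_quantiles(log_pi, mu, sigma, quantiles, n_iter=60):
    """Hypothetical helper: per-dimension quantiles of an isotropic Gaussian mixture.

    log_pi: (B, K), mu: (B, K, D), sigma: (B, K). Returns (Q, B, D).
    """
    pi = log_pi.exp().unsqueeze(-1)  # (B, K, 1)
    s = sigma.unsqueeze(-1)          # (B, K, 1)
    q = torch.as_tensor(quantiles, dtype=mu.dtype).view(-1, 1, 1)  # (Q, 1, 1)
    n_q = q.shape[0]

    # Bracket wide enough to contain essentially all mass of every component
    lo = (mu - 10 * s).amin(dim=1).expand(n_q, -1, -1).clone()  # (Q, B, D)
    hi = (mu + 10 * s).amax(dim=1).expand(n_q, -1, -1).clone()

    for _ in range(n_iter):
        mid = 0.5 * (lo + hi)
        # Marginal CDF at mid: sum_k pi_k * Phi((mid - mu_k) / sigma_k)
        z = (mid.unsqueeze(2) - mu.unsqueeze(0)) / s.unsqueeze(0)   # (Q, B, K, D)
        cdf = (pi.unsqueeze(0) * torch.special.ndtr(z)).sum(dim=2)  # (Q, B, D)
        below = cdf < q  # quantile lies above mid where the CDF is still too small
        lo = torch.where(below, mid, lo)
        hi = torch.where(below, hi, mid)
    return 0.5 * (lo + hi)
```

For D > 1 these are marginal quantiles computed independently per dimension, not a joint prediction region.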

Requirements

  • Python ≥ 3.13
  • PyTorch ≥ 2.9.0
  • PyTorch Lightning ≥ 2.5.6
  • timm ≥ 1.0.22
  • Pydantic ≥ 2.12.4
  • NumPy ≥ 2.3.4
  • scikit-learn ≥ 1.7.2

License

This project is licensed under the MIT License. See the LICENSE file for details.

About

Mixture Density Network (MDN) implementation in PyTorch, with Lightning training support, Pydantic config management, and C-Mixup data augmentation
