VariationalPosterior

class torchbayesian.bnn.VariationalPosterior(shape: Size | list[int] | tuple[int, ...], *, dtype: dtype | None = None, device: device | str | int | None = None)

Bases: Module, ABC

Base class for all variational posteriors used in Bayes by Backprop (BBB) variational inference (VI).

Parameters

shape : _size

The shape of the parameter replaced by the variational posterior.

dtype : Optional[_dtype]

The dtype of the parameter replaced by the variational posterior. Optional. Defaults to torch’s default dtype.

device : Device

The device of the parameter replaced by the variational posterior. Optional. Defaults to torch’s default device.

Attributes

shape : _size

The shape of the tensors sampled from the variational posterior.

Notes

Subclasses used in ‘bnn.BayesianModule’ must work with ‘get_posterior()’, which instantiates them through the ‘from_param’ alternate constructor class method (see below).

Recommended PyTorch-style pattern for the constructor (‘__init__’ method) of custom subclasses of ‘VariationalPosterior’:
(1) Call ‘super().__init__(…)’.
(2) Create empty variational parameters of the appropriate size, e.g. ‘self.mu = nn.Parameter(torch.empty(…))’.
(3) At the end of ‘__init__’, call a ‘self.reset_parameters()’ method that initializes the variational parameters.
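As a sketch of this pattern, the block below defines a hypothetical mean-field Gaussian posterior. So that it runs without torchbayesian installed, it subclasses a minimal stand-in that only mirrors the interface documented on this page (constructor arguments, ‘dtype’/‘device’ properties, abstract ‘distribution’ and ‘sample_parameters’, and a ‘forward’ that is assumed to delegate to ‘sample_parameters’); in real code you would subclass ‘torchbayesian.bnn.VariationalPosterior’ directly. The names ‘GaussianPosterior’, ‘mu’ and ‘rho’, and the softplus parametrization of the standard deviation, are illustrative choices rather than part of the library.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn
from torch.distributions import Distribution, Normal


class VariationalPosterior(nn.Module, ABC):
    """Hypothetical stand-in mirroring the interface documented on this page.

    In real code, subclass torchbayesian.bnn.VariationalPosterior instead.
    """

    def __init__(self, shape, *, dtype=None, device=None):
        super().__init__()
        self.shape = torch.Size(shape)
        self._dtype = dtype if dtype is not None else torch.get_default_dtype()
        self._device = device

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    @property
    def device(self):
        return self._device

    @property
    @abstractmethod
    def distribution(self) -> Distribution: ...

    @abstractmethod
    def sample_parameters(self) -> torch.Tensor: ...

    def forward(self) -> torch.Tensor:
        # Assumed behaviour: forward simply returns a fresh sample.
        return self.sample_parameters()


class GaussianPosterior(VariationalPosterior):
    """Hypothetical mean-field Gaussian posterior N(mu, softplus(rho)**2)."""

    def __init__(self, shape, *, dtype=None, device=None):
        # (1) Let the base class record shape, dtype and device.
        super().__init__(shape, dtype=dtype, device=device)
        # (2) Allocate uninitialized variational parameters of the right size.
        factory_kwargs = {"dtype": self.dtype, "device": self.device}
        self.mu = nn.Parameter(torch.empty(self.shape, **factory_kwargs))
        self.rho = nn.Parameter(torch.empty(self.shape, **factory_kwargs))
        # (3) Initialize them in a single, re-callable place.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.normal_(self.mu, mean=0.0, std=0.1)
        nn.init.constant_(self.rho, -5.0)  # small initial std: softplus(-5) ~ 0.0067

    @property
    def distribution(self) -> Distribution:
        return Normal(self.mu, nn.functional.softplus(self.rho))

    def sample_parameters(self) -> torch.Tensor:
        # rsample() keeps the draw differentiable w.r.t. mu and rho.
        return self.distribution.rsample()


posterior = GaussianPosterior((3, 4))
weight = posterior()  # nn.Module.__call__ -> forward() -> sample_parameters()
```

Keeping initialization in ‘reset_parameters()’ mirrors built-in layers such as ‘nn.Linear’, so the variational parameters can be re-initialized later without rebuilding the module.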

classmethod from_param(param: Tensor, **kwargs) → T

Alternate constructor used by ‘get_posterior’ inside ‘bnn.BayesianModule’.

Instantiates a variational posterior given the parameter to be replaced.

The default implementation constructs the posterior using only the parameter’s shape, dtype and device. Subclasses of ‘VariationalPosterior’ that require further access to ‘param’ should override this method.

Parameters

param : Tensor

The tensor (parameter or buffer) being replaced by a variational posterior.

**kwargs

Additional keyword arguments passed along to the variational posterior constructor.

Returns

posterior_instance : Self

An instance of the variational posterior class.
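The following sketch illustrates both cases under stated assumptions. The stand-in base class is hypothetical, and its ‘from_param’ is written to match the default behaviour described above (only ‘param.shape’, ‘param.dtype’ and ‘param.device’ are used, and ‘**kwargs’ is forwarded to the constructor); it is not the library's actual source. ‘WarmStartGaussianPosterior’ is a hypothetical subclass that also needs the parameter's values (it centers the posterior mean on the existing weights), so it overrides ‘from_param’. The ‘init_rho’ keyword is likewise illustrative.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn
from torch.distributions import Distribution, Normal


class VariationalPosterior(nn.Module, ABC):
    """Hypothetical, trimmed-down stand-in for torchbayesian.bnn.VariationalPosterior."""

    def __init__(self, shape, *, dtype=None, device=None):
        super().__init__()
        self.shape = torch.Size(shape)
        # Plain attributes for brevity (the documented class exposes properties).
        self.dtype = dtype if dtype is not None else torch.get_default_dtype()
        self.device = device

    @classmethod
    def from_param(cls, param: torch.Tensor, **kwargs):
        # Assumed default: use only param's shape, dtype and device,
        # and forward any extra keyword arguments to the constructor.
        return cls(param.shape, dtype=param.dtype, device=param.device, **kwargs)

    @property
    @abstractmethod
    def distribution(self) -> Distribution: ...

    @abstractmethod
    def sample_parameters(self) -> torch.Tensor: ...

    def forward(self) -> torch.Tensor:
        return self.sample_parameters()


class GaussianPosterior(VariationalPosterior):
    """Hypothetical mean-field Gaussian; `init_rho` is an illustrative keyword."""

    def __init__(self, shape, *, init_rho: float = -5.0, dtype=None, device=None):
        super().__init__(shape, dtype=dtype, device=device)
        self.init_rho = init_rho
        factory_kwargs = {"dtype": self.dtype, "device": self.device}
        self.mu = nn.Parameter(torch.empty(self.shape, **factory_kwargs))
        self.rho = nn.Parameter(torch.empty(self.shape, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.zeros_(self.mu)
        nn.init.constant_(self.rho, self.init_rho)

    @property
    def distribution(self) -> Distribution:
        return Normal(self.mu, nn.functional.softplus(self.rho))

    def sample_parameters(self) -> torch.Tensor:
        return self.distribution.rsample()


class WarmStartGaussianPosterior(GaussianPosterior):
    """Hypothetical subclass that also needs param's values, so it overrides from_param."""

    @classmethod
    def from_param(cls, param: torch.Tensor, **kwargs):
        posterior = super().from_param(param, **kwargs)
        with torch.no_grad():
            posterior.mu.copy_(param)  # center the posterior mean on the current weights
        return posterior


pretrained = torch.randn(5, 2, dtype=torch.float64)
plain = GaussianPosterior.from_param(pretrained, init_rho=-4.0)  # shape/dtype/device only
warm = WarmStartGaussianPosterior.from_param(pretrained)          # also copies the values
```

Calling ‘super().from_param(...)’ inside the override reuses the default shape/dtype/device handling, so the subclass only adds the part that genuinely needs the tensor's contents.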

property dtype: dtype

The dtype that the replaced parameter is expected to have.

property device: device | str | int | None

The device that the replaced parameter is expected to be on.

abstract property distribution: Distribution

An element-wise torch.distributions.Distribution corresponding to the variational posterior. It is used to compute the KL divergence within torch’s own distributions framework.

Its shape should match the shape of the parameters sampled from the variational posterior.

Returns

distribution : Distribution

A torch.distributions.Distribution instance.
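A minimal sketch of what an element-wise distribution can look like, using only ‘torch.distributions’: one independent ‘Normal’ per parameter entry, so ‘batch_shape’ equals the parameter shape and ‘event_shape’ is empty. The hypothetical ‘mu’/‘rho’ tensors and the standard-normal prior are assumptions for illustration; how torchbayesian itself combines the per-entry KL terms is not shown here.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence

# Hypothetical variational parameters for a (3, 4) weight tensor.
mu = torch.zeros(3, 4, requires_grad=True)
rho = torch.full((3, 4), -3.0, requires_grad=True)
sigma = F.softplus(rho)

# Element-wise posterior: one independent Normal per weight entry.
posterior = Normal(mu, sigma)
# Assumed element-wise prior of the same shape: N(0, 1) for every entry.
prior = Normal(torch.zeros(3, 4), torch.ones(3, 4))

# kl_divergence dispatches to the registered Normal/Normal closed form
# and returns one KL value per entry, with the same shape as the parameter.
kl_per_entry = kl_divergence(posterior, prior)
kl_total = kl_per_entry.sum()  # scalar KL term, e.g. for an ELBO loss
kl_total.backward()            # gradients flow back to mu and rho
```

Had the posterior been wrapped so that the whole tensor formed a single event, ‘kl_divergence’ would return one scalar instead of a per-entry tensor; the element-wise form keeps the KL aligned with the parameter's shape.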

abstract sample_parameters() → Tensor

Samples a value for the parameters from the variational posterior distribution (itself parametrized by the variational parameters).

Returns

param : Tensor

A value of the parameter, sampled from the variational posterior.
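Bayes by Backprop relies on the reparametrization trick so that sampled parameters stay differentiable with respect to the variational parameters. A minimal sketch for a Gaussian posterior, assuming a hypothetical ‘mu’/‘rho’ parametrization with σ = softplus(ρ); the standalone function below is illustrative and is not the library's method.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal


def sample_parameters(mu: torch.Tensor, rho: torch.Tensor) -> torch.Tensor:
    """Reparametrized draw w = mu + softplus(rho) * eps, with eps ~ N(0, I)."""
    sigma = F.softplus(rho)     # keeps the standard deviation positive
    eps = torch.randn_like(mu)  # parameter-free noise, carries no gradient
    return mu + sigma * eps


torch.manual_seed(0)
mu = torch.zeros(3, 4, requires_grad=True)
rho = torch.full((3, 4), -3.0, requires_grad=True)

w = sample_parameters(mu, rho)
w.sum().backward()  # d(sum w)/d(mu) is 1 for every entry

# Normal.rsample() implements the same trick internally.
w_rsample = Normal(mu, F.softplus(rho)).rsample()
```

Because the randomness is isolated in ‘eps’, gradients of any loss computed from ‘w’ reach ‘mu’ and ‘rho’; a plain ‘sample()’ call would cut that path.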

forward() → Tensor

Forward call for the reparametrization module.

Samples the parameters from the variational posterior distribution (itself parametrized by the variational parameters).

Returns

param : Tensor

A value of the parameter, sampled from the variational posterior.
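Because the posterior is an ‘nn.Module’, calling it invokes ‘forward()’, which returns a sample. The sketch below shows one way a layer could draw new weights on every forward pass. Both classes are hypothetical: ‘TinyGaussianPosterior’ subclasses ‘nn.Module’ directly instead of ‘VariationalPosterior’ to stay self-contained, and ‘HypotheticalBayesLinear’ is not meant to reflect how ‘bnn.BayesianModule’ wires posteriors into a model; it only illustrates consuming the sampled tensor.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TinyGaussianPosterior(nn.Module):
    """Hypothetical stand-in for a VariationalPosterior subclass."""

    def __init__(self, shape):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(shape))
        self.rho = nn.Parameter(torch.full(shape, -3.0))

    def sample_parameters(self) -> torch.Tensor:
        sigma = F.softplus(self.rho)
        return self.mu + sigma * torch.randn_like(self.mu)

    def forward(self) -> torch.Tensor:
        # Calling the module samples parameters from the posterior.
        return self.sample_parameters()


class HypotheticalBayesLinear(nn.Module):
    """Illustrative layer that samples a new weight matrix on every call."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight_posterior = TinyGaussianPosterior((out_features, in_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight_posterior()  # nn.Module.__call__ -> forward()
        return F.linear(x, weight)


torch.manual_seed(0)
layer = HypotheticalBayesLinear(in_features=2, out_features=3)
x = torch.ones(5, 2)
y1, y2 = layer(x), layer(x)  # two forward passes, two different weight samples
```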