VariationalPosterior
- class torchbayesian.bnn.VariationalPosterior(shape: Size | list[int] | tuple[int, ...], *, dtype: dtype | None = None, device: device | str | int | None = None)
Bases: Module, ABC
This class serves as a base class for all variational posteriors used for Bayes by Backprop (BBB) variational inference (VI).
Parameters
- shape : _size
The shape of the parameter replaced by the variational posterior.
- dtype : Optional[_dtype]
The dtype of the parameter replaced by the variational posterior. Optional. Defaults to torch’s default dtype.
- device : Device
The device of the parameter replaced by the variational posterior. Optional. Defaults to torch’s default device.
Attributes
- shape : _size
The shape of the tensors sampled from the variational posterior.
Notes
Subclasses used in ‘bnn.BayesianModule’ must be compatible with ‘get_posterior()’; see the ‘from_param’ alternate constructor (a class method).
Recommended PyTorch-esque pattern for the constructor (‘__init__’ method) of custom subclasses of ‘VariationalPosterior’: (1) call ‘super().__init__(…)’; (2) create empty (uninitialized) variational parameters of the appropriate size, e.g. ‘self.mu = nn.Parameter(torch.empty(…))’; (3) define a ‘reset_parameters()’ method and call ‘self.reset_parameters()’ at the end of ‘__init__’ to initialize the variational parameters.
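The three-step constructor pattern above can be sketched as follows. This is a minimal illustration, not torchbayesian code: ‘PosteriorBase’ is a hypothetical stand-in that mirrors only the documented constructor signature, ‘shape’ attribute and abstract ‘distribution’ property (in practice you would subclass ‘torchbayesian.bnn.VariationalPosterior’, which may define further members not covered here). ‘GaussianPosterior’, its ‘mu’/‘rho’ parameters and the softplus parameterization of the standard deviation are assumptions in the usual Bayes by Backprop style.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Distribution, Normal


class PosteriorBase(nn.Module, ABC):
    """Hypothetical stand-in for VariationalPosterior: mirrors only the
    documented constructor signature, `shape` attribute and abstract
    `distribution` property."""

    def __init__(self, shape, *, dtype=None, device=None):
        super().__init__()
        self.shape = torch.Size(shape)
        self._dtype = dtype
        self._device = device

    @property
    @abstractmethod
    def distribution(self) -> Distribution:
        """Element-wise distribution over tensors of size `self.shape`."""


class GaussianPosterior(PosteriorBase):
    """Hypothetical mean-field Gaussian posterior with sigma = softplus(rho)."""

    def __init__(self, shape, *, dtype=None, device=None):
        # (1) Let the base class record shape, dtype and device.
        super().__init__(shape, dtype=dtype, device=device)
        factory_kwargs = {"dtype": dtype, "device": device}
        # (2) Allocate uninitialized variational parameters of the right size.
        self.mu = nn.Parameter(torch.empty(self.shape, **factory_kwargs))
        self.rho = nn.Parameter(torch.empty(self.shape, **factory_kwargs))
        # (3) Initialize them last, in a dedicated method.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Small random means; rho = -5 gives sigma = softplus(-5) ~ 0.0067.
        nn.init.normal_(self.mu, mean=0.0, std=0.1)
        nn.init.constant_(self.rho, -5.0)

    @property
    def distribution(self) -> Distribution:
        return Normal(self.mu, F.softplus(self.rho))


posterior = GaussianPosterior((3, 4))
sample = posterior.distribution.rsample()  # differentiable w.r.t. mu and rho
```

Keeping initialization in ‘reset_parameters()’ follows the convention of built-in layers such as ‘nn.Linear’ and lets the variational parameters be re-initialized later without rebuilding the module.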
- classmethod from_param(param: Tensor, **kwargs) → T
Alternate constructor used by ‘get_posterior’ inside ‘bnn.BayesianModule’.
Instantiates a variational posterior for the given parameter to be replaced.
The default implementation constructs the posterior using only the parameter’s shape, dtype and device. Subclasses of ‘VariationalPosterior’ that need more information from ‘param’ (e.g. its values) should override this method.
Parameters
- param : Tensor
The tensor (parameter or buffer) being replaced by a variational posterior.
- **kwargs
Additional keyword arguments passed along to the variational posterior constructor.
Returns
- posterior_instance : Self
An instance of the variational posterior class.
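A minimal sketch of the two ‘from_param’ styles described above, using hypothetical classes that do not depend on torchbayesian: ‘GaussianPosterior.from_param’ mirrors the documented default (only shape, dtype and device are used, and ‘**kwargs’ is forwarded to the constructor), while the hypothetical ‘WarmStartGaussianPosterior’ overrides it to also read the values of ‘param’. The ‘init_std’ keyword and the log-sigma parameterization are assumptions made for illustration.

```python
import math

import torch
from torch import nn
from torch.distributions import Normal


class GaussianPosterior(nn.Module):
    """Hypothetical posterior used only to illustrate `from_param`."""

    def __init__(self, shape, *, dtype=None, device=None, init_std=0.01):
        super().__init__()
        self.shape = torch.Size(shape)
        factory_kwargs = {"dtype": dtype, "device": device}
        self.mu = nn.Parameter(torch.zeros(self.shape, **factory_kwargs))
        # Store log(sigma) so that sigma = exp(log_sigma) is always positive.
        self.log_sigma = nn.Parameter(
            torch.full(self.shape, math.log(init_std), **factory_kwargs)
        )

    @classmethod
    def from_param(cls, param: torch.Tensor, **kwargs):
        # Default style: only shape, dtype and device of `param` are used;
        # extra keyword arguments are forwarded to the constructor.
        return cls(param.shape, dtype=param.dtype, device=param.device, **kwargs)

    @property
    def distribution(self) -> Normal:
        return Normal(self.mu, self.log_sigma.exp())


class WarmStartGaussianPosterior(GaussianPosterior):
    """Hypothetical subclass that also needs the values of `param`."""

    @classmethod
    def from_param(cls, param: torch.Tensor, **kwargs):
        posterior = super().from_param(param, **kwargs)
        # Center the posterior on the replaced tensor's current values.
        with torch.no_grad():
            posterior.mu.copy_(param)
        return posterior


weight = nn.Linear(4, 2).weight  # tensor of shape (2, 4)
cold = GaussianPosterior.from_param(weight)
warm = WarmStartGaussianPosterior.from_param(weight, init_std=0.05)
```

Because ‘from_param’ is a class method, calling ‘super().from_param(...)’ inside the override still builds an instance of the subclass, so the shape, dtype and device handling stays in one place.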
- property dtype: dtype
The expected dtype of the parameter replaced by the variational posterior.
- property device: device | str | int | None
The expected device of the parameter replaced by the variational posterior.
- abstract property distribution: Distribution
An element-wise torch.distributions.Distribution corresponding to the variational posterior. It is used for KL computation within torch’s ‘torch.distributions’ framework.
Its shape (the shape of a single sample, i.e. batch_shape + event_shape) should be the same as the shape of the parameters sampled from the variational posterior.
Returns
- distribution : Distribution
A torch.distributions.Distribution.
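To illustrate why ‘distribution’ should be element-wise with the parameter’s shape, the sketch below builds such a distribution directly with ‘torch.distributions.Normal’ and passes it to ‘torch.distributions.kl_divergence’. The mean-field Gaussian parameterization, the standard-normal prior and the summation of per-entry KL terms are assumptions for illustration; this section does not specify how torchbayesian itself forms the KL term.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence

shape = torch.Size([2, 3])  # shape of the replaced parameter

# Variational parameters of a hypothetical mean-field Gaussian posterior.
mu = torch.zeros(shape, requires_grad=True)
rho = torch.full(shape, -3.0, requires_grad=True)

# Element-wise posterior: one independent Normal per entry, so
# batch_shape == shape and event_shape == ().
posterior = Normal(mu, F.softplus(rho))

# Assumed element-wise standard-normal prior with the same shape.
prior = Normal(torch.zeros(shape), torch.ones(shape))

# Both distributions are element-wise with matching shapes, so
# kl_divergence returns one KL term per parameter entry.
kl_per_entry = kl_divergence(posterior, prior)
kl_total = kl_per_entry.sum()  # e.g. the complexity term of a BBB loss
kl_total.backward()  # gradients flow back to mu and rho
```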