BayesianModule#
- class torchbayesian.bnn.BayesianModule(module: Module, variational_posterior: str | Tuple[str, Dict[str, Any]] | None = None, prior: str | Tuple[str, Dict[str, Any]] | None = None, *, dtype: dtype | None = None, device: device | str | int | None = None, verbose: bool = False)#
Bases: Module
This class is a Torch ‘nn.Module’ container that reparameterizes in-place the ‘nn.Parameter’ parameters of any torch model or module with a variational posterior, and allows computation of the KL divergence.
The wrapped modules are thus converted into a Bayesian neural network (BNN) via variational inference (VI) through Bayes by Backprop (BBB), as described in “Weight Uncertainty in Neural Networks” by Blundell et al.
After construction, the original parameters of ‘module’ are no longer optimizable ‘nn.Parameter’ objects; instead, the trainable ‘nn.Parameter’ weights live inside the injected reparameterizations.
Parameters#
- module: Module
The module or model to turn into a Bayesian neural network (BNN). Its parameters will be reparameterized in-place.
- variational_posterior: Optional[str | Tuple[str, Dict[str, Any]]]
The variational posterior distribution for the parameters. Either the name (str) of the variational posterior or a tuple of its name and keyword arguments. Defaults to ‘GaussianPosterior’.
- prior: Optional[str | Tuple[str, Dict[str, Any]]]
The prior distribution for the parameters. Either the name (str) of the prior or a tuple of its name and keyword arguments. Defaults to ‘GaussianPrior’ with zero mean and unit standard deviation.
- dtype: Optional[dtype]
The dtype with which the KL divergence accumulator reference buffer is initialized. This buffer tracks the device and dtype of the module’s parameters through internal calls to ‘_apply’, so that the KL accumulator’s device and dtype match those of the module’s parameters. Optional. Defaults to the torch default dtype. It is recommended to use ‘BayesianModule(…).to(device, dtype)’ instead of this argument.
- device: Optional[Device]
The device on which the KL divergence accumulator reference buffer is initialized. As for ‘dtype’, this buffer tracks the device and dtype of the module’s parameters through internal calls to ‘_apply’, so that the KL accumulator’s device and dtype match those of the module’s parameters. Optional. Defaults to the torch default device. It is recommended to use ‘BayesianModule(…).to(device, dtype)’ instead of this argument.
- verbose: bool
Whether to print a message for each parameter being reparameterized. Defaults to False.
Attributes#
- module: nn.Module
The reparameterized module or model.
Notes#
By default, a Gaussian variational posterior distribution and a Gaussian prior distribution are used. This allows an analytical evaluation of the KL divergence and is a common standard in BBB VI, even though the original paper proposes a scale mixture of Gaussians as the prior.
Custom variational posteriors and priors can be used by registering a factory function to ‘PosteriorFactory’ or ‘PriorFactory’, as described in the docs of ‘Factory’ in ‘torchbayesian.bnn.utils.factories’.
Warning#
‘BayesianModule’ should be applied before registering the model’s parameters to an optimizer. Otherwise, the new variational parameters must be registered to the optimizer manually.
- forward(input: Tensor) Tensor#
Forward pass of the module.
Parameters#
- input: Tensor
The input tensor.
Returns#
- output: Tensor
The output tensor.
- kl_divergence(*, reduction: str = 'sum', approx_num_samples: int | None = None) Tensor#
Gets the KL divergence KL(posterior || prior) of all parameters in the ‘BayesianModule’.
If no analytical solution is defined between the two distributions, a Monte Carlo (MC) approximation of the KL divergence can be used.
Parameters#
- reduction: str
The reduction to apply to the elementwise KL divergence over all parameters. Either “sum” (sum over all elements of all parameters) or “mean” (mean over all elements of all parameters). In theory, the true ELBO uses the sum of the elementwise KL divergences, but in practice this can scale badly with model size and mini-batching, so it is not uncommon to rescale the KL divergence or to use a mean reduction. Defaults to “sum”.
- approx_num_samples: Optional[int]
The number of samples for the MC approximation of the KL divergence. Only useful if no analytical solution of the KL divergence between the posterior and prior distributions is implemented. Defaults to None.
Returns#
- kl_div: Tensor
The KL divergence KL(posterior || prior) of the ‘BayesianModule’.
Raises#
- ValueError
If the reduction type is invalid.
Warning#
The KL divergence is computed with an accumulator to avoid the overhead of building a list of KL terms, but the accumulator must live on the appropriate device, which is why ‘BayesianModule’ tracks the buffer ‘_kl_meta’. If ‘BayesianModule’ is not initialized on the same device as the original module, and no subsequent move to the appropriate dtype/device is made (e.g. via ‘net.to(…)’ or ‘net.cuda()’), the accumulator’s dtype/device may not match the KL divergence terms coming from the parameters. This is easily fixed by calling ‘.to(…)’ to move all parameters and buffers of the ‘BayesianModule’ to the same dtype/device.
- TODO – Currently, this feature may not work for distributions that do not factorize: the elementwise 1-d KL followed by a sum (or mean) reduction breaks down.