Reparametrization
- class torchbayesian.bnn.utils.Reparametrization(parametrization: Module, original: Tensor, *, unsafe: bool = False)
Bases: Module
This class wraps the variational posterior module. It runs safety checks when the variational posterior replaces the original parameter or buffer, and it wraps the posterior’s forward call.
It is the type of ‘module.reparametrizations[tensor_name]’ after the tensor ‘tensor_name’ of ‘module’ has been reparametrized with ‘register_reparametrization()’.
Parameters
- parametrization : Module
The parametrization function; the variational posterior in the context of Bayes by Backprop.
- original : Tensor
The tensor that is being reparametrized.
- unsafe : bool
Whether to bypass correctness checks. Optional. Defaults to False.
Attributes
- variational_posterior : Module
The parametrization replacing ‘original’.
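To illustrate the role described above, here is a minimal pure-Python sketch of such a wrapper. It is not torchbayesian’s implementation: `GaussianPosterior` and `ReparametrizationSketch` are hypothetical names, weights are plain lists of floats rather than tensors, and the only correctness check assumed is that a posterior sample has the same size as `original`.

```python
import random


class GaussianPosterior:
    """Hypothetical variational posterior over a flat list of weights."""

    def __init__(self, size, sigma=0.1):
        self.mu = [0.0] * size  # one mean per weight
        self.sigma = sigma      # shared standard deviation, for simplicity

    def __call__(self):
        # One reparametrized draw per weight: mu_i + sigma * eps_i, eps_i ~ N(0, 1).
        return [m + self.sigma * random.gauss(0.0, 1.0) for m in self.mu]


class ReparametrizationSketch:
    """Hypothetical stand-in for the wrapper; not torchbayesian's code."""

    def __init__(self, parametrization, original, *, unsafe=False):
        self.variational_posterior = parametrization
        if not unsafe:
            # Assumed correctness check: a sample must be able to replace `original`.
            sample = parametrization()
            if len(sample) != len(original):
                raise ValueError(
                    f"posterior sample has size {len(sample)}, "
                    f"expected {len(original)}"
                )

    def forward(self):
        # Delegate to the posterior: each call returns a fresh sample.
        return self.variational_posterior()


wrapper = ReparametrizationSketch(GaussianPosterior(3), [0.5, -0.2, 1.0])
print(len(wrapper.forward()))  # 3
```

With `unsafe=True` the size check is skipped, so a mismatched posterior is accepted without error.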
Notes
This class is used internally by ‘register_reparametrization()’ and is not meant to be instantiated directly by the user.
- forward() → Tensor
Wraps the parametrization’s forward call and checks whether the module is being scripted with TorchScript.
In the context of Bayes by Backprop, this returns a sample from the variational posterior.
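In Bayes by Backprop, a common choice of variational posterior is a diagonal Gaussian whose standard deviation is parametrized as σ = log(1 + exp(ρ)), with samples drawn by the reparametrization trick w = μ + σ·ε, ε ~ N(0, 1). The sketch below shows one such draw for a single scalar weight in plain Python. Whether torchbayesian’s posteriors use this exact parametrization is an assumption, and `softplus` and `sample_weight` are hypothetical helpers.

```python
import math
import random


def softplus(x):
    # Numerically stable log(1 + exp(x)), used to map rho to sigma.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sample_weight(mu, rho, rng):
    """One reparametrized draw w = mu + softplus(rho) * eps with eps ~ N(0, 1)."""
    sigma = softplus(rho)
    eps = rng.gauss(0.0, 1.0)  # noise is independent of mu and rho
    return mu + sigma * eps


rng = random.Random(0)
draws = [sample_weight(1.5, -3.0, rng) for _ in range(10_000)]
print(round(sum(draws) / len(draws), 2))  # close to the posterior mean mu = 1.5
```

Because w is a deterministic function of μ and ρ once ε is fixed, gradients of the loss can flow through the sample back to the variational parameters.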