Reparametrization

class torchbayesian.bnn.utils.Reparametrization(parametrization: Module, original: Tensor, *, unsafe: bool = False)

Bases: Module

This class wraps the variational posterior module. It runs safety checks when the variational posterior replaces a parameter or buffer, and it wraps the variational posterior's forward call.

It is the type of ‘module.reparametrizations[tensor_name]’ once the tensor ‘tensor_name’ of ‘module’ has been reparametrized with ‘register_reparametrization()’.

Parameters

parametrization : Module

The parametrization module. In the context of Bayes by Backprop, this is the variational posterior.
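For readers unfamiliar with Bayes by Backprop, the sketch below shows what such a variational posterior module could look like: a mean-field Gaussian whose forward call takes no input and returns a sample with the shape of the tensor it replaces. The class name GaussianPosterior, the softplus parametrization of the standard deviation and the init_rho default are illustrative assumptions, not torchbayesian's actual posterior.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GaussianPosterior(nn.Module):
    """Hypothetical mean-field Gaussian posterior q(w) = N(mu, sigma^2)."""

    def __init__(self, original: torch.Tensor, init_rho: float = -5.0) -> None:
        super().__init__()
        # Assumption: the mean starts at the current value of the replaced tensor.
        self.mu = nn.Parameter(original.detach().clone())
        # rho is unconstrained; softplus(rho) maps it to a positive std.
        self.rho = nn.Parameter(torch.full_like(original, init_rho))

    @property
    def sigma(self) -> torch.Tensor:
        return F.softplus(self.rho)

    def forward(self) -> torch.Tensor:
        # Reparametrization trick: w = mu + sigma * eps with eps ~ N(0, I),
        # so gradients reach mu and rho through the sample.
        eps = torch.randn_like(self.mu)
        return self.mu + self.sigma * eps


torch.manual_seed(0)
original = torch.zeros(3, 4)
posterior = GaussianPosterior(original)
sample = posterior()  # one draw of the weights, same shape as `original`
```

Because the noise is drawn outside the learnable parameters, each forward call is a differentiable function of mu and rho, which is what lets backpropagation train the posterior.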

original : Tensor

The tensor that is being reparametrized.

unsafe : bool

Whether to bypass correctness checks. Optional. Defaults to False.
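The exact checks are not listed here. As an assumption, they are likely similar to those skipped by the unsafe flag of torch.nn.utils.parametrize.register_parametrization: the parametrization must return a tensor whose shape and dtype match the original. The hypothetical helper check_replacement below sketches such checks and how unsafe=True bypasses them.

```python
import torch
import torch.nn as nn


def check_replacement(posterior: nn.Module, original: torch.Tensor, unsafe: bool = False) -> None:
    """Hypothetical consistency checks; not torchbayesian's actual code."""
    if unsafe:
        # unsafe=True skips every check, trading safety for speed/flexibility.
        return
    with torch.no_grad():
        sample = posterior()
    if not isinstance(sample, torch.Tensor):
        raise ValueError(f"posterior must return a Tensor, got {type(sample).__name__}")
    if sample.shape != original.shape:
        raise ValueError(
            f"shape mismatch: posterior returns {tuple(sample.shape)}, "
            f"original is {tuple(original.shape)}"
        )
    if sample.dtype != original.dtype:
        raise ValueError(f"dtype mismatch: {sample.dtype} vs {original.dtype}")


class ZerosPosterior(nn.Module):
    """Toy posterior that always returns zeros of a fixed shape."""

    def __init__(self, *shape: int) -> None:
        super().__init__()
        self.shape = shape

    def forward(self) -> torch.Tensor:
        return torch.zeros(self.shape)


check_replacement(ZerosPosterior(3), torch.zeros(3))               # passes
check_replacement(ZerosPosterior(2), torch.zeros(3), unsafe=True)  # mismatch ignored
```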

Attributes

variational_posterior : Module

The parametrization replacing ‘original’.

Notes

This class is used internally by ‘register_reparametrization()’. It should not be instantiated directly by the user.

forward() → Tensor

Wraps the parametrization’s forward call. It first checks whether the call is happening under TorchScript scripting.

In the context of Bayes by Backprop, this returns a sample from the variational posterior.
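The following minimal sketch illustrates the documented behaviour of forward(), assuming the scripting check works as in torch.nn.utils.parametrize (an error is raised when torch.jit.is_scripting() is true). ReparametrizationSketch and NoisyPosterior are hypothetical stand-ins, not torchbayesian classes; since the real class should not be instantiated directly, the sketch uses its own wrapper.

```python
import torch
import torch.nn as nn


class ReparametrizationSketch(nn.Module):
    """Hypothetical stand-in for the documented wrapper."""

    def __init__(self, parametrization: nn.Module) -> None:
        super().__init__()
        # Registered as a submodule, so its parameters are trainable.
        self.variational_posterior = parametrization

    def forward(self) -> torch.Tensor:
        # Assumed scripting check, mirroring torch.nn.utils.parametrize.
        if torch.jit.is_scripting():
            raise RuntimeError("Reparametrization does not support scripting.")
        # Each call draws a fresh sample from the variational posterior.
        return self.variational_posterior()


class NoisyPosterior(nn.Module):
    """Toy posterior: a learnable mean plus fixed-scale Gaussian noise."""

    def __init__(self, *shape: int) -> None:
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(shape))

    def forward(self) -> torch.Tensor:
        return self.mu + 0.1 * torch.randn_like(self.mu)


wrapper = ReparametrizationSketch(NoisyPosterior(2, 3))
first, second = wrapper(), wrapper()  # two independent samples
```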

extra_repr() → str

Returns a string describing the variational posterior put in place.

Returns

extra_repr : str

The extra string representation of the reparametrization.
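In PyTorch, the string returned by extra_repr() is inserted by nn.Module.__repr__ into the module's printed form, ahead of the listing of child modules. The sketch below shows that mechanism. The exact string torchbayesian returns is not documented here, so the "variational_posterior=<class name>" format and the class names are hypothetical.

```python
import torch
import torch.nn as nn


class Posterior(nn.Module):
    """Toy posterior used only to demonstrate printing."""

    def forward(self) -> torch.Tensor:
        return torch.zeros(1)


class ReparametrizationRepr(nn.Module):
    """Hypothetical wrapper showing how extra_repr feeds nn.Module.__repr__."""

    def __init__(self, parametrization: nn.Module) -> None:
        super().__init__()
        self.variational_posterior = parametrization

    def extra_repr(self) -> str:
        # Assumed format: name the posterior currently in place.
        return f"variational_posterior={type(self.variational_posterior).__name__}"


wrapper = ReparametrizationRepr(Posterior())
text = repr(wrapper)  # includes the extra_repr line and the child module listing
```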