register_reparametrization#

torchbayesian.bnn.utils.register_reparametrization(module: Module, tensor_name: str, parametrization: Module, *, unsafe: bool = False) → Module#

Replaces a tensor (parameter or buffer) in a module by registering a stochastic reparametrization module in its place.

Assume that 'tensor_name="weight"' for simplicity. After reparametrization, the original tensor 'module.weight' is removed and replaced by a Python property. Accessing 'module.weight' now calls the corresponding 'Reparametrization' module, which returns the tensor produced by calling 'parametrization()' (typically a sample from the variational posterior) instead of the original tensor.
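The property mechanism can be sketched in plain PyTorch. This is an illustrative toy, not torchbayesian's actual implementation; the names 'Sampler' and 'attach_property' are hypothetical, and real implementations handle many details (state_dict hooks, safety checks) omitted here:

```python
import torch
import torch.nn as nn

class Sampler(nn.Module):
    """Toy stand-in for a reparametrization module: returns a freshly
    computed tensor on every call (hypothetical, for illustration)."""
    def __init__(self, shape):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(shape))

    def forward(self):
        # Each access draws a new perturbed tensor around the learnable mean.
        return self.mu + 0.1 * torch.randn_like(self.mu)

def attach_property(module, name, sampler):
    """Sketch of the mechanism: remove the original parameter and install a
    class-level property so attribute reads call the sampler instead."""
    delattr(module, name)      # removes the original Parameter from the module
    module.sampler = sampler   # registered as a submodule -> appears in state_dict
    # Properties live on the class, so give this instance its own subclass.
    cls = type(module)
    new_cls = type(f"Reparametrized{cls.__name__}", (cls,), {
        name: property(lambda self: self.sampler())
    })
    module.__class__ = new_cls
    return module

mod = nn.Linear(2, 4)
attach_property(mod, "weight", Sampler((4, 2)))
w1, w2 = mod.weight, mod.weight
print(torch.equal(w1, w2))  # False: each access draws a new sample
```

After this, 'mod.weight' is no longer a stored tensor: every read triggers a forward pass through the sampler, which is why forward passes through the reparametrized module become stochastic.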

If the original tensor requires a gradient, the backward pass differentiates through the reparametrization module, and the optimizer updates the variational parameters of the reparametrization module instead of the original parameter. The parameters and buffers of 'parametrization' are registered to the model and its state_dict, and the original tensor is removed from the state_dict.
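The gradient flow described above relies on the standard reparameterization trick: a sample is written as a differentiable function of the variational parameters plus noise, so autograd propagates gradients to those parameters. A generic sketch in plain PyTorch (independent of torchbayesian's API):

```python
import torch
import torch.nn.functional as F

# Variational parameters: mean and an unconstrained scale parameter rho,
# mapped to a positive standard deviation via softplus.
mu = torch.zeros(3, requires_grad=True)
rho = torch.full((3,), -1.0, requires_grad=True)

eps = torch.randn(3)                  # noise sample, carries no gradient
w = mu + F.softplus(rho) * eps        # reparameterized sample of the weight

loss = (w ** 2).sum()                 # any downstream loss using the sample
loss.backward()

# Gradients reach mu and rho, so an optimizer would update them,
# not the sampled tensor w itself.
print(mu.grad is not None, rho.grad is not None)  # True True
```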

Parameters#

module : Module

The module whose tensor is to be reparametrized.

tensor_name : str

The name of the parameter or buffer to reparametrize.

parametrization : Module

The module with which to replace the tensor. Typically, this is the variational posterior.

unsafe : bool

Whether to bypass correctness checks. Optional. Defaults to False.

Returns#

module : Module

The module, reparametrized in-place.

Examples#

>>> mod = nn.Linear(2, 4)  # This module has parameters 'mod.bias' and 'mod.weight'
>>>
>>> # We replace the parameter 'mod.weight' by a Gaussian variational posterior
>>> # with learnable parameters corresponding to the mean and standard deviation:
>>> register_reparametrization(mod, "weight", bnn.GaussianPosterior(...))
>>>
>>> # The module 'mod' now has parameters 'mod.bias',
>>> # 'mod.reparametrizations.weight.variational_posterior.mu'
>>> # and 'mod.reparametrizations.weight.variational_posterior.rho'