torchbayesian.bnn.utils.PriorFactory
torchbayesian.bnn.utils.PriorFactory = <torchbayesian.bnn.utils.factories.factory.Factory object>
PriorFactory is an instance of Factory, the base class for object factories.
A Factory serves as a dynamic registry of factory functions: new factory functions can be registered on an instance with the decorator register_factory_function(). This lets users, for example, register a custom 'Prior' with the API and call it through the usual torchbayesian pipeline.
Examples
import torch.nn as nn
from torchbayesian.bnn.utils.factories.factory import Factory

# Create an instance of 'Factory' to serve as a normalization layer factory
Norm = Factory()

# Register to the normalization factory 'Norm' a factory function that gets batch normalization layers:
@Norm.register_factory_function("batch")
def batch_norm_factory(dim) -> type[nn.Module]:
    types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    return types[dim - 1]

# Now, one can 'get' from the normalization factory 'Norm' to obtain a batch normalization layer
batch_norm_layer = Norm["batch", 2]  # norm_dim=2

# note: batch_norm_layer is nn.BatchNorm2d (the class) and not nn.BatchNorm2d() (an instance)
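The registry mechanism described above can be illustrated without torchbayesian installed. The following is a minimal sketch, not the library's implementation: MiniFactory, GaussianPrior, and the "gaussian" key are hypothetical names invented for this illustration, and the real Factory may differ in its lookup rules and error handling. It only shows how a decorator can register a factory function and how indexing can call it.

```python
from typing import Any, Callable, Dict


class MiniFactory:
    """Toy registry of factory functions (hypothetical stand-in for Factory)."""

    def __init__(self) -> None:
        self._factories: Dict[str, Callable[..., Any]] = {}

    def register_factory_function(self, name: str) -> Callable:
        """Return a decorator that stores the decorated function under `name`."""

        def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
            self._factories[name] = func
            return func  # leave the decorated function usable on its own

        return decorator

    def __getitem__(self, args: Any) -> Any:
        # Accept factory["name"] as well as factory["name", arg1, arg2, ...].
        if isinstance(args, str):
            name, call_args = args, []
        else:
            name, *call_args = args
        if name not in self._factories:
            raise KeyError(f"No factory function registered under {name!r}")
        # Call the registered factory function and return whatever it produces.
        return self._factories[name](*call_args)


class GaussianPrior:
    """Hypothetical custom prior, used only to have something to register."""

    def __init__(self, mean: float = 0.0, std: float = 1.0) -> None:
        self.mean = mean
        self.std = std


# A registry instance playing the role that PriorFactory plays in the API.
Prior = MiniFactory()


@Prior.register_factory_function("gaussian")
def gaussian_prior_factory() -> type:
    # Return the class itself, not an instance, as in the batch norm example.
    return GaussianPrior


prior_cls = Prior["gaussian"]  # looks up and calls the factory function
prior = prior_cls(std=0.5)     # the caller instantiates the returned class
```

As in the batch normalization example, the lookup yields a class, so constructor arguments are supplied by the caller after the lookup rather than at registration time.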