This package implements Stochastic Gradient Langevin Dynamics (SGLD) and cyclical SGLD (cSGLD) as a PyTorch Optimizer.
Install from pip as:
pip install torch-sgld
To install the latest version directly from source, run:
pip install git+https://github.com/activatedgeek/torch-sgld.git
The general idea is to keep the usual gradient-based update loop in PyTorch
and use the SGLD optimizer in place of a standard optimizer such as SGD.
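Conceptually, an SGLD step is a gradient-descent step on an energy (for example, the negative log-posterior) plus injected Gaussian noise whose scale is tied to the step size. The sketch below is plain Python, independent of torch-sgld, and assumes one common parameterization, theta <- theta - lr * grad U(theta) + sqrt(2 * lr) * xi with xi ~ N(0, 1). The function name `sgld_sample` is hypothetical, and torch-sgld's exact scaling of `lr` and the noise may differ.

```python
import math
import random


def sgld_sample(grad_energy, theta0, lr, num_steps, seed=0):
    """Run a 1-D SGLD chain and return every iterate (hypothetical helper).

    Assumed update rule:
        theta <- theta - lr * grad_energy(theta) + sqrt(2 * lr) * N(0, 1)
    For small lr the iterates approximately sample from p(theta) ~ exp(-U(theta)).
    """
    rng = random.Random(seed)
    theta = theta0
    samples = []
    for _ in range(num_steps):
        noise = rng.gauss(0.0, 1.0)
        # Gradient step on the energy, plus Gaussian noise scaled by sqrt(2 * lr).
        theta = theta - lr * grad_energy(theta) + math.sqrt(2.0 * lr) * noise
        samples.append(theta)
    return samples


# Energy of a standard normal: U(theta) = theta**2 / 2, so grad U(theta) = theta.
samples = sgld_sample(lambda t: t, theta0=0.0, lr=0.05, num_steps=50_000)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean={mean:.3f}, var={var:.3f}")  # expected to be near 0 and 1
```

Without the noise term the chain would simply collapse to the mode at 0; the noise is what turns optimization into sampling. With `lr=0.05` the chain's stationary variance is 2 / (2 - lr), about 1.026, rather than exactly 1, a discretization bias that shrinks as the step size decreases.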
from torch_sgld import SGLD

f = module()  ## Construct a PyTorch nn.Module whose forward pass returns a scalar energy.
sgld = SGLD(f.parameters(), lr=lr, momentum=.9)  ## Add momentum to make it SG-HMC.
sgld_scheduler = None  ## Optionally add a step-size scheduler wrapping sgld.

for _ in range(num_steps):
    sgld.zero_grad()  ## Clear gradients left over from the previous step.
    energy = f()
    energy.backward()
    sgld.step()
    if sgld_scheduler is not None:
        sgld_scheduler.step()  ## Optional scheduler step.

cSGLD can be implemented by using a cyclical learning-rate schedule.
See the toy_csgld.ipynb notebook for a
complete example.
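As a rough illustration of such a schedule, the sketch below assumes the cosine cyclical step size described for cSGLD: the run of `total_steps` iterations is split into `num_cycles` cycles of length ceil(total_steps / num_cycles), and within each cycle the step size decays from `base_lr` toward 0 along a half cosine before restarting. The function `cyclical_lr` and its parameter names are hypothetical and not part of torch-sgld.

```python
import math


def cyclical_lr(step, base_lr, total_steps, num_cycles):
    """Cosine cyclical step size (hypothetical helper; `step` is 0-indexed).

    lr(k) = base_lr / 2 * (cos(pi * (k mod L) / L) + 1), with L = ceil(total_steps / num_cycles).
    """
    cycle_len = math.ceil(total_steps / num_cycles)
    position = (step % cycle_len) / cycle_len  # fraction of the current cycle, in [0, 1)
    return 0.5 * base_lr * (math.cos(math.pi * position) + 1.0)


# 100 steps split into 5 cycles of 20 steps each.
schedule = [cyclical_lr(k, base_lr=0.1, total_steps=100, num_cycles=5) for k in range(100)]
print(schedule[0], schedule[10], schedule[19], schedule[20])
```

One way to wire this into the loop above, assuming the standard PyTorch scheduler API, is `torch.optim.lr_scheduler.LambdaLR(sgld, lr_lambda=lambda k: cyclical_lr(k, 1.0, num_steps, num_cycles))`: `LambdaLR` multiplies the optimizer's initial `lr` by the returned factor, so `base_lr=1.0` makes the helper return that factor. This is a suggested wiring, not necessarily what the notebook does. Large step sizes early in each cycle favor exploring new modes, while the small step sizes late in each cycle are where samples are typically collected.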
License: Apache 2.0