We should provide an API for changing the dtype of a module's parameters after the module has been constructed.
Without such an API, users must thread a `dtype` argument through every submodule at construction time, which is cumbersome.
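Possible APIs (the three options are alternatives, not a sequence of calls):
module = MyModule()
module.to(tp.float16)                  # Option 1: `to` method, as in PyTorch's `nn.Module.to`
module.cast(tp.float16)                # Option 2: `cast` method, named after `tp.cast`
module = tp.cast(module, tp.float16)   # Option 3: extend the free function `tp.cast` to accept modules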
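The proposed options differ mainly in whether the module is mutated in place (options 1 and 2) or a new module is returned (option 3). Below is a minimal, dependency-free sketch of both behaviors. `Module`, `Tensor`, `cast`, and the string dtypes such as `"float16"` are hypothetical stand-ins, not Tripy's actual classes or API, and the toy `Tensor.cast` only relabels the dtype rather than converting values. The essential part is the recursive walk over parameters and child modules, which any of the proposed APIs would need.

```python
import copy
from dataclasses import dataclass


# Hypothetical stand-in for a Tripy tensor; not the real tp.Tensor.
@dataclass(frozen=True)
class Tensor:
    data: tuple
    dtype: str

    def cast(self, dtype: str) -> "Tensor":
        # A real cast would convert the values; this sketch only relabels the dtype.
        return Tensor(self.data, dtype)


class Module:
    """Hypothetical base class whose attributes hold parameters or child modules."""

    def named_parameters(self):
        # Yield (owner, attribute name, tensor), recursing into child modules.
        for name, value in vars(self).items():
            if isinstance(value, Tensor):
                yield self, name, value
            elif isinstance(value, Module):
                yield from value.named_parameters()

    def to(self, dtype: str) -> "Module":
        # In-place variant (options 1 and 2): rebind every parameter on its owner.
        for owner, name, param in list(self.named_parameters()):
            setattr(owner, name, param.cast(dtype))
        return self  # returning self allows chaining, e.g. MyModule().to(...)


def cast(module: Module, dtype: str) -> Module:
    # Functional variant (option 3): copy first so the original is untouched.
    return copy.deepcopy(module).to(dtype)


class Linear(Module):
    def __init__(self):
        self.weight = Tensor((1.0, 2.0), "float32")
        self.bias = Tensor((0.0,), "float32")


class MyModule(Module):
    def __init__(self):
        self.fc = Linear()  # child module, visited recursively
        self.scale = Tensor((1.0,), "float32")


module = MyModule()
half = cast(module, "float16")  # functional: `module` is still float32 here
module.to("float16")  # in-place: rebinds the parameters on `module` itself
```

Returning `self` from `to` mirrors PyTorch's `nn.Module.to` and makes the in-place form easy to chain; the functional form avoids surprising aliasing at the cost of copying the module.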