Question 1. Did anyone get Orb running in float16 or bfloat16?
Question 2. If I take the float32 weights and train a bit in float16, what would convince us the model didn't break? (e.g., would low validation MAE on MPTraj be sufficient?)
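One check beyond validation MAE would be to compare the reduced-precision model's predictions directly against the fp32 reference on held-out structures and look at the relative error. A toy sketch of that comparison (a random linear map standing in for the real model, and numpy float16 standing in for bf16 since numpy has no bfloat16; a real check would run the actual Orb model twice):

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy stand-in for model weights and inputs; not the Orb architecture.
w32 = rng.standard_normal((256, 256)).astype(np.float32)
x = rng.standard_normal((1000, 256)).astype(np.float32)

y32 = x @ w32  # fp32 reference predictions
# Same computation with weights and inputs cast down to fp16.
y16 = (x.astype(np.float16) @ w32.astype(np.float16)).astype(np.float32)

mae = np.mean(np.abs(y32 - y16))
rel = mae / np.mean(np.abs(y32))
print(f"MAE vs fp32: {mae:.4e}  relative: {rel:.2%}")
```

If the relative error on forces/energies stays well below the model's intrinsic error on the benchmark, the precision drop is unlikely to be the bottleneck.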
Apologies if I'm missing anything.
(update) Sorry for the wall of text below; I added details so others don't have to reproduce this from scratch. If useful, I'm happy to write a PR (with a non-hacky fix for brute_force_kNN).
TLDR: bf16 is 2-3x faster on an RTX 4090. To get the xs model working I had to hack brute_force_kNN. This gave 6.5 steps/s = 0.28 ns/day for a 36k-atom system; I'd expect around 0.6-0.8 ns/day on an H100. The big thing left is evaluating the accuracy drop, e.g., reproducing Figure 2 in bf16.
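For anyone checking the throughput arithmetic: 6.5 steps/s works out to 0.28 ns/day if the MD timestep is 0.5 fs (the timestep is my assumption, not stated above):

```python
# Sanity-check the quoted throughput conversion.
steps_per_s = 6.5          # measured steps/s from the post
timestep_fs = 0.5          # assumed MD timestep; not stated in the post
seconds_per_day = 86_400

fs_per_day = steps_per_s * seconds_per_day * timestep_fs
ns_per_day = fs_per_day / 1e6
print(f"{ns_per_day:.2f} ns/day")  # → 0.28 ns/day
```

The same arithmetic scaled by a ~2.5x H100-over-4090 speedup assumption gives the 0.6-0.8 ns/day ballpark.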