float16 #71

@AlexanderMath

Description

Question 1. Did anyone get Orb running in float16 or bfloat16?

Question 2. If I take the float32 weights and train a bit in float16, what would convince us the model didn't break? (e.g., would low validation MAE on MPTraj be sufficient?)
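For context on what "running in bfloat16" could mean in practice: nothing Orb-specific is assumed below, just a generic PyTorch module standing in for the model. The sketch shows the two usual options, a full weight cast versus autocast (mixed precision, which keeps master weights in fp32 and is typically safer for the fine-tuning in Question 2).

```python
import torch

# Hypothetical stand-in for the Orb network; only the dtype handling matters here.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 8), torch.nn.SiLU(), torch.nn.Linear(8, 1)
)
x = torch.randn(4, 8)

# Option A: cast every parameter and buffer to bfloat16 outright.
bf16_model = torch.nn.Sequential(
    torch.nn.Linear(8, 8), torch.nn.SiLU(), torch.nn.Linear(8, 1)
).to(dtype=torch.bfloat16)
y_cast = bf16_model(x.to(torch.bfloat16))

# Option B: mixed precision via autocast -- weights stay fp32,
# matmuls/activations run in bf16 inside the context.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y_amp = model(x)

assert y_cast.dtype == torch.bfloat16
assert y_amp.dtype == torch.bfloat16
```

On CUDA the same autocast call takes `device_type="cuda"`; bf16 needs no loss scaling, unlike fp16.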

Apologies if I'm missing anything.

(update) Sorry for the wall of text below; I added details so others don't have to reproduce this. If useful, happy to write a PR (with a non-hacky fix for brute_force_kNN).

TLDR: bf16 is 2-3x faster on an RTX 4090. To get the xs model working I had to hack brute_force_kNN. This gave 6.5 steps/s = 0.28 ns/day for a 36k-atom system. I expect around 0.6-0.8 ns/day on an H100. The big thing left is to evaluate the accuracy drop, e.g., reproduce Figure 2 in bf16.
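I don't know what the actual brute_force_kNN hack looked like, but a plausible non-hacky version of the fix is sketched below: compute the pairwise distances and top-k in float32 even when positions arrive as bfloat16, since bf16 has only ~8 bits of mantissa and can scramble neighbour ordering (and some distance ops lack half-precision kernels). The function name and signature here are assumptions, not Orb's API.

```python
import torch

def brute_force_knn_bf16_safe(pos: torch.Tensor, k: int):
    """Hypothetical bf16-safe brute-force kNN.

    pos: (N, 3) atomic positions, possibly bfloat16.
    Returns (distances, indices) of the k nearest neighbours per atom,
    with distances cast back to the input dtype.
    """
    # Do the numerically sensitive part in fp32.
    d = torch.cdist(pos.float(), pos.float())  # (N, N) pairwise distances
    d.fill_diagonal_(torch.inf)                # exclude self-neighbours
    dist, idx = d.topk(k, largest=False)       # k smallest distances per row
    return dist.to(pos.dtype), idx

# Usage: three collinear points; atom 2 (at x=10) is nearer to atom 1 than atom 0.
pos = torch.tensor([[0.0, 0, 0], [1.0, 0, 0], [10.0, 0, 0]], dtype=torch.bfloat16)
dist, idx = brute_force_knn_bf16_safe(pos, k=1)
assert idx[:, 0].tolist() == [1, 0, 1]
```

The fp32 round-trip costs little relative to the model forward pass, so the 2-3x bf16 speedup should survive it.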
