Downloaded GitHub-hosted images to local assets directory and updated
all image references to use local paths. Converted standalone images to
markdown syntax while keeping centered images as HTML img tags for
proper rendering.
Signed-off-by: Bram Wasti <bwasti@meta.com>
_posts/2025-11-10-bitwise-exact-rl.md: 6 additions & 6 deletions
@@ -12,34 +12,34 @@ Discussion on this can be found on ThinkingMachine’s post Defeating Nondeterminism
Floating point numbers are effectively a binary scientific notation. They are made up of three components: a sign bit (s), a mantissa (M), and an exponent (e).
<p align="center">
<imgwidth="340"height="130"alt="Screenshot 2025-11-10 at 5 12 41 PM"src="https://github.com/user-attachments/assets/24275084-1b8c-45fd-b40c-6169ed04c837" />
Each of these components is represented as an integer and suffers from exactly the rounding errors you might expect. In bf16, the most commonly used representation for machine learning, 7 bits are dedicated to the mantissa. This is not very many bits! The value 3.0 can be represented exactly, but a value like 3.6 cannot…
<p align="center">
<imgwidth="480"height="355"alt="Screenshot 2025-11-10 at 5 13 24 PM"src="https://github.com/user-attachments/assets/1a51da11-b0b4-45fb-853d-bc19a23c1300" />
When you need a value that bf16 cannot represent, it gets rounded to the nearest representable value. What’s of particular interest today is the implication of this rounding happening at different points in a sequence of additions.
<imgwidth="944"height="414"alt="Screenshot 2025-11-10 at 5 13 56 PM"src="https://github.com/user-attachments/assets/aa334e61-778a-4a18-ab11-e88bd202d7d2" />
These rounding steps can cause two runs over the exact same inputs to generate *different* outputs! That means the same framework on the same hardware with the same inputs and the same weights can produce distinct outputs if *any* of the logic *anywhere* in the execution dispatches a different (but still correct) kernel.
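A hedged sketch of how that can look in practice: multiplying one row versus a whole batch may be served by different (equally correct) kernels with different reduction orders, so the check below can print False depending on hardware and backend. The shapes here are arbitrary, chosen only for illustration.

```python
import torch

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(32, 1024, dtype=torch.bfloat16, device=device)
w = torch.randn(1024, 1024, dtype=torch.bfloat16, device=device)

full_batch = x @ w       # kernel chosen for a 32-row problem
single_row = x[:1] @ w   # kernel chosen for a 1-row problem

# Both results are valid matmuls over the same data, but the reduction order
# inside the kernels may differ, so row 0 may not match bit-for-bit.
print(torch.equal(full_batch[:1], single_row))
```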
## Demonstration
Reinforcement learning has been shown to amplify tiny numerical perturbations, leading to non-deterministic and unstable training behavior. By combining the [recent work](https://github.com/pytorch/torchtitan/tree/main/torchtitan/experiments/deterministic_vllm_rl) of vLLM with TorchTitan, we were able to demonstrate the stabilized training dynamics of reinforcement learning with exact bitwise parity between the generator and the trainer. This has landed as a script in TorchTitan [here](https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py).
<imgwidth="1051"height="430"alt="Screenshot 2025-11-10 at 5 14 45 PM"src="https://github.com/user-attachments/assets/6cb38cab-89d4-409f-8abf-db1aeb1e24f2" />
Running the demonstration associated with this blog post, we see exactly the issue described below: running the generator with different kernels than the trainer (batch_inv_OFF) shows a reduced reward over 100 steps. Enabling bitwise-exact training, we see the model not only train in fewer steps but also reach a higher total reward!
<imgwidth="1319"height="473"alt="Screenshot 2025-11-10 at 5 17 16 PM"src="https://github.com/user-attachments/assets/f2c9d6aa-68c2-4064-b4ab-de425f2b78a7" />