flashrl does RL with millions of steps/second 🚨 while being tiny: ~200 lines of code
🛠️ `pip install flashrl` or clone the repo & `pip install -r requirements.txt`
- If cloned (or if envs changed), compile: `python setup.py build_ext --inplace`
💡 flashrl will always be tiny: Read the code (+paste it into an LLM) to understand it!
flashrl uses a `Learner` that holds an env and a model (default: `Policy` with LSTM)
```python
import flashrl as frl

learn = frl.Learner(frl.envs.Pong(n_agents=2**14))
curves = learn.fit(40, steps=16, desc='done')
frl.print_curve(curves['loss'], label='loss')
frl.play(learn.env, learn.model, fps=8)
learn.env.close()
```
`.fit` does RL with ~10 million steps: 40 iterations × 16 steps × 2**14 agents!
Run it yourself via `python train.py` and play against the AI 💪
Click here to read a tiny doc 👇
`Learner` takes the arguments
- `env`: RL environment
- `model`: A `Policy` model
- `device`: Per default picks `mps` or `cuda` if available, else `cpu`
- `dtype`: Per default `torch.bfloat16` if device is `cuda`, else `torch.float32`
- `compile_no_lstm`: Speedup via `torch.compile` if `model` has no `lstm`
- `**kwargs`: Passed to the `Policy`, e.g. `hidden_size` or `lstm`
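For example, a `Learner` with these arguments set explicitly could look like this (a minimal sketch: the keyword names come from the list above, the values are illustrative, and `hidden_size=128` is just a guessed value):
```python
import torch
import flashrl as frl

# sketch: keyword names are from the list above, values are illustrative
learn = frl.Learner(
    frl.envs.Pong(n_agents=2**14),
    device='cuda',         # instead of the default mps/cuda/cpu pick
    dtype=torch.bfloat16,  # the default on cuda anyway
    hidden_size=128,       # **kwargs are passed to the Policy
)
```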
`Learner.fit` takes the arguments
- `iters`: Number of iterations
- `steps`: Number of steps in `rollout`
- `desc`: Progress bar description (e.g. `'reward'`)
- `log`: If `True`, `tensorboard` logging is enabled: run `tensorboard --logdir=runs` and visit http://localhost:6006 in the browser!
 
- `stop_func`: Function that stops training if it returns `True`, e.g.
```python
def stop(kl, **kwargs):
  return kl > .1
curves = learn.fit(40, steps=16, stop_func=stop)
```
- `lr`, `anneal_lr` & args of `ppo` after `bs`: Hyperparameters
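Putting a few of these together (a minimal sketch: the keyword names follow the list above, the `lr` value is illustrative):
```python
# sketch: keyword names are from the list above, the lr value is illustrative
curves = learn.fit(40, steps=16, desc='reward', log=True, lr=1e-3, anneal_lr=True)
frl.print_curve(curves['loss'], label='loss')
```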
The most important functions in `flashrl/utils.py` are
- `print_curve`: Visualizes the loss across the `iters`
- `play`: Plays the environment in the terminal and takes
  - `model`: A `Policy` model
  - `playable`: If `True`, allows you to act (or decide to let the model act)
  - `steps`: Number of steps
  - `fps`: Frames per second
  - `obs`: Argument of the env that should be rendered as observations
  - `dump`: If `True`, no frame refresh -> Frames accumulate in the terminal
  - `idx`: Agent index between `0` and `n_agents` (default: `0`)
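For example, to watch agent `0` in the terminal and optionally take over yourself (a minimal sketch using the arguments above):
```python
# sketch: keyword names are from the list above
frl.play(learn.env, learn.model, playable=True, fps=8, idx=0)
```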
 
Each env is one Cython (=`.pyx`) file in `flashrl/envs`. That's it!
To add custom envs, use `grid.pyx`, `pong.pyx` or `multigrid.pyx` as a template:
- `grid.pyx` for single-agent envs (~110 LOC)
- `pong.pyx` for 1 vs 1 agent envs (~150 LOC)
- `multigrid.pyx` for multi-agent envs (~190 LOC)
| Grid | Pong | MultiGrid | 
|---|---|---|
| Agent must reach goal | Agent must score | Agent must reach goal first | 
I want to thank
- Joseph Suarez for open sourcing RL envs in C(ython)! Star PufferLib ⭐
- Costa Huang for open sourcing high-quality single-file RL code! Star cleanrl ⭐
and last but not least...

