Add token sampling (temperature, top-k, top-p) for GPT-2 generation #33

Draft
neegovindan wants to merge 1 commit into govindansriram:main from neegovindan:issue26

Summary

Fixes #26 by adding non-greedy token selection for generation.

Acceptance Criteria Mapping

  • Temperature sampling: implemented via logit scaling.
  • Top-K sampling: implemented via logits filtering.
  • Top-P (nucleus) sampling: implemented via cumulative-probability filtering with boundary-token retention.
  • Torch/Triton acceptable for now: implemented in Torch, as requested.
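
The three criteria above can be illustrated with a minimal sketch in pure standard-library Python. This is not the PR's Torch implementation: the names `filter_top_k`, `filter_top_p`, and `sketch_sample`, their signatures, and the order in which the filters are applied are all assumptions made for illustration.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax; -inf logits map to probability 0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def filter_top_k(logits, k):
    """Keep the k largest logits and mask the rest with -inf."""
    if k <= 0:
        raise ValueError("top_k must be positive")
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    keep = set(order[:k])
    return [x if i in keep else -math.inf for i, x in enumerate(logits)]


def filter_top_p(logits, p):
    """Keep the most probable tokens until their cumulative mass reaches p.

    The token that crosses the threshold (the boundary token) is retained,
    so at least one token always survives.
    """
    if not 0.0 < p <= 1.0:
        raise ValueError("top_p must be in (0, 1]")
    probs = softmax(logits)
    order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    return [x if i in keep else -math.inf for i, x in enumerate(logits)]


def sketch_sample(logits, temperature=1.0, top_k=None, top_p=None, rng=random):
    """Hypothetical sampler: temperature, then top-k, then top-p, then draw."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    if top_k is not None:
        scaled = filter_top_k(scaled, top_k)
    if top_p is not None:
        scaled = filter_top_p(scaled, top_p)
    probs = softmax(scaled)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

With `top_k=1` only the highest logit survives the filter, so the draw is deterministic and equals the argmax. Passing a seeded `random.Random` instance as `rng` makes other settings reproducible.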

Changes

  • Added a sampling module
    • input validation for sampling arguments
  • Updated the generation path
    • supports greedy and sampling modes
    • EOS-aware early-stop handling for batches
  • Updated package exports
  • Updated parity path
    • switched custom model generation from manual argmax loop to deterministic generation
  • Added tests
    • greedy matches argmax
    • top-k deterministic behavior
    • top-p filtering behavior and boundary-token coverage
    • invalid argument checks
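
The greedy mode and the EOS-aware early stop for batches can be sketched as below, again in pure Python rather than Torch. The names `sketch_generate`, `greedy`, and `toy_logits` are hypothetical. How the PR handles finished rows (for example, padding them) is not stated, so this sketch simply stops extending them.

```python
def greedy(logits):
    """Deterministic token choice: index of the largest logit (argmax)."""
    return max(range(len(logits)), key=lambda i: logits[i])


def sketch_generate(next_logits_fn, prompts, max_new_tokens, eos_id, pick=greedy):
    """Extend each prompt token by token until EOS or the token budget.

    next_logits_fn(seq) returns a list of logits for the next token;
    pick(logits) returns a token id (greedy here, a sampler otherwise).
    """
    seqs = [list(p) for p in prompts]
    finished = [False] * len(seqs)
    for _ in range(max_new_tokens):
        if all(finished):
            break  # early stop: every sequence in the batch has emitted EOS
        for b, seq in enumerate(seqs):
            if finished[b]:
                continue  # assumption: finished rows are left untouched
            token = pick(next_logits_fn(seq))
            seq.append(token)
            if token == eos_id:
                finished[b] = True
    return seqs


def toy_logits(seq, vocab=5):
    """Toy stand-in for a model: always prefers (last token + 1) mod vocab."""
    target = (seq[-1] + 1) % vocab
    return [1.0 if i == target else 0.0 for i in range(vocab)]


result = sketch_generate(toy_logits, [[0], [2]], max_new_tokens=10, eos_id=4)
print(result)
```

Swapping `pick` for a sampling function turns the same loop into non-greedy generation, which is one way to keep a single entry point for both modes.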

Notes

  • This PR keeps deterministic generation for parity tests while enabling sampling for real inference.
  • Sampling currently runs in Torch; CUDA/Triton kernels can be layered on later without API changes.

Validation

  • Byte-compilation check on all modified Python files passed.
  • Full pytest run could not be completed in this environment due to a local torch import/runtime abort; collection stopped with 3 errors:

    ============================= test session starts ==============================
    platform darwin -- Python 3.11.3, pytest-7.3.1, pluggy-1.0.0
    rootdir: /Users/ngovindan/Documents/Work/app/cobraml/CobraMLng
    configfile: pyproject.toml
    testpaths: python/tests
    plugins: anyio-3.5.0, hydra-core-1.3.2
    collected 0 items / 3 errors

    ==================================== ERRORS ====================================
    ____________ ERROR collecting python/tests/layers/test_attention.py ____________
    ImportError while importing test module '/Users/ngovindan/Documents/Work/app/cobraml/CobraMLng/python/tests/layers/test_attention.py'.
    Hint: make sure your test modules/packages have valid Python names.
    Traceback:
    ../../../../../anaconda3/lib/python3.11/importlib/__init__.py:126: in import_module
        return _bootstrap._gcd_import(name[level:], package, level)
    python/tests/layers/test_attention.py:2: in <module>
        from cobraml.layers import MultiHeadAttention, FusedMultiHeadAttention
    E   ModuleNotFoundError: No module named 'cobraml'
    _____________ ERROR collecting python/tests/models/test_gpt2-xl.py _____________
    ImportError while importing test module '/Users/ngovindan/Documents/Work/app/cobraml/CobraMLng/python/tests/models/test_gpt2-xl.py'.
    Hint: make sure your test modules/packages have valid Python names.
    Traceback:
    ../../../../../anaconda3/lib/python3.11/importlib/__init__.py:126: in import_module
        return _bootstrap._gcd_import(name[level:], package, level)
    python/tests/models/test_gpt2-xl.py:1: in <module>
        from cobraml.utils import load_hf_config
    E   ModuleNotFoundError: No module named 'cobraml'
    ____________ ERROR collecting python/tests/models/test_sampling.py _____________
    ImportError while importing test module '/Users/ngovindan/Documents/Work/app/cobraml/CobraMLng/python/tests/models/test_sampling.py'.
    Hint: make sure your test modules/packages have valid Python names.
    Traceback:
    ../../../../../anaconda3/lib/python3.11/importlib/__init__.py:126: in import_module
        return _bootstrap._gcd_import(name[level:], package, level)
    python/tests/models/test_sampling.py:4: in <module>
        from cobraml.models import sample_next_token
    E   ModuleNotFoundError: No module named 'cobraml'
    =========================== short test summary info ============================
    ERROR python/tests/layers/test_attention.py
    ERROR python/tests/models/test_gpt2-xl.py
    ERROR python/tests/models/test_sampling.py
    !!!!!!!!!!!!!!!!!!! Interrupted: 3 errors during collection !!!!!!!!!!!!!!!!!!!!
    ============================== 3 errors in 10.81s ==============================

Development

Successfully merging this pull request may close these issues.

Add Sampling
