This repository was archived by the owner on Jul 2, 2025. It is now read-only.
143 commits
a339171
Added flash attention.
mmcdermott Sep 7, 2023
f479082
Merge branch 'dev' into enable_flash_attention
mmcdermott Oct 25, 2023
c313c1c
Added a schema output to pytorch dataset.
mmcdermott Oct 25, 2023
8437fe0
Fixed small bug with schema and added test.
mmcdermott Oct 25, 2023
8cbacb9
Small updates to dataset for testing compute costs.
mmcdermott Nov 9, 2023
0504580
First attempt; likely very buggy.
mmcdermott Nov 9, 2023
49c4f02
More changes; not sure if it is 100% working yet or not. Tests certai…
mmcdermott Nov 9, 2023
3d577ab
Fixed some more typos; now can run up through getitem and collate at …
mmcdermott Nov 10, 2023
00a0305
Fixed some small bugs; tests still not fully passing but sample data …
mmcdermott Nov 10, 2023
260161a
Fixed a number of small bugs; tests closer to passing.
mmcdermott Nov 10, 2023
6052b8a
I think all tests are now passing.
mmcdermott Nov 10, 2023
7098db8
Updated notebook.
mmcdermott Nov 10, 2023
070e711
Starter code for nested caching and subsampling
mmcdermott Nov 14, 2023
a7ab111
Added doctest.
mmcdermott Nov 21, 2023
c2c4a2e
Merge branch 'improved_compute' into duplication_bug_in_task_caching
mmcdermott Nov 21, 2023
6235bd0
Added test for task caching bug.
mmcdermott Nov 21, 2023
20ef6af
Fixed test.
mmcdermott Nov 21, 2023
c78cb4d
Lint fixes.
mmcdermott Nov 21, 2023
3a1821c
Merge pull request #77 from mmcdermott/duplication_bug_in_task_caching
mmcdermott Nov 21, 2023
e4f3e98
Merge branch 'improved_compute' into nested_caching
mmcdermott Nov 21, 2023
cf7b9e3
Added loguru
mmcdermott Nov 21, 2023
16e0c87
Trying logger changes.
mmcdermott Nov 21, 2023
e327db2
Added some logging and got tests passing.
mmcdermott Nov 21, 2023
9c0eeb5
Removed the old increment calls.
mmcdermott Nov 21, 2023
64df044
Added loguru throughout.
mmcdermott Nov 22, 2023
63a838a
Merge pull request #78 from mmcdermott/remove_unnecessary_increment
mmcdermott Nov 25, 2023
33215e7
Merge pull request #79 from mmcdermott/enable_flash_attention
mmcdermott Nov 25, 2023
866e5c1
Merge branch 'v0.5' into proper_logs
mmcdermott Nov 25, 2023
8c4d60f
Added loguru initialization code to all scripts.
mmcdermott Nov 25, 2023
1b5c30b
Fixed tests.
mmcdermott Nov 26, 2023
1534c03
Merge pull request #80 from mmcdermott/proper_logs
mmcdermott Nov 27, 2023
b24becd
Added caching fns
pargaw Nov 28, 2023
d9fe46b
Partial pseudo-code for caching subsets
mmcdermott Nov 28, 2023
f8447e9
Some small changes.
mmcdermott Nov 30, 2023
e5feba5
Fixed cache_subset to run fully
pargaw Dec 2, 2023
0a8a797
Merge pull request #81 from mmcdermott/remove_unnecessary_subject_man…
mmcdermott Dec 2, 2023
dd2035f
Fixed all tests and upgraded to polars 0.19.19
mmcdermott Dec 3, 2023
273e09a
Merge pull request #82 from mmcdermott/remove_with_row_count
mmcdermott Dec 3, 2023
b34b3fe
Added logging to other aspects of ESGPT.
mmcdermott Dec 3, 2023
ccf9da9
Merge pull request #83 from mmcdermott/more_logging
mmcdermott Dec 3, 2023
e1a3c3e
Removing currently unfinished and unused evaluation code.
mmcdermott Dec 3, 2023
ea9ce9f
Removed other buggy code and an unnecessary break statement in sklear…
mmcdermott Dec 3, 2023
7d1efef
Fixed URI function bug.
mmcdermott Dec 3, 2023
cf8fe51
Added an ESDS conversion option.
mmcdermott Dec 4, 2023
4bd8499
Added ESDS support; need to test with modifiers.
mmcdermott Dec 4, 2023
537ec1a
Added modifiers into standard sample data example and fixed slight is…
mmcdermott Dec 4, 2023
38a0897
Fixed bugs with saving new data_stats for cached subset data
pargaw Dec 4, 2023
1b59238
Fixed tests by adding medications to the sample NA config.
mmcdermott Dec 4, 2023
acd2bf0
Merge pull request #85 from mmcdermott/ESDS_export
mmcdermott Dec 4, 2023
f7aa908
Changed cache_subset for train vs other splits
pargaw Dec 4, 2023
65d3c75
Fixed small bugs in data pipeline inhibiting runs on real data.
mmcdermott Dec 4, 2023
3078b9e
Fixed subset caching to account for extra padding on time_delta
pargaw Dec 7, 2023
44704e0
Stylistic changes.
mmcdermott Dec 7, 2023
b1af020
Merge branch 'v0.5' into nested_caching
mmcdermott Dec 7, 2023
acf5b69
Made subset size work in case of float subset size (which is supporte…
mmcdermott Dec 7, 2023
41cbcde
Made all uses of full data config a consistent property.
mmcdermott Dec 7, 2023
a303b39
Fixed tests with nested caching
mmcdermott Dec 7, 2023
0c04c89
Merge pull request #86 from mmcdermott/nested_caching
mmcdermott Dec 7, 2023
235d8b6
Updated some deprecated polars functions.
mmcdermott Dec 7, 2023
8e38596
Merge pull request #87 from mmcdermott/polars_0.19
mmcdermott Dec 7, 2023
99b7e11
Added two more logging lines.
mmcdermott Dec 8, 2023
5faa9a5
Merge branch 'v0.5' of github.com:mmcdermott/EventStreamML into v0.5
mmcdermott Dec 8, 2023
c017794
re-set to dev.
mmcdermott Dec 14, 2023
dd81924
Merge pull request #84 from mmcdermott/v0.5
mmcdermott Dec 14, 2023
54e4698
Merge branch 'dev' of github.com:mmcdermott/EventStreamGPT into dev
mmcdermott Dec 15, 2023
9a1d72d
Fixed typo in length_constraint in ConstructorPytorchDataset init
pargaw Dec 16, 2023
007fd49
Updated error message
mmcdermott Dec 18, 2023
ca65a71
Attempting to try using ragged tensors from https://github.com/mmcder…
mmcdermott Dec 18, 2023
5947a91
Fixed a few small errors
mmcdermott Dec 19, 2023
774a9bc
Fixed small typo; may or may not still be working.
mmcdermott Dec 19, 2023
d0ee233
Fixed small typo
mmcdermott Dec 19, 2023
5d3b403
Improved collate fn.
mmcdermott Dec 19, 2023
621a939
First working version
mmcdermott Dec 19, 2023
9f5013b
A version using numpy instead of torch for collation and such.
mmcdermott Dec 19, 2023
c05428b
Further optimizations
mmcdermott Dec 20, 2023
e346808
Added _cache_subset
pargaw Dec 22, 2023
b22e6ce
Fixed some logger typos
pargaw Dec 22, 2023
a023b62
temporarily set max recursion limit to account for large subset sizes
pargaw Dec 22, 2023
29c3b9f
Removed unused comment
pargaw Dec 22, 2023
8c73fd3
Make logging more detailed.
mmcdermott Jan 4, 2024
78a1aac
Merge branch 'using_ragged_tensors' of github.com:mmcdermott/EventStr…
mmcdermott Jan 4, 2024
4245128
Fixed some small typos and made cached subsets actually be re-loaded …
mmcdermott Jan 4, 2024
35bacf2
Cache data_parameters for subset sizes
pargaw Jan 4, 2024
6095101
Fixed lint errors.
mmcdermott Jan 9, 2024
fa33387
Added nested_ragged_tensors as a dependency.
mmcdermott Jan 22, 2024
357caac
Fixed a doc mismatch and a scripts typo.
mmcdermott Jan 22, 2024
7c4b8a9
Added tuning subset size filter
pargaw Feb 1, 2024
91a700b
Changed defaults of final_validation_metrics_config to allow for flex…
pargaw Feb 4, 2024
3ef80b8
Fixed lint errors.
mmcdermott Feb 23, 2024
31bbdbc
Added basic code for setting some mandatory patient split sets.
mmcdermott Mar 11, 2024
1ff2c4b
updated tutorial notebook and added TODO.
mmcdermott Apr 5, 2024
1c148a8
Merge pull request #94 from mmcdermott/minor_fixes
mmcdermott Apr 5, 2024
afb029c
Merge branch 'dev' of github.com:mmcdermott/EventStreamGPT into dev
mmcdermott Apr 5, 2024
7fa60a6
Fixed a typo in the documentation.
mmcdermott Apr 7, 2024
81afe4d
Merge branch 'dev' of github.com:mmcdermott/EventStreamML into dev
mmcdermott Apr 11, 2024
c217975
Updated a few deprecated functions.
mmcdermott Apr 19, 2024
aa953ab
Removed pre-processors; tests passing except polars dataset integrati…
mmcdermott Apr 20, 2024
9c3b897
Added seeding to caching of data.
mmcdermott Apr 20, 2024
5fc2d63
Updated python version in tests workflow.
mmcdermott Apr 21, 2024
2c22322
Split improperly shared parameters.
mmcdermott Apr 22, 2024
3d3a60b
Merge pull request #99 from mmcdermott/update_transformers
mmcdermott Apr 22, 2024
2434b4f
Merge pull request #98 from mmcdermott/add_seeding_to_nrt
mmcdermott Apr 22, 2024
96fbca1
Merge branch 'dev' into polars_upgrade
mmcdermott Apr 22, 2024
0ae5092
upgraded polars and is working except for new ordering in polars test.
mmcdermott Apr 22, 2024
26ab54a
Removed the brittle polars dataset test. Should be broken out into pr…
mmcdermott Apr 22, 2024
6b44e77
Merge pull request #100 from mmcdermott/remove_polars_integration_uni…
mmcdermott Apr 22, 2024
1ce756e
Merge branch 'dev' into polars_upgrade
mmcdermott Apr 22, 2024
839231e
Updated notebook to fix error when no subjects are in valid set.
mmcdermott Apr 22, 2024
4deaf85
Re-scan file after collect and write -- for some reason this is neces…
mmcdermott Apr 22, 2024
f370f08
Forgot nb change.
mmcdermott Apr 22, 2024
9eead53
Merge pull request #101 from mmcdermott/polars_upgrade
mmcdermott Apr 22, 2024
475a7f0
Merged.
mmcdermott Apr 22, 2024
a51e695
Merged
mmcdermott Apr 22, 2024
8d866d9
Merge pull request #97 from mmcdermott/remove_preprocessors
mmcdermott Apr 22, 2024
fd20002
Merge branch 'dev' into using_ragged_tensors
mmcdermott Apr 22, 2024
187a150
Merge branch 'dev' into load_from_ckpt
mmcdermott Apr 22, 2024
3545a40
Merge branch 'dev' into specify_test_set_exclusions
mmcdermott Apr 22, 2024
71e4b6c
Merge pull request #102 from mmcdermott/load_from_ckpt
mmcdermott Apr 22, 2024
c959fd8
Fixed test set exclusion code and added tests
mmcdermott Apr 22, 2024
2d00aec
Merge branch 'dev' into specify_test_set_exclusions
mmcdermott Apr 22, 2024
70c7174
Merge pull request #103 from mmcdermott/specify_test_set_exclusions
mmcdermott Apr 22, 2024
1dfb6ef
Merge branch 'dev' into using_ragged_tensors
mmcdermott Apr 22, 2024
da2307f
Partially tested attempt at adding future windows.
mmcdermott Apr 24, 2024
f84069c
Merged in dev changes.
mmcdermott Apr 24, 2024
8da042a
Added tests to touch new capability.
mmcdermott Apr 24, 2024
5f9697b
Merge pull request #105 from mmcdermott/add_future_summary_windows_to…
mmcdermott Apr 24, 2024
2873344
Merge branch 'dev' into using_ragged_tensors
mmcdermott Apr 24, 2024
476661a
Updated to write NRT files and use NRT files in a smarter, much less …
mmcdermott May 16, 2024
cd9a661
Some improvements to the test code; pytorch tests are currently failing.
mmcdermott May 17, 2024
6a370bf
Removing out-dated tests with new pytorch dataset format.
mmcdermott May 17, 2024
79d41a4
Fixed pytorch dataset tests (mostly by removing those that were faili…
mmcdermott May 17, 2024
6c20b5d
Further updated polars and fixed a small test case that the polars ch…
mmcdermott May 17, 2024
330f6ac
Makes measurement_configs a property instead of a static access.
mmcdermott May 18, 2024
a3e4bd9
Updated the docstring for the generate_synthetic_data script.
mmcdermott May 18, 2024
2f433a6
Merge branch 'dev' into using_ragged_tensors
mmcdermott May 18, 2024
0350d7c
Merge pull request #90 from mmcdermott/using_ragged_tensors
mmcdermott May 31, 2024
e77b5c9
Merge branch 'dev' of github.com:mmcdermott/EventStreamML into dev
mmcdermott Jun 20, 2024
a0550b1
Updated to_int_index to not fail with newer versions of polars. Also …
mmcdermott Jun 22, 2024
5b4a279
Updated polars version in the pyproject.toml
mmcdermott Jun 22, 2024
ce9e2c3
Merge pull request #115 from mmcdermott/update_packages
mmcdermott Jun 22, 2024
fc42974
Merged main
mmcdermott Mar 20, 2025
c5909c3
Fixed typo
mmcdermott Mar 20, 2025
6cb0906
Maybe fixed tests
mmcdermott Mar 20, 2025
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
@@ -19,10 +19,10 @@ jobs:
 - name: Checkout
   uses: actions/checkout@v3

-- name: Set up Python 3.10
+- name: Set up Python 3.11
   uses: actions/setup-python@v3
   with:
-    python-version: "3.10"
+    python-version: "3.11"

 - name: Install packages
   run: |
22 changes: 13 additions & 9 deletions EventStream/baseline/FT_task_baseline.py
@@ -16,6 +16,7 @@
 import polars.selectors as cs
 import wandb
 from hydra.core.config_store import ConfigStore
+from loguru import logger
 from omegaconf import OmegaConf
 from sklearn.decomposition import NMF, PCA
 from sklearn.ensemble import RandomForestClassifier
@@ -35,7 +36,7 @@
 from ..tasks.profile import add_tasks_from
 from ..utils import task_wrapper

-pl.enable_string_cache(True)
+pl.enable_string_cache()


 def load_flat_rep(
@@ -187,6 +188,7 @@ def load_flat_rep(
 if do_cache_filtered_task:
     cached_fp.parent.mkdir(exist_ok=True, parents=True)
     df.collect().write_parquet(cached_fp, use_pyarrow=True)
+    df = pl.scan_parquet(cached_fp).select("subject_id", "timestamp", *window_features)

 df = df.select("subject_id", "timestamp", *window_features)
 if subjects_included.get(sp, None) is not None:
@@ -649,7 +651,7 @@ def eval_binary_classification(Y: np.ndarray, probs: np.ndarray) -> dict[str, fl


 def train_sklearn_pipeline(cfg: SklearnConfig):
-    print(f"Saving config to {cfg.save_dir / 'config.yaml'}")
+    logger.info(f"Saving config to {cfg.save_dir / 'config.yaml'}")
     cfg.save_dir.mkdir(exist_ok=True, parents=True)
     OmegaConf.save(cfg, cfg.save_dir / "config.yaml")

@@ -674,7 +676,7 @@ def train_sklearn_pipeline(cfg: SklearnConfig):

     # TODO(mmd): Window sizes may violate start_time constraints in task dfs!

-    print(f"Loading representations for {', '.join(cfg.feature_selector.window_sizes)}")
+    logger.info(f"Loading representations for {', '.join(cfg.feature_selector.window_sizes)}")
     subjects_included = {}

     if cfg.train_subset_size not in (None, "FULL"):
@@ -706,24 +708,26 @@ def train_sklearn_pipeline(cfg: SklearnConfig):
     Xs_and_Ys = {}
     for split in ("train", "tuning", "held_out"):
         st = datetime.now()
-        print(f"Loading dataset for {split}")
+        logger.info(f"Loading dataset for {split}")
         df = flat_reps[split].with_columns(normalized_label.alias(cfg.finetuning_task_label)).collect()

         X = df.drop(["subject_id", "timestamp", cfg.finetuning_task_label])
         Y = df[cfg.finetuning_task_label].to_numpy()
-        print(f"Done with {split} dataset with X of shape {X.shape} " f"(elapsed: {datetime.now() - st})")
+        logger.info(
+            f"Done with {split} dataset with X of shape {X.shape} " f"(elapsed: {datetime.now() - st})"
+        )
         Xs_and_Ys[split] = (X, Y)

-    print("Initializing model!")
+    logger.info("Initializing model!")
     model = cfg.get_model(dataset=ESD)

-    print("Fitting model!")
+    logger.info("Fitting model!")
     model.fit(*Xs_and_Ys["train"])
-    print(f"Saving model to {cfg.save_dir}")
+    logger.info(f"Saving model to {cfg.save_dir}")
     with open(cfg.save_dir / "model.pkl", mode="wb") as f:
         pickle.dump(model, f)

-    print("Evaluating model!")
+    logger.info("Evaluating model!")
     all_metrics = {}
     for split in ("tuning", "held_out"):
         X, Y = Xs_and_Ys[split]
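
The hunks in this file replace `print` with `logger.info` and keep the per-split elapsed-time message. A minimal sketch of that pattern follows, with the standard library's `logging` module standing in for loguru (a deliberate swap so the snippet runs without extra dependencies). The helper `timed_step` and the logger name are hypothetical and do not appear in the PR.

```python
import logging
from datetime import datetime

# Hypothetical logger name; the PR itself imports loguru's global `logger`.
logger = logging.getLogger("ft_task_baseline_sketch")


def timed_step(split, load_fn):
    """Call load_fn() and log a start message plus the elapsed wall-clock time."""
    st = datetime.now()
    logger.info("Loading dataset for %s", split)
    result = load_fn()
    # datetime subtraction yields a timedelta, which formats as H:MM:SS.ffffff
    logger.info("Done with %s dataset (elapsed: %s)", split, datetime.now() - st)
    return result
```

Calling `logging.basicConfig(level=logging.INFO)` once at start-up and then `timed_step("train", some_loader)` would emit both messages. Unlike `print`, the messages carry a level and can be filtered or redirected without touching the call sites.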
6 changes: 3 additions & 3 deletions EventStream/data/README.md
@@ -76,8 +76,8 @@ the following data:
    indices of the measures that correspond to the measurement observations in `dynamic_indices`.
 8. `dynamic_values`, which is of the same (ragged) shape as `dynamic_indices` and contains any unique
    numerical values associated with those measurements. Items may be missing (reflected with `None` or
-   `np.NaN`, depending on the data library format) or may have been filtered out as outliers (reflected with
-   `np.NaN`).
+   `float('nan')`, depending on the data library format) or may have been filtered out as outliers (reflected with
+   `float('nan')`).

 ### Measurements

@@ -390,7 +390,7 @@ Let us define the following variables:
 }
 ```

-`static_data_values` and `data_values` in the above dictionary may contain `np.NaN` entries where values were
+`static_data_values` and `data_values` in the above dictionary may contain `float('nan')` entries where values were
 not observed with a given data element. All other data elements are fully observed. The elements correspond to
 the following kinds of features:

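
The README hunks above change the documented missing-value marker from `np.NaN` to `float('nan')`. The PR does not state why; one plausible reason, offered only as an assumption, is that the `np.NaN` alias was removed in NumPy 2.0. The sketch below uses only the standard library to show how such entries behave; `count_observed` and the sample list are hypothetical and not part of the repository.

```python
import math

nan = float("nan")

# NaN never compares equal to anything, including itself, so `==` cannot be
# used to detect it; math.isnan is the reliable check.
is_self_equal = nan == nan
is_nan = math.isnan(nan)


def count_observed(values):
    """Count entries that are neither None (missing) nor NaN (missing or filtered outlier)."""
    return sum(1 for v in values if v is not None and not math.isnan(v))


# Hypothetical stand-in for one event's `dynamic_values`.
dynamic_values = [1.5, None, float("nan"), 0.0]
n_observed = count_observed(dynamic_values)
```

Here `is_self_equal` is `False`, `is_nan` is `True`, and `n_observed` is 2, since only `1.5` and `0.0` are real observations.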