Skip to content

Commit 94e14aa

Browse files
committed
use test_split and validation_split for pubchem
1 parent 1d8a7c3 commit 94e14aa

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

chebai/preprocessing/datasets/pubchem.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,13 @@ def setup_processed(self):
154154
print("Load data from file", filename)
155155
data = self._load_data_from_file(filename)
156156
print("Create splits")
157-
train, test = train_test_split(data, train_size=self.train_split)
157+
train, test = train_test_split(
158+
data, train_size=1 - (self.validation_split + self.test_split)
159+
)
158160
del data
159-
test, val = train_test_split(test, train_size=self.train_split)
161+
test, val = train_test_split(
162+
test, train_size=self.test_split / (self.validation_split + self.test_split)
163+
)
160164
torch.save(train, os.path.join(self.processed_dir, "train.pt"))
161165
torch.save(test, os.path.join(self.processed_dir, "test.pt"))
162166
torch.save(val, os.path.join(self.processed_dir, "validation.pt"))

0 commit comments

Comments
 (0)