-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
executable file
·50 lines (43 loc) · 2.98 KB
/
main.py
File metadata and controls
executable file
·50 lines (43 loc) · 2.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from argparse import ArgumentParser
import json
import os
from loguru import logger
from utils.train_utils import trainer
def main(args):
    """Hand the parsed command-line configuration to the trainer.

    Args:
        args: Namespace parsed from the command line in the ``__main__``
            section of this file; forwarded to ``trainer`` unchanged.

    Returns:
        None. Any value returned by ``trainer`` is discarded.
    """
    config = args
    trainer(config)
def build_parser():
    """Build the command-line parser for a training run.

    Returns:
        ArgumentParser with every training/model/data option and its default.
    """
    parser = ArgumentParser()
    parser.add_argument("--experiment_name", type=str, required=True, help="Name of experiment")
    parser.add_argument("--save_path", type=str, help="Path to save results", default="data/results/")
    parser.add_argument("--train_aa_data_path", type=str, help="the path to the train aa data", default="data/AAList/train_aa_with_similarity.csv")
    parser.add_argument("--train_mol_data_path", type=str, help="the path to train general mol", default="data/ZINC15/train_zinc15_10M_2D.csv")
    parser.add_argument("--val_aa_data_path", type=str, help="the path to the val aa data", default="data/AAList/val_aa_with_similarity.csv")
    parser.add_argument("--val_mol_data_path", type=str, help="the path to val general mol", default="data/ZINC15/val_zinc15_10M_2D.csv")
    parser.add_argument("--model_channels", type=int, default=256, help="The number of channels for model")
    # Help text previously duplicated --model_channels' description.
    parser.add_argument("--num_head", type=int, default=16, help="The number of attention heads for model")
    parser.add_argument("--topological_net_layers", type=int, default=6, help="The number of topological net layers")
    parser.add_argument("--decoder_layers", type=int, default=2, help="The number of decoder layers")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=10000, help="The number of training epochs")
    parser.add_argument("--logger_step", type=int, help="Step for logger", default=5)
    parser.add_argument("--cache_path", type=str, help="path to cache files", default="data/cache/")
    parser.add_argument("--load_path", type=str, help="path to load state dict", default=None)
    parser.add_argument("--num_workers", type=int, help="num of workers to load data", default=5)
    # NOTE(review): semantics of --max_combine, --aba, --num_inner_l and
    # --cont_weight live in utils.train_utils; add help text once confirmed.
    parser.add_argument("--max_combine", type=int, help="", default=4)
    # The literal choice "None" is mapped to Python None after parsing.
    parser.add_argument("--norm", type=str, choices=["BatchNorm", "GraphNorm", "LayerNorm", "None"], default=None)
    parser.add_argument("--aba", type=int, choices=[0, 1, 2], default=0)
    parser.add_argument("--model", type=str, choices=["GPS"], default="GPS")
    parser.add_argument("--num_inner_l", type=int, default=2)
    parser.add_argument("--cont_weight", type=float, default=0.01)
    return parser


def prepare_experiment_dir(args):
    """Create ``<save_path>/<experiment_name>`` and write ``config.json`` into it.

    Args:
        args: parsed Namespace; ``save_path`` and ``experiment_name`` pick the
            directory, and every attribute is serialized into the config file.

    Returns:
        The experiment directory path.
    """
    save_path = os.path.join(args.save_path, args.experiment_name)
    # makedirs (not mkdir) so a missing parent such as the default
    # "data/results/" is created too; exist_ok keeps re-runs working.
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, "config.json"), "w", encoding="utf-8") as f:
        json.dump(vars(args), f)
    return save_path


if __name__ == "__main__":
    args = build_parser().parse_args()
    # "--norm None" arrives as the string "None"; normalize it before the
    # config is saved so config.json and the trainer both see null/None.
    if args.norm == "None":
        args.norm = None
    save_path = prepare_experiment_dir(args)
    logger.add(os.path.join(save_path, "log"))
    main(args)