-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel_parameters.py
More file actions
42 lines (28 loc) · 838 Bytes
/
model_parameters.py
File metadata and controls
42 lines (28 loc) · 838 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""Count the trainable parameters of a DKT LSTM model."""
import os
import dkt_model
import numpy
import quick_experiment.utils
from quick_experiment import dataset
# Number of distinct skills described by the mock dataset below.
SKILLS = 140

# Keyword arguments forwarded verbatim to dkt_model.DktLSTMModel.
# training_epochs is 0, presumably so fit() only builds the model without
# training it -- TODO confirm against dkt_model.
CONFIG = dict(
    hidden_layer_size=200,
    batch_size=100,
    training_epochs=0,
    max_num_steps=100,
)
class MockDKTDataset(dataset.LabeledSequenceDataset):
    """Stand-in dataset that only reports the sizes the DKT model needs.

    It exposes the label dtype, the number of classes and the feature-vector
    width, all derived from the module-level SKILLS constant. No data
    loading is defined in this class.
    """

    @property
    def feature_vector_size(self):
        """Width of one input feature vector: two entries per skill."""
        # NOTE(review): presumably one slot per (skill, answer outcome)
        # pair -- confirm against the real DKT encoding.
        return 2 * SKILLS

    @property
    def labels_type(self):
        """NumPy dtype used for the labels."""
        return numpy.float32

    def classes_num(self, _=None):
        """Return the number of problems in the dataset (SKILLS + 1).

        The optional argument is accepted and ignored. What the extra class
        beyond SKILLS represents is not visible here -- TODO confirm
        (possibly padding).
        """
        return 1 + SKILLS
def main():
    """Build a DktLSTMModel on the mock dataset and count its parameters.

    The model is created with the module-level CONFIG and fitted on the
    'train' partition before count_trainable_parameters() is invoked.
    """
    mock_data = MockDKTDataset()
    lstm = dkt_model.DktLSTMModel(mock_data, **CONFIG)
    lstm.fit(partition_name='train')
    # The return value is discarded; presumably the method reports the
    # count itself -- TODO confirm in dkt_model.
    lstm.count_trainable_parameters()


if __name__ == '__main__':
    main()