15
15
import time
16
16
import argparse
17
17
import glob
18
- import scipy .signal
19
18
import os
20
19
from einops import rearrange , reduce , repeat
21
20
from einops .layers .torch import Rearrange , Reduce
22
21
print (f"Torch Version : { torch .__version__ } " )
23
22
24
23
from datasets .sleep_edf import split_data , SleepEDF_MultiChan_Dataset , get_dataset
25
- from models .epoch_cmt import Epoch_Cross_Transformer_Network
24
+ from models .epoch_cmt import Epoch_Cross_Transformer_Network , train_epoch_cmt
26
25
from models .sequence_cmt import Seq_Cross_Transformer_Network
27
26
from utils .metrics import accuracy , kappa , g_mean , plot_confusion_matrix , confusion_matrix , AverageMeter
28
27
@@ -33,15 +32,17 @@ def parse_option():
33
32
34
33
parser .add_argument ('--project_path' , type = str , default = './results' , help = 'Path to store project results' )
35
34
parser .add_argument ('--data_path' , type = str , help = 'Path to the dataset file' )
36
- parser .add_argument ('--train_data_list' , type = list , default = [0 ,1 ,2 ,3 ] , help = 'Folds in the dataset for training' )
37
- parser .add_argument ('--val_data_list' , type = list , default = [0 ,1 ,2 ,3 ] , help = 'Folds in the dataset for validation' )
35
+ parser .add_argument ('--train_data_list' , nargs = "+" , default = [0 ,1 ,2 ,3 ] , help = 'Folds in the dataset for training' )
36
+ parser .add_argument ('--val_data_list' , nargs = "+" , default = [4 ] , help = 'Folds in the dataset for validation' )
37
+ parser .add_argument ('--is_retrain' , type = bool , default = False , help = 'To retrain a from saved checkpoint' )
38
+ parser .add_argument ('--model_path' , type = str , default = "" , help = 'Path to saved checkpoint for retraining' )
38
39
parser .add_argument ('--save_model_freq' , type = int , default = 50 , help = 'Frequency of saving the model checkpoint' )
39
40
40
41
#model parameters
41
- parser .add_argument ('--model_type' , type = str , default = 'Epoch' ,choices = ['Epoch' , 'Sequence' ], help = 'Model type epoch or sequence cross modal transformer ' )
42
- parser .add_argument ('--d_model ' , type = int , default = 256 , help = 'Embedding size of the CMT' )
43
- parser .add_argument ('--dim_feedforward' , type = int , default = 1024 , help = 'No of neurons in the hidden layer of feed forward block' )
44
- parser .add_argument ('--window_size ' , type = int , default = 50 , help = 'Size of non-overlapping window' )
42
+ parser .add_argument ('--model_type' , type = str , default = 'Epoch' ,choices = ['Epoch' , 'Sequence' ], help = 'Model type' )
43
+ parser .add_argument ('--d_model' , type = int , default = 256 , help = 'Embedding size of the CMT' )
44
+ parser .add_argument ('--dim_feedforward' , type = int , default = 1024 , help = 'No of neurons feed forward block' )
45
+ parser .add_argument ('--window_size' , type = int , default = 50 , help = 'Size of non-overlapping window' )
45
46
46
47
#training parameters
47
48
parser .add_argument ('--batch_size' , type = int , default = 32 , help = 'Batch Size' )
@@ -54,7 +55,7 @@ def parse_option():
54
55
parser .add_argument ('--beta_1' , type = float , default = 0.9 , help = 'beta 1 for adam optimizer' )
55
56
parser .add_argument ('--beta_2' , type = float , default = 0.999 , help = 'beta 2 for adam optimizer' )
56
57
parser .add_argument ('--eps' , type = float , default = 1e-9 , help = 'eps for adam optimizer' )
57
- parser .add_argument ('--weight_decay ' , type = float , default = 0.0001 , help = 'weight_decay for adam optimizer' )
58
+ parser .add_argument ('--weight_decay' , type = float , default = 0.0001 , help = 'weight_decay for adam optimizer' )
58
59
parser .add_argument ('--n_epochs' , type = int , default = 200 , help = 'No of training epochs' )
59
60
60
61
#For scheduler
@@ -75,8 +76,9 @@ def main():
75
76
76
77
args = parse_option ()
77
78
device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
78
-
79
- print (args )
79
+ print ("Training Arguements ====================================>" )
80
+ for arg in vars (args ):
81
+ print (f" { arg } : { getattr (args , arg )} " )
80
82
81
83
if args .is_neptune : # Initiate Neptune
82
84
import neptune .new as neptune
@@ -89,11 +91,64 @@ def main():
89
91
print (f"Project directory already available at { args .project_path } " )
90
92
91
93
94
+ #Get Dataset
95
+ print ("Getting Dataset ===================================>" )
96
+ train_data_loader , val_data_loader = get_dataset (device ,args )
92
97
93
98
99
+ ##Load Model
100
+ if args .model_type == "Epoch" : # Initialize epoch cross-modal transformer
101
+ if args .is_retrain :
102
+ print (f"Loading previous checkpoint from { args .model_path } " )
103
+ Net = torch .load (f"{ args .model_path } " )
104
+ else :
105
+ print (f"Initializing Epoch Cross Modal Transformer ==================>" )
106
+ Net = Epoch_Cross_Transformer_Network (d_model = args .d_model , dim_feedforward = args .dim_feedforward ,
107
+ window_size = args .window_size ).to (device )
108
+
109
+ weights = torch .tensor (args .weigths )
110
+ criterion = nn .CrossEntropyLoss (weight = weights )
111
+ optimizer = torch .optim .Adam (Net .parameters (), lr = args .lr , betas = (args .beta_1 , args .beta_2 ),
112
+ eps = args .eps , weight_decay = args .weight_decay )
113
+ lr_scheduler = torch .optim .lr_scheduler .StepLR (optimizer , step_size = args .step_size , gamma = args .gamma )
114
+
94
115
95
-
96
-
97
-
116
+ if args .is_neptune :
117
+ parameters = {
118
+ "Experiment" : "Training test" ,
119
+ 'Model Type' : "Epoch Cross-Modal Transformer" ,
120
+ 'd_model' : args .d_model ,
121
+ 'dim_feedforward' : args .dim_feedforward ,
122
+ 'window_size ' :args .window_size ,
123
+ 'Batch Size' : args .batch_size ,
124
+ 'Loss' : f"Weighted Categorical Loss,{ args .weights } " , # Check this every time
125
+ 'Optimizer' : "Adam" , # Check this every time
126
+ 'Learning Rate' : args .lr ,
127
+ 'eps' : args .eps ,
128
+ "LR Schduler" : "StepLR" ,
129
+ 'Beta 1' : args .beta_1 ,
130
+ 'Beta 2' : args .beta_2 ,
131
+ 'n_epochs' : args .n_epochs ,
132
+ 'val_set' : args .val_data_list [0 ]+ 1
133
+ }
134
+ run ['model/parameters' ] = parameters
135
+ run ['model/model_architecture' ] = Net
136
+
137
+
138
+ if not os .path .isdir (os .path .join (args .project_path ,"model_check_points" )):
139
+ os .makedirs (os .path .join (args .project_path ,"model_check_points" ))
140
+
141
+
142
+ # Train Epoch Cross-Modal Transformer
143
+ if args .model_type == "Epoch" :
144
+ train_epoch_cmt (Net , train_data_loader , val_data_loader , criterion , optimizer , lr_scheduler , device , args )
145
+
146
+
147
+
148
+ if __name__ == '__main__' :
149
+ main ()
150
+
98
151
99
-
152
+
153
+ # Training Epoch CMT
154
+ #python cmt_training.py --project_path "testing" --data_path "/home/mmsm/Sleep_EDF_Dataset" --train_data_list [0,1,2,3] --val_data_list [4] --model_type "Epoch" --is_neptune True --nep_project "jathurshan0330/V2-Cros" --nep_api "eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJmYmRmNjE0Zi0xMDRkLTRlNzUtYmIxNi03NzM2ODBlZDc5NTMifQ=="
0 commit comments