Skip to content

Commit d36365c

Browse files
Complete epoch cmt training script
1 parent bb9887f commit d36365c

12 files changed

+1824
-42
lines changed

.ipynb_checkpoints/Epoch_cross_transformer_training-checkpoint.ipynb

+655-1
Large diffs are not rendered by default.

.ipynb_checkpoints/cmt_training-checkpoint.py

+70-15
Original file line number | Diff line number | Diff line change
@@ -15,14 +15,13 @@
1515
import time
1616
import argparse
1717
import glob
18-
import scipy.signal
1918
import os
2019
from einops import rearrange, reduce, repeat
2120
from einops.layers.torch import Rearrange, Reduce
2221
print(f"Torch Version : {torch.__version__}")
2322

2423
from datasets.sleep_edf import split_data, SleepEDF_MultiChan_Dataset, get_dataset
25-
from models.epoch_cmt import Epoch_Cross_Transformer_Network
24+
from models.epoch_cmt import Epoch_Cross_Transformer_Network,train_epoch_cmt
2625
from models.sequence_cmt import Seq_Cross_Transformer_Network
2726
from utils.metrics import accuracy, kappa, g_mean, plot_confusion_matrix, confusion_matrix, AverageMeter
2827

@@ -33,15 +32,17 @@ def parse_option():
3332

3433
parser.add_argument('--project_path', type=str, default='./results', help='Path to store project results')
3534
parser.add_argument('--data_path', type=str, help='Path to the dataset file')
36-
parser.add_argument('--train_data_list', type=list, default = [0,1,2,3] , help='Folds in the dataset for training')
37-
parser.add_argument('--val_data_list', type=list, default = [0,1,2,3] , help='Folds in the dataset for validation')
35+
parser.add_argument('--train_data_list', nargs="+", default = [0,1,2,3] , help='Folds in the dataset for training')
36+
parser.add_argument('--val_data_list', nargs="+", default = [4] , help='Folds in the dataset for validation')
37+
parser.add_argument('--is_retrain', type=bool, default=False, help='To retrain a from saved checkpoint')
38+
parser.add_argument('--model_path', type=str, default="", help='Path to saved checkpoint for retraining')
3839
parser.add_argument('--save_model_freq', type=int, default = 50 , help='Frequency of saving the model checkpoint')
3940

4041
#model parameters
41-
parser.add_argument('--model_type', type=str, default = 'Epoch' ,choices=['Epoch', 'Sequence'], help='Model type epoch or sequence cross modal transformer')
42-
parser.add_argument('--d_model ', type=int, default = 256 , help='Embedding size of the CMT')
43-
parser.add_argument('--dim_feedforward', type=int, default = 1024 , help='No of neurons in the hidden layer of feed forward block')
44-
parser.add_argument('--window_size ', type=int, default = 50 , help='Size of non-overlapping window')
42+
parser.add_argument('--model_type', type=str, default = 'Epoch' ,choices=['Epoch', 'Sequence'], help='Model type')
43+
parser.add_argument('--d_model', type=int, default = 256, help='Embedding size of the CMT')
44+
parser.add_argument('--dim_feedforward', type=int, default = 1024, help='No of neurons feed forward block')
45+
parser.add_argument('--window_size', type=int, default = 50, help='Size of non-overlapping window')
4546

4647
#training parameters
4748
parser.add_argument('--batch_size', type=int, default = 32 , help='Batch Size')
@@ -54,7 +55,7 @@ def parse_option():
5455
parser.add_argument('--beta_1', type=float, default = 0.9 , help='beta 1 for adam optimizer')
5556
parser.add_argument('--beta_2', type=float, default = 0.999 , help='beta 2 for adam optimizer')
5657
parser.add_argument('--eps', type=float, default = 1e-9 , help='eps for adam optimizer')
57-
parser.add_argument('--weight_decay ', type=float, default = 0.0001 , help='weight_decay for adam optimizer')
58+
parser.add_argument('--weight_decay', type=float, default = 0.0001 , help='weight_decay for adam optimizer')
5859
parser.add_argument('--n_epochs', type=int, default = 200 , help='No of training epochs')
5960

6061
#For scheduler
@@ -75,8 +76,9 @@ def main():
7576

7677
args = parse_option()
7778
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78-
79-
print(args)
79+
print("Training Arguements ====================================>")
80+
for arg in vars(args):
81+
print(f" {arg} : {getattr(args, arg)}")
8082

8183
if args.is_neptune: # Initiate Neptune
8284
import neptune.new as neptune
@@ -89,11 +91,64 @@ def main():
8991
print(f"Project directory already available at {args.project_path}")
9092

9193

94+
#Get Dataset
95+
print("Getting Dataset ===================================>")
96+
train_data_loader, val_data_loader = get_dataset(device,args)
9297

9398

99+
##Load Model
100+
if args.model_type == "Epoch": # Initialize epoch cross-modal transformer
101+
if args.is_retrain:
102+
print(f"Loading previous checkpoint from {args.model_path}")
103+
Net = torch.load(f"{args.model_path}")
104+
else:
105+
print(f"Initializing Epoch Cross Modal Transformer ==================>")
106+
Net = Epoch_Cross_Transformer_Network(d_model = args.d_model, dim_feedforward = args.dim_feedforward,
107+
window_size = args.window_size ).to(device)
108+
109+
weights = torch.tensor(args.weigths)
110+
criterion = nn.CrossEntropyLoss(weight=weights)
111+
optimizer = torch.optim.Adam(Net.parameters(), lr=args.lr, betas=(args.beta_1, args.beta_2),
112+
eps = args.eps, weight_decay = args.weight_decay)
113+
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
114+
94115

95-
96-
97-
116+
if args.is_neptune:
117+
parameters = {
118+
"Experiment" : "Training test",
119+
'Model Type' : "Epoch Cross-Modal Transformer",
120+
'd_model' : args.d_model,
121+
'dim_feedforward' : args.dim_feedforward,
122+
'window_size ':args.window_size ,
123+
'Batch Size': args.batch_size,
124+
'Loss': f"Weighted Categorical Loss,{args.weights}", # Check this every time
125+
'Optimizer' : "Adam", # Check this every time
126+
'Learning Rate': args.lr,
127+
'eps' : args.eps,
128+
"LR Schduler": "StepLR",
129+
'Beta 1': args.beta_1,
130+
'Beta 2': args.beta_2,
131+
'n_epochs': args.n_epochs,
132+
'val_set' : args.val_data_list[0]+1
133+
}
134+
run['model/parameters'] = parameters
135+
run['model/model_architecture'] = Net
136+
137+
138+
if not os.path.isdir(os.path.join(args.project_path,"model_check_points")):
139+
os.makedirs(os.path.join(args.project_path,"model_check_points"))
140+
141+
142+
# Train Epoch Cross-Modal Transformer
143+
if args.model_type == "Epoch":
144+
train_epoch_cmt(Net, train_data_loader, val_data_loader, criterion, optimizer, lr_scheduler, device, args)
145+
146+
147+
148+
if __name__ == '__main__':
149+
main()
150+
98151

99-
152+
153+
# Training Epoch CMT
154+
#python cmt_training.py --project_path "testing" --data_path "/home/mmsm/Sleep_EDF_Dataset" --train_data_list [0,1,2,3] --val_data_list [4] --model_type "Epoch" --is_neptune True --nep_project "jathurshan0330/V2-Cros" --nep_api "eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJmYmRmNjE0Zi0xMDRkLTRlNzUtYmIxNi03NzM2ODBlZDc5NTMifQ=="

0 commit comments

Comments (0)