import os
import time

import torch
import matplotlib.pyplot as plt

from orm_model import CTC_CNN, default_model_params
from orm_dataset import CTC_PriMuS


def decode(target):
    """Greedy CTC decoding: collapse repeated symbols, then drop blanks (0)."""
    decoded = []
    decoded_2 = []
    prev = 0
    for note in target:
        if note == prev:
            continue
        else:
            decoded.append(note)
            prev = note
    for note in decoded:
        if note != 0:
            decoded_2.append(note)
    return decoded_2


data_dir = './package_aa'
dict_path = './vocabulary_semantic.txt'

# Keep the two loss histories as separate lists; `train_loss = val_loss = []`
# would alias both names to the same list object.
train_loss = []
val_loss = []

img_height = 128
learning_rate = 0.04
num_epochs = 500
batch_size = 16

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

corpus_list = os.listdir(data_dir)[0:700]
primus = CTC_PriMuS(data_dir, corpus_list, dict_path, True, val_split=0.1)
params = default_model_params(img_height, primus.vocabulary_size)

# model
model = CTC_CNN(img_height, primus.vocabulary_size).to(device)

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# loss
loss_func = torch.nn.CTCLoss()

# training loop
start_time = time.time()
for epoch in range(num_epochs):
    print(f'Epoch {epoch}/{num_epochs - 1}')
    print('-' * 10)
    epoch_time = time.time()

    # Each epoch has a training and a validation phase.
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()  # set model to training mode
        else:
            val_i = 0
            model.eval()  # set model to evaluation mode

        running_loss = 0.0
        running_corrects = 0

        # Iterate over the data.
        while True:
            if phase == 'train':
                batch = primus.nextBatch(params)
                inputs = torch.tensor(batch['inputs']).to(device)
                targets = torch.tensor(batch['targets']).to(device)
            else:
                batch = primus.getValidation(params)
                inputs = torch.tensor(batch['inputs'][val_i:val_i + params['batch_size']]).to(device)
                targets = torch.tensor(batch['targets'][val_i:val_i + params['batch_size']]).to(device)

            # NHWC -> NCHW: bring the channel axis in front of height and width.
            inputs = torch.transpose(inputs, 1, 2)
            inputs = torch.transpose(inputs, 1, 3)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward pass; track gradient history only in the training phase
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                # CTCLoss expects inputs shaped (T, N, C); the model is assumed
                # to emit (N, T, C) log-probabilities (i.e. end in log_softmax).
                outputs_transposed = outputs.permute(1, 0, 2)
                output_lengths = [len(x) for x in outputs]
                target_lengths = [len(x) for x in targets]

                # Frame-level greedy predictions for a rough accuracy estimate.
                decoded_out = [torch.argmax(out, dim=1) for out in outputs]
                decoded_target = [target for target in targets]
                correct = 0
                total = 0
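                # Note: comparing raw per-frame argmax labels against the
                # padded targets is only a coarse accuracy proxy; a stricter
                # metric would first collapse repeats and blanks with decode()
                # above and compare the resulting symbol sequences (e.g. by
                # edit distance).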
                for out, target in zip(decoded_out, decoded_target):
                    total += len(target)
                    for o, t in zip(out, target):
                        if o == t:
                            correct += 1
                print(decoded_out[0])
                print(decoded_target[0])
                running_corrects = correct / total

                loss = loss_func(outputs_transposed, targets, output_lengths, target_lengths)

                # backward + optimize only in the training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            # statistics  TODO: ask whether scaling by batch size is wanted here
            running_loss += loss.item() * inputs.size(0)

            if phase == 'train':
                # nextBatch() wraps around to index 0 after a full pass.
                if primus.current_idx == 0:
                    break
            else:
                val_i += params['batch_size']
                if val_i >= len(primus.validation_list):
                    break

        if phase == 'train':
            epoch_loss = running_loss / len(primus.training_list)
            train_loss.append(epoch_loss)
        else:
            epoch_loss = running_loss / len(primus.validation_list)
            val_loss.append(epoch_loss)

        print(f'{phase} Loss: {epoch_loss:.4f}')
        # `running_corrects` holds the frame accuracy of the last batch only.
        print(f'{phase} Acc: {running_corrects:.4f}')

        '''
        # deep copy the model
        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())
        '''

    print()
    #scheduler.step()
    torch.save(model, './model1')  # checkpoint after every epoch
    print(time.time() - epoch_time)

time_elapsed = time.time() - start_time
print(f'Training complete in {(time_elapsed // 60):.0f}m {time_elapsed % 60:.0f}s')

# Loss curves: epoch index on the x-axis, loss on the y-axis.
plt.plot(train_loss, color='r')
plt.show()
plt.plot(val_loss, color='g')
plt.show()