import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim

from torch import autograd
from torch.utils.data import DataLoader

from sklearn.metrics import precision_score, recall_score
    
def pretty_print(*values):
    col_width = 13
    def format_val(v):
        if not isinstance(v, str):
            v = np.array2string(v, precision=5, floatmode='fixed')
        return v.ljust(col_width)
    str_values = [format_val(v) for v in values]
    print("   ".join(str_values))    
        
class Train:
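    """IRM-style trainer: fits a binary classifier over multiple training
    environments with an invariance penalty, and evaluates on a held-out
    test set after every step."""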
    def __init__(self, envs, X_te, Y_te, net, handler, args):
        self.envs = envs
        self.X_te = X_te
        self.Y_te = Y_te
        self.net = net
        self.handler = handler
        self.args = args
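        # args is expected to provide: 'n_classes', 'fc_only', 'steps',
        # 'transform' ('train'/'test'), 'loader_tr_args', 'loader_te_args',
        # and 'optimizer_args' ('lr', 'l2_regularizer_weight',
        # 'penalty_weight', 'penalty_anneal_iters')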
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

    def get_distribution(self):
        # Note: class_distribution is not set anywhere in this class; it is
        # expected to be assigned externally before this getter is called
        return self.class_distribution
    
    # Loss function helpers
    def mean_nll(self, logits, y):
        return F.binary_cross_entropy_with_logits(logits, y)

    def mean_accuracy(self, logits, y):
        # Thresholding raw logits at 0 is equivalent to thresholding sigmoid
        # probabilities at 0.5
        preds = (logits > 0.).float()
        return ((preds - y).abs() < 1e-2).float().mean(), preds

    def penalty(self, logits, y):
        # IRMv1 gradient penalty: the squared gradient of the risk with
        # respect to a dummy scale on the logits (Arjovsky et al., 2019);
        # created on the logits' device rather than assuming CUDA
        scale = torch.tensor(1., device=logits.device, requires_grad=True)
        loss = self.mean_nll(logits * scale, y)
        grad = autograd.grad(loss, [scale], create_graph=True)[0]
        return torch.sum(grad ** 2)
    
    def predict(self, X, Y):
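        """Evaluate self.clf on (X, Y); returns mean NLL, mean accuracy,
        hard 0/1 predictions, and sigmoid probabilities, with the last two
        index-aligned with Y."""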
        loader_te = DataLoader(self.handler(X, Y,
                                            transform=self.args['transform']['test']),
                               shuffle=True, **self.args['loader_te_args'])
        self.clf.eval()
        nll = acc = 0.0
        preds = torch.zeros(len(Y), 1, dtype=torch.float)
        preds_Y = torch.zeros(len(Y), 1, dtype=torch.float)
        with torch.no_grad():
            for x, y, idxs in loader_te:
                x, y = x.to(self.device), y.to(self.device)
                out = self.clf(x)
                y = y.view(-1, 1)
                test_nll = self.mean_nll(out, y.float())
                test_acc, temp_preds = self.mean_accuracy(out, y.float())

                nll += test_nll
                acc += test_acc

                probs = torch.sigmoid(out)
                # Write results back by index, so the shuffled loader order is
                # irrelevant; .cpu() is a no-op for CPU tensors, so no device
                # branch is needed
                preds[idxs] = probs.cpu()
                preds_Y[idxs] = temp_preds.cpu()

        return nll / len(loader_te), acc / len(loader_te), preds_Y, preds

    def train(self):        
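        """IRM training loop: accumulate NLL, accuracy, and the gradient
        penalty per environment, then take one optimizer step on the L2- and
        penalty-regularized average risk; logs metrics every 10 steps."""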
        n_classes = self.args['n_classes']
        self.clf = self.net(n_classes=n_classes).to(self.device)
        if self.args['fc_only']: # feature extraction
            optimizer = optim.Adam(self.clf.fc.parameters(), self.args['optimizer_args']['lr'])
        else:
            optimizer = optim.Adam(self.clf.parameters(), self.args['optimizer_args']['lr'])
        
        pretty_print('step', 'train nll', 'train acc', 'train penalty', 'test nll', 'test acc', 'test prec', 'test rec')

        for step in range(self.args['steps']):  
            for env_idx, env in enumerate(self.envs):
                x = env['images']
                y = env['labels']
                loader_tr = DataLoader(self.handler(x, y, transform=self.args['transform']['train']), 
                                       shuffle=True, **self.args['loader_tr_args'])
                self.clf.train()
                nll = acc = penalty = 0.0
                
                for batch_idx, (x, y, idxs) in enumerate(loader_tr):
                    x, y = x.to(self.device), y.to(self.device)
                    logits = self.clf(x)

                    y = y.view(-1, 1)
                    train_nll = self.mean_nll(logits, y.float())
                    train_acc, _ = self.mean_accuracy(logits, y.float())
                    train_penalty = self.penalty(logits, y.float())

                    # Accumulate per-batch stats; the graphs are kept alive so
                    # that one backward pass per step covers every environment
                    nll += train_nll
                    acc += train_acc
                    penalty += train_penalty
                env['nll'] = nll / len(loader_tr)
                env['acc'] = acc / len(loader_tr)
                env['penalty'] = penalty / len(loader_tr)
                
            train_nll = torch.stack([env['nll'] for env in self.envs]).mean()
            train_acc = torch.stack([env['acc'] for env in self.envs]).mean()
            train_penalty = torch.stack([env['penalty'] for env in self.envs]).mean()
            weight_norm = torch.tensor(0., device=self.device)
            
            if self.args['fc_only']:
                for w in self.clf.fc.parameters():
                    weight_norm += w.norm().pow(2)
            else:
                for w in self.clf.parameters():
                    weight_norm += w.norm().pow(2)     
            
            loss = train_nll.clone()
            loss += self.args['optimizer_args']['l2_regularizer_weight'] * weight_norm
            # Hold the penalty weight at 1.0 until the annealing period ends
            penalty_weight = (self.args['optimizer_args']['penalty_weight']
                              if step >= self.args['optimizer_args']['penalty_anneal_iters'] else 1.0)
            loss += penalty_weight * train_penalty
            if penalty_weight > 1.0:
                # Rescale the entire loss to keep gradients in a reasonable range
                loss /= penalty_weight

            # One optimization step per outer step, over the loss accumulated
            # across all environments
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            test_loss, test_acc, preds, probs = self.predict(self.X_te, self.Y_te)
            
            test_prec = precision_score(self.Y_te.detach().cpu().numpy(), preds.detach().cpu().numpy())
            test_rec = recall_score(self.Y_te.detach().cpu().numpy(), preds.detach().cpu().numpy())
           
            if step % 10 == 0:
                pretty_print(np.int32(step), train_nll.detach().cpu().numpy(), 
                             train_acc.detach().cpu().numpy(), train_penalty.detach().cpu().numpy(), 
                             test_loss.detach().cpu().numpy(), test_acc.detach().cpu().numpy(),
                             test_prec, test_rec)
                
        return train_acc.detach().cpu().numpy(), test_acc.detach().cpu().numpy(), preds, probs
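

if __name__ == '__main__':
    # Minimal smoke-test sketch. TinyNet, TensorHandler, and every value in
    # `args` below are illustrative stand-ins assumed for demonstration, not
    # this repo's actual model, dataset handler, or defaults.
    import torch.nn as nn
    from torch.utils.data import Dataset

    class TinyNet(nn.Module):
        # Single linear layer exposing `.fc` so that fc_only=True also works
        def __init__(self, n_classes=1):
            super().__init__()
            self.fc = nn.Linear(4, n_classes)

        def forward(self, x):
            return self.fc(x)

    class TensorHandler(Dataset):
        # Wraps (X, Y) tensors and yields the (x, y, index) triples that
        # Train.predict and Train.train expect from the handler
        def __init__(self, X, Y, transform=None):
            self.X, self.Y, self.transform = X, Y, transform

        def __len__(self):
            return len(self.Y)

        def __getitem__(self, i):
            x = self.X[i]
            if self.transform is not None:
                x = self.transform(x)
            return x, self.Y[i], i

    args = {
        'n_classes': 1,
        'fc_only': False,
        'steps': 20,
        'transform': {'train': None, 'test': None},
        'loader_tr_args': {'batch_size': 32},
        'loader_te_args': {'batch_size': 64},
        'optimizer_args': {'lr': 1e-3,
                           'l2_regularizer_weight': 1e-3,
                           'penalty_weight': 1e4,
                           'penalty_anneal_iters': 10},
    }
    # Two synthetic training environments plus a held-out test split
    envs = [{'images': torch.randn(128, 4), 'labels': torch.randint(0, 2, (128,))},
            {'images': torch.randn(128, 4), 'labels': torch.randint(0, 2, (128,))}]
    X_te, Y_te = torch.randn(64, 4), torch.randint(0, 2, (64,))

    trainer = Train(envs, X_te, Y_te, TinyNet, TensorHandler, args)
    train_acc, test_acc, preds, probs = trainer.train()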