eval.py
"""Evaluate a STUNT-pretrained encoder on a few-shot tabular task using
prototypical (nearest-prototype) classification."""

import argparse
import os

import numpy as np
import torch
import torch.nn as nn
from torchmeta.utils.prototype import get_prototypes
parser = argparse.ArgumentParser(description='STUNT')
parser.add_argument('--data_name', default='income', type=str)
parser.add_argument('--shot_num', default=1, type=int)
parser.add_argument('--load_path', default='', type=str)
parser.add_argument('--seed', default=0, type=int)
args = parser.parse_args()
if args.data_name == 'income':
    input_size = 105   # number of input features for the 'income' dataset
    output_size = 2    # binary classification
    hidden_dim = 1024  # encoder width
class MLPProto(nn.Module):
    """MLP encoder that maps raw tabular features to embeddings."""

    def __init__(self, in_features, out_features, hidden_sizes, drop_p=0.):
        super(MLPProto, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_sizes = hidden_sizes
        self.drop_p = drop_p

        self.encoder = nn.Sequential(
            nn.Linear(in_features, hidden_sizes, bias=True),
            nn.ReLU(),
            nn.Linear(hidden_sizes, hidden_sizes, bias=True)
        )

    def forward(self, inputs):
        embeddings = self.encoder(inputs)
        return embeddings
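# Note: the encoder's embedding width is hidden_sizes (its final Linear layer),
# so MLPProto(input_size, hidden_dim, hidden_dim) below yields hidden_dim-wide
# embeddings; out_features and drop_p are stored but not used by this script.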
# Build the encoder and load the pre-trained STUNT weights (loaded onto CPU,
# since all evaluation below runs on CPU tensors).
model = MLPProto(input_size, hidden_dim, hidden_dim)
model.load_state_dict(torch.load(args.load_path, map_location='cpu'))
model.eval()
# Load the full train/test splits and the pre-generated few-shot support indices.
train_x = np.load('./data/' + args.data_name + '/xtrain.npy')
train_y = np.load('./data/' + args.data_name + '/ytrain.npy')
test_x = np.load('./data/' + args.data_name + '/xtest.npy')
test_y = np.load('./data/' + args.data_name + '/ytest.npy')
train_idx = np.load('./data/{}/index{}/train_idx_{}.npy'.format(
    args.data_name, args.shot_num, args.seed))
# Embed the few-shot support set and the full test (query) set with the
# frozen encoder; gradients are not needed at evaluation time.
with torch.no_grad():
    few_train = model(torch.tensor(train_x[train_idx]).float())
    few_test = model(torch.tensor(test_x).float())
support_x = few_train.numpy()
support_y = train_y[train_idx]
query_x = few_test.numpy()
query_y = test_y
def get_accuracy(prototypes, embeddings, targets):
    """Assign each query embedding to its nearest prototype (squared Euclidean
    distance) and return the accuracy as a percentage."""
    sq_distances = torch.sum((prototypes.unsqueeze(1)
                              - embeddings.unsqueeze(2)) ** 2, dim=-1)
    _, predictions = torch.min(sq_distances, dim=-1)
    return torch.mean(predictions.eq(targets).float()) * 100.
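# Shape convention assumed above (mirroring torchmeta's prototypical-network
# utilities): prototypes is (1, output_size, hidden_dim), embeddings is
# (1, num_query, hidden_dim), and targets is (1, num_query); the leading
# batch dimension comes from the unsqueeze(0) calls below.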
# Wrap the support/query sets as (1, num_examples, ...) tensors, compute one
# prototype per class from the support embeddings, and score the query set.
train_x = torch.tensor(support_x.astype(np.float32)).unsqueeze(0)
train_y = torch.tensor(support_y.astype(np.int64)).unsqueeze(0).type(torch.LongTensor)
val_x = torch.tensor(query_x.astype(np.float32)).unsqueeze(0)
val_y = torch.tensor(query_y.astype(np.int64)).unsqueeze(0).type(torch.LongTensor)

prototypes = get_prototypes(train_x, train_y, output_size)
acc = get_accuracy(prototypes, val_x, val_y).item()
print(args.seed, acc)
# Append this run's accuracy to the per-dataset result file, creating the
# output directory if it does not exist yet.
out_file = 'result/{}_{}shot/test'.format(args.data_name, args.shot_num)
os.makedirs(os.path.dirname(out_file), exist_ok=True)
with open(out_file, 'a+') as f:
    f.write('seed: ' + str(args.seed) + ' test: ' + str(acc))
    f.write('\n')
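# Example invocation (the checkpoint path below is an assumption; point
# --load_path at your own STUNT-pretrained state_dict):
#   python eval.py --data_name income --shot_num 1 --seed 0 \
#       --load_path ./checkpoints/income/encoder.pth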