-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
221 lines (179 loc) · 7.81 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import argparse
import copy
import math
from multiprocessing import freeze_support
import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
import torch
# our libraries
from model.base_model import BaseModel
from model.multimodal_model import MultimodalModel
from uah_dataset.dataset import Dataset
from utils.config import config
# Module-level run configuration. It stays None until the __main__ block loads
# it via config.get_args(...); prepare_data, prepare_model, train and analyse
# all read it as a global, so they must only run after that assignment.
config_dict = None
# initialize dataloader here
def prepare_data(mode: str):
    """Build the dataloader(s) for ``mode`` from the global ``config_dict``.

    Args:
        mode: ``"train"`` for a random 80/20 train/validation split of the
            dataset, or ``"test"`` for a single loader over the whole dataset.

    Returns:
        ``(trainloader, validationloader)`` for ``"train"``; a single
        ``DataLoader`` for ``"test"``.

    Raises:
        ValueError: if ``mode`` is neither ``"train"`` nor ``"test"``
            (previously the function silently returned ``None``).
    """
    if mode not in ("train", "test"):
        raise ValueError(f"Unknown data mode {mode!r}; expected 'train' or 'test'")

    dataset = Dataset(**config_dict["dataset_args"])

    if mode == "test":
        # NOTE(review): built from the same dataset_args as the training data,
        # so this "test" set presumably overlaps train/validation — confirm intended.
        return DataLoader(dataset, pin_memory=True)

    # 80/20 split; deriving the second size from the first guarantees the
    # sizes always sum to len(dataset), as random_split requires.
    train_size = int(math.ceil(len(dataset) * 0.8))
    split_sizes = [train_size, len(dataset) - train_size]
    trainset, valset = torch.utils.data.random_split(dataset, split_sizes)
    trainloader = DataLoader(trainset, **config_dict["dataloader_args"])
    validationloader = DataLoader(valset, **config_dict["dataloader_args"])
    return trainloader, validationloader
# initialize model here
def prepare_model():
    """Create the configured model, either fresh for training or restored for evaluation.

    ``config_dict["evaluation"]`` selects the mode:

    * ``"None"`` (or a YAML null): a new model is built. Its ``log_path`` is
      written back to ``config_dict["evaluation"]``, and the config is stored at
      ``<log_path>/config.yml``.
    * anything else: treated as a directory. The file whose path sorts last
      among entries containing ``".torch"`` is loaded, and logging is disabled.

    Returns:
        The prepared model instance.

    Raises:
        FileNotFoundError: if the evaluation directory holds no ``.torch`` file.
    """
    # load model flag, which decides whether the model is trained or evaluated;
    # a real None (YAML null) must not fall through to os.listdir(None) == CWD
    evaluation = config_dict["evaluation"]
    load_flag = evaluation not in (None, "None")
    # an evaluated (already trained) model should not write new logs
    if load_flag:
        config_dict["model_args"]["log"] = False

    # create the model, by loading its class name
    model_name = config_dict["model_name"]
    window_size = config_dict["dataset_args"]["window_size"]
    model: BaseModel = config.get_model(model_name)(**config_dict["model_args"], window_size=window_size)
    model.use_device(config_dict["device"])

    # training run: record the log path in the config and store the current
    # hyperparameters next to the model's logs
    if not load_flag:
        config_dict["evaluation"] = model.log_path
        config.store_args(f"{model.log_path}/config.yml", config_dict)
        print(f"Prepared model: {model_name}")
        return model

    # evaluation run: load the latest saved checkpoint from the given directory
    path = evaluation
    model_versions = [f"{path}/{file}" for file in os.listdir(path) if ".torch" in file]
    if not model_versions:
        raise FileNotFoundError(f"No '.torch' checkpoint found in {path}")
    # NOTE(review): "latest" is the lexicographically greatest file name; if
    # checkpoint names embed unpadded numbers (e.g. _9 vs _10) this is not
    # chronological — confirm against the saving code.
    latest = max(model_versions)
    print(latest)
    model.load(latest)
    model.log_path = path
    print(f"Loaded model: {model_name} ({path})")
    return model
def train():
    """Train the configured model, save it, and print its predictions on the training data.

    For the ``Multimodal_v1`` / ``Sensor_v1`` models the weight-importance
    plots are rendered both before and after training.
    """
    # prepare the train-, validation- and test dataloaders
    # (named *_loader so they do not shadow this function's name)
    train_loader, validation_loader = prepare_data(mode="train")
    test_loader = prepare_data(mode="test")

    # prepare the model
    model: BaseModel = prepare_model()
    explainable = config_dict["model_name"] in ["Multimodal_v1", "Sensor_v1"]

    # showing weight analysis before training
    if explainable:
        explain_model(model, initial=True)

    # train the model and save it in the end
    model.learn(train_loader, validation_loader, test_loader, epochs=config_dict["train_epochs"],
                save_every=config_dict["save_every"])
    BaseModel.save_to_default(model)

    # explain the model's weights
    if explainable:
        explain_model(model)

    # execute the model and look at results. Use the configured device (the
    # one passed to model.use_device) instead of a hard-coded "cuda", and skip
    # autograd bookkeeping since nothing is back-propagated here.
    # NOTE(review): the model is not switched to an eval mode here; if BaseModel
    # is an nn.Module with dropout/batch-norm, consider model.eval() — confirm.
    device = config_dict["device"]
    with torch.no_grad():
        for sensor, image, label in train_loader:
            sensor = sensor.to(device)
            image = image.to(device)
            label = label.to(device)
            pred = model.forward((sensor, image)).argmax(dim=1)
            print(pred, label)
def _plot_importance_bars(values, out_file: str):
    """Save a bar chart of per-sensor importances to ``out_file`` and close the figure."""
    fig, ax = plt.subplots()
    colors = [plt.cm.Reds(0.36 + .02 * i) for i in range(len(values))]
    ax.bar(0.5 + np.arange(len(values)), values, color=colors)
    ax.set_xlabel("Sensor ID")
    ax.set_ylabel("Relative importance %")
    plt.axis('off')
    fig.savefig(out_file)
    # close explicitly: pyplot keeps every open figure alive otherwise
    plt.close(fig)


def _plot_importance_grid(values, out_file: str, cols: int = 6, min_rows: int = 4):
    """Save importances as a labelled heat-map grid (row-major, ``cols`` per row) to ``out_file``.

    Cells past the last sensor are NaN so they render blank; real values,
    including genuine zeros, are always shown. The grid has at least
    ``min_rows`` rows and grows if there are more sensors than fit.
    """
    flat = np.asarray(values, dtype=float).ravel()
    count = flat.size
    rows = max(min_rows, math.ceil(count / cols))
    grid = np.full(rows * cols, np.nan)
    grid[:count] = flat
    grid = grid.reshape(rows, cols)

    fig, ax = plt.subplots()
    im = ax.matshow(grid, cmap="Reds")
    ax.set_axis_off()
    # label each real sensor cell with its ID
    for (i, j), _ in np.ndenumerate(grid):
        sensor_id = j + i * cols
        if sensor_id < count:
            ax.text(j, i, f"{sensor_id}", ha='center', va='center')
    plt.tight_layout(pad=6)

    # slim colour bar right of the grid, ticked only at min and max
    pos = ax.get_position()
    cax = fig.add_axes([pos.x1 + 0.01, pos.y0, 0.02, pos.height])
    ticks = [np.nanmin(grid), np.nanmax(grid)]
    bar = fig.colorbar(im, cax=cax, ticks=ticks)
    bar.outline.set_visible(False)
    bar.ax.set_yticklabels([f"{x: .2f}" for x in ticks])
    fig.savefig(out_file)
    plt.close(fig)


def explain_model(model: MultimodalModel, initial: bool = False):
    """Render the model's sensor-importance plots as PNG files.

    Writes ``sensor_importance_{weights,biases}.png`` (bar charts) and
    ``sensor_importance_{weights,biases}_alt.png`` (heat-map grids) into
    ``model.log_path``, or into ``<log_path>/initial`` when ``initial`` is True.

    Args:
        model: model exposing ``log_path`` and ``sensor_importance()``; the
            latter is indexed as ``[0]`` (weights) and ``[1]`` (biases) —
            presumably one value per sensor each; confirm against the model.
        initial: plot into the ``initial`` sub-directory (pre-training snapshot).
    """
    path = f"{model.log_path}/initial" if initial else model.log_path
    # makedirs also creates a missing log_path and avoids the exists/mkdir race
    os.makedirs(path, exist_ok=True)

    # The former sensor/image-ratio pie chart (model.sensor_image_ratio()) was
    # disabled; reintroduce it as its own helper if it is needed again.
    sensor_importance = model.sensor_importance()
    weights, biases = sensor_importance[0], sensor_importance[1]

    _plot_importance_bars(weights, f"{path}/sensor_importance_weights.png")
    _plot_importance_bars(biases, f"{path}/sensor_importance_biases.png")
    _plot_importance_grid(weights, f"{path}/sensor_importance_weights_alt.png")
    _plot_importance_grid(biases, f"{path}/sensor_importance_biases_alt.png")
def analyse():
    """Load the trained model described by ``config_dict`` and plot its weight analysis.

    Only ``Multimodal_v1`` and ``Sensor_v1`` support the analysis; for any
    other model a notice is printed and nothing is plotted.
    """
    model: BaseModel = prepare_model()
    model_name = config_dict["model_name"]
    # guard: bail out early for models without importance introspection
    if model_name not in ["Multimodal_v1", "Sensor_v1"]:
        print(f"Analysis not supported for model: {model_name}")
        return
    explain_model(model)
if __name__ == "__main__":
    # required when the script is frozen into a Windows executable
    freeze_support()
    # defining arguments
    parser = argparse.ArgumentParser(description="This program trains and tests a deep " +
                                     "learning model to detect a driving behaviour.")
    parser.add_argument("--config", dest="train_config",
                        help="Trains a model given the path to a configuration file.")
    parser.add_argument("--analyse", dest="analyse",
                        help="Analyse the weights of a given model. Provide the config of a trained model.")
    args = parser.parse_args()

    # if/elif instead of quit(): quit() is an interactive site helper and is
    # not guaranteed to exist (e.g. `python -S` or frozen executables)
    if args.analyse:
        # --analyse takes precedence when both options are given
        config_dict = config.get_args(args.analyse)
        analyse()
    elif args.train_config:
        # load a configuration file and train
        config_dict = config.get_args(args.train_config)
        train()
    else:
        # previously config.get_args(None) was attempted; fail with usage instead
        parser.error("one of --config or --analyse is required")