Skip to content

Commit 66ef5d5

Browse files
committed
ENH: Script to evaluate the dental model segmentation
1 parent 8ae4b7e commit 66ef5d5

File tree

1 file changed

+251
-0
lines changed

1 file changed

+251
-0
lines changed

src/py/eval_dental_modelseg.py

+251
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
2+
from __future__ import print_function
3+
import numpy as np
4+
import vtk
5+
import argparse
6+
import os
7+
from datetime import datetime, time
8+
import json
9+
import glob
10+
import time
11+
from sklearn.metrics import confusion_matrix
12+
from sklearn.metrics import roc_curve, auc
13+
import matplotlib as mpl
14+
mpl.use('Agg')
15+
import matplotlib.pyplot as plt
16+
import itertools
17+
from scipy import interp
18+
import csv
19+
20+
parser = argparse.ArgumentParser(description='Evaluate dental model seg', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
21+
22+
parser.add_argument('--csv', type=str, help='CSV columns gt and prediction. Both VTK files with label array "RegionId"', required=True)
23+
parser.add_argument('--out', type=str, help='Out filename for plots', default="./out")
24+
25+
args = parser.parse_args()
26+
27+
y_pred_arr = []
28+
y_true_arr = []
29+
30+
fpr_arr = []
31+
tpr_arr = []
32+
roc_auc_arr = []
33+
iou_arr = []
34+
35+
abs_diff_arr = []
36+
mse_arr = []
37+
38+
eval_type = "class"
39+
class_names = ["Gum", "Teeth", "Boundary"]
40+
41+
if(eval_type == "class"):
42+
43+
with open(args.csv) as csvfile:
44+
csv_reader = csv.DictReader(csvfile)
45+
for row in csv_reader:
46+
47+
args = parser.parse_args()
48+
49+
reader = vtk.vtkPolyDataReader()
50+
reader.SetFileName(row["gt"])
51+
reader.Update()
52+
53+
clean = vtk.vtkCleanPolyData()
54+
clean.SetInputData(reader.GetOutput())
55+
clean.SetTolerance(0.0001)
56+
clean.Update()
57+
surf1 = clean.GetOutput()
58+
59+
surf1_label = surf1.GetPointData().GetArray('RegionId')
60+
for pid in range(surf1_label.GetNumberOfTuples()):
61+
y_true_arr.append(surf1_label.GetTuple(pid)[0])
62+
63+
reader = vtk.vtkPolyDataReader()
64+
reader.SetFileName(row["prediction"])
65+
reader.Update()
66+
67+
clean = vtk.vtkCleanPolyData()
68+
clean.SetInputData(reader.GetOutput())
69+
clean.SetTolerance(0.0001)
70+
clean.Update()
71+
surf2 = clean.GetOutput()
72+
73+
surf2_label = surf2.GetPointData().GetArray('RegionId')
74+
for pid in range(surf2_label.GetNumberOfTuples()):
75+
if surf2_label.GetTuple(pid)[0] == -1:
76+
y_pred_arr.append(0)
77+
else:
78+
y_pred_arr.append(surf2_label.GetTuple(pid)[0] - 1)
79+
80+
elif(eval_type == "segmentation"):
81+
fpr, tpr, _ = roc_curve(np.array(image_batch[1]).reshape(-1), np.array(y_pred).reshape(-1), pos_label=1)
82+
roc_auc = auc(fpr,tpr)
83+
84+
fpr_arr.append(fpr)
85+
tpr_arr.append(tpr)
86+
roc_auc_arr.append(roc_auc)
87+
88+
y_pred_flat = np.array(y_pred).reshape((len(y_pred), -1))
89+
labels_flat = np.array(image_batch[1]).reshape((len(y_pred), -1))
90+
91+
for i in range(len(y_pred)):
92+
intersection = 2.0 * np.sum(y_pred_flat[i] * labels_flat[i]) + 1e-7
93+
union = np.sum(y_pred_flat[i]) + np.sum(labels_flat[i]) + 1e-7
94+
iou_arr.append(intersection/union)
95+
96+
elif(eval_type == "image" or eval_type == "numeric"):
97+
abs_diff_arr.extend(np.average(np.absolute(y_pred - image_batch[1]).reshape([1, -1]), axis=-1))
98+
mse_arr.extend(np.average(np.square(y_pred - image_batch[1]).reshape([1, -1]), axis=-1))
99+
100+
101+
def plot_confusion_matrix(cm, classes,
102+
normalize=False,
103+
title='Confusion matrix',
104+
cmap=plt.cm.Blues):
105+
"""
106+
This function prints and plots the confusion matrix.
107+
Normalization can be applied by setting `normalize=True`.
108+
"""
109+
if normalize:
110+
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
111+
print("Normalized confusion matrix")
112+
else:
113+
print('Confusion matrix, without normalization')
114+
115+
print(cm)
116+
117+
plt.imshow(cm, interpolation='nearest', cmap=cmap)
118+
plt.title(title)
119+
plt.colorbar()
120+
tick_marks = np.arange(len(classes))
121+
plt.xticks(tick_marks, classes, rotation=45)
122+
plt.yticks(tick_marks, classes)
123+
124+
fmt = '.3f' if normalize else 'd'
125+
thresh = cm.max() / 2.
126+
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
127+
plt.text(j, i, format(cm[i, j], fmt),
128+
horizontalalignment="center",
129+
color="white" if cm[i, j] > thresh else "black")
130+
131+
plt.ylabel('True label')
132+
plt.xlabel('Predicted label')
133+
plt.tight_layout()
134+
135+
if(eval_type == "class"):
136+
# Compute confusion matrix
137+
138+
cnf_matrix = confusion_matrix(y_true_arr, y_pred_arr)
139+
np.set_printoptions(precision=3)
140+
141+
# Plot non-normalized confusion matrix
142+
fig = plt.figure()
143+
plot_confusion_matrix(cnf_matrix, classes=class_names, title='Confusion matrix, without normalization')
144+
confusion_filename = os.path.splitext(args.out)[0] + "_confusion.png"
145+
fig.savefig(confusion_filename)
146+
# Plot normalized confusion matrix
147+
fig2 = plt.figure()
148+
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True, title='Normalized confusion matrix')
149+
150+
norm_confusion_filename = os.path.splitext(args.out)[0] + "_norm_confusion.png"
151+
fig2.savefig(norm_confusion_filename)
152+
153+
elif(eval_type == "segmentation"):
154+
155+
# First aggregate all false positive rates
156+
all_fpr = np.unique(np.concatenate([fpr for fpr in fpr_arr]))
157+
158+
# Then interpolate all ROC curves at this points
159+
mean_tpr = np.zeros_like(all_fpr)
160+
for i in range(len(fpr_arr)):
161+
mean_tpr += interp(all_fpr, fpr_arr[i], tpr_arr[i])
162+
163+
mean_tpr /= len(fpr_arr)
164+
165+
roc_auc = auc(all_fpr, mean_tpr)
166+
167+
roc_fig = plt.figure()
168+
lw = 1
169+
plt.plot(all_fpr, mean_tpr, color='darkorange', lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
170+
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
171+
plt.xlim([0.0, 1.0])
172+
plt.ylim([0.0, 1.05])
173+
plt.xlabel('False Positive Rate')
174+
plt.ylabel('True Positive Rate')
175+
plt.title('Receiver operating characteristic')
176+
plt.legend(loc="lower right")
177+
178+
roc_filename = os.path.splitext(json_tf_records)[0] + "_roc.png"
179+
roc_fig.savefig(roc_filename)
180+
181+
iou_obj = {}
182+
iou_obj["iou"] = iou_arr
183+
184+
iou_json = os.path.splitext(json_tf_records)[0] + "_iou_arr.json"
185+
186+
with open(iou_json, "w") as f:
187+
f.write(json.dumps(iou_obj))
188+
189+
iou_fig_polar = plt.figure()
190+
ax = iou_fig_polar.add_subplot(111, projection='polar')
191+
theta = 2 * np.pi * np.arange(len(iou_arr))/len(iou_arr)
192+
colors = iou_arr
193+
ax.scatter(theta, iou_arr, c=colors, cmap='autumn', alpha=0.75)
194+
ax.set_rlim(0,1)
195+
plt.title('Intersection over union')
196+
locs, labels = plt.xticks()
197+
plt.xticks(locs, np.arange(0, len(iou_arr), round(len(iou_arr)/len(locs))))
198+
199+
iou_polar_filename = os.path.splitext(json_tf_records)[0] + "_iou_polar.png"
200+
iou_fig_polar.savefig(iou_polar_filename)
201+
202+
iou_fig = plt.figure()
203+
x_samples = np.arange(len(iou_arr))
204+
plt.scatter(x_samples, iou_arr, c=colors, cmap='autumn', alpha=0.75)
205+
plt.title('Intersection over union')
206+
iou_mean = np.mean(iou_arr)
207+
plt.plot(x_samples,[iou_mean]*len(iou_arr), label='Mean', linestyle='--')
208+
plt.text(len(iou_arr) + 2,iou_mean, '%.3f'%iou_mean)
209+
iou_stdev = np.std(iou_arr)
210+
stdev_line = plt.plot(x_samples,iou_mean + [iou_stdev]*len(iou_arr), label='Stdev', linestyle=':', alpha=0.75)
211+
stdev_line = plt.plot(x_samples,iou_mean - [iou_stdev]*len(iou_arr), label='Stdev', linestyle=':', alpha=0.75)
212+
plt.text(len(iou_arr) + 2,iou_mean + iou_stdev, '%.3f'%iou_stdev, alpha=0.75, fontsize='x-small')
213+
iou_filename = os.path.splitext(json_tf_records)[0] + "_iou.png"
214+
iou_fig.savefig(iou_filename)
215+
216+
elif(eval_type == "image" or eval_type == "numeric"):
217+
abs_diff_arr = np.array(abs_diff_arr)
218+
abs_diff_fig = plt.figure()
219+
x_samples = np.arange(len(abs_diff_arr))
220+
221+
plt.scatter(x_samples, abs_diff_arr, c=abs_diff_arr, cmap='cool', alpha=0.75, label='Mean absolute error')
222+
plt.xlabel('Samples')
223+
plt.ylabel('Absolute error')
224+
plt.title('Mean absolute error')
225+
226+
abs_diff_mean = np.array([np.mean(abs_diff_arr)]*len(abs_diff_arr))
227+
mean_line = plt.plot(x_samples,abs_diff_mean, label='Mean', linestyle='--')
228+
abs_diff_stdev = np.array([np.std(abs_diff_mean)]*len(abs_diff_mean))
229+
stdev_line = plt.plot(x_samples, abs_diff_mean + abs_diff_stdev, label='Mean', linestyle=':', alpha=0.75)
230+
plt.text(len(abs_diff_mean), np.mean(abs_diff_mean), "{0:.3f}".format(np.mean(abs_diff_mean)))
231+
232+
abs_filename = os.path.splitext(json_tf_records)[0] + "_abs_diff.png"
233+
abs_diff_fig.savefig(abs_filename)
234+
235+
mse_arr = np.array(mse_arr)
236+
mse_fig = plt.figure()
237+
plt.scatter(x_samples, mse_arr, c=mse_arr, cmap='cool', alpha=0.75, label='MSE')
238+
plt.xlabel('Samples')
239+
plt.ylabel('MSE')
240+
plt.title('Mean squared error')
241+
242+
mse_mean = np.array([np.mean(mse_arr)]*len(mse_arr))
243+
mse_line = plt.plot(x_samples,mse_mean, label='Mean', linestyle='--')
244+
mse_stdev = np.array([np.std(mse_arr)]*len(mse_arr))
245+
stdev_line = plt.plot(x_samples, mse_mean + mse_stdev, label='Mean', linestyle=':', alpha=0.75)
246+
plt.text(len(mse_mean), np.mean(mse_mean), "{0:.3f}".format(np.mean(mse_mean)))
247+
248+
mse_filename = os.path.splitext(json_tf_records)[0] + "_mse.png"
249+
mse_fig.savefig(mse_filename)
250+
251+
print("mae:", abs_diff_mean[0], "mse:", mse_mean[0])

0 commit comments

Comments
 (0)