1
+
2
+ from __future__ import print_function
3
+ import numpy as np
4
+ import vtk
5
+ import argparse
6
+ import os
7
+ from datetime import datetime , time
8
+ import json
9
+ import glob
10
+ import time
11
+ from sklearn .metrics import confusion_matrix
12
+ from sklearn .metrics import roc_curve , auc
13
+ import matplotlib as mpl
14
+ mpl .use ('Agg' )
15
+ import matplotlib .pyplot as plt
16
+ import itertools
17
+ from scipy import interp
18
+ import csv
19
+
20
+ parser = argparse .ArgumentParser (description = 'Evaluate dental model seg' , formatter_class = argparse .ArgumentDefaultsHelpFormatter )
21
+
22
+ parser .add_argument ('--csv' , type = str , help = 'CSV columns gt and prediction. Both VTK files with label array "RegionId"' , required = True )
23
+ parser .add_argument ('--out' , type = str , help = 'Out filename for plots' , default = "./out" )
24
+
25
+ args = parser .parse_args ()
26
+
27
+ y_pred_arr = []
28
+ y_true_arr = []
29
+
30
+ fpr_arr = []
31
+ tpr_arr = []
32
+ roc_auc_arr = []
33
+ iou_arr = []
34
+
35
+ abs_diff_arr = []
36
+ mse_arr = []
37
+
38
+ eval_type = "class"
39
+ class_names = ["Gum" , "Teeth" , "Boundary" ]
40
+
41
+ if (eval_type == "class" ):
42
+
43
+ with open (args .csv ) as csvfile :
44
+ csv_reader = csv .DictReader (csvfile )
45
+ for row in csv_reader :
46
+
47
+ args = parser .parse_args ()
48
+
49
+ reader = vtk .vtkPolyDataReader ()
50
+ reader .SetFileName (row ["gt" ])
51
+ reader .Update ()
52
+
53
+ clean = vtk .vtkCleanPolyData ()
54
+ clean .SetInputData (reader .GetOutput ())
55
+ clean .SetTolerance (0.0001 )
56
+ clean .Update ()
57
+ surf1 = clean .GetOutput ()
58
+
59
+ surf1_label = surf1 .GetPointData ().GetArray ('RegionId' )
60
+ for pid in range (surf1_label .GetNumberOfTuples ()):
61
+ y_true_arr .append (surf1_label .GetTuple (pid )[0 ])
62
+
63
+ reader = vtk .vtkPolyDataReader ()
64
+ reader .SetFileName (row ["prediction" ])
65
+ reader .Update ()
66
+
67
+ clean = vtk .vtkCleanPolyData ()
68
+ clean .SetInputData (reader .GetOutput ())
69
+ clean .SetTolerance (0.0001 )
70
+ clean .Update ()
71
+ surf2 = clean .GetOutput ()
72
+
73
+ surf2_label = surf2 .GetPointData ().GetArray ('RegionId' )
74
+ for pid in range (surf2_label .GetNumberOfTuples ()):
75
+ if surf2_label .GetTuple (pid )[0 ] == - 1 :
76
+ y_pred_arr .append (0 )
77
+ else :
78
+ y_pred_arr .append (surf2_label .GetTuple (pid )[0 ] - 1 )
79
+
80
+ elif (eval_type == "segmentation" ):
81
+ fpr , tpr , _ = roc_curve (np .array (image_batch [1 ]).reshape (- 1 ), np .array (y_pred ).reshape (- 1 ), pos_label = 1 )
82
+ roc_auc = auc (fpr ,tpr )
83
+
84
+ fpr_arr .append (fpr )
85
+ tpr_arr .append (tpr )
86
+ roc_auc_arr .append (roc_auc )
87
+
88
+ y_pred_flat = np .array (y_pred ).reshape ((len (y_pred ), - 1 ))
89
+ labels_flat = np .array (image_batch [1 ]).reshape ((len (y_pred ), - 1 ))
90
+
91
+ for i in range (len (y_pred )):
92
+ intersection = 2.0 * np .sum (y_pred_flat [i ] * labels_flat [i ]) + 1e-7
93
+ union = np .sum (y_pred_flat [i ]) + np .sum (labels_flat [i ]) + 1e-7
94
+ iou_arr .append (intersection / union )
95
+
96
+ elif (eval_type == "image" or eval_type == "numeric" ):
97
+ abs_diff_arr .extend (np .average (np .absolute (y_pred - image_batch [1 ]).reshape ([1 , - 1 ]), axis = - 1 ))
98
+ mse_arr .extend (np .average (np .square (y_pred - image_batch [1 ]).reshape ([1 , - 1 ]), axis = - 1 ))
99
+
100
+
101
+ def plot_confusion_matrix (cm , classes ,
102
+ normalize = False ,
103
+ title = 'Confusion matrix' ,
104
+ cmap = plt .cm .Blues ):
105
+ """
106
+ This function prints and plots the confusion matrix.
107
+ Normalization can be applied by setting `normalize=True`.
108
+ """
109
+ if normalize :
110
+ cm = cm .astype ('float' ) / cm .sum (axis = 1 )[:, np .newaxis ]
111
+ print ("Normalized confusion matrix" )
112
+ else :
113
+ print ('Confusion matrix, without normalization' )
114
+
115
+ print (cm )
116
+
117
+ plt .imshow (cm , interpolation = 'nearest' , cmap = cmap )
118
+ plt .title (title )
119
+ plt .colorbar ()
120
+ tick_marks = np .arange (len (classes ))
121
+ plt .xticks (tick_marks , classes , rotation = 45 )
122
+ plt .yticks (tick_marks , classes )
123
+
124
+ fmt = '.3f' if normalize else 'd'
125
+ thresh = cm .max () / 2.
126
+ for i , j in itertools .product (range (cm .shape [0 ]), range (cm .shape [1 ])):
127
+ plt .text (j , i , format (cm [i , j ], fmt ),
128
+ horizontalalignment = "center" ,
129
+ color = "white" if cm [i , j ] > thresh else "black" )
130
+
131
+ plt .ylabel ('True label' )
132
+ plt .xlabel ('Predicted label' )
133
+ plt .tight_layout ()
134
+
135
+ if (eval_type == "class" ):
136
+ # Compute confusion matrix
137
+
138
+ cnf_matrix = confusion_matrix (y_true_arr , y_pred_arr )
139
+ np .set_printoptions (precision = 3 )
140
+
141
+ # Plot non-normalized confusion matrix
142
+ fig = plt .figure ()
143
+ plot_confusion_matrix (cnf_matrix , classes = class_names , title = 'Confusion matrix, without normalization' )
144
+ confusion_filename = os .path .splitext (args .out )[0 ] + "_confusion.png"
145
+ fig .savefig (confusion_filename )
146
+ # Plot normalized confusion matrix
147
+ fig2 = plt .figure ()
148
+ plot_confusion_matrix (cnf_matrix , classes = class_names , normalize = True , title = 'Normalized confusion matrix' )
149
+
150
+ norm_confusion_filename = os .path .splitext (args .out )[0 ] + "_norm_confusion.png"
151
+ fig2 .savefig (norm_confusion_filename )
152
+
153
+ elif (eval_type == "segmentation" ):
154
+
155
+ # First aggregate all false positive rates
156
+ all_fpr = np .unique (np .concatenate ([fpr for fpr in fpr_arr ]))
157
+
158
+ # Then interpolate all ROC curves at this points
159
+ mean_tpr = np .zeros_like (all_fpr )
160
+ for i in range (len (fpr_arr )):
161
+ mean_tpr += interp (all_fpr , fpr_arr [i ], tpr_arr [i ])
162
+
163
+ mean_tpr /= len (fpr_arr )
164
+
165
+ roc_auc = auc (all_fpr , mean_tpr )
166
+
167
+ roc_fig = plt .figure ()
168
+ lw = 1
169
+ plt .plot (all_fpr , mean_tpr , color = 'darkorange' , lw = lw , label = 'ROC curve (area = %0.2f)' % roc_auc )
170
+ plt .plot ([0 , 1 ], [0 , 1 ], color = 'navy' , lw = lw , linestyle = '--' )
171
+ plt .xlim ([0.0 , 1.0 ])
172
+ plt .ylim ([0.0 , 1.05 ])
173
+ plt .xlabel ('False Positive Rate' )
174
+ plt .ylabel ('True Positive Rate' )
175
+ plt .title ('Receiver operating characteristic' )
176
+ plt .legend (loc = "lower right" )
177
+
178
+ roc_filename = os .path .splitext (json_tf_records )[0 ] + "_roc.png"
179
+ roc_fig .savefig (roc_filename )
180
+
181
+ iou_obj = {}
182
+ iou_obj ["iou" ] = iou_arr
183
+
184
+ iou_json = os .path .splitext (json_tf_records )[0 ] + "_iou_arr.json"
185
+
186
+ with open (iou_json , "w" ) as f :
187
+ f .write (json .dumps (iou_obj ))
188
+
189
+ iou_fig_polar = plt .figure ()
190
+ ax = iou_fig_polar .add_subplot (111 , projection = 'polar' )
191
+ theta = 2 * np .pi * np .arange (len (iou_arr ))/ len (iou_arr )
192
+ colors = iou_arr
193
+ ax .scatter (theta , iou_arr , c = colors , cmap = 'autumn' , alpha = 0.75 )
194
+ ax .set_rlim (0 ,1 )
195
+ plt .title ('Intersection over union' )
196
+ locs , labels = plt .xticks ()
197
+ plt .xticks (locs , np .arange (0 , len (iou_arr ), round (len (iou_arr )/ len (locs ))))
198
+
199
+ iou_polar_filename = os .path .splitext (json_tf_records )[0 ] + "_iou_polar.png"
200
+ iou_fig_polar .savefig (iou_polar_filename )
201
+
202
+ iou_fig = plt .figure ()
203
+ x_samples = np .arange (len (iou_arr ))
204
+ plt .scatter (x_samples , iou_arr , c = colors , cmap = 'autumn' , alpha = 0.75 )
205
+ plt .title ('Intersection over union' )
206
+ iou_mean = np .mean (iou_arr )
207
+ plt .plot (x_samples ,[iou_mean ]* len (iou_arr ), label = 'Mean' , linestyle = '--' )
208
+ plt .text (len (iou_arr ) + 2 ,iou_mean , '%.3f' % iou_mean )
209
+ iou_stdev = np .std (iou_arr )
210
+ stdev_line = plt .plot (x_samples ,iou_mean + [iou_stdev ]* len (iou_arr ), label = 'Stdev' , linestyle = ':' , alpha = 0.75 )
211
+ stdev_line = plt .plot (x_samples ,iou_mean - [iou_stdev ]* len (iou_arr ), label = 'Stdev' , linestyle = ':' , alpha = 0.75 )
212
+ plt .text (len (iou_arr ) + 2 ,iou_mean + iou_stdev , '%.3f' % iou_stdev , alpha = 0.75 , fontsize = 'x-small' )
213
+ iou_filename = os .path .splitext (json_tf_records )[0 ] + "_iou.png"
214
+ iou_fig .savefig (iou_filename )
215
+
216
+ elif (eval_type == "image" or eval_type == "numeric" ):
217
+ abs_diff_arr = np .array (abs_diff_arr )
218
+ abs_diff_fig = plt .figure ()
219
+ x_samples = np .arange (len (abs_diff_arr ))
220
+
221
+ plt .scatter (x_samples , abs_diff_arr , c = abs_diff_arr , cmap = 'cool' , alpha = 0.75 , label = 'Mean absolute error' )
222
+ plt .xlabel ('Samples' )
223
+ plt .ylabel ('Absolute error' )
224
+ plt .title ('Mean absolute error' )
225
+
226
+ abs_diff_mean = np .array ([np .mean (abs_diff_arr )]* len (abs_diff_arr ))
227
+ mean_line = plt .plot (x_samples ,abs_diff_mean , label = 'Mean' , linestyle = '--' )
228
+ abs_diff_stdev = np .array ([np .std (abs_diff_mean )]* len (abs_diff_mean ))
229
+ stdev_line = plt .plot (x_samples , abs_diff_mean + abs_diff_stdev , label = 'Mean' , linestyle = ':' , alpha = 0.75 )
230
+ plt .text (len (abs_diff_mean ), np .mean (abs_diff_mean ), "{0:.3f}" .format (np .mean (abs_diff_mean )))
231
+
232
+ abs_filename = os .path .splitext (json_tf_records )[0 ] + "_abs_diff.png"
233
+ abs_diff_fig .savefig (abs_filename )
234
+
235
+ mse_arr = np .array (mse_arr )
236
+ mse_fig = plt .figure ()
237
+ plt .scatter (x_samples , mse_arr , c = mse_arr , cmap = 'cool' , alpha = 0.75 , label = 'MSE' )
238
+ plt .xlabel ('Samples' )
239
+ plt .ylabel ('MSE' )
240
+ plt .title ('Mean squared error' )
241
+
242
+ mse_mean = np .array ([np .mean (mse_arr )]* len (mse_arr ))
243
+ mse_line = plt .plot (x_samples ,mse_mean , label = 'Mean' , linestyle = '--' )
244
+ mse_stdev = np .array ([np .std (mse_arr )]* len (mse_arr ))
245
+ stdev_line = plt .plot (x_samples , mse_mean + mse_stdev , label = 'Mean' , linestyle = ':' , alpha = 0.75 )
246
+ plt .text (len (mse_mean ), np .mean (mse_mean ), "{0:.3f}" .format (np .mean (mse_mean )))
247
+
248
+ mse_filename = os .path .splitext (json_tf_records )[0 ] + "_mse.png"
249
+ mse_fig .savefig (mse_filename )
250
+
251
+ print ("mae:" , abs_diff_mean [0 ], "mse:" , mse_mean [0 ])
0 commit comments