-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathpost_process.cpp
116 lines (94 loc) · 2.68 KB
/
post_process.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
/**********************************************************
* Author : lingteng qiu
* Email :
* Create time : 2018-12-25 14:36
* Last modified : 2018-12-25 14:36
* Filename : post_process.cpp
* Description : post_process for our detection system
* *******************************************************/
#include<torch/script.h>
#include<cstdint>
#include<iostream>
#include<queue>
#include<tuple>
#include<vector>
#include"post_process.hpp"
/**
 * Intersection-over-union (Jaccard index) of every box in `source` against a
 * single box `point`.
 *
 * @param source 2-D tensor with one box per row. Columns 0..3 are read as
 *               corner coordinates (x1, y1, x2, y2), with area computed as
 *               (x2-x1)*(y2-y1). `source` itself is not modified: the
 *               clipping below happens on a float copy.
 * @param point  1-D tensor whose elements 0..3 are read the same way.
 * @return 1-D float tensor of length source.size(0) holding
 *         intersection_area / union_area for each row. Rows whose union area
 *         is not positive (their intersection is necessarily 0; with two
 *         zero-area boxes the plain division would be 0/0 = NaN) score 0.
 */
torch::Tensor jaccard(torch::Tensor source,torch::Tensor point)
{
    // Float working copy allocated on source's own device, so the in-place
    // clipping never touches the caller's tensor and the comparisons with
    // `point` below do not mix devices.
    torch::Tensor temp = torch::zeros({source.size(0), source.size(1)},
                                      source.options().dtype(torch::kFloat));
    temp.copy_(source);
    auto c1 = temp.select(1,0);
    auto c2 = temp.select(1,1);
    auto c3 = temp.select(1,2);
    auto c4 = temp.select(1,3);

    // Areas must be taken before c1..c4 are clipped in place below.
    auto area_a = (c3-c1)*(c4-c2);
    auto area_b = (point[2]-point[0])*(point[3]-point[1]);

    // c1..c4 are views into temp: clipping them to `point` turns every row
    // of temp into the intersection rectangle of that row with `point`.
    c1.masked_fill_(c1<point[0],point[0]);
    c2.masked_fill_(c2<point[1],point[1]);
    c3.masked_fill_(c3>point[2],point[2]);
    c4.masked_fill_(c4>point[3],point[3]);

    // Disjoint boxes give a negative width/height; treat that as zero overlap.
    auto w = c3-c1;
    auto h = c4-c2;
    w.masked_fill_(w<0,0);
    h.masked_fill_(h<0,0);
    auto inter = w.mul(h);
    auto unions = area_a+area_b-inter;

    // Guard the degenerate case instead of returning NaN.
    return torch::where(unions > 0, inter.div(unions), torch::zeros_like(inter));
}
/**
 * Merge transitively overlapping boxes into their enclosing rectangles.
 *
 * The boxes are first clipped to the image. They are then partitioned into
 * groups: two boxes share a group when a chain of boxes with strictly
 * positive pairwise jaccard() links them. Each group is reduced to the
 * smallest rectangle covering all of its members.
 *
 * @param boxes  2-D tensor, one box per row. Only columns 0..3 are used,
 *               read as (x1, y1, x2, y2): columns 0/2 are clipped against the
 *               width and 1/3 against the height. Extra columns are ignored.
 * @param shape  (h, w) of the image. x1/y1 below 0 become 0; x2 >= w becomes
 *               w-1 and y2 >= h becomes h-1.
 * @return CPU tensor of the default dtype with shape (num_groups, 4): one
 *         merged (x1, y1, x2, y2) rectangle per group, or (0, 4) when `boxes`
 *         has no rows.
 */
torch::Tensor post_process(torch::Tensor boxes,std::tuple<int,int> shape)
{
    const int h = std::get<0>(shape);
    const int w = std::get<1>(shape);

    // Copy of the first four columns. The column index is created on boxes'
    // device so index_select also works for non-CPU inputs.
    auto temp = boxes.index_select(1, torch::arange(4, boxes.options().dtype(torch::kLong)));

    // Clip the boxes to the image.
    temp.select(1,0).masked_fill_(temp.select(1,0)<0,0);
    temp.select(1,1).masked_fill_(temp.select(1,1)<0,0);
    temp.select(1,2).masked_fill_(temp.select(1,2)>=w,w-1);
    temp.select(1,3).masked_fill_(temp.select(1,3)>=h,h-1);

    // Breadth-first grouping. `temp` always holds the boxes not yet assigned
    // to any group.
    std::vector<torch::Tensor> groups;
    while (temp.size(0) > 0)
    {
        // Seed a new group with the first unassigned box.
        std::vector<torch::Tensor> members{temp.select(0,0)};
        std::queue<torch::Tensor> frontier;
        frontier.push(members.front());
        temp = temp.slice(0,1,temp.size(0));

        // Pull in every unassigned box that overlaps any group member.
        while (!frontier.empty() && temp.size(0) > 0)
        {
            auto point = frontier.front();
            frontier.pop();
            auto overlap = jaccard(temp,point);

            // Rows of temp that do NOT overlap `point` and so stay unassigned.
            // int64_t (not long) so the buffer layout matches torch::kLong on
            // every platform, including LLP64 where long is only 32 bits.
            std::vector<int64_t> remaining;
            for (int64_t i = 0; i < overlap.size(0); ++i)
            {
                if (overlap[i].item<float>() > 0)
                {
                    auto ele = temp.select(0,i);
                    members.push_back(ele);
                    frontier.push(ele);
                }
                else
                {
                    remaining.push_back(i);
                }
            }
            // from_blob does not own `remaining`; index_select copies the rows
            // out before the vector goes away. The index must live on temp's
            // device.
            auto keep = torch::from_blob(remaining.data(),
                                         {static_cast<int64_t>(remaining.size())},
                                         torch::kLong).to(temp.device());
            temp = temp.index_select(0,keep);
        }
        // Stack once instead of re-concatenating the growing group per member.
        groups.push_back(torch::stack(members));   // shape (k, 4)
    }

    // Enclosing rectangle of each group: min of (x1, y1), max of (x2, y2).
    auto retv = torch::zeros({static_cast<int64_t>(groups.size()), 4});
    int64_t row = 0;
    for (const auto& group : groups)
    {
        auto min_tensor = std::get<0>(group.min(0));
        auto max_tensor = std::get<0>(group.max(0));
        retv.select(0,row).copy_(torch::cat({min_tensor.slice(0,0,2),
                                             max_tensor.slice(0,2,4)}));
        ++row;
    }
    return retv;
}