perc_cw.py
from typing import Tuple

import torch
import torch.nn as nn
import torch.optim as optim

from differential_color_functions import rgb2lab_diff, ciede2000_diff


def quantization(x):
    """Quantize continuous image tensors into 256 levels (8-bit encoding)."""
    x_quan = torch.round(x * 255) / 255
    return x_quan


class PerC_CW:
    """
    PerC_CW: the C&W attack with the penalty term replaced by one based on
    perceptual color differences (CIEDE2000).
    Modified from https://github.com/jeromerony/fast_adversarial/blob/master/fast_adv/attacks/carlini.py

    Parameters
    ----------
    image_constraints : tuple
        Bounds of the images.
    num_classes : int
        Number of classes of the model to attack.
    confidence : float, optional
        Confidence of the attack for Carlini's loss, in terms of the distance between logits.
    learning_rate : float
        Learning rate for the optimization.
    search_steps : int
        Number of binary-search steps to find the best trade-off constant for Carlini's loss.
    max_iterations : int
        Maximum number of iterations during a single search step.
    abort_early : bool
        Whether to abort a search step early once the loss stops decreasing.
    initial_const : float
        Initial trade-off constant of the attack.
    device : torch.device, optional
        Device on which to run the attack.
    """

    def __init__(self,
                 image_constraints: Tuple[float, float] = (0, 1),
                 num_classes: int = 1000,
                 confidence: float = 0,
                 learning_rate: float = 0.01,
                 search_steps: int = 9,
                 max_iterations: int = 1000,
                 abort_early: bool = True,
                 initial_const: float = 10,
                 device: torch.device = torch.device('cpu')
                 ) -> None:
self.confidence = confidence
self.learning_rate = learning_rate
self.binary_search_steps = search_steps
self.max_iterations = max_iterations
self.abort_early = abort_early
self.initial_const = initial_const
self.num_classes = num_classes
self.repeat = self.binary_search_steps >= 10
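        # tanh(w) ranges over (-1, 1); boxmul and boxplus rescale it onto [boxmin, boxmax]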
self.boxmin = image_constraints[0]
self.boxmax = image_constraints[1]
self.boxmul = (self.boxmax - self.boxmin) / 2
self.boxplus = (self.boxmin + self.boxmax) / 2
self.device = device

    @staticmethod
    def _arctanh(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # shrink slightly so the argument stays strictly inside (-1, 1)
        x = x * (1. - eps)
        return 0.5 * torch.log((1 + x) / (1 - x))

def _step(self, model: nn.Module, optimizer: optim.Optimizer, inputs: torch.Tensor, tinputs: torch.Tensor,
modifier: torch.Tensor, labels: torch.Tensor, labels_infhot: torch.Tensor, targeted: bool,
const: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = inputs.shape[0]
adv_input = torch.tanh(tinputs + modifier) * self.boxmul + self.boxplus
        # calculate the global perceptual color difference: the L2 norm of the CIEDE2000 color-distance map
        l2 = torch.norm(
            ciede2000_diff(rgb2lab_diff(inputs, self.device),
                           rgb2lab_diff(adv_input, self.device),
                           self.device).view(batch_size, -1),
            dim=1)
        # the model expects inputs normalized from [0, 1] to [-1, 1]
        logits = model((adv_input - 0.5) / 0.5)
        # logit of the true (or target) class
        real = logits.gather(1, labels.unsqueeze(1)).squeeze(1)
        # highest logit among all other classes (labels_infhot masks out `labels`)
        other = (logits - labels_infhot).max(1)[0]
if targeted:
# if targeted, optimize for making the other class most likely
logit_dists = torch.clamp(other - real + self.confidence, min=0)
else:
# if non-targeted, optimize for making this class least likely.
logit_dists = torch.clamp(real - other + self.confidence, min=0)
logit_loss = (const * logit_dists).sum()
l2_loss = l2.sum()
loss = logit_loss + l2_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
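        # gradients flowed through the continuous adv_input above; the returned
        # candidate is quantized to 8 bits so that it is a valid image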
return quantization(adv_input).detach(), logits.detach(), l2.detach(), logit_dists.detach(), loss.detach()
def adversary(self, model: nn.Module, inputs: torch.Tensor, labels: torch.Tensor,
targeted: bool = False) -> torch.Tensor:
"""
Performs the adversary of the model given the inputs and labels.
Parameters
----------
model : nn.Module
Model to fool.
inputs : torch.Tensor
Batch of image examples.
labels : torch.Tensor
Original labels if untargeted, else labels of targets.
targeted : bool, optional
Whether to perform a targeted adversary or not.
Returns
-------
torch.Tensor
Batch of image samples modified to be adversarial
"""
batch_size = inputs.shape[0]
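        # map inputs into unconstrained tanh space: inputs == tanh(tinputs) * boxmul + boxplus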
tinputs = self._arctanh((inputs - self.boxplus) / self.boxmul)
# set the lower and upper bounds accordingly
lower_bound = torch.zeros(batch_size, device=self.device)
CONST = torch.full((batch_size,), self.initial_const, device=self.device)
upper_bound = torch.full((batch_size,), 1e10, device=self.device)
o_best_l2 = torch.full((batch_size,), 1e10, device=self.device)
o_best_score = torch.full((batch_size,), -1, dtype=torch.long, device=self.device)
o_best_adversary = inputs.clone()
# setup the target variable, we need it to be in one-hot form for the loss function
labels_onehot = torch.zeros(labels.size(0), self.num_classes, device=self.device)
labels_onehot.scatter_(1, labels.unsqueeze(1), 1)
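        # labels_infhot sets the true-class logit to +inf so that `other` in _step
        # always picks the strongest competing class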
labels_infhot = torch.zeros_like(labels_onehot).scatter_(1, labels.unsqueeze(1), float('inf'))
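        # outer loop: binary search over the constant CONST that trades off the
        # perceptual color penalty against the misclassification (Carlini) loss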
for outer_step in range(self.binary_search_steps):
# setup the modifier variable, this is the variable we are optimizing over
modifier = torch.zeros_like(inputs, requires_grad=True)
# setup the optimizer
optimizer = optim.Adam([modifier], lr=self.learning_rate, betas=(0.9, 0.999), eps=1e-8)
best_l2 = torch.full((batch_size,), 1e10, device=self.device)
best_score = torch.full((batch_size,), -1, dtype=torch.long, device=self.device)
            # on the last binary-search step (when running many steps), repeat the
            # search once using the upper bound found so far
            if self.repeat and outer_step == (self.binary_search_steps - 1):
                CONST = upper_bound
prev = float('inf')
for iteration in range(self.max_iterations):
# perform the adversary
adv, logits, l2, logit_dists, loss = self._step(model, optimizer, inputs, tinputs, modifier,
labels, labels_infhot, targeted, CONST)
                # periodically check whether the loss is still decreasing and abort early if not
if self.abort_early and iteration % (self.max_iterations // 10) == 0:
if loss > prev * 0.9999:
break
prev = loss
                # adjust the best result found so far (single forward pass, reused for both branches)
                adv_logits = model((adv - 0.5) / 0.5)
                predicted_classes = (adv_logits - labels_onehot * self.confidence).argmax(1) if targeted else \
                    (adv_logits + labels_onehot * self.confidence).argmax(1)
is_adv = (predicted_classes == labels) if targeted else (predicted_classes != labels)
                is_smaller = l2 < best_l2
                o_is_smaller = l2 < o_best_l2
                is_both = is_adv & is_smaller
                o_is_both = is_adv & o_is_smaller
best_l2[is_both] = l2[is_both]
best_score[is_both] = predicted_classes[is_both]
o_best_l2[o_is_both] = l2[o_is_both]
o_best_score[o_is_both] = predicted_classes[o_is_both]
o_best_adversary[o_is_both] = adv[o_is_both]
            # adjust the constant: tighten the bounds via binary search where an
            # adversarial example was found, otherwise raise the lower bound
            adv_found = (best_score == labels) if targeted else ((best_score != labels) & (best_score != -1))
            upper_bound[adv_found] = torch.min(upper_bound[adv_found], CONST[adv_found])
            adv_not_found = ~adv_found
            lower_bound[adv_not_found] = torch.max(lower_bound[adv_not_found], CONST[adv_not_found])
            is_smaller = upper_bound < 1e9
            # bisect where an upper bound exists; otherwise grow the constant tenfold
            CONST[is_smaller] = (lower_bound[is_smaller] + upper_bound[is_smaller]) / 2
            CONST[(~is_smaller) & adv_not_found] *= 10
# return the best solution found
return o_best_adversary
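

if __name__ == '__main__':
    # Minimal usage sketch, assuming a pretrained torchvision classifier; the
    # model choice, batch shape, and iteration count below are illustrative
    # placeholders, not prescribed by the attack itself.
    import torchvision.models as models

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = models.resnet50(pretrained=True).to(device).eval()

    # random placeholder batch in [0, 1]; in practice, load real images
    images = torch.rand(4, 3, 224, 224, device=device)
    # attack the model's current predictions (untargeted)
    with torch.no_grad():
        labels = model((images - 0.5) / 0.5).argmax(1)

    attacker = PerC_CW(num_classes=1000, max_iterations=100, device=device)
    adv_images = attacker.adversary(model, images, labels, targeted=False)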