Skip to content

Commit 8bde01a

Browse files
authored
Merge pull request #687 from contramundum53/fast-pareto-front
Fast pareto-front calculation for 2D
2 parents bb2a7ed + 1203501 commit 8bde01a

File tree

1 file changed

+69
-14
lines changed

1 file changed

+69
-14
lines changed

optuna_dashboard/ts/components/GraphParetoFront.tsx

+69-14
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,72 @@ const makeMarker = (
165165
}
166166
}
167167

168+
const getIsDominatedND = (normalizedValues: number[][]) => {
169+
// Fallback for straight-forward pareto front algorithm (O(N^2) complexity).
170+
const isDominated: boolean[] = []
171+
normalizedValues.forEach((values0: number[]) => {
172+
const dominated = normalizedValues.some((values1: number[]) => {
173+
if (values0.every((value0: number, k: number) => values1[k] === value0)) {
174+
return false
175+
}
176+
return values0.every((value0: number, k: number) => values1[k] <= value0)
177+
})
178+
isDominated.push(dominated)
179+
})
180+
return isDominated
181+
}
182+
183+
const getIsDominated2D = (normalizedValues: number[][]) => {
184+
// Fast pareto front algorithm (O(N log N) complexity).
185+
const sorted = normalizedValues
186+
.map((values, i) => [values[0], values[1], i])
187+
.sort((a, b) =>
188+
a[0] > b[0]
189+
? 1
190+
: a[0] < b[0]
191+
? -1
192+
: a[1] > b[1]
193+
? 1
194+
: a[1] < b[1]
195+
? -1
196+
: 0
197+
)
198+
let maxValueSeen0 = sorted[0][0]
199+
let minValueSeen1 = sorted[0][1]
200+
201+
const isDominated: boolean[] = new Array(normalizedValues.length).fill(false)
202+
sorted.forEach((values) => {
203+
if (
204+
values[1] > minValueSeen1 ||
205+
(values[1] === minValueSeen1 && values[0] > maxValueSeen0)
206+
) {
207+
isDominated[values[2]] = true
208+
} else {
209+
minValueSeen1 = values[1]
210+
}
211+
maxValueSeen0 = values[0]
212+
})
213+
return isDominated
214+
}
215+
216+
const getIsDominated1D = (normalizedValues: number[][]) => {
217+
const best_value = Math.min(...normalizedValues.map((values) => values[0]))
218+
return normalizedValues.map((values) => values[0] !== best_value)
219+
}
220+
221+
const getIsDominated = (normalizedValues: number[][]) => {
222+
if (normalizedValues.length === 0) {
223+
return []
224+
}
225+
if (normalizedValues[0].length === 1) {
226+
return getIsDominated1D(normalizedValues)
227+
} else if (normalizedValues[0].length === 2) {
228+
return getIsDominated2D(normalizedValues)
229+
} else {
230+
return getIsDominatedND(normalizedValues)
231+
}
232+
}
233+
168234
const plotParetoFront = (
169235
study: StudyDetail,
170236
objectiveXId: number,
@@ -218,22 +284,11 @@ const plotParetoFront = (
218284
}
219285
})
220286

221-
const dominatedTrials: boolean[] = []
222-
normalizedValues.forEach((values0: number[], i: number) => {
223-
const dominated = normalizedValues.some((values1: number[], j: number) => {
224-
if (i === j) {
225-
return false
226-
}
227-
return values0.every((value0: number, k: number) => {
228-
return values1[k] <= value0
229-
})
230-
})
231-
dominatedTrials.push(dominated)
232-
})
287+
const isDominated: boolean[] = getIsDominated(normalizedValues)
233288

234289
const plotData: Partial<plotly.PlotData>[] = [
235290
makeScatterObject(
236-
feasibleTrials.filter((t, i) => dominatedTrials[i]),
291+
feasibleTrials.filter((t, i) => isDominated[i]),
237292
objectiveXId,
238293
objectiveYId,
239294
infeasibleTrials.length === 0
@@ -244,7 +299,7 @@ const plotParetoFront = (
244299
mode
245300
),
246301
makeScatterObject(
247-
feasibleTrials.filter((t, i) => !dominatedTrials[i]),
302+
feasibleTrials.filter((t, i) => !isDominated[i]),
248303
objectiveXId,
249304
objectiveYId,
250305
"%{text}<extra>Best Trial</extra>",

0 commit comments

Comments
 (0)