Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast pareto-front calculation for 2D #687

Merged
merged 14 commits into from
Dec 21, 2023
83 changes: 69 additions & 14 deletions optuna_dashboard/ts/components/GraphParetoFront.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,72 @@ const makeMarker = (
}
}

const getIsDominatedND = (normalizedValues: number[][]) => {
// Fallback for straight-forward pareto front algorithm (O(N^2) complexity).
const isDominated: boolean[] = []
normalizedValues.forEach((values0: number[]) => {
const dominated = normalizedValues.some((values1: number[]) => {
if (values0.every((value0: number, k: number) => values1[k] === value0)) {
return false
}
return values0.every((value0: number, k: number) => values1[k] <= value0)
})
isDominated.push(dominated)
})
return isDominated
}

const getIsDominated2D = (normalizedValues: number[][]) => {
// Fast pareto front algorithm (O(N log N) complexity).
const sorted = normalizedValues
.map((values, i) => [values[0], values[1], i])
.sort((a, b) =>
a[0] > b[0]
? 1
: a[0] < b[0]
? -1
: a[1] > b[1]
? 1
: a[1] < b[1]
? -1
: 0
)
let maxValueSeen0 = sorted[0][0]
let minValueSeen1 = sorted[0][1]

const isDominated: boolean[] = new Array(normalizedValues.length).fill(false)
sorted.forEach((values) => {
if (
values[1] > minValueSeen1 ||
(values[1] === minValueSeen1 && values[0] > maxValueSeen0)
) {
isDominated[values[2]] = true
} else {
minValueSeen1 = values[1]
}
maxValueSeen0 = values[0]
})
return isDominated
}

const getIsDominated1D = (normalizedValues: number[][]) => {
const best_value = Math.min(...normalizedValues.map((values) => values[0]))
return normalizedValues.map((values) => values[0] !== best_value)
}

const getIsDominated = (normalizedValues: number[][]) => {
if (normalizedValues.length === 0) {
return []
}
if (normalizedValues[0].length === 1) {
return getIsDominated1D(normalizedValues)
} else if (normalizedValues[0].length === 2) {
return getIsDominated2D(normalizedValues)
} else {
return getIsDominatedND(normalizedValues)
}
}

const plotParetoFront = (
study: StudyDetail,
objectiveXId: number,
Expand Down Expand Up @@ -218,22 +284,11 @@ const plotParetoFront = (
}
})

const dominatedTrials: boolean[] = []
normalizedValues.forEach((values0: number[], i: number) => {
const dominated = normalizedValues.some((values1: number[], j: number) => {
if (i === j) {
return false
}
return values0.every((value0: number, k: number) => {
return values1[k] <= value0
})
})
dominatedTrials.push(dominated)
})
const isDominated: boolean[] = getIsDominated(normalizedValues)

const plotData: Partial<plotly.PlotData>[] = [
makeScatterObject(
feasibleTrials.filter((t, i) => dominatedTrials[i]),
feasibleTrials.filter((t, i) => isDominated[i]),
objectiveXId,
objectiveYId,
infeasibleTrials.length === 0
Expand All @@ -244,7 +299,7 @@ const plotParetoFront = (
mode
),
makeScatterObject(
feasibleTrials.filter((t, i) => !dominatedTrials[i]),
feasibleTrials.filter((t, i) => !isDominated[i]),
objectiveXId,
objectiveYId,
"%{text}<extra>Best Trial</extra>",
Expand Down