Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

eval: optimize iteration #7327

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 135 additions & 2 deletions v1/rego/rego_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"os"
"strconv"
"testing"

"github.com/open-policy-agent/opa/internal/runtime"
Expand Down Expand Up @@ -175,16 +176,148 @@ func BenchmarkAciTestOnlyEval(b *testing.B) {
b.ReportAllocs()

for i := 0; i < b.N; i++ {

res, err := pq.Eval(ctx, EvalParsedInput(input.Value))
if err != nil {
b.Fatal(err)
}

_ = res
}
}

// BenchmarkArrayIteration-10
// 15574 77121 ns/op 67249 B/op 1115 allocs/op // handleErr wrapping, not inlined
// 33862 35864 ns/op 5768 B/op 93 allocs/op // handleErr only on error, inlined
func BenchmarkArrayIteration(b *testing.B) {
ctx := context.Background()

at := make([]*ast.Term, 512)
for i := 0; i < 511; i++ {
at[i] = ast.StringTerm("a")
}
at[511] = ast.StringTerm("v")

input := ast.NewObject(ast.Item(ast.StringTerm("foo"), ast.ArrayTerm(at...)))
module := ast.MustParseModule(`package test

default r := false

r if input.foo[_] == "v"`)

r := New(Query("data.test.r = x"), ParsedModule(module))

pq, err := r.PrepareForEval(ctx)
if err != nil {
b.Fatal(err)
}

b.ResetTimer()
b.ReportAllocs()

for i := 0; i < b.N; i++ {
res, err := pq.Eval(ctx, EvalParsedInput(input))
if err != nil {
b.Fatal(err)
}

if res == nil {
b.Fatal("expected result")
}

if res[0].Bindings["x"].(bool) != true {
b.Fatalf("expected true, got %v", res[0].Bindings["x"])
}
}
}

// BenchmarkSetIteration-10
// 4800 272403 ns/op 80875 B/op 1193 allocs/op // handleErr wrapping, not inlined
// 4933 223234 ns/op 76772 B/op 681 allocs/op // handleErr only on error, not inlined
func BenchmarkSetIteration(b *testing.B) {
ctx := context.Background()

at := make([]*ast.Term, 512)
for i := 0; i < 512; i++ {
at[i] = ast.StringTerm(strconv.Itoa(i))
}

input := ast.NewObject(ast.Item(ast.StringTerm("foo"), ast.ArrayTerm(at...)))
module := ast.MustParseModule(`package test

s := {x | x := input.foo[_]}

default r := false

r if s[_] == "not found"`)

r := New(Query("data.test.r = x"), ParsedModule(module))

pq, err := r.PrepareForEval(ctx)
if err != nil {
b.Fatal(err)
}

b.ResetTimer()
b.ReportAllocs()

for i := 0; i < b.N; i++ {
res, err := pq.Eval(ctx, EvalParsedInput(input))
if err != nil {
b.Fatal(err)
}
if res == nil {
b.Fatal("expected result")
}
if res[0].Bindings["x"].(bool) != false {
b.Fatalf("expected false, got %v", res[0].Bindings["x"])
}
}
}

// BenchmarkObjectIteration-10
// 12067 99582 ns/op 72830 B/op 1126 allocs/op // handleErr wrapping, not inlined
// 15358 85080 ns/op 27752 B/op 615 allocs/op // handleErr only on error, not inlined
func BenchmarkObjectIteration(b *testing.B) {
ctx := context.Background()

at := make([][2]*ast.Term, 512)
for i := 0; i < 512; i++ {
at[i] = ast.Item(ast.StringTerm(strconv.Itoa(i)), ast.StringTerm(strconv.Itoa(i)))
}

input := ast.NewObject(ast.Item(ast.StringTerm("foo"), ast.ObjectTerm(at...)))
module := ast.MustParseModule(`package test

default r := false

r if {
input.foo[_] == "512"
}
`)

r := New(Query("data.test.r = x"), ParsedModule(module))

pq, err := r.PrepareForEval(ctx)
if err != nil {
b.Fatal(err)
}

b.ResetTimer()
b.ReportAllocs()

for i := 0; i < b.N; i++ {
res, err := pq.Eval(ctx, EvalParsedInput(input))
if err != nil {
b.Fatal(err)
}
if res == nil {
b.Fatal("expected result")
}
if res[0].Bindings["x"].(bool) != false {
b.Fatalf("expected false, got %v", res[0].Bindings["x"])
}
}
}

func mustReadFileAsString(b *testing.B, path string) string {
b.Helper()

Expand Down
49 changes: 38 additions & 11 deletions v1/topdown/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -3650,28 +3650,55 @@ func (e evalTerm) enumerate(iter unifyIterator) error {

switch v := e.term.Value.(type) {
case *ast.Array:
// Note(anders):
// For this case (e.g. input.foo[_]), we can avoid the (quite expensive) overhead of a callback
// function literal escaping to the heap in each iteration by inlining the biunification logic,
// meaning a 10x reduction in both the number of allocations made as well as the memory consumed.
// It is possible that such inlining could be done for the set/object cases as well, and that's
// worth looking into later, as I imagine set iteration in particular would be an even greater
// win across most policies. Those cases are however much more complex, as we need to deal with
// any type on either side, not just int/var as is the case here.
for i := 0; i < v.Len(); i++ {
k := ast.InternedIntNumberTerm(i)
if err := handleErr(e.e.biunify(k, e.ref[e.pos], e.bindings, e.bindings, func() error {
return e.next(iter, k)
})); err != nil {
return err
a := ast.InternedIntNumberTerm(i)
b := e.ref[e.pos]

if _, ok := b.Value.(ast.Var); ok {
if e.e.traceEnabled {
e.e.traceUnify(a, b)
}
var undo undo
b, e.bindings = e.bindings.apply(b)
e.bindings.bind(b, a, e.bindings, &undo)

err := e.next(iter, a)
undo.Undo()
Comment on lines +3669 to +3674
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this supposed to do? Sorry not familiar with how this is expected to work.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question! The code inlined is from here: https://github.com/open-policy-agent/opa/blob/main/v1/topdown/eval.go#L1155-L1160 I can't claim to fully understand how bindings work, but if I am to guess, it resets the binding in scope for the next iteration?

if err != nil {
if err := handleErr(err); err != nil {
return err
}
}
}
}
case ast.Object:
for _, k := range v.Keys() {
if err := handleErr(e.e.biunify(k, e.ref[e.pos], e.termbindings, e.bindings, func() error {
err := e.e.biunify(k, e.ref[e.pos], e.termbindings, e.bindings, func() error {
return e.next(iter, e.termbindings.Plug(k))
})); err != nil {
return err
})
if err != nil {
if err := handleErr(err); err != nil {
return err
}
}
}
case ast.Set:
for _, elem := range v.Slice() {
if err := handleErr(e.e.biunify(elem, e.ref[e.pos], e.termbindings, e.bindings, func() error {
err := e.e.biunify(elem, e.ref[e.pos], e.termbindings, e.bindings, func() error {
return e.next(iter, e.termbindings.Plug(elem))
})); err != nil {
return err
})
if err != nil {
if err := handleErr(err); err != nil {
return err
}
}
}
}
Expand Down
Loading