Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve query input support #156

Merged
9 changes: 9 additions & 0 deletions ast/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,15 @@ func (body Body) Vars(skipClosures bool) VarSet {
return vis.vars
}

// NewExpr returns a new Expr object.
func NewExpr(terms interface{}) *Expr {
return &Expr{
Negated: false,
Terms: terms,
Index: 0,
}
}

// Complement returns a copy of this expression with the negation flag flipped.
func (expr *Expr) Complement() *Expr {
cpy := *expr
Expand Down
7 changes: 7 additions & 0 deletions ast/term.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ type Term struct {
Location *Location `json:"-"` // the location of the Term in the source
}

// NewTerm returns a new Term object.
func NewTerm(v Value) *Term {
return &Term{
Value: v,
}
}

// Equal returns true if this term equals the other term. Equality is
// defined for each kind of term.
func (term *Term) Equal(other *Term) bool {
Expand Down
8 changes: 8 additions & 0 deletions ast/varset.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ func (s VarSet) Diff(vs VarSet) VarSet {
return r
}

// Equal returns true if s contains exactly the same elements as vs.
func (s VarSet) Equal(vs VarSet) bool {
if len(s.Diff(vs)) > 0 {
return false
}
return len(vs.Diff(s)) == 0
}

// Intersect returns a VarSet containing variables in s that are in vs.
func (s VarSet) Intersect(vs VarSet) VarSet {
r := VarSet{}
Expand Down
12 changes: 12 additions & 0 deletions ast/visit.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,18 @@ func WalkRefs(x interface{}, f func(Ref) bool) {
Walk(vis, x)
}

// WalkVars calls the function f on all vars under x. If the function f
// returns true, AST nodes under the last node will not be visited.
func WalkVars(x interface{}, f func(Var) bool) {
vis := &GenericVisitor{func(x interface{}) bool {
if v, ok := x.(Var); ok {
return f(v)
}
return false
}}
Walk(vis, x)
}

// WalkBodies calls the function f on all bodies under x. If the function f
// returns true, AST nodes under the last node will not be visited.
func WalkBodies(x interface{}, f func(Body) bool) {
Expand Down
13 changes: 13 additions & 0 deletions ast/visit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,16 @@ func TestVisitor(t *testing.T) {
}

}

func TestWalkVars(t *testing.T) {
x := MustParseBody("x = 1, data.abc[2] = y, y[z] = [q | q = 1]")
found := NewVarSet()
WalkVars(x, func(v Var) bool {
found.Add(v)
return false
})
expected := NewVarSet(Var("x"), Var("data"), Var("y"), Var("z"), Var("q"), Var("eq"))
if !expected.Equal(found) {
t.Fatalf("Expected %v but got: %v", expected, found)
}
}
60 changes: 54 additions & 6 deletions repl/repl.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,46 @@ func (r *REPL) loadCompiler() (*ast.Compiler, error) {
return compiler, nil
}

// loadGlobals returns the globals mapping currently defined in the REPL. The
// REPL loads globals from the data.repl.globals document.
func (r *REPL) loadGlobals(compiler *ast.Compiler) (*ast.ValueMap, error) {

params := topdown.NewQueryParams(compiler, r.store, r.txn, nil, ast.MustParseRef("data.repl.globals"))

result, err := topdown.Query(params)
if err != nil {
return nil, err
}

if result.Undefined() {
return nil, nil
}

pairs := [][2]*ast.Term{}

obj, ok := result[0].Result.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("globals is %T but expected object", result)
}

for k, v := range obj {

gk, err := ast.ParseTerm(k)
if err != nil {
return nil, err
}

gv, err := ast.InterfaceToValue(v)
if err != nil {
return nil, err
}

pairs = append(pairs, [...]*ast.Term{gk, &ast.Term{Value: gv}})
}

return topdown.MakeGlobals(pairs)
}

func (r *REPL) evalStatement(stmt interface{}) bool {
switch s := stmt.(type) {
case ast.Body:
Expand All @@ -470,7 +510,12 @@ func (r *REPL) evalStatement(stmt interface{}) bool {
fmt.Fprintln(r.output, "error:", err)
return false
}
return r.evalBody(compiler, body)
globals, err := r.loadGlobals(compiler)
if err != nil {
fmt.Fprintln(r.output, "error:", err)
return false
}
return r.evalBody(compiler, globals, body)
case *ast.Rule:
if err := r.compileRule(s); err != nil {
fmt.Fprintln(r.output, "error:", err)
Expand All @@ -483,22 +528,23 @@ func (r *REPL) evalStatement(stmt interface{}) bool {
return false
}

func (r *REPL) evalBody(compiler *ast.Compiler, body ast.Body) bool {
func (r *REPL) evalBody(compiler *ast.Compiler, globals *ast.ValueMap, body ast.Body) bool {

// Special case for positive, single term inputs.
if len(body) == 1 {
expr := body[0]
if !expr.Negated {
if _, ok := expr.Terms.(*ast.Term); ok {
if singleValue(body) {
return r.evalTermSingleValue(compiler, body)
return r.evalTermSingleValue(compiler, globals, body)
}
return r.evalTermMultiValue(compiler, body)
return r.evalTermMultiValue(compiler, globals, body)
}
}
}

ctx := topdown.NewContext(body, compiler, r.store, r.txn)
ctx.Globals = globals

var buf *topdown.BufferTracer

Expand Down Expand Up @@ -610,13 +656,14 @@ func (r *REPL) evalPackage(p *ast.Package) bool {
// and comprehensions always evaluate to a single value. To handle references, this function
// still executes the query, except it does so by rewriting the body to assign the term
// to a variable. This allows the REPL to obtain the result even if the term is false.
func (r *REPL) evalTermSingleValue(compiler *ast.Compiler, body ast.Body) bool {
func (r *REPL) evalTermSingleValue(compiler *ast.Compiler, globals *ast.ValueMap, body ast.Body) bool {

term := body[0].Terms.(*ast.Term)
outputVar := ast.Wildcard
body = ast.NewBody(ast.Equality.Expr(term, outputVar))

ctx := topdown.NewContext(body, compiler, r.store, r.txn)
ctx.Globals = globals

var buf *topdown.BufferTracer

Expand Down Expand Up @@ -656,7 +703,7 @@ func (r *REPL) evalTermSingleValue(compiler *ast.Compiler, body ast.Body) bool {

// evalTermMultiValue evaluates and prints terms in cases where the term may evaluate to multiple
// ground values, e.g., a[i], [servers[x]], etc.
func (r *REPL) evalTermMultiValue(compiler *ast.Compiler, body ast.Body) bool {
func (r *REPL) evalTermMultiValue(compiler *ast.Compiler, globals *ast.ValueMap, body ast.Body) bool {

// Mangle the expression in the same way we do for evalTermSingleValue. When handling the
// evaluation result below, we will ignore this variable.
Expand All @@ -665,6 +712,7 @@ func (r *REPL) evalTermMultiValue(compiler *ast.Compiler, body ast.Body) bool {
body = ast.NewBody(ast.Equality.Expr(term, outputVar))

ctx := topdown.NewContext(body, compiler, r.store, r.txn)
ctx.Globals = globals

var buf *topdown.BufferTracer

Expand Down
21 changes: 21 additions & 0 deletions repl/repl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,27 @@ func TestEvalBodyContainingWildCards(t *testing.T) {

}

func TestEvalBodyGlobals(t *testing.T) {
store := newTestStore()
var buffer bytes.Buffer
repl := newRepl(store, &buffer)

repl.OneShot("package repl")
repl.OneShot(`globals["foo.bar"] = "hello" :- true`)
repl.OneShot(`globals["baz"] = data.a[0].b.c[2] :- true`)
repl.OneShot("package test")
repl.OneShot("import foo.bar")
repl.OneShot("import baz")
repl.OneShot(`p :- bar = "hello", baz = false`)

repl.OneShot("p")

result := buffer.String()
if result != "true\n" {
t.Fatalf("expected true but got: %v", result)
}
}

func TestEvalImport(t *testing.T) {
store := newTestStore()
var buffer bytes.Buffer
Expand Down
Loading