Skip to content

Commit

Permalink
cmd/compile: devirtualize interface calls with type assertions
Browse files Browse the repository at this point in the history
  • Loading branch information
mateusz834 committed Feb 13, 2025
1 parent ff27d27 commit d3d5a3b
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 9 deletions.
10 changes: 2 additions & 8 deletions src/cmd/compile/internal/devirtualize/devirtualize.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,8 @@ func StaticCall(call *ir.CallExpr) {
}

sel := call.Fun.(*ir.SelectorExpr)
r := ir.StaticValue(sel.X)
if r.Op() != ir.OCONVIFACE {
return
}
recv := r.(*ir.ConvExpr)

typ := recv.X.Type()
if typ.IsInterface() {
typ := ir.StaticType(sel.X)
if typ == nil {
return
}

Expand Down
49 changes: 48 additions & 1 deletion src/cmd/compile/internal/ir/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,37 @@ func IsAddressable(n Node) bool {
return false
}

var Implements = func(t, iface *types.Type) bool {
panic("unreachable")
}

// StaticType is like StaticValue but for types.
func StaticType(n Node) *types.Type {
out, typs := staticValue(n, true)

if out.Op() != OCONVIFACE {
return nil
}

recv := out.(*ConvExpr)

typ := recv.X.Type()
if typ.IsInterface() {
return nil
}

// Make sure that every type assertion that involves interfaes is satisfied.
for _, t := range typs {
if t.IsInterface() {
if !Implements(typ, t) {
return nil
}
}
}

return typ
}

// StaticValue analyzes n to find the earliest expression that always
// evaluates to the same value as n, which might be from an enclosing
// function.
Expand All @@ -855,6 +886,16 @@ func IsAddressable(n Node) bool {
// calling StaticValue on the "int(y)" expression returns the outer
// "g()" expression.
func StaticValue(n Node) Node {
v, t := staticValue(n, false)
if len(t) != 0 {
base.Fatalf("len(t) != 0; len(t) = %v", len(t))
}
return v

}

func staticValue(n Node, forDevirt bool) (Node, []*types.Type) {
typeAssertTypes := []*types.Type{}
for {
switch n1 := n.(type) {
case *ConvExpr:
Expand All @@ -870,11 +911,17 @@ func StaticValue(n Node) Node {
case *ParenExpr:
n = n1.X
continue
case *TypeAssertExpr:
if forDevirt && n1.Op() == ODOTTYPE {
typeAssertTypes = append(typeAssertTypes, n1.Type())
n = n1.X
continue
}
}

n1 := staticValue1(n)
if n1 == nil {
return n
return n, typeAssertTypes
}
n = n1
}
Expand Down
4 changes: 4 additions & 0 deletions src/cmd/compile/internal/typecheck/subr.go
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,10 @@ func Implements(t, iface *types.Type) bool {
return implements(t, iface, &missing, &have, &ptr)
}

func init() {
ir.Implements = Implements
}

// ImplementsExplain reports whether t implements the interface iface. t can be
// an interface, a type parameter, or a concrete type. If t does not implement
// iface, a non-empty string is returned explaining why.
Expand Down
151 changes: 151 additions & 0 deletions test/escape_iface_with_devirt_type_assertions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// errorcheck -0 -m

// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package escape

import (
"crypto/sha256"
"encoding"
"hash"
"io"
)

type M interface{ M() }

type A interface{ A() }

type C interface{ C() }

type Impl struct{}

func (*Impl) M() {} // ERROR "can inline"

func (*Impl) A() {} // ERROR "can inline"

type CImpl struct{}

func (CImpl) C() {} // ERROR "can inline"

func t() {
var a M = &Impl{} // ERROR "&Impl{} does not escape"

a.(M).M() // ERROR "devirtualizing a.\(M\).M" "inlining call"
a.(A).A() // ERROR "devirtualizing a.\(A\).A" "inlining call"
a.(*Impl).M() // ERROR "inlining call"
a.(*Impl).A() // ERROR "inlining call"

v := a.(M)
v.M() // ERROR "devirtualizing v.M" "inlining call"
v.(A).A() // ERROR "devirtualizing v.\(A\).A" "inlining call"
v.(*Impl).A() // ERROR "inlining call"
v.(*Impl).M() // ERROR "inlining call"

v2 := a.(A)
v2.A() // ERROR "devirtualizing v2.A" "inlining call"
v2.(M).M() // ERROR "devirtualizing v2.\(M\).M" "inlining call"
v2.(*Impl).A() // ERROR "inlining call"
v2.(*Impl).M() // ERROR "inlining call"

a.(M).(A).A() // ERROR "devirtualizing a.\(M\).\(A\).A" "inlining call"
a.(A).(M).M() // ERROR "devirtualizing a.\(A\).\(M\).M" "inlining call"

a.(M).(A).(*Impl).A() // ERROR "inlining call"
a.(A).(M).(*Impl).M() // ERROR "inlining call"

{
var a C = &CImpl{} // ERROR "does not escape"
a.(any).(C).C() // ERROR "devirtualizing" "inlining"
a.(any).(*CImpl).C() // ERROR "inlining"
}
}

// TODO: these type assertions could also be devirtualized.
func t2() {
{
var a M = &Impl{} // ERROR "&Impl{} escapes to heap"
if v, ok := a.(M); ok {
v.M()
}
}
{
var a M = &Impl{} // ERROR "&Impl{} escapes to heap"
if v, ok := a.(A); ok {
v.A()
}
}
{
var a M = &Impl{} // ERROR "&Impl{} escapes to heap"
v, ok := a.(M)
if ok {
v.M()
}
}
{
var a M = &Impl{} // ERROR "&Impl{} escapes to heap"
v, ok := a.(A)
if ok {
v.A()
}
}
{
var a M = &Impl{} // ERROR "does not escape"
v, ok := a.(*Impl)
if ok {
v.A() // ERROR "inlining"
v.M() // ERROR "inlining"
}
}
{
var a M = &Impl{} // ERROR "&Impl{} escapes to heap"
v, _ := a.(M)
v.M()
}
{
var a M = &Impl{} // ERROR "&Impl{} escapes to heap"
v, _ := a.(A)
v.A()
}
{
var a M = &Impl{} // ERROR "does not escape"
v, _ := a.(*Impl)
v.A() // ERROR "inlining"
v.M() // ERROR "inlining"
}
}

//go:noinline
func testInvalidAsserts() {
{
var a M = &Impl{} // ERROR "escapes"
a.(C).C() // this will panic
a.(any).(C).C() // this will panic
}
{
var a C = &CImpl{} // ERROR "escapes"
a.(M).M() // this will panic
a.(any).(M).M() // this will panic
}
{
var a C = &CImpl{} // ERROR "does not escape"

// this will panic
a.(M).(*Impl).M() // ERROR "inlining"

// this will panic
a.(any).(M).(*Impl).M() // ERROR "inlining"
}
}

func testSha256() {
h := sha256.New() // ERROR "inlining call" "does not escape"
h.Write(nil) // ERROR "devirtualizing"
h.(io.Writer).Write(nil) // ERROR "devirtualizing"
h.(hash.Hash).Write(nil) // ERROR "devirtualizing"
h.(encoding.BinaryUnmarshaler).UnmarshalBinary(nil) // ERROR "devirtualizing"

h2 := sha256.New() // ERROR "escapes" "inlining call"
h2.(M).M() // this will panic
}

0 comments on commit d3d5a3b

Please sign in to comment.