Skip to content

Commit

Permalink
Introduce new heuristics for block types
Browse files Browse the repository at this point in the history
  • Loading branch information
soutaro committed Dec 19, 2022
1 parent e91be72 commit cb3aaa1
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 14 deletions.
1 change: 1 addition & 0 deletions lib/rbs.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
require "rbs/prototype/rbi"
require "rbs/prototype/rb"
require "rbs/prototype/runtime"
require "rbs/prototype/node_usage"
require "rbs/type_name_resolver"
require "rbs/environment_walker"
require "rbs/vendorer"
Expand Down
42 changes: 29 additions & 13 deletions lib/rbs/prototype/helpers.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,43 @@ module Helpers

def block_from_body(node)
_, args_node, body_node = node.children
_pre_num, _pre_init, _opt, _first_post, _post_num, _post_init, _rest, _kw, _kwrest, block_var = args_from_node(args_node)

_pre_num, _pre_init, _opt, _first_post, _post_num, _post_init, _rest, _kw, _kwrest, block = args_from_node(args_node)
# @type var body_node: node?
if body_node
yields = any_node?(body_node) {|n| n.type == :YIELD }
end

if yields || block_var
required = true

if body_node
if any_node?(body_node) {|n| n.type == :FCALL && n.children[0] == :block_given? && !n.children[1] }
required = false
end
end

if _rest == :* && block_var == :&
# ... is given
required = false
end

method_block = nil
if block_var
if body_node
usage = NodeUsage.new(body_node)
if usage.each_conditional_node.any? {|n| n.type == :LVAR && n.children[0] == block_var }
required = false
end
end
end

if block
method_block = Types::Block.new(
required: false,
required: required,
type: Types::Function.empty(untyped),
self_type: nil
)
end

if body_node
if (yields = any_node?(body_node) {|n| n.type == :YIELD })
method_block = Types::Block.new(
required: true,
type: Types::Function.empty(untyped),
self_type: nil
)

if yields
yields.each do |yield_node|
array_content = yield_node.children[0]&.children&.compact || []

Expand Down
97 changes: 97 additions & 0 deletions lib/rbs/prototype/node_usage.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
module RBS
module Prototype
class NodeUsage
include Helpers

attr_reader :conditional_nodes

def initialize(node)
@node = node
@conditional_nodes = Set[].compare_by_identity

calculate(node, conditional: false)
end

def each_conditional_node(&block)
if block
conditional_nodes.each(&block)
else
conditional_nodes.each
end
end

def calculate(node, conditional:)
if conditional
conditional_nodes << node
end

case node.type
when :IF, :UNLESS
cond_node, true_node, false_node = node.children
calculate(cond_node, conditional: true)
calculate(true_node, conditional: conditional) if true_node
calculate(false_node, conditional: conditional) if false_node
when :AND, :OR
left, right = node.children
calculate(left, conditional: true)
calculate(right, conditional: conditional)
when :QCALL
receiver, _, args = node.children
calculate(receiver, conditional: true)
calculate(args, conditional: false) if args
when :WHILE
cond, body = node.children
calculate(cond, conditional: true)
calculate(body, conditional: false) if body
when :OP_ASGN_OR, :OP_ASGN_AND
var, _, asgn = node.children
calculate(var, conditional: true)
calculate(asgn, conditional: conditional)
when :LASGN, :IASGN, :GASGN
_, lhs = node.children
calculate(lhs, conditional: conditional) if lhs
when :MASGN
lhs, _ = node.children
calculate(lhs, conditional: conditional)
when :CDECL
if node.children.size == 2
name, lhs = node.children
calculate(lhs, conditional: conditional)
else
const, rhs, lhs = node.children
calculate(const, conditional: false)
calculate(lhs, conditional: conditional)
end
when :SCOPE
_, _, body = node.children
calculate(body, conditional: conditional)
when :CASE2
_, *branches = node.children
branches.each do |branch|
if branch.type == :WHEN
list, body = branch.children
list.children.each do |child|
if child
calculate(child, conditional: true)
end
end
calculate(body, conditional: conditional)
else
calculate(branch, conditional: conditional)
end
end
when :BLOCK
*nodes, last = node.children
nodes.each do |no|
calculate(no, conditional: false)
end
calculate(last, conditional: conditional) if last
else
each_child(node) do |child|
calculate(child, conditional: false)
end
end
end
end
end
end
20 changes: 20 additions & 0 deletions sig/prototype/node_usage.rbs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module RBS
module Prototype
class NodeUsage
include Helpers

type node = RubyVM::AbstractSyntaxTree::Node

attr_reader node: node

attr_reader conditional_nodes: Set[node]

def initialize: (node) -> void

def calculate: (node, conditional: bool) -> void

def each_conditional_node: () { (node) -> void } -> void
| () -> Enumerator[node, void]
end
end
end
47 changes: 47 additions & 0 deletions test/rbs/node_usage_test.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
require "test_helper"

class RBS::NodeUsageTest < Test::Unit::TestCase
include RBS::Prototype

def parse(string)
RubyVM::AbstractSyntaxTree.parse(string)
end

def test_conditional
usage = NodeUsage.new(parse(<<~RB))
if block
yield
end
foo && bar || baz
1&.+(2)
begin
bar
end while baz
a ||= b
a += 123
x = 1
@y = 2
Z = 3
Z::Z1 = 4
x, y, z = foo
puts unless foo
case
when foo
else
hello
end
[
(foo(); bar; baz)
]
RB
end
end
34 changes: 33 additions & 1 deletion test/rbs/rb_prototype_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def kw_req(a:) end

assert_write parser.decls, <<-EOF
class Hello
def hello: (untyped a, ?::Integer b, *untyped c, untyped d, e: untyped, ?f: ::Integer, **untyped g) ?{ () -> untyped } -> nil
def hello: (untyped a, ?::Integer b, *untyped c, untyped d, e: untyped, ?f: ::Integer, **untyped g) { () -> untyped } -> nil
def self.world: () { (untyped, untyped, untyped, x: untyped, y: untyped) -> untyped } -> untyped
Expand Down Expand Up @@ -233,6 +233,38 @@ def when_last_is_nil: () -> nil
EOF
end

def test_defs_return_type_with_block_optional
parser = RB.new

rb = <<~'EOR'
class Hello
def with_optional_block1
# `block_given?` call makes the block optional
if block_given?
yield 1
end
end
def with_optional_block2(&block)
# testing block var makes the block optional
if block
yield 1
end
end
end
EOR

parser.parse(rb)

assert_write parser.decls, <<~EOF
class Hello
def with_optional_block1: () ?{ (untyped) -> untyped } -> (untyped | nil)
def with_optional_block2: () ?{ (untyped) -> untyped } -> (untyped | nil)
end
EOF
end

def test_defs_return_type_with_if
parser = RB.new

Expand Down

0 comments on commit cb3aaa1

Please sign in to comment.