diff --git a/lib/rbs.rb b/lib/rbs.rb index 0559fa422..a9bbdc839 100644 --- a/lib/rbs.rb +++ b/lib/rbs.rb @@ -39,6 +39,7 @@ require "rbs/prototype/rbi" require "rbs/prototype/rb" require "rbs/prototype/runtime" +require "rbs/prototype/node_usage" require "rbs/type_name_resolver" require "rbs/environment_walker" require "rbs/vendorer" diff --git a/lib/rbs/prototype/helpers.rb b/lib/rbs/prototype/helpers.rb index 10baa0679..1c36ff18e 100644 --- a/lib/rbs/prototype/helpers.rb +++ b/lib/rbs/prototype/helpers.rb @@ -7,27 +7,43 @@ module Helpers def block_from_body(node) _, args_node, body_node = node.children + _pre_num, _pre_init, _opt, _first_post, _post_num, _post_init, _rest, _kw, _kwrest, block_var = args_from_node(args_node) - _pre_num, _pre_init, _opt, _first_post, _post_num, _post_init, _rest, _kw, _kwrest, block = args_from_node(args_node) + # @type var body_node: node? + if body_node + yields = any_node?(body_node) {|n| n.type == :YIELD } + end + + if yields || block_var + required = true + + if body_node + if any_node?(body_node) {|n| n.type == :FCALL && n.children[0] == :block_given? && !n.children[1] } + required = false + end + end + + if _rest == :* && block_var == :& + # ... is given + required = false + end - method_block = nil + if block_var + if body_node + usage = NodeUsage.new(body_node) + if usage.each_conditional_node.any? {|n| n.type == :LVAR && n.children[0] == block_var } + required = false + end + end + end - if block method_block = Types::Block.new( - required: false, + required: required, type: Types::Function.empty(untyped), self_type: nil ) - end - - if body_node - if (yields = any_node?(body_node) {|n| n.type == :YIELD }) - method_block = Types::Block.new( - required: true, - type: Types::Function.empty(untyped), - self_type: nil - ) + if yields yields.each do |yield_node| array_content = yield_node.children[0]&.children&.compact || [] diff --git a/lib/rbs/prototype/node_usage.rb b/lib/rbs/prototype/node_usage.rb new file mode 100644 index 000000000..2c5f07e0b --- /dev/null +++ b/lib/rbs/prototype/node_usage.rb @@ -0,0 +1,99 @@ +# frozen_string_literal: true + +module RBS + module Prototype + class NodeUsage + include Helpers + + attr_reader :conditional_nodes + + def initialize(node) + @node = node + @conditional_nodes = Set[].compare_by_identity + + calculate(node, conditional: false) + end + + def each_conditional_node(&block) + if block + conditional_nodes.each(&block) + else + conditional_nodes.each + end + end + + def calculate(node, conditional:) + if conditional + conditional_nodes << node + end + + case node.type + when :IF, :UNLESS + cond_node, true_node, false_node = node.children + calculate(cond_node, conditional: true) + calculate(true_node, conditional: conditional) if true_node + calculate(false_node, conditional: conditional) if false_node + when :AND, :OR + left, right = node.children + calculate(left, conditional: true) + calculate(right, conditional: conditional) + when :QCALL + receiver, _, args = node.children + calculate(receiver, conditional: true) + calculate(args, conditional: false) if args + when :WHILE + cond, body = node.children + calculate(cond, conditional: true) + calculate(body, conditional: false) if body + when :OP_ASGN_OR, :OP_ASGN_AND + var, _, asgn = node.children + calculate(var, conditional: true) + calculate(asgn, conditional: conditional) + when :LASGN, :IASGN, :GASGN + _, lhs = node.children + calculate(lhs, conditional: conditional) if lhs + when :MASGN + lhs, _ = node.children + calculate(lhs, conditional: conditional) + when :CDECL + if node.children.size == 2 + _, lhs = node.children + calculate(lhs, conditional: conditional) + else + const, _, lhs = node.children + calculate(const, conditional: false) + calculate(lhs, conditional: conditional) + end + when :SCOPE + _, _, body = node.children + calculate(body, conditional: conditional) + when :CASE2 + _, *branches = node.children + branches.each do |branch| + if branch.type == :WHEN + list, body = branch.children + list.children.each do |child| + if child + calculate(child, conditional: true) + end + end + calculate(body, conditional: conditional) + else + calculate(branch, conditional: conditional) + end + end + when :BLOCK + *nodes, last = node.children + nodes.each do |no| + calculate(no, conditional: false) + end + calculate(last, conditional: conditional) if last + else + each_child(node) do |child| + calculate(child, conditional: false) + end + end + end + end + end +end diff --git a/sig/prototype/node_usage.rbs b/sig/prototype/node_usage.rbs new file mode 100644 index 000000000..6a72f8346 --- /dev/null +++ b/sig/prototype/node_usage.rbs @@ -0,0 +1,20 @@ +module RBS + module Prototype + class NodeUsage + include Helpers + + type node = RubyVM::AbstractSyntaxTree::Node + + attr_reader node: node + + attr_reader conditional_nodes: Set[node] + + def initialize: (node) -> void + + def calculate: (node, conditional: bool) -> void + + def each_conditional_node: () { (node) -> void } -> void + | () -> Enumerator[node, void] + end + end +end diff --git a/test/rbs/node_usage_test.rb b/test/rbs/node_usage_test.rb new file mode 100644 index 000000000..5ded3b572 --- /dev/null +++ b/test/rbs/node_usage_test.rb @@ -0,0 +1,47 @@ +require "test_helper" + +class RBS::NodeUsageTest < Test::Unit::TestCase + include RBS::Prototype + + def parse(string) + RubyVM::AbstractSyntaxTree.parse(string) + end + + def test_conditional + usage = NodeUsage.new(parse(<<~RB)) + if block + yield + end + + foo && bar || baz + + 1&.+(2) + + begin + bar + end while baz + + a ||= b + a += 123 + + x = 1 + @y = 2 + Z = 3 + Z::Z1 = 4 + + x, y, z = foo + + puts unless foo + + case + when foo + else + hello + end + + [ + (foo(); bar; baz) + ] + RB + end +end diff --git a/test/rbs/rb_prototype_test.rb b/test/rbs/rb_prototype_test.rb index abbb575dc..87d119a3f 100644 --- a/test/rbs/rb_prototype_test.rb +++ b/test/rbs/rb_prototype_test.rb @@ -62,7 +62,7 @@ def kw_req(a:) end assert_write parser.decls, <<-EOF class Hello - def hello: (untyped a, ?::Integer b, *untyped c, untyped d, e: untyped, ?f: ::Integer, **untyped g) ?{ () -> untyped } -> nil + def hello: (untyped a, ?::Integer b, *untyped c, untyped d, e: untyped, ?f: ::Integer, **untyped g) { () -> untyped } -> nil def self.world: () { (untyped, untyped, untyped, x: untyped, y: untyped) -> untyped } -> untyped @@ -233,6 +233,38 @@ def when_last_is_nil: () -> nil EOF end + def test_defs_return_type_with_block_optional + parser = RB.new + + rb = <<~'EOR' + class Hello + def with_optional_block1 + # `block_given?` call makes the block optional + if block_given? + yield 1 + end + end + + def with_optional_block2(&block) + # testing block var makes the block optional + if block + yield 1 + end + end + end + EOR + + parser.parse(rb) + + assert_write parser.decls, <<~EOF + class Hello + def with_optional_block1: () ?{ (untyped) -> untyped } -> (untyped | nil) + + def with_optional_block2: () ?{ (untyped) -> untyped } -> (untyped | nil) + end + EOF + end + def test_defs_return_type_with_if parser = RB.new