diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 1d30ffceae77d..12b3ca11faf9d 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -1151,7 +1151,7 @@ def build_If(ctx, node): build_stmts(ctx, node.orelse) return node - with ctx.non_static_control_flow_guard(): + with ctx.non_static_if_guard(node): impl.begin_frontend_if(ctx.ast_builder, node.test.ptr) ctx.ast_builder.begin_frontend_if_true() build_stmts(ctx, node.body) @@ -1280,6 +1280,13 @@ def build_Assert(ctx, node): @staticmethod def build_Break(ctx, node): if ctx.is_in_static_for(): + nearest_non_static_if: ast.If = ctx.current_loop_scope( + ).nearest_non_static_if + if nearest_non_static_if: + msg = ctx.get_pos_info(nearest_non_static_if.test) + msg += "You are trying to `break` a static `for` loop, " \ + "but the `break` statement is inside a non-static `if`. " + raise TaichiSyntaxError(msg) ctx.set_loop_status(LoopStatus.Break) else: ctx.ast_builder.insert_break_stmt() @@ -1288,6 +1295,13 @@ def build_Break(ctx, node): @staticmethod def build_Continue(ctx, node): if ctx.is_in_static_for(): + nearest_non_static_if: ast.If = ctx.current_loop_scope( + ).nearest_non_static_if + if nearest_non_static_if: + msg = ctx.get_pos_info(nearest_non_static_if.test) + msg += "You are trying to `continue` a static `for` loop, " \ + "but the `continue` statement is inside a non-static `if`. " + raise TaichiSyntaxError(msg) ctx.set_loop_status(LoopStatus.Continue) else: ctx.ast_builder.insert_continue_stmt() diff --git a/python/taichi/lang/ast/ast_transformer_utils.py b/python/taichi/lang/ast/ast_transformer_utils.py index 7ec6b44228ba9..70474ca394e2a 100644 --- a/python/taichi/lang/ast/ast_transformer_utils.py +++ b/python/taichi/lang/ast/ast_transformer_utils.py @@ -4,6 +4,7 @@ from enum import Enum from sys import version_info from textwrap import TextWrapper +from typing import List from taichi.lang import impl from taichi.lang.exception import (TaichiCompilationError, TaichiNameError, @@ -68,7 +69,7 @@ def __init__(self): class NonStaticControlFlowGuard: - def __init__(self, status): + def __init__(self, status: NonStaticControlFlowStatus): self.status = status def __enter__(self): @@ -89,6 +90,7 @@ class LoopScopeAttribute: def __init__(self, is_static): self.is_static = is_static self.status = LoopStatus.Normal + self.nearest_non_static_if = None class LoopScopeGuard: @@ -107,6 +109,25 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.non_static_guard.__exit__(exc_type, exc_val, exc_tb) +class NonStaticIfGuard: + def __init__(self, if_node: ast.If, loop_attribute: LoopScopeAttribute, + non_static_status: NonStaticControlFlowStatus): + self.loop_attribute = loop_attribute + self.if_node = if_node + self.non_static_guard = NonStaticControlFlowGuard(non_static_status) + + def __enter__(self): + if self.loop_attribute: + self.old_non_static_if = self.loop_attribute.nearest_non_static_if + self.loop_attribute.nearest_non_static_if = self.if_node + self.non_static_guard.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.loop_attribute: + self.loop_attribute.nearest_non_static_if = self.old_non_static_if + self.non_static_guard.__exit__(exc_type, exc_val, exc_tb) + + class ReturnStatus(Enum): NoReturn = 0 ReturnedVoid = 1 @@ -128,7 +149,7 @@ def __init__(self, is_real_function=False): self.func = func self.local_scopes = [] - self.loop_scopes = [] + self.loop_scopes: List[LoopScopeAttribute] = [] self.excluded_parameters = excluded_parameters self.is_kernel = is_kernel self.arg_features = arg_features @@ -164,6 +185,12 @@ def loop_scope_guard(self, is_static=False): return LoopScopeGuard(self.loop_scopes, self.non_static_control_flow_guard()) + def non_static_if_guard(self, if_node: ast.If): + return NonStaticIfGuard( + if_node, + self.current_loop_scope() if self.loop_scopes else None, + self.non_static_control_flow_status) + def non_static_control_flow_guard(self): return NonStaticControlFlowGuard(self.non_static_control_flow_status) diff --git a/tests/python/test_syntax_errors.py b/tests/python/test_syntax_errors.py index 6ecc89809c49a..e443ffaaafbe7 100644 --- a/tests/python/test_syntax_errors.py +++ b/tests/python/test_syntax_errors.py @@ -328,3 +328,32 @@ def baz(): ti.TaichiSyntaxError, match="Function definition is not allowed in 'ti.func'"): baz() + + +@test_utils.test() +def test_continue_in_static_for_in_non_static_if(): + @ti.kernel + def test_static_loop(): + for i in ti.static(range(5)): + x = 0.1 + if x == 0.0: + continue + + with pytest.raises( + ti.TaichiSyntaxError, + match="You are trying to `continue` a static `for` loop"): + test_static_loop() + + +@test_utils.test() +def test_break_in_static_for_in_non_static_if(): + @ti.kernel + def test_static_loop(): + for i in ti.static(range(5)): + x = 0.1 + if x == 0.0: + break + + with pytest.raises(ti.TaichiSyntaxError, + match="You are trying to `break` a static `for` loop"): + test_static_loop()