Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Error] [lang] Add error when breaking/continuing a static for inside non-static if #5755

Merged
merged 7 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,7 +1151,7 @@ def build_If(ctx, node):
build_stmts(ctx, node.orelse)
return node

with ctx.non_static_control_flow_guard():
with ctx.non_static_if_guard(node):
impl.begin_frontend_if(ctx.ast_builder, node.test.ptr)
ctx.ast_builder.begin_frontend_if_true()
build_stmts(ctx, node.body)
Expand Down Expand Up @@ -1280,6 +1280,13 @@ def build_Assert(ctx, node):
@staticmethod
def build_Break(ctx, node):
if ctx.is_in_static_for():
nearest_non_static_if: ast.If = ctx.current_loop_scope(
).nearest_non_static_if
if nearest_non_static_if:
msg = ctx.get_pos_info(nearest_non_static_if.test)
msg += "You are trying to `break` a static `for` loop, " \
"but the `break` statement is inside a non-static `if`. "
raise TaichiSyntaxError(msg)
ctx.set_loop_status(LoopStatus.Break)
else:
ctx.ast_builder.insert_break_stmt()
Expand All @@ -1288,6 +1295,13 @@ def build_Break(ctx, node):
@staticmethod
def build_Continue(ctx, node):
if ctx.is_in_static_for():
nearest_non_static_if: ast.If = ctx.current_loop_scope(
).nearest_non_static_if
if nearest_non_static_if:
msg = ctx.get_pos_info(nearest_non_static_if.test)
msg += "You are trying to `continue` a static `for` loop, " \
"but the `continue` statement is inside a non-static `if`. "
raise TaichiSyntaxError(msg)
ctx.set_loop_status(LoopStatus.Continue)
else:
ctx.ast_builder.insert_continue_stmt()
Expand Down
31 changes: 29 additions & 2 deletions python/taichi/lang/ast/ast_transformer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from enum import Enum
from sys import version_info
from textwrap import TextWrapper
from typing import List

from taichi.lang import impl
from taichi.lang.exception import (TaichiCompilationError, TaichiNameError,
Expand Down Expand Up @@ -68,7 +69,7 @@ def __init__(self):


class NonStaticControlFlowGuard:
def __init__(self, status):
def __init__(self, status: NonStaticControlFlowStatus):
self.status = status

def __enter__(self):
Expand All @@ -89,6 +90,7 @@ class LoopScopeAttribute:
def __init__(self, is_static):
self.is_static = is_static
self.status = LoopStatus.Normal
self.nearest_non_static_if = None


class LoopScopeGuard:
Expand All @@ -107,6 +109,25 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.non_static_guard.__exit__(exc_type, exc_val, exc_tb)


class NonStaticIfGuard:
def __init__(self, if_node: ast.If, loop_attribute: LoopScopeAttribute,
non_static_status: NonStaticControlFlowStatus):
self.loop_attribute = loop_attribute
self.if_node = if_node
self.non_static_guard = NonStaticControlFlowGuard(non_static_status)

def __enter__(self):
if self.loop_attribute:
self.old_non_static_if = self.loop_attribute.nearest_non_static_if
self.loop_attribute.nearest_non_static_if = self.if_node
self.non_static_guard.__enter__()

def __exit__(self, exc_type, exc_val, exc_tb):
if self.loop_attribute:
self.loop_attribute.nearest_non_static_if = self.old_non_static_if
self.non_static_guard.__exit__(exc_type, exc_val, exc_tb)


class ReturnStatus(Enum):
NoReturn = 0
ReturnedVoid = 1
Expand All @@ -128,7 +149,7 @@ def __init__(self,
is_real_function=False):
self.func = func
self.local_scopes = []
self.loop_scopes = []
self.loop_scopes: List[LoopScopeAttribute] = []
self.excluded_parameters = excluded_parameters
self.is_kernel = is_kernel
self.arg_features = arg_features
Expand Down Expand Up @@ -164,6 +185,12 @@ def loop_scope_guard(self, is_static=False):
return LoopScopeGuard(self.loop_scopes,
self.non_static_control_flow_guard())

def non_static_if_guard(self, if_node: ast.If):
return NonStaticIfGuard(
if_node,
self.current_loop_scope() if self.loop_scopes else None,
self.non_static_control_flow_status)

def non_static_control_flow_guard(self):
return NonStaticControlFlowGuard(self.non_static_control_flow_status)

Expand Down
29 changes: 29 additions & 0 deletions tests/python/test_syntax_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,32 @@ def baz():
ti.TaichiSyntaxError,
match="Function definition is not allowed in 'ti.func'"):
baz()


@test_utils.test()
def test_continue_in_static_for_in_non_static_if():
@ti.kernel
def test_static_loop():
for i in ti.static(range(5)):
x = 0.1
if x == 0.0:
continue

with pytest.raises(
ti.TaichiSyntaxError,
match="You are trying to `continue` a static `for` loop"):
test_static_loop()


@test_utils.test()
def test_break_in_static_for_in_non_static_if():
@ti.kernel
def test_static_loop():
for i in ti.static(range(5)):
x = 0.1
if x == 0.0:
break

with pytest.raises(ti.TaichiSyntaxError,
match="You are trying to `break` a static `for` loop"):
test_static_loop()