Skip to content

Commit 8dbe32a

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent e2fbcc8 commit 8dbe32a

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

python/taichi/lang/util.py

+16-17
Original file line numberDiff line numberDiff line change
@@ -166,16 +166,16 @@ def to_pytorch_type(dt):
166166
return torch.uint8
167167
if dt == f16:
168168
return torch.float16
169-
169+
170170
if dt in (u16, u32, u64):
171-
if hasattr(torch, "uint16"):
172-
if dt == u16:
173-
return torch.uint16
174-
if dt == u32:
175-
return torch.uint32
176-
if dt == u64:
177-
return torch.uint64
178-
raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")
171+
if hasattr(torch, "uint16"):
172+
if dt == u16:
173+
return torch.uint16
174+
if dt == u32:
175+
return torch.uint32
176+
if dt == u64:
177+
return torch.uint64
178+
raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")
179179

180180
raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.")
181181
assert False
@@ -276,18 +276,17 @@ def to_taichi_type(dt):
276276
return u8
277277
if dt == torch.float16:
278278
return f16
279-
279+
280280
if hasattr(torch, "uint16"):
281-
if dt == torch.uint16:
282-
return u16
283-
if dt == torch.uint32:
284-
return u32
285-
if dt == torch.uint64:
286-
return u64
281+
if dt == torch.uint16:
282+
return u16
283+
if dt == torch.uint32:
284+
return u32
285+
if dt == torch.uint64:
286+
return u64
287287

288288
raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")
289289

290-
291290
if has_paddle():
292291
import paddle # pylint: disable=C0415
293292

0 commit comments

Comments
 (0)