Commit 8dbe32a 1 parent e2fbcc8 commit 8dbe32a Copy full SHA for 8dbe32a
File tree 1 file changed +16
-17
lines changed
1 file changed +16
-17
lines changed Original file line number Diff line number Diff line change @@ -166,16 +166,16 @@ def to_pytorch_type(dt):
166
166
return torch .uint8
167
167
if dt == f16 :
168
168
return torch .float16
169
-
169
+
170
170
if dt in (u16 , u32 , u64 ):
171
- if hasattr (torch , "uint16" ):
172
- if dt == u16 :
173
- return torch .uint16
174
- if dt == u32 :
175
- return torch .uint32
176
- if dt == u64 :
177
- return torch .uint64
178
- raise RuntimeError (f"PyTorch doesn't support { dt .to_string ()} data type before version 2.3.0." )
171
+ if hasattr (torch , "uint16" ):
172
+ if dt == u16 :
173
+ return torch .uint16
174
+ if dt == u32 :
175
+ return torch .uint32
176
+ if dt == u64 :
177
+ return torch .uint64
178
+ raise RuntimeError (f"PyTorch doesn't support { dt .to_string ()} data type before version 2.3.0." )
179
179
180
180
raise RuntimeError (f"PyTorch doesn't support { dt .to_string ()} data type." )
181
181
assert False
@@ -276,18 +276,17 @@ def to_taichi_type(dt):
276
276
return u8
277
277
if dt == torch .float16 :
278
278
return f16
279
-
279
+
280
280
if hasattr (torch , "uint16" ):
281
- if dt == torch .uint16 :
282
- return u16
283
- if dt == torch .uint32 :
284
- return u32
285
- if dt == torch .uint64 :
286
- return u64
281
+ if dt == torch .uint16 :
282
+ return u16
283
+ if dt == torch .uint32 :
284
+ return u32
285
+ if dt == torch .uint64 :
286
+ return u64
287
287
288
288
raise RuntimeError (f"PyTorch doesn't support { dt .to_string ()} data type before version 2.3.0." )
289
289
290
-
291
290
if has_paddle ():
292
291
import paddle # pylint: disable=C0415
293
292
You can’t perform that action at this time.
0 commit comments