
Commit 0651887

Kiyosora authored and facebook-github-bot committed on Jul 10, 2020
Improve repr for torch.iinfo & torch.finfo (pytorch#40488)
Summary:
- Fix pytorch#39991
- Include the `min`/`max`/`eps`/`tiny` values directly in the repr of `torch.iinfo` & `torch.finfo` for easier inspection
- Use `torch.float16` / `torch.int16` instead of the non-matching names `Half` / `Short`
- The improved repr looks like this:

```
>>> torch.iinfo(torch.int8)
iinfo(type=torch.int8, max=127, min=-128)
>>> torch.iinfo(torch.int16)
iinfo(type=torch.int16, max=32767, min=-32768)
>>> torch.iinfo(torch.int32)
iinfo(type=torch.int32, max=2.14748e+09, min=-2.14748e+09)
>>> torch.iinfo(torch.int64)
iinfo(type=torch.int64, max=9.22337e+18, min=-9.22337e+18)
>>> torch.finfo(torch.float16)
finfo(type=torch.float16, eps=0.000976563, max=65504, min=-65504, tiny=6.10352e-05)
>>> torch.finfo(torch.float32)
finfo(type=torch.float32, eps=1.19209e-07, max=3.40282e+38, min=-3.40282e+38, tiny=1.17549e-38)
>>> torch.finfo(torch.float64)
finfo(type=torch.float64, eps=2.22045e-16, max=1.79769e+308, min=-1.79769e+308, tiny=2.22507e-308)
```

Pull Request resolved: pytorch#40488
Differential Revision: D22445301
Pulled By: mruberry
fbshipit-source-id: 552af9904c423006084b45d6c4adfb4b5689db54
1 parent cb6c352 commit 0651887

File tree: 6 files changed, +91 −27 lines

‎docs/source/type_info.rst

+1

@@ -27,6 +27,7 @@ eps float The smallest representable number such that ``1.0 + eps != 1
 max float The largest representable number.
 min float The smallest representable number (typically ``-max``).
 tiny float The smallest positive representable number.
+resolution float The approximate decimal resolution of this type, i.e., ``10**-precision``.
 ========= ===== ========================================

 .. note::
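For a concrete feel of the documented field, a minimal sketch (assuming a build that includes this change; the printed values are the usual IEEE ones):

```python
import torch

# resolution is the approximate decimal resolution, i.e. 10**-precision,
# and dtype is the canonical type name as a string.
for dt in (torch.float16, torch.float32, torch.float64):
    info = torch.finfo(dt)
    print(info.dtype, info.resolution)

# Typically prints:
#   float16 0.001
#   float32 1e-06
#   float64 1e-15
```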

‎test/test_type_info.py

+24 −6

@@ -14,29 +14,30 @@
 class TestDTypeInfo(TestCase):

     def test_invalid_input(self):
-        for dtype in [torch.float32, torch.float64]:
+        for dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.bool]:
             with self.assertRaises(TypeError):
                 _ = torch.iinfo(dtype)

-        for dtype in [torch.int64, torch.int32, torch.int16, torch.uint8]:
+        for dtype in [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool]:
             with self.assertRaises(TypeError):
                 _ = torch.finfo(dtype)

     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_iinfo(self):
-        for dtype in [torch.int64, torch.int32, torch.int16, torch.uint8]:
+        for dtype in [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8]:
             x = torch.zeros((2, 2), dtype=dtype)
             xinfo = torch.iinfo(x.dtype)
             xn = x.cpu().numpy()
             xninfo = np.iinfo(xn.dtype)
             self.assertEqual(xinfo.bits, xninfo.bits)
             self.assertEqual(xinfo.max, xninfo.max)
             self.assertEqual(xinfo.min, xninfo.min)
+            self.assertEqual(xinfo.dtype, xninfo.dtype)

     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_finfo(self):
         initial_default_type = torch.get_default_dtype()
-        for dtype in [torch.float32, torch.float64]:
+        for dtype in [torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128]:
             x = torch.zeros((2, 2), dtype=dtype)
             xinfo = torch.finfo(x.dtype)
             xn = x.cpu().numpy()

@@ -46,8 +47,25 @@ def test_finfo(self):
             self.assertEqual(xinfo.min, xninfo.min)
             self.assertEqual(xinfo.eps, xninfo.eps)
             self.assertEqual(xinfo.tiny, xninfo.tiny)
-            torch.set_default_dtype(dtype)
-            self.assertEqual(torch.finfo(dtype), torch.finfo())
+            self.assertEqual(xinfo.resolution, xninfo.resolution)
+            self.assertEqual(xinfo.dtype, xninfo.dtype)
+            if not dtype.is_complex:
+                torch.set_default_dtype(dtype)
+                self.assertEqual(torch.finfo(dtype), torch.finfo())
+
+        # Special test case for BFloat16 type
+        x = torch.zeros((2, 2), dtype=torch.bfloat16)
+        xinfo = torch.finfo(x.dtype)
+        self.assertEqual(xinfo.bits, 16)
+        self.assertEqual(xinfo.max, 3.38953e+38)
+        self.assertEqual(xinfo.min, -3.38953e+38)
+        self.assertEqual(xinfo.eps, 0.0078125)
+        self.assertEqual(xinfo.tiny, 1.17549e-38)
+        self.assertEqual(xinfo.resolution, 0.01)
+        self.assertEqual(xinfo.dtype, "bfloat16")
+        torch.set_default_dtype(x.dtype)
+        self.assertEqual(torch.finfo(x.dtype), torch.finfo())
+
         # Restore the default type to ensure that the test has no side effect
         torch.set_default_dtype(initial_default_type)
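The cross-check the test performs can be reproduced by hand; a minimal sketch (assumes NumPy is available and a build with this change; note the test's `assertEqual` is tolerance-based, so `resolution` is compared approximately here):

```python
import math

import numpy as np
import torch

t = torch.finfo(torch.float32)
n = np.finfo(np.float32)

# These fields are expected to line up with NumPy's finfo.
assert t.bits == n.bits
assert t.eps == n.eps
assert t.tiny == n.tiny
assert math.isclose(t.resolution, float(n.resolution), rel_tol=1e-6)  # newly exposed field
assert t.dtype == n.dtype  # "float32" compares equal to np.dtype('float32')

# With no argument, torch.finfo() describes the current default floating dtype.
torch.set_default_dtype(torch.float64)
assert torch.finfo() == torch.finfo(torch.float64)
torch.set_default_dtype(torch.float32)  # restore the default
```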

‎torch/_C/__init__.pyi.in

+3

@@ -59,6 +59,7 @@ class iinfo:
     bits: _int
     min: _int
     max: _int
+    dtype: str

     def __init__(self, dtype: _dtype) -> None: ...

@@ -68,6 +69,8 @@ class finfo:
     max: _float
     eps: _float
     tiny: _float
+    resolution: _float
+    dtype: str

     @overload
     def __init__(self, dtype: _dtype) -> None: ...
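With the stub additions above, static type checkers see the new attributes; a minimal sketch of the declared types (the local variable names are just illustrative):

```python
import torch

ii = torch.iinfo(torch.int16)
bits: int = ii.bits
low: int = ii.min
high: int = ii.max
iname: str = ii.dtype        # e.g. "int16"

fi = torch.finfo(torch.float64)
res: float = fi.resolution   # declared alongside eps / max / min / tiny
fname: str = fi.dtype        # e.g. "float64"
```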

‎torch/csrc/TypeInfo.cpp

+60 −20

@@ -6,6 +6,7 @@
 #include <torch/csrc/utils/python_arg_parser.h>
 #include <torch/csrc/utils/python_numbers.h>
 #include <torch/csrc/utils/python_strings.h>
+#include <torch/csrc/utils/tensor_dtypes.h>

 #include <c10/util/Exception.h>

@@ -20,7 +21,7 @@ PyObject* THPFInfo_New(const at::ScalarType& type) {
   if (!self)
     throw python_error();
   auto self_ = reinterpret_cast<THPDTypeInfo*>(self.get());
-  self_->type = type;
+  self_->type = c10::toValueType(type);
   return self.release();
 }

@@ -34,18 +35,6 @@ PyObject* THPIInfo_New(const at::ScalarType& type) {
   return self.release();
 }

-PyObject* THPFInfo_str(THPFInfo* self) {
-  std::ostringstream oss;
-  oss << "finfo(type=" << self->type << ")";
-  return THPUtils_packString(oss.str().c_str());
-}
-
-PyObject* THPIInfo_str(THPIInfo* self) {
-  std::ostringstream oss;
-  oss << "iinfo(type=" << self->type << ")";
-  return THPUtils_packString(oss.str().c_str());
-}
-
 PyObject* THPFInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
   HANDLE_TH_ERRORS
   static torch::PythonArgParser parser({

@@ -63,7 +52,7 @@ PyObject* THPFInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
     AT_ASSERT(at::isFloatingType(scalar_type));
   } else {
     scalar_type = r.scalartype(0);
-    if (!at::isFloatingType(scalar_type)) {
+    if (!at::isFloatingType(scalar_type) && !at::isComplexType(scalar_type)) {
       return PyErr_Format(
           PyExc_TypeError,
           "torch.finfo() requires a floating point input type. Use torch.iinfo to handle '%s'",

@@ -123,7 +112,7 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
 }

 static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
-  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf,
+  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16,
       self->type, "epsilon", [] {
         return PyFloat_FromDouble(
             std::numeric_limits<

@@ -132,20 +121,20 @@ static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
 }

 static PyObject* THPFInfo_max(THPFInfo* self, void*) {
-  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf, self->type, "max", [] {
+  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "max", [] {
     return PyFloat_FromDouble(
         std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max());
   });
 }

 static PyObject* THPFInfo_min(THPFInfo* self, void*) {
-  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf, self->type, "min", [] {
+  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "lowest", [] {
     return PyFloat_FromDouble(
         std::numeric_limits<at::scalar_value_type<scalar_t>::type>::lowest());
   });
 }

-static PyObject* THPIInfo_max(THPFInfo* self, void*) {
+static PyObject* THPIInfo_max(THPIInfo* self, void*) {
   if (at::isIntegralType(self->type, /*includeBool=*/false)) {
     return AT_DISPATCH_INTEGRAL_TYPES(self->type, "max", [] {
       return THPUtils_packInt64(std::numeric_limits<scalar_t>::max());

@@ -157,7 +146,7 @@ static PyObject* THPIInfo_max(THPFInfo* self, void*) {
   });
 }

-static PyObject* THPIInfo_min(THPFInfo* self, void*) {
+static PyObject* THPIInfo_min(THPIInfo* self, void*) {
   if (at::isIntegralType(self->type, /*includeBool=*/false)) {
     return AT_DISPATCH_INTEGRAL_TYPES(self->type, "min", [] {
       return THPUtils_packInt64(std::numeric_limits<scalar_t>::lowest());

@@ -169,19 +158,69 @@ static PyObject* THPIInfo_min(THPFInfo* self, void*) {
   });
 }

+static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
+  std::string primary_name, legacy_name;
+  std::tie(primary_name, legacy_name) = torch::utils::getDtypeNames(self->type);
+  return AT_DISPATCH_INTEGRAL_TYPES(self->type, "dtype", [primary_name] {
+    return PyUnicode_FromString((char*)primary_name.data());
+  });
+}
+
 static PyObject* THPFInfo_tiny(THPFInfo* self, void*) {
-  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf, self->type, "min", [] {
+  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "min", [] {
     return PyFloat_FromDouble(
         std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min());
   });
 }

+static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
+  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "digits10", [] {
+    return PyFloat_FromDouble(
+        std::pow(10, -std::numeric_limits<at::scalar_value_type<scalar_t>::type>::digits10));
+  });
+}
+
+static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
+  std::string primary_name, legacy_name;
+  std::tie(primary_name, legacy_name) = torch::utils::getDtypeNames(self->type);
+  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "dtype", [primary_name] {
+    return PyUnicode_FromString((char*)primary_name.data());
+  });
+}
+
+PyObject* THPFInfo_str(THPFInfo* self) {
+  std::ostringstream oss;
+  oss << "finfo(resolution=" << PyFloat_AsDouble(THPFInfo_resolution(self, nullptr));
+  oss << ", min=" << PyFloat_AsDouble(THPFInfo_min(self, nullptr));
+  oss << ", max=" << PyFloat_AsDouble(THPFInfo_max(self, nullptr));
+  oss << ", eps=" << PyFloat_AsDouble(THPFInfo_eps(self, nullptr));
+  oss << ", tiny=" << PyFloat_AsDouble(THPFInfo_tiny(self, nullptr));
+  oss << ", dtype=" << PyUnicode_AsUTF8(THPFInfo_dtype(self, nullptr)) << ")";
+
+  return THPUtils_packString(oss.str().c_str());
+}
+
+PyObject* THPIInfo_str(THPIInfo* self) {
+  auto type = self->type;
+  std::string primary_name, legacy_name;
+  std::tie(primary_name, legacy_name) = torch::utils::getDtypeNames(type);
+  std::ostringstream oss;
+
+  oss << "iinfo(min=" << PyFloat_AsDouble(THPIInfo_min(self, nullptr));
+  oss << ", max=" << PyFloat_AsDouble(THPIInfo_max(self, nullptr));
+  oss << ", dtype=" << PyUnicode_AsUTF8(THPIInfo_dtype(self, nullptr)) << ")";
+
+  return THPUtils_packString(oss.str().c_str());
+}
+
 static struct PyGetSetDef THPFInfo_properties[] = {
     {"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
     {"eps", (getter)THPFInfo_eps, nullptr, nullptr, nullptr},
     {"max", (getter)THPFInfo_max, nullptr, nullptr, nullptr},
     {"min", (getter)THPFInfo_min, nullptr, nullptr, nullptr},
     {"tiny", (getter)THPFInfo_tiny, nullptr, nullptr, nullptr},
+    {"resolution", (getter)THPFInfo_resolution, nullptr, nullptr, nullptr},
+    {"dtype", (getter)THPFInfo_dtype, nullptr, nullptr, nullptr},
     {nullptr}};

 static PyMethodDef THPFInfo_methods[] = {

@@ -232,6 +271,7 @@ static struct PyGetSetDef THPIInfo_properties[] = {
     {"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
     {"max", (getter)THPIInfo_max, nullptr, nullptr, nullptr},
     {"min", (getter)THPIInfo_min, nullptr, nullptr, nullptr},
+    {"dtype", (getter)THPIInfo_dtype, nullptr, nullptr, nullptr},
     {nullptr}};

 static PyMethodDef THPIInfo_methods[] = {
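With `THPFInfo_str` / `THPIInfo_str` now enumerating every field, an interactive session looks roughly like the following (a sketch based on the code above; the numeric values come from this commit's tests and summary, and the digits follow the C++ stream's default 6-significant-digit formatting, so treat the exact formatting as illustrative):

```
>>> import torch
>>> torch.iinfo(torch.int32)
iinfo(min=-2.14748e+09, max=2.14748e+09, dtype=int32)
>>> torch.finfo(torch.float32)
finfo(resolution=1e-06, min=-3.40282e+38, max=3.40282e+38, eps=1.19209e-07, tiny=1.17549e-38, dtype=float32)
>>> torch.finfo(torch.bfloat16)
finfo(resolution=0.01, min=-3.38953e+38, max=3.38953e+38, eps=0.0078125, tiny=1.17549e-38, dtype=bfloat16)
```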

‎torch/csrc/utils/tensor_dtypes.cpp

+1 −1

@@ -10,7 +10,7 @@
 namespace torch {
 namespace utils {

-static std::pair<std::string, std::string> getDtypeNames(
+std::pair<std::string, std::string> getDtypeNames(
     at::ScalarType scalarType) {
   switch (scalarType) {
     case at::ScalarType::Byte:

‎torch/csrc/utils/tensor_dtypes.h

+2

@@ -6,6 +6,8 @@

 namespace torch { namespace utils {

+std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType);
+
 void initializeDtypes();

 }} // namespace torch::utils
