6
6
#include < torch/csrc/utils/python_arg_parser.h>
7
7
#include < torch/csrc/utils/python_numbers.h>
8
8
#include < torch/csrc/utils/python_strings.h>
9
+ #include < torch/csrc/utils/tensor_dtypes.h>
9
10
10
11
#include < c10/util/Exception.h>
11
12
@@ -20,7 +21,7 @@ PyObject* THPFInfo_New(const at::ScalarType& type) {
20
21
if (!self)
21
22
throw python_error ();
22
23
auto self_ = reinterpret_cast <THPDTypeInfo*>(self.get ());
23
- self_->type = type;
24
+ self_->type = c10::toValueType ( type) ;
24
25
return self.release ();
25
26
}
26
27
@@ -34,18 +35,6 @@ PyObject* THPIInfo_New(const at::ScalarType& type) {
34
35
return self.release ();
35
36
}
36
37
37
- PyObject* THPFInfo_str (THPFInfo* self) {
38
- std::ostringstream oss;
39
- oss << " finfo(type=" << self->type << " )" ;
40
- return THPUtils_packString (oss.str ().c_str ());
41
- }
42
-
43
- PyObject* THPIInfo_str (THPIInfo* self) {
44
- std::ostringstream oss;
45
- oss << " iinfo(type=" << self->type << " )" ;
46
- return THPUtils_packString (oss.str ().c_str ());
47
- }
48
-
49
38
PyObject* THPFInfo_pynew (PyTypeObject* type, PyObject* args, PyObject* kwargs) {
50
39
HANDLE_TH_ERRORS
51
40
static torch::PythonArgParser parser ({
@@ -63,7 +52,7 @@ PyObject* THPFInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
63
52
AT_ASSERT (at::isFloatingType (scalar_type));
64
53
} else {
65
54
scalar_type = r.scalartype (0 );
66
- if (!at::isFloatingType (scalar_type)) {
55
+ if (!at::isFloatingType (scalar_type) && ! at::isComplexType (scalar_type) ) {
67
56
return PyErr_Format (
68
57
PyExc_TypeError,
69
58
" torch.finfo() requires a floating point input type. Use torch.iinfo to handle '%s'" ,
@@ -123,7 +112,7 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
123
112
}
124
113
125
114
static PyObject* THPFInfo_eps (THPFInfo* self, void *) {
126
- return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1 (at::kHalf ,
115
+ return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2 (at::kHalf , at::ScalarType::BFloat16 ,
127
116
self->type , " epsilon" , [] {
128
117
return PyFloat_FromDouble (
129
118
std::numeric_limits<
@@ -132,20 +121,20 @@ static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
132
121
}
133
122
134
123
static PyObject* THPFInfo_max (THPFInfo* self, void *) {
135
- return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1 (at::kHalf , self->type , " max" , [] {
124
+ return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2 (at::kHalf , at::ScalarType::BFloat16 , self->type , " max" , [] {
136
125
return PyFloat_FromDouble (
137
126
std::numeric_limits<at::scalar_value_type<scalar_t >::type>::max ());
138
127
});
139
128
}
140
129
141
130
static PyObject* THPFInfo_min (THPFInfo* self, void *) {
142
- return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1 (at::kHalf , self->type , " min " , [] {
131
+ return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2 (at::kHalf , at::ScalarType::BFloat16, self->type , " lowest " , [] {
143
132
return PyFloat_FromDouble (
144
133
std::numeric_limits<at::scalar_value_type<scalar_t >::type>::lowest ());
145
134
});
146
135
}
147
136
148
- static PyObject* THPIInfo_max (THPFInfo * self, void *) {
137
+ static PyObject* THPIInfo_max (THPIInfo * self, void *) {
149
138
if (at::isIntegralType (self->type , /* includeBool=*/ false )) {
150
139
return AT_DISPATCH_INTEGRAL_TYPES (self->type , " max" , [] {
151
140
return THPUtils_packInt64 (std::numeric_limits<scalar_t >::max ());
@@ -157,7 +146,7 @@ static PyObject* THPIInfo_max(THPFInfo* self, void*) {
157
146
});
158
147
}
159
148
160
- static PyObject* THPIInfo_min (THPFInfo * self, void *) {
149
+ static PyObject* THPIInfo_min (THPIInfo * self, void *) {
161
150
if (at::isIntegralType (self->type , /* includeBool=*/ false )) {
162
151
return AT_DISPATCH_INTEGRAL_TYPES (self->type , " min" , [] {
163
152
return THPUtils_packInt64 (std::numeric_limits<scalar_t >::lowest ());
@@ -169,19 +158,69 @@ static PyObject* THPIInfo_min(THPFInfo* self, void*) {
169
158
});
170
159
}
171
160
161
+ static PyObject* THPIInfo_dtype (THPIInfo* self, void *) {
162
+ std::string primary_name, legacy_name;
163
+ std::tie (primary_name, legacy_name) = torch::utils::getDtypeNames (self->type );
164
+ return AT_DISPATCH_INTEGRAL_TYPES (self->type , " dtype" , [primary_name] {
165
+ return PyUnicode_FromString ((char *)primary_name.data ());
166
+ });
167
+ }
168
+
172
169
static PyObject* THPFInfo_tiny (THPFInfo* self, void *) {
173
- return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1 (at::kHalf , self->type , " min" , [] {
170
+ return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2 (at::kHalf , at::ScalarType::BFloat16 , self->type , " min" , [] {
174
171
return PyFloat_FromDouble (
175
172
std::numeric_limits<at::scalar_value_type<scalar_t >::type>::min ());
176
173
});
177
174
}
178
175
176
+ static PyObject* THPFInfo_resolution (THPFInfo* self, void *) {
177
+ return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2 (at::kHalf , at::ScalarType::BFloat16, self->type , " digits10" , [] {
178
+ return PyFloat_FromDouble (
179
+ std::pow (10 , -std::numeric_limits<at::scalar_value_type<scalar_t >::type>::digits10));
180
+ });
181
+ }
182
+
183
+ static PyObject* THPFInfo_dtype (THPFInfo* self, void *) {
184
+ std::string primary_name, legacy_name;
185
+ std::tie (primary_name, legacy_name) = torch::utils::getDtypeNames (self->type );
186
+ return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2 (at::kHalf , at::ScalarType::BFloat16, self->type , " dtype" , [primary_name] {
187
+ return PyUnicode_FromString ((char *)primary_name.data ());
188
+ });
189
+ }
190
+
191
+ PyObject* THPFInfo_str (THPFInfo* self) {
192
+ std::ostringstream oss;
193
+ oss << " finfo(resolution=" << PyFloat_AsDouble (THPFInfo_resolution (self, nullptr ));
194
+ oss << " , min=" << PyFloat_AsDouble (THPFInfo_min (self, nullptr ));
195
+ oss << " , max=" << PyFloat_AsDouble (THPFInfo_max (self, nullptr ));
196
+ oss << " , eps=" << PyFloat_AsDouble (THPFInfo_eps (self, nullptr ));
197
+ oss << " , tiny=" << PyFloat_AsDouble (THPFInfo_tiny (self, nullptr ));
198
+ oss << " , dtype=" << PyUnicode_AsUTF8 (THPFInfo_dtype (self, nullptr )) << " )" ;
199
+
200
+ return THPUtils_packString (oss.str ().c_str ());
201
+ }
202
+
203
+ PyObject* THPIInfo_str (THPIInfo* self) {
204
+ auto type = self->type ;
205
+ std::string primary_name, legacy_name;
206
+ std::tie (primary_name, legacy_name) = torch::utils::getDtypeNames (type);
207
+ std::ostringstream oss;
208
+
209
+ oss << " iinfo(min=" << PyFloat_AsDouble (THPIInfo_min (self, nullptr ));
210
+ oss << " , max=" << PyFloat_AsDouble (THPIInfo_max (self, nullptr ));
211
+ oss << " , dtype=" << PyUnicode_AsUTF8 (THPIInfo_dtype (self, nullptr )) << " )" ;
212
+
213
+ return THPUtils_packString (oss.str ().c_str ());
214
+ }
215
+
179
216
static struct PyGetSetDef THPFInfo_properties[] = {
180
217
{" bits" , (getter)THPDTypeInfo_bits, nullptr , nullptr , nullptr },
181
218
{" eps" , (getter)THPFInfo_eps, nullptr , nullptr , nullptr },
182
219
{" max" , (getter)THPFInfo_max, nullptr , nullptr , nullptr },
183
220
{" min" , (getter)THPFInfo_min, nullptr , nullptr , nullptr },
184
221
{" tiny" , (getter)THPFInfo_tiny, nullptr , nullptr , nullptr },
222
+ {" resolution" , (getter)THPFInfo_resolution, nullptr , nullptr , nullptr },
223
+ {" dtype" , (getter)THPFInfo_dtype, nullptr , nullptr , nullptr },
185
224
{nullptr }};
186
225
187
226
static PyMethodDef THPFInfo_methods[] = {
@@ -232,6 +271,7 @@ static struct PyGetSetDef THPIInfo_properties[] = {
232
271
{" bits" , (getter)THPDTypeInfo_bits, nullptr , nullptr , nullptr },
233
272
{" max" , (getter)THPIInfo_max, nullptr , nullptr , nullptr },
234
273
{" min" , (getter)THPIInfo_min, nullptr , nullptr , nullptr },
274
+ {" dtype" , (getter)THPIInfo_dtype, nullptr , nullptr , nullptr },
235
275
{nullptr }};
236
276
237
277
static PyMethodDef THPIInfo_methods[] = {
0 commit comments