Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tensor topk param #9703

Merged
merged 4 commits into from
Jan 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions oneflow/api/python/framework/tensor_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ DIRECT_PASS_FUNC(PyTensorObject_bincount, functional::bincount)
DIRECT_PASS_FUNC(PyTensorObject_isclose, functional::isclose)
DIRECT_PASS_FUNC(PyTensorObject_broadcast_to, functional::broadcast_to)
DIRECT_PASS_FUNC(PyTensorObject_unique, functional::unique)
DIRECT_PASS_FUNC(PyTensorObject_topk, functional::topk)

// functions that parsing at Python C api layer
static PyObject* PyTensorObject_byte(PyObject* self, PyObject* unused) {
Expand Down Expand Up @@ -1018,6 +1019,7 @@ PyMethodDef PyTensorObject_extra_methods[] = {
{"isclose", (PyCFunction)PyTensorObject_isclose, METH_VARARGS | METH_KEYWORDS, NULL},
{"broadcast_to", (PyCFunction)PyTensorObject_broadcast_to, METH_VARARGS | METH_KEYWORDS, NULL},
{"unique", (PyCFunction)PyTensorObject_unique, METH_VARARGS | METH_KEYWORDS, NULL},
{"topk", (PyCFunction)PyTensorObject_topk, METH_VARARGS | METH_KEYWORDS, NULL},

// macro UNARY_METHOD
{"abs", PyTensorObject_abs, METH_NOARGS, NULL},
Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2647,7 +2647,7 @@
bind_python: True

- name: "topk"
signature: "TensorTuple[values, indices] (Tensor input, Int32 k, Int32 dim=-1, Bool largest=True, Bool sorted=True) => TopK"
signature: "TensorTuple[values, indices] (Tensor input, Int32 k, Int32 dim=None, Bool largest=True, Bool sorted=True) => TopK"
bind_python: True

- name: "in_top_k"
Expand Down
10 changes: 6 additions & 4 deletions oneflow/core/functional/impl/array_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3118,16 +3118,18 @@ class ToDeviceFunctor {
class TopKFunctor {
public:
TopKFunctor() { op_ = CHECK_JUST(one::OpBuilder("top_k").Input("in").Output("out").Build()); }
Maybe<TensorTuple> operator()(const std::shared_ptr<Tensor>& input, const int32_t& k,
const int32_t& dim, const bool largest, const bool sorted) const {
Maybe<TensorTuple> operator()(const std::shared_ptr<Tensor>& input, const int32_t k,
const Optional<int32_t>& dim, const bool largest,
const bool sorted) const {
auto outputs = std::make_shared<TensorTuple>(2);
std::shared_ptr<Tensor> values;
std::shared_ptr<Tensor> indices;

auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("k", "sorted");
attrs.SetAllAttrs(k, sorted);

int32_t axis = dim;
int32_t dim_value = dim.value_or(-1);
int32_t axis = dim_value;
axis = JUST(maybe_wrap_dim(axis, input->ndim()));
if (axis == input->ndim() - 1) {
if (largest) {
Expand All @@ -3139,7 +3141,7 @@ class TopKFunctor {
values = JUST(DimGather(input, axis, indices, false));

} else {
auto perm = JUST(GetPermWhenTransposeAxisToLastDim(input->ndim(), dim));
auto perm = JUST(GetPermWhenTransposeAxisToLastDim(input->ndim(), dim_value));
auto x = JUST(Transpose(input, *perm));
if (largest) {
indices = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {x}, attrs));
Expand Down
5 changes: 0 additions & 5 deletions python/oneflow/framework/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,6 @@ def _T(self):
return flow._C.T(self)


def _topk(self, k, dim: int = None, largest: bool = True, sorted: bool = True):
return flow.topk(self, k, dim, largest, sorted)


def _nms(boxes, scores, iou_threshold: float):
return flow.nms(boxes, scores, iou_threshold)

Expand Down Expand Up @@ -627,7 +623,6 @@ def RegisterMethods():
Tensor.eq = _eq
Tensor.sort = _sort
Tensor.tolist = _tolist
Tensor.topk = _topk
Tensor.nms = _nms
Tensor.nonzero = _nonzero
Tensor.prod = _prod
Expand Down
6 changes: 3 additions & 3 deletions python/oneflow/test/tensor/test_tensor_part_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,9 +1035,9 @@ def test_tensor_topk_with_random_data(test_case):
x = random_tensor(ndim=4, dim1=8, dim2=9, dim3=10).to(device)
y = x.topk(
random(low=1, high=8).to(int),
dim=random(low=1, high=4).to(int),
largest=random_bool(),
sorted=constant(True),
dim=random(low=1, high=4).to(int) | nothing(),
largest=random_bool() | nothing(),
sorted=constant(True) | nothing(),
)
return y[0], y[1]

Expand Down