Skip to content

Commit

Permalink
Merge pull request #9 from xeTaiz/master
Browse files Browse the repository at this point in the history
Use torch.searchsorted
  • Loading branch information
aliutkus authored Aug 27, 2020
2 parents f108f32 + 5e5a436 commit 8fc1563
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 16 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# torchinterp1d
## CUDA 1-D interpolation for Pytorch

Requires PyTorch >= 1.6 (due to [torch.searchsorted](https://pytorch.org/docs/master/generated/torch.searchsorted.html)).

## Presentation

This repository implements an `Interp1d` class that overrides torch.autograd.Function, enabling
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@
packages=['torchinterp1d'],
keywords='interp1d torch',
install_requires=[
'torchsearchsorted @ git+https://github.com/aliutkus/torchsearchsorted',
'torch>=1.6',
],
)
16 changes: 1 addition & 15 deletions torchinterp1d/interp1d.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
import torch
import contextlib
SEARCHSORTED_AVAILABLE = True
try:
from torchsearchsorted import searchsorted
except ImportError:
SEARCHSORTED_AVAILABLE = False


class Interp1d(torch.autograd.Function):
def __call__(self, x, y, xnew, out=None):
Expand Down Expand Up @@ -37,14 +31,6 @@ def forward(ctx, x, y, xnew, out=None):
Tensor for the output. If None: allocated automatically.
"""
# checking availability of the searchsorted pytorch module
if not SEARCHSORTED_AVAILABLE:
raise Exception(
'The interp1d function depends on the '
'torchsearchsorted module, which is not available.\n'
'You must get it at ',
'https://github.com/aliutkus/torchsearchsorted\n')

# making the vectors at least 2D
is_flat = {}
require_grad = {}
Expand Down Expand Up @@ -107,7 +93,7 @@ def forward(ctx, x, y, xnew, out=None):

# calling searchsorted on the x values.
ind = ynew.long()
searchsorted(v['x'].contiguous(), v['xnew'].contiguous(), ind)
torch.searchsorted(v['x'].contiguous(), v['xnew'].contiguous(), out=ind)

# the `-1` is because searchsorted looks for the index where the values
# must be inserted to preserve order. And we want the index of the
Expand Down

0 comments on commit 8fc1563

Please sign in to comment.