Skip to content

Commit

Permalink
Revert "Fix performance regression of simple indexing cases (pytorch#…
Browse files Browse the repository at this point in the history
…6793)" (pytorch#6886)

This reverts commit 8a01669.
  • Loading branch information
colesbury authored and soumith committed Apr 24, 2018
1 parent b6ed729 commit 9765bb5
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 161 deletions.
11 changes: 0 additions & 11 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,22 +1370,11 @@ def test(x, ia, ib):
# test setitem
x_clone1 = x.clone()
x_clone2 = x.clone()
x_clone3 = x.clone()
first_shape = x[:, ia, None, ib, 0].shape
second_shape = x[ia].shape
x_clone1[:, ia, None, ib, 0] = torch.randn(first_shape).to(x_clone1)
x_clone2[ia] = torch.randn(second_shape).to(x_clone2)

# fill equivalents
x_clone1[:, ia, None, ib, 0] = 5
x_clone2[ia] = 7

# mask equivalents
mask = (torch.randn(x_clone3.size()) < 0).to(ia.device)
x_clone3[mask]
self.assertEqual(x_clone3[mask].cpu(), x_clone3.cpu()[mask.cpu()])
x_clone3[mask] = 6

cpu = torch.device('cpu')
for device in ['cuda:0', 'cuda:1'] if torch.cuda.device_count() > 1 else ['cuda']:
# Index cpu tensor with cuda tensor
Expand Down
26 changes: 0 additions & 26 deletions test/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,32 +254,6 @@ def test_zero_dim_index(self):
self.assertEqual(x, x[0])
self.assertEqual(len(w), 1)

def test_legacy_dispatch(self):
# compare with indexing using index_select / index_fill etc
x = torch.arange(0, 9).view(3, 3)
idx = torch.tensor([0, 2])
self.assertEqual(x[idx], x.index_select(0, idx))
self.assertEqual(x[:, idx], x.index_select(1, idx))

mask = x > 4
self.assertEqual(x[mask], x.masked_select(mask))

y = x.clone()
yr = x.clone()
y[idx] = 0
yr.index_fill_(0, idx, 0)
self.assertEqual(y, yr)
y[:, idx] = 2
yr.index_fill_(1, idx, 2)
self.assertEqual(y, yr)

mask = x > 4
y = x.clone()
yr = x.clone()
y[mask] = 10
yr.masked_fill_(mask, 10)
self.assertEqual(y, yr)


# The tests below are from NumPy test_indexing.py with some modifications to
# make them compatible with PyTorch. It's licensed under the BSD license below:
Expand Down
11 changes: 5 additions & 6 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,16 @@ def fn(x, y):
# index-2 is not implemented in interpreter
@unittest.expectedFailure
def test_index(self):
x = Variable(torch.rand(2, 2, 2), requires_grad=True)
x = Variable(torch.Tensor([0.4]), requires_grad=True)
y = Variable(torch.LongTensor([0]), requires_grad=True)
y2 = Variable(torch.LongTensor([1]), requires_grad=True)

@torch.jit.compile(nderivs=0)
def fn(x, y, y2):
return x[y, y2]
def fn(x, y):
return x[y]

z = fn(x, y, y2)
z = fn(x, y)
with self.assertCompiled(fn):
z2 = fn(x, y, y2)
z2 = fn(x, y)
self.assertEqual(z, z2)

# Backwards tracing was broken for indexing by a constant,
Expand Down
119 changes: 1 addition & 118 deletions torch/csrc/autograd/python_variable_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/utils/python_compat.h"
#include "torch/csrc/utils/python_numbers.h"
#include "torch/csrc/utils/tensor_conversion_dispatch.h"
#include "torch/csrc/utils/tensor_new.h"
#include "torch/csrc/utils/tensor_conversion_dispatch.h"

#include <ATen/ExpandUtils.h>
#include <vector>
Expand Down Expand Up @@ -169,16 +169,6 @@ static Variable applySlicing(const Variable& self, PyObject* index, variable_lis
return result;
}

// Converts a single index tensor so that it lives on the same backend as
// `self`. The target device passed to the conversion is `self`'s device when
// `self` is a CUDA tensor and -1 otherwise. An undefined index is returned
// unchanged, since there is nothing to convert.
static Tensor typeConvertIndex(const Variable& self, const Variable& ind) {
  int64_t device = -1;
  if (self.is_cuda()) {
    device = self.get_device();
  }

  if (!ind.defined()) {
    return ind;
  }

  // Keep the index's own scalar type; only the backend is switched.
  auto& target_type = ind.type().toBackend(self.type().backend());
  return torch::utils::dispatch_type_conversion(ind, target_type, device, false);
}

static std::vector<Tensor> typeConvertIndices(const Variable& self, const variable_list& indices) {
std::vector<Tensor> converted_inds(indices.size());
int64_t device = self.is_cuda() ? self.get_device() : -1;
Expand Down Expand Up @@ -271,97 +261,6 @@ static PyObject* applyBoolGetitem(const Variable& self, bool index) {
}
}

// Classification of an index list produced by getLegacyIndexingType, used to
// decide whether getitem/setitem can be routed to a dedicated operation
// instead of the generic advanced-indexing path.
enum class LegacyIndexingType {
// No shortcut applies; callers fall through to dispatch_index /
// dispatch_index_put_.
None,
// A single Byte index with the same size as the indexed tensor; handled with
// masked_select / masked_fill_.
Mask,
// Exactly one defined index, which is a non-empty 1-D Long tensor; handled
// with index_select / index_fill_ along that index's position.
Index,
};

static std::pair<LegacyIndexingType, int64_t>
getLegacyIndexingType(const Variable& self, const variable_list& vars) {
// TODO: this could be that the broadcasted size is the same.
if (vars.size() == 1 && vars[0].type().scalarType() == ScalarType::Byte && vars[0].is_same_size(self)) {
return std::make_pair(LegacyIndexingType::Mask, -1);
}

// single tensor indexing
int num_defined_variables = 0;
int64_t index_dim = -1;
for (size_t i = 0; i < vars.size(); i++) {
auto& variable = vars[i];
auto is_defined = variable.defined();
num_defined_variables += is_defined;
if (is_defined) {
index_dim = (int64_t)i;
if (num_defined_variables > 1) {
break;
}
if (variable.dim() != 1 || variable.type().scalarType() != ScalarType::Long || variable.numel() == 0) {
num_defined_variables = -1;
break;
}
}
}

if (num_defined_variables == 1) {
return std::make_pair(LegacyIndexingType::Index, index_dim);
}
// advanced indexing
return std::make_pair(LegacyIndexingType::None, -1);
}

// Performs a getitem through the legacy operation selected by `legacyIndex`
// (as returned by getLegacyIndexingType):
//   - Mask  -> self.masked_select(vars[0])
//   - Index -> self.index_select(dim, vars[dim])
// The index tensor is first converted with typeConvertIndex; the GIL is then
// released and a GPU guard for `self` is installed before the operation runs.
// Throws std::runtime_error for LegacyIndexingType::None or any other value.
static Variable dispatch_legacy_index(const Variable& self, const variable_list& vars,
    std::pair<LegacyIndexingType, int64_t> legacyIndex) {
  const LegacyIndexingType kind = legacyIndex.first;

  if (kind == LegacyIndexingType::Mask) {
    const Variable& mask = vars[0];
    auto converted_mask = typeConvertIndex(self, mask);
    AutoNoGIL no_gil;
    AutoGPU auto_gpu(self);
    return self.masked_select(converted_mask);
  }

  if (kind == LegacyIndexingType::Index) {
    const int64_t index_dim = legacyIndex.second;
    const Variable& index = vars[index_dim];
    auto converted_index = typeConvertIndex(self, index);
    AutoNoGIL no_gil;
    AutoGPU auto_gpu(self);
    return self.index_select(index_dim, converted_index);
  }

  // LegacyIndexingType::None and unknown values are caller errors.
  throw std::runtime_error("Unexpected indexing type");
}

// Performs an in-place setitem through the legacy fill operation selected by
// `legacyIndex` (as returned by getLegacyIndexingType):
//   - Mask  -> self.masked_fill_(vars[0], value)
//   - Index -> self.index_fill_(dim, vars[dim], value)
// The index tensor is first converted with typeConvertIndex; the GIL is then
// released and a GPU guard for `self` is installed before the operation runs.
// Returns the result of the fill call. Throws std::runtime_error for
// LegacyIndexingType::None or any other value.
static Variable dispatch_legacy_index_put_(Variable& self, const variable_list& vars, const Variable& value,
    std::pair<LegacyIndexingType, int64_t> legacyIndex) {
  const LegacyIndexingType kind = legacyIndex.first;

  if (kind == LegacyIndexingType::Mask) {
    const Variable& mask = vars[0];
    auto converted_mask = typeConvertIndex(self, mask);
    AutoNoGIL no_gil;
    AutoGPU auto_gpu(self);
    return self.masked_fill_(converted_mask, value);
  }

  if (kind == LegacyIndexingType::Index) {
    const int64_t index_dim = legacyIndex.second;
    const Variable& index = vars[index_dim];
    auto converted_index = typeConvertIndex(self, index);
    AutoNoGIL no_gil;
    AutoGPU auto_gpu(self);
    return self.index_fill_(index_dim, converted_index, value);
  }

  // LegacyIndexingType::None and unknown values are caller errors.
  throw std::runtime_error("Unexpected indexing type");
}

PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
HANDLE_TH_ERRORS
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
Expand Down Expand Up @@ -396,12 +295,6 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
return applyBoolGetitem(self_, variableIndices[0].toCByte());
}

// TODO move this to ATen
auto legacy_index = getLegacyIndexingType(sliced, variableIndices);
if (std::get<0>(legacy_index) != LegacyIndexingType::None) {
return wrap(dispatch_legacy_index(sliced, variableIndices, legacy_index));
}

// indexing by tensors ("advanced" indexing)
return wrap(dispatch_index(sliced, variableIndices));
Py_RETURN_NONE;
Expand Down Expand Up @@ -468,16 +361,6 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
return 0;
}

// TODO move this to ATen
// we are being overly cautious here and only considering the *_fill_ variants
// (value is a scalar), as there could be broadcasting in the value that could
// happen and is not handled by masked_scatter_ and index_copy_
auto legacy_index = getLegacyIndexingType(sliced, variableIndices);
if (std::get<0>(legacy_index) != LegacyIndexingType::None && value.dim() == 0) {
dispatch_legacy_index_put_(sliced, variableIndices, value, legacy_index);
return 0;
}

// indexing by tensors ("advanced" indexing)
dispatch_index_put_(sliced, variableIndices, value);
return 0;
Expand Down

0 comments on commit 9765bb5

Please sign in to comment.