Skip to content

Commit

Permalink
Allow indexing with numeric or logical matrixes. (#1183)
Browse files Browse the repository at this point in the history
* Allow indexing with numeric or logical matrixes.

* Make sure there's protection

* Numeric indexes should be moved to the correct device too.

* fixes to device placement

* + fixes
  • Loading branch information
dfalbel authored Jul 30, 2024
1 parent fe5020f commit fbb2d80
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 30 deletions.
16 changes: 8 additions & 8 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ mnist-r.*
# ^vignettes/using-autograd\.Rmd

# uncomment below for CRAN submission
^inst/bin/.*
^inst/include/(?!torch.h|lantern|torch_RcppExports.h|utils.h|torch_impl.h|torch_types.h|torch_api.h|torch_deleters.h|torch_imports.h).*
^inst/lib/.*
^inst/share/.*
^inst/build-hash
^inst/build-versions
^src/lantern/.*
^tests/testthat/assets/model-v.*
# ^inst/bin/.*
# ^inst/include/(?!torch.h|lantern|torch_RcppExports.h|utils.h|torch_impl.h|torch_types.h|torch_api.h|torch_deleters.h|torch_imports.h).*
# ^inst/lib/.*
# ^inst/share/.*
# ^inst/build-hash
# ^inst/build-versions
# ^src/lantern/.*
# ^tests/testthat/assets/model-v.*

^doc$
^Meta$
Expand Down
2 changes: 1 addition & 1 deletion R/indexing.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ print.slice <- function(x, ...) {
N = .Machine$integer.max,
newaxis = NULL,
`..` = structure(list(), class = "fill")
)
)

tensor_slice <- function(tensor, ..., drop = TRUE) {
Tensor_slice(tensor, environment(), drop = drop, mask = .d)
Expand Down
4 changes: 2 additions & 2 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44146,11 +44146,11 @@ BEGIN_RCPP
END_RCPP
}
// Tensor_slice_put
void Tensor_slice_put(Rcpp::XPtr<XPtrTorchTensor> self, Rcpp::Environment e, SEXP rhs, Rcpp::List mask);
void Tensor_slice_put(XPtrTorchTensor self, Rcpp::Environment e, SEXP rhs, Rcpp::List mask);
RcppExport SEXP _torch_Tensor_slice_put(SEXP selfSEXP, SEXP eSEXP, SEXP rhsSEXP, SEXP maskSEXP) {
BEGIN_RCPP
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< Rcpp::XPtr<XPtrTorchTensor> >::type self(selfSEXP);
Rcpp::traits::input_parameter< XPtrTorchTensor >::type self(selfSEXP);
Rcpp::traits::input_parameter< Rcpp::Environment >::type e(eSEXP);
Rcpp::traits::input_parameter< SEXP >::type rhs(rhsSEXP);
Rcpp::traits::input_parameter< Rcpp::List >::type mask(maskSEXP);
Expand Down
39 changes: 27 additions & 12 deletions src/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ struct index_info {
// returns true if appended a vector like object. We use the boolean vector
// to decide if we should start a new index object.
index_info index_append_sexp(XPtrTorchTensorIndex& index, SEXP slice,
bool drop) {
bool drop, torch::Device device) {
// a single NA means empty argument which and in turn we must select
// all elements in that dimension.
if (TYPEOF(slice) == LGLSXP && LENGTH(slice) == 1 &&
Expand Down Expand Up @@ -249,13 +249,26 @@ index_info index_append_sexp(XPtrTorchTensorIndex& index, SEXP slice,
// if it's a numeric vector
if ((TYPEOF(slice) == REALSXP || TYPEOF(slice) == INTSXP) &&
LENGTH(slice) > 1) {
index_append_integer_vector(index, slice);
return {1, true, false};
// if it's a numeric vector but has a dim attribute, we convert the value to a Tensor
// before adding it to the index.
const auto dims = Rcpp::RObject(Rf_getAttrib(slice, R_DimSymbol));
if (Rf_isNull(dims)) {
index_append_integer_vector(index, slice);
return {1, true, false};
}
// If the slice has a dim attribute, we convert it to a tensor and let the code
// continue to add it to the index.
slice = torch_tensor_cpp(slice, torch::Dtype(lantern_Dtype_int64()), device);
}

if (TYPEOF(slice) == LGLSXP) {
index_append_bool_vector(index, slice);
return {1, true, false};
const auto dims = Rcpp::RObject(Rf_getAttrib(slice, R_DimSymbol));
if (Rf_isNull(dims)) {
index_append_bool_vector(index, slice);
return {1, true, false};
}
/// convert to tensor a let it go
slice = torch_tensor_cpp(slice, torch::Dtype(lantern_Dtype_bool()));
}

if (Rf_inherits(slice, "torch_tensor")) {
Expand All @@ -271,15 +284,15 @@ index_info index_append_sexp(XPtrTorchTensorIndex& index, SEXP slice,
}

std::vector<XPtrTorchTensorIndex> slices_to_index(
std::vector<Rcpp::RObject> slices, bool drop) {
std::vector<Rcpp::RObject> slices, bool drop, torch::Device device) {
std::vector<XPtrTorchTensorIndex> output;
XPtrTorchTensorIndex index = lantern_TensorIndex_new();
SEXP slice;
int num_dim = 0;
bool has_ellipsis = false;
for (auto i = 0; i < slices.size(); i++) {
slice = slices[i];
auto info = index_append_sexp(index, slice, drop);
auto info = index_append_sexp(index, slice, drop, device);

if (!has_ellipsis && info.ellipsis) {
has_ellipsis = true;
Expand Down Expand Up @@ -328,7 +341,8 @@ std::vector<XPtrTorchTensorIndex> slices_to_index(
XPtrTorchTensor Tensor_slice(XPtrTorchTensor self, Rcpp::Environment e,
bool drop, Rcpp::List mask) {
auto dots = evaluate_slices(enquos0(e), mask);
auto index = slices_to_index(dots, drop);
auto device = torch::Device(lantern_Tensor_device(self.get()));
auto index = slices_to_index(dots, drop, device);
XPtrTorchTensor out = self;
for (auto& ind : index) {
out = lantern_Tensor_index(out.get(), ind.get());
Expand All @@ -339,10 +353,11 @@ XPtrTorchTensor Tensor_slice(XPtrTorchTensor self, Rcpp::Environment e,
XPtrTorchScalar cpp_torch_scalar(SEXP x);

// [[Rcpp::export]]
void Tensor_slice_put(Rcpp::XPtr<XPtrTorchTensor> self, Rcpp::Environment e,
void Tensor_slice_put(XPtrTorchTensor self, Rcpp::Environment e,
SEXP rhs, Rcpp::List mask) {
auto dots = evaluate_slices(enquos0(e), mask);
auto indexes = slices_to_index(dots, true);
auto device = torch::Device(lantern_Tensor_device(self.get()));
auto indexes = slices_to_index(dots, true, device);

if (indexes.size() > 1) {
Rcpp::stop(
Expand All @@ -356,13 +371,13 @@ void Tensor_slice_put(Rcpp::XPtr<XPtrTorchTensor> self, Rcpp::Environment e,
TYPEOF(rhs) == LGLSXP || TYPEOF(rhs) == STRSXP) &&
LENGTH(rhs) == 1) {
auto s = cpp_torch_scalar(rhs);
lantern_Tensor_index_put_scalar_(self->get(), index.get(), s.get());
lantern_Tensor_index_put_scalar_(self.get(), index.get(), s.get());
return;
}

if (Rf_inherits(rhs, "torch_tensor")) {
Rcpp::XPtr<XPtrTorchTensor> t = Rcpp::as<Rcpp::XPtr<XPtrTorchTensor>>(rhs);
lantern_Tensor_index_put_tensor_(self->get(), index.get(), t->get());
lantern_Tensor_index_put_tensor_(self.get(), index.get(), t->get());
return;
}

Expand Down
7 changes: 0 additions & 7 deletions src/lantern/src/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,39 +156,34 @@ bool *_lantern_Tensor_data_ptr_bool(void *self) {
int64_t _lantern_Tensor_numel(void *self) {
LANTERN_FUNCTION_START
torch::Tensor x = from_raw::Tensor(self);
;
return x.numel();
LANTERN_FUNCTION_END_RET(0)
}

int64_t _lantern_Tensor_element_size(void *self) {
LANTERN_FUNCTION_START
torch::Tensor x = from_raw::Tensor(self);
;
return x.element_size();
LANTERN_FUNCTION_END_RET(0)
}

int64_t _lantern_Tensor_ndimension(void *self) {
LANTERN_FUNCTION_START
torch::Tensor x = from_raw::Tensor(self);
;
return x.ndimension();
LANTERN_FUNCTION_END_RET(0)
}

int64_t _lantern_Tensor_size(void *self, int64_t i) {
LANTERN_FUNCTION_START
torch::Tensor x = from_raw::Tensor(self);
;
return x.size(i);
LANTERN_FUNCTION_END_RET(0)
}

void *_lantern_Tensor_dtype(void *self) {
LANTERN_FUNCTION_START
torch::Tensor x = from_raw::Tensor(self);
;
torch::Dtype dtype = c10::typeMetaToScalarType(x.dtype());
return make_raw::Dtype(dtype);
LANTERN_FUNCTION_END
Expand All @@ -197,7 +192,6 @@ void *_lantern_Tensor_dtype(void *self) {
void *_lantern_Tensor_device(void *self) {
LANTERN_FUNCTION_START
torch::Tensor x = from_raw::Tensor(self);
;
torch::Device device = x.device();
return make_raw::Device(device);
LANTERN_FUNCTION_END
Expand Down Expand Up @@ -238,7 +232,6 @@ void *_lantern_Tensor_names(void *self) {
bool _lantern_Tensor_has_any_zeros(void *self) {
LANTERN_FUNCTION_START
torch::Tensor x = from_raw::Tensor(self);
;
return (x == 0).any().item().toBool();
LANTERN_FUNCTION_END
}
Expand Down
38 changes: 38 additions & 0 deletions tests/testthat/test-indexing.R
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,41 @@ test_that("NULL tensor", {
expect_error(torch_tensor(as.integer(NULL))[1], regexp = "out of bounds")

})

test_that("works with numeric /logic matrix", {
# Regression test for: https://github.com/mlverse/torch/issues/1181
x <- torch_randn(4, 4)
y <- rbind(c(1, 1), c(1,2))

expect_true(
torch_allclose(
x[y],
x[torch_tensor(y, dtype = "long")]
)
)

expect_true(
torch_allclose(
x[x > 0],
x[as.array(x>0)]
)
)

# also test if it works when the tensor is in a different device
skip_if_not_m1_mac()
x <- x$to(device="mps")

expect_true(
torch_allclose(
x[y],
x[torch_tensor(y, dtype = "long")]
)
)

expect_true(
torch_allclose(
x[x > 0],
x[as.array(x>0)]
)
)
})

0 comments on commit fbb2d80

Please sign in to comment.