diff --git a/k2/csrc/ragged_ops.cu b/k2/csrc/ragged_ops.cu index 65518d288..490c423bd 100644 --- a/k2/csrc/ragged_ops.cu +++ b/k2/csrc/ragged_ops.cu @@ -2515,12 +2515,18 @@ struct HashOutputIteratorDeref { // this is what you get when you dereference template struct HashOutputIterator { // outputs just the index of the pair. - explicit HashOutputIterator(T *t) : t_(t) {} - __device__ __forceinline__ HashOutputIteratorDeref operator[]( + explicit __host__ __device__ __forceinline__ HashOutputIterator(T *t) + : t_(t) {} + __host__ __device__ __forceinline__ HashOutputIteratorDeref operator[]( int32_t idx) const { return HashOutputIteratorDeref(t_ + idx); } - __device__ __forceinline__ HashOutputIterator operator+(size_t offset) { + __host__ __device__ __forceinline__ HashOutputIteratorDeref operator*() + const { + return HashOutputIteratorDeref(t_); + } + __host__ __device__ __forceinline__ HashOutputIterator + operator+(size_t offset) { return HashOutputIterator{t_ + offset}; } T *t_; diff --git a/k2/csrc/ragged_ops_inl.h b/k2/csrc/ragged_ops_inl.h index 322871a37..92536f1cb 100644 --- a/k2/csrc/ragged_ops_inl.h +++ b/k2/csrc/ragged_ops_inl.h @@ -578,12 +578,18 @@ struct PairOutputIteratorDeref { // this is what you get when you dereference template struct PairOutputIterator { // outputs just the index of the pair. - explicit PairOutputIterator(int32_t *i) : i_(i) {} - __device__ __forceinline__ PairOutputIteratorDeref operator[]( + explicit __host__ __device__ __forceinline__ PairOutputIterator(int32_t *i) + : i_(i) {} + __host__ __device__ __forceinline__ PairOutputIteratorDeref operator[]( int32_t idx) const { return PairOutputIteratorDeref(i_ + idx); } - __device__ __forceinline__ PairOutputIterator operator+(int32_t offset) { + __host__ __device__ __forceinline__ PairOutputIteratorDeref operator*() + const { + return PairOutputIteratorDeref(i_); + } + __host__ __device__ __forceinline__ PairOutputIterator + operator+(int32_t offset) { return PairOutputIterator{i_ + offset}; } int32_t *i_; diff --git a/k2/python/csrc/torch.h b/k2/python/csrc/torch.h index 454469e69..396a82bd6 100644 --- a/k2/python/csrc/torch.h +++ b/k2/python/csrc/torch.h @@ -30,6 +30,14 @@ #include "k2/python/csrc/torch.h" #include "torch/extension.h" +#if K2_TORCH_VERSION_MAJOR > 2 || \ + (K2_TORCH_VERSION_MAJOR == 2 && K2_TORCH_VERSION_MINOR >= 4) +// For torch >= 2.4.x +// do nothing to fix the following error +// error: class "pybind11::detail::type_caster" has +// already been defined +#else +// For torch < 2.4 namespace pybind11 { namespace detail { @@ -71,6 +79,7 @@ struct type_caster { } // namespace detail } // namespace pybind11 +#endif namespace k2 { /* Transfer an object to a specific device.