Skip to content

Commit 59422b2

Browse files
committed
fix comments
1 parent 8622fa0 commit 59422b2

File tree

3 files changed

+54
-61
lines changed

3 files changed

+54
-61
lines changed

examples/transposed3d.py

-60
This file was deleted.

src/idtr.cpp

+17
Original file line numberDiff line numberDiff line change
@@ -572,13 +572,17 @@ _idtr_copy_reshape(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr,
572572

573573
namespace {
574574

575+
///
576+
/// An util class of multi-dimensional index
577+
///
575578
class id {
576579
public:
577580
id(size_t dims) : _values(dims) {}
578581
id(size_t dims, int64_t *value) : _values(value, value + dims) {}
579582
id(const std::vector<int64_t> &values) : _values(values) {}
580583
id(const std::vector<int64_t> &&values) : _values(std::move(values)) {}
581584

585+
/// Permute this id by axes and return a new id
582586
id permute(std::vector<int64_t> axes) const {
583587
std::vector<int64_t> new_values(_values.size());
584588
for (size_t i = 0; i < _values.size(); i++) {
@@ -590,6 +594,7 @@ class id {
590594
int64_t operator[](size_t i) const { return _values[i]; }
591595
int64_t &operator[](size_t i) { return _values[i]; }
592596

597+
/// Subtract another id from this id and return a new id
593598
id operator-(const id &rhs) const {
594599
std::vector<int64_t> new_values(_values.size());
595600
for (size_t i = 0; i < _values.size(); i++) {
@@ -598,6 +603,7 @@ class id {
598603
return id(std::move(new_values));
599604
}
600605

606+
/// Subtract another id from this id and return a new id
601607
id operator-(const int64_t *rhs) const {
602608
std::vector<int64_t> new_values(_values.size());
603609
for (size_t i = 0; i < _values.size(); i++) {
@@ -606,6 +612,10 @@ class id {
606612
return id(std::move(new_values));
607613
}
608614

615+
/// Increase the last dimension value of this id which bounds by shape
616+
///
617+
/// Example:
618+
/// In shape (2,2) : (0,0)->(0,1)->(1,0)->(1,1)->(0,0)
609619
void next(const int64_t *shape) {
610620
size_t i = _values.size();
611621
while (i--) {
@@ -623,15 +633,20 @@ class id {
623633
std::vector<int64_t> _values;
624634
};
625635

636+
///
637+
/// An wrapper template class for distribute multi-dimensional array
638+
///
626639
template <typename T> class ndarray {
627640
public:
628641
ndarray(int64_t nDims, int64_t *gShape, int64_t *gOffsets, void *lData,
629642
int64_t *lShape, int64_t *lStrides)
630643
: _nDims(nDims), _gShape(gShape), _gOffsets(gOffsets), _lData((T *)lData),
631644
_lShape(lShape), _lStrides(lStrides) {}
632645

646+
/// Return the first global index of local data
633647
id firstLocalIndex() const { return id(_nDims, _gOffsets); }
634648

649+
/// Interate all global indices in local data
635650
void localIndices(const std::function<void(const id &)> &callback) const {
636651
size_t size = lSize();
637652
id idx = firstLocalIndex();
@@ -641,6 +656,7 @@ template <typename T> class ndarray {
641656
}
642657
}
643658

659+
/// Interate all global indices of the array
644660
void globalIndices(const std::function<void(const id &)> &callback) const {
645661
size_t size = gSize();
646662
id idx(_nDims);
@@ -660,6 +676,7 @@ template <typename T> class ndarray {
660676
return offset;
661677
}
662678

679+
/// Using global index to access its data
663680
T &operator[](const id &idx) { return _lData[getLocalDataOffset(idx)]; }
664681
T operator[](const id &idx) const { return _lData[getLocalDataOffset(idx)]; }
665682

test/test_manip.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,45 @@ def test_todevice_host2gpu(self):
9494
b = a.to_device(device="GPU")
9595
assert numpy.allclose(sp.to_numpy(b), [0, 1, 2, 3, 4, 5, 6, 7])
9696

97-
def test_permute_dims(self):
97+
def test_permute_dims1(self):
9898
a = sp.arange(0, 10, 1, sp.int64)
9999
b = sp.reshape(a, (2, 5))
100100
c1 = sp.to_numpy(sp.permute_dims(b, [1, 0]))
101101
c2 = sp.to_numpy(b).transpose(1, 0)
102102
assert numpy.allclose(c1, c2)
103+
104+
def test_permute_dims2(self):
105+
# === sharpy
106+
sp_a = sp.arange(0, 2 * 3 * 4, 1)
107+
sp_a = sp.reshape(sp_a, [2, 3, 4])
108+
109+
# b = a.swapaxes(1,0).swapaxes(1,2)
110+
sp_b = sp.permute_dims(sp_a, (1, 0, 2)) # 2x4x4 -> 4x2x4 || 4x4x4
111+
sp_b = sp.permute_dims(sp_b, (0, 2, 1)) # 4x2x4 -> 4x4x2 || 4x4x4
112+
113+
# c = b.swapaxes(1,2).swapaxes(1,0)
114+
sp_c = sp.permute_dims(sp_b, (0, 2, 1))
115+
sp_c = sp.permute_dims(sp_c, (1, 0, 2))
116+
117+
assert numpy.allclose(sp.to_numpy(sp_a), sp.to_numpy(sp_c))
118+
119+
# d = a.swapaxes(2,1).swapaxes(2,0)
120+
sp_d = sp.permute_dims(sp_a, (0, 2, 1))
121+
sp_d = sp.permute_dims(sp_d, (2, 1, 0))
122+
123+
# c = d.swapaxes(2,1).swapaxes(0,1)
124+
sp_e = sp.permute_dims(sp_d, (0, 2, 1))
125+
sp_e = sp.permute_dims(sp_e, (1, 0, 2))
126+
127+
# === numpy
128+
np_a = numpy.arange(0, 2 * 3 * 4, 1)
129+
np_a = numpy.reshape(np_a, [2, 3, 4])
130+
131+
np_b = np_a.swapaxes(1, 0).swapaxes(1, 2)
132+
assert numpy.allclose(sp.to_numpy(sp_b), np_b)
133+
134+
np_d = np_a.swapaxes(2, 1).swapaxes(2, 0)
135+
assert numpy.allclose(sp.to_numpy(sp_d), np_d)
136+
137+
np_e = np_d.swapaxes(2, 1).swapaxes(0, 1)
138+
assert numpy.allclose(sp.to_numpy(sp_e), np_e)

0 commit comments

Comments
 (0)