@@ -572,13 +572,17 @@ _idtr_copy_reshape(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr,
572
572
573
573
namespace {
574
574
575
+ // /
576
+ // / An util class of multi-dimensional index
577
+ // /
575
578
class id {
576
579
public:
577
580
id (size_t dims) : _values(dims) {}
578
581
id (size_t dims, int64_t *value) : _values(value, value + dims) {}
579
582
id (const std::vector<int64_t > &values) : _values(values) {}
580
583
id (const std::vector<int64_t > &&values) : _values(std::move(values)) {}
581
584
585
+ // / Permute this id by axes and return a new id
582
586
id permute (std::vector<int64_t > axes) const {
583
587
std::vector<int64_t > new_values (_values.size ());
584
588
for (size_t i = 0 ; i < _values.size (); i++) {
@@ -590,6 +594,7 @@ class id {
590
594
int64_t operator [](size_t i) const { return _values[i]; }
591
595
int64_t &operator [](size_t i) { return _values[i]; }
592
596
597
+ // / Subtract another id from this id and return a new id
593
598
id operator -(const id &rhs) const {
594
599
std::vector<int64_t > new_values (_values.size ());
595
600
for (size_t i = 0 ; i < _values.size (); i++) {
@@ -598,6 +603,7 @@ class id {
598
603
return id (std::move (new_values));
599
604
}
600
605
606
+ // / Subtract another id from this id and return a new id
601
607
id operator -(const int64_t *rhs) const {
602
608
std::vector<int64_t > new_values (_values.size ());
603
609
for (size_t i = 0 ; i < _values.size (); i++) {
@@ -606,6 +612,10 @@ class id {
606
612
return id (std::move (new_values));
607
613
}
608
614
615
+ // / Increase the last dimension value of this id which bounds by shape
616
+ // /
617
+ // / Example:
618
+ // / In shape (2,2) : (0,0)->(0,1)->(1,0)->(1,1)->(0,0)
609
619
void next (const int64_t *shape) {
610
620
size_t i = _values.size ();
611
621
while (i--) {
@@ -623,15 +633,20 @@ class id {
623
633
std::vector<int64_t > _values;
624
634
};
625
635
636
+ // /
637
+ // / An wrapper template class for distribute multi-dimensional array
638
+ // /
626
639
template <typename T> class ndarray {
627
640
public:
628
641
ndarray (int64_t nDims, int64_t *gShape , int64_t *gOffsets , void *lData,
629
642
int64_t *lShape, int64_t *lStrides)
630
643
: _nDims(nDims), _gShape(gShape ), _gOffsets(gOffsets ), _lData((T *)lData),
631
644
_lShape (lShape), _lStrides(lStrides) {}
632
645
646
+ // / Return the first global index of local data
633
647
id firstLocalIndex () const { return id (_nDims, _gOffsets); }
634
648
649
+ // / Interate all global indices in local data
635
650
void localIndices (const std::function<void (const id &)> &callback) const {
636
651
size_t size = lSize ();
637
652
id idx = firstLocalIndex ();
@@ -641,6 +656,7 @@ template <typename T> class ndarray {
641
656
}
642
657
}
643
658
659
+ // / Interate all global indices of the array
644
660
void globalIndices (const std::function<void (const id &)> &callback) const {
645
661
size_t size = gSize ();
646
662
id idx (_nDims);
@@ -660,6 +676,7 @@ template <typename T> class ndarray {
660
676
return offset;
661
677
}
662
678
679
+ // / Using global index to access its data
663
680
T &operator [](const id &idx) { return _lData[getLocalDataOffset (idx)]; }
664
681
T operator [](const id &idx) const { return _lData[getLocalDataOffset (idx)]; }
665
682
0 commit comments