Skip to content

Commit

Permalink
Closes Bears-R-Us#3870: bug in reshape for bigint type
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Nov 20, 2024
1 parent b456d3e commit ae91c63
Showing 1 changed file with 103 additions and 1 deletion.
104 changes: 103 additions & 1 deletion src/AryUtil.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ module AryUtil
use List;
use CommAggregation;
use CommPrimitives;
use BigInteger;


param bitsPerDigit = RSLSD_bitsPerDigit;
Expand Down Expand Up @@ -905,7 +906,8 @@ module AryUtil
/*
unflatten a 1D array into a multi-dimensional array of the given shape
*/
proc unflatten(const ref a: [?d] ?t, shape: ?N*int): [] t throws {
proc unflatten(const ref a: [?d] ?t, shape: ?N*int): [] t throws
where t!=bigint {
var unflat = makeDistArray((...shape), t);

if N == 1 {
Expand Down Expand Up @@ -952,6 +954,103 @@ module AryUtil
// flat region is spread across multiple locales, do a get for each source locale
for locInID in locInStart..locInStop {
const flatSubSlice = flatSlice[flatLocRanges[locInID]];
get(
c_ptrTo(unflat[dufc.orderToIndex(flatSubSlice.low)]),
getAddr(a[flatSubSlice.low]),
locInID,
c_sizeof(t) * flatSubSlice.size
);
}
}
}
}

return unflat;
}

proc unflatten(const ref a: [?d] ?t, shape: ?N*int): [] t throws
where t==bigint {
var unflat = makeDistArray((...shape), t);

if N == 1 {
unflat = a;
return unflat;
}

// ranges of flat indices owned by each locale
const flatLocRanges = [loc in Locales] d.localSubdomain(loc).dim(0);
writeln("flatLocRanges");
writeln(flatLocRanges);

coforall loc in Locales with (ref unflat) do on loc {
const lduf = unflat.domain.localSubdomain(),
lastRank = lduf.dim(N-1);
writeln("lduf");
writeln(lduf);
writeln("lastRank");
writeln(lastRank);

// iterate over each slice of contiguous memory in the local subdomain
forall idx in domOffAxis(lduf, N-1) with (
const ord = new orderer(shape),
const dufc = unflat.domain,
in flatLocRanges
) {
var idxTup: (N-1)*int;
for i in 0..<(N-1) do idxTup[i] = idx[i];

writeln("idxTup");
writeln(idxTup);

const low = ((...idxTup), lastRank.low),
high = ((...idxTup), lastRank.high),
flatSlice = ord.indexToOrder(low)..ord.indexToOrder(high);

writeln("low");
writeln(low);
writeln("high");
writeln(high);
writeln("flatSlice");
writeln(flatSlice);
// compute which locales in the input array this slice corresponds to
var locInStart, locInStop = 0;
for (flr, locID) in zip(flatLocRanges, 0..<numLocales) {
if flr.contains(flatSlice.low) then locInStart = locID;
if flr.contains(flatSlice.high) then locInStop = locID;
}

if locInStart == locInStop {
// flat region sits within a single locale, do a single get
writeln("case1");
writeln("c_ptrTo(unflat[low])");
writeln(c_ptrTo(unflat[low]));
writeln("getAddr(a[flatSlice.low])");
writeln(getAddr(a[flatSlice.low]));
writeln("locInStart");
writeln(locInStart);
writeln("c_sizeof(t) * flatSlice.size");
writeln(c_sizeof(t) * flatSlice.size);

get(
c_ptrTo(unflat[low]),
getAddr(a[flatSlice.low]),
locInStart,
c_sizeof(t) * flatSlice.size
);
} else {
// flat region is spread across multiple locales, do a get for each source locale
for locInID in locInStart..locInStop {
const flatSubSlice = flatSlice[flatLocRanges[locInID]];
writeln("case2");

writeln("c_ptrTo(unflat[dufc.orderToIndex(flatSubSlice.low)])");
writeln(c_ptrTo(unflat[dufc.orderToIndex(flatSubSlice.low)]));
writeln("getAddr(a[flatSubSlice.low])");
writeln(getAddr(a[flatSubSlice.low]));
writeln("locInID");
writeln(locInID);
writeln("c_sizeof(t) * flatSubSlice.size");
writeln(c_sizeof(t) * flatSubSlice.size);

get(
c_ptrTo(unflat[dufc.orderToIndex(flatSubSlice.low)]),
Expand All @@ -964,6 +1063,9 @@ module AryUtil
}
}

writeln("unflat");
writeln(unflat);

return unflat;
}

Expand Down

0 comments on commit ae91c63

Please sign in to comment.