diff --git a/include/taco/lower/lowerer_impl.h b/include/taco/lower/lowerer_impl.h index d039655a9..d484d799c 100644 --- a/include/taco/lower/lowerer_impl.h +++ b/include/taco/lower/lowerer_impl.h @@ -472,6 +472,7 @@ class LowererImpl : public util::Uncopyable { private: bool assemble; bool compute; + std::string funcname; bool loopOrderAllowsShortCircuit = false; int markAssignsAtomicDepth = 0; diff --git a/include/taco/util/scopedset.h b/include/taco/util/scopedset.h index c06d3a9cc..a45f83ae7 100644 --- a/include/taco/util/scopedset.h +++ b/include/taco/util/scopedset.h @@ -6,6 +6,7 @@ #include #include "taco/error.h" +#include "taco/util/strings.h" namespace taco { namespace util { @@ -56,7 +57,10 @@ class ScopedSet { } friend std::ostream& operator<<(std::ostream& os, ScopedSet sset) { - os << "ScopedSet:" << std::endl; + os << "ScopedSet: " << std::endl; + for (auto& s : sset.scopes) { + os << "scope: " << util::join(s) << std::endl; + } return os; } diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index d481c69d4..7f7f020e5 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -115,6 +115,7 @@ LowererImpl::lower(IndexStmt stmt, string name, { this->assemble = assemble; this->compute = compute; + this->funcname = name; definedIndexVarsOrdered = {}; definedIndexVars = {}; loopOrderAllowsShortCircuit = allForFreeLoopsBeforeAllReductionLoops(stmt); @@ -568,6 +569,9 @@ Stmt LowererImpl::lowerForall(Forall forall) } MergeLattice caseLattice = MergeLattice::make(forall, iterators, provGraph, definedIndexVars, whereTempsToResult); + // std::cout << "case lattice: " << forall.getIndexVar() << " " << caseLattice << std::endl; + // std::cout << "merge lattice: " << forall.getIndexVar() << " " << caseLattice.getLoopLattice() << std::endl; + vector resultAccesses; set reducedAccesses; std::tie(resultAccesses, reducedAccesses) = getResultAccesses(forall); @@ -586,7 +590,7 @@ Stmt LowererImpl::lowerForall(Forall forall) Stmt loops; // Emit a loop that iterates over over a single iterator (optimization) - if (caseLattice.iterators().size() == 1 && caseLattice.iterators()[0].isUnique()) { + if (caseLattice.iterators().size() == 1 && caseLattice.iterators()[0].isUnique() && false) { MergeLattice loopLattice = caseLattice.getLoopLattice(); MergePoint point = loopLattice.points()[0]; @@ -664,10 +668,11 @@ Stmt LowererImpl::lowerForall(Forall forall) loops = lowerMergeLattice(caseLattice, underivedAncestors[0], forall.getStmt(), reducedAccesses); } -// taco_iassert(loops.defined()); + taco_iassert(loops.defined()); +// std::cout << "LOOPS " << loops << std::endl; if (!generateComputeCode() && !hasStores(loops)) { - // If assembly loop does not modify output arrays, then it can be safely + // If assembly loop does not modify output arrays, then it can be safely // omitted. loops = Stmt(); } @@ -1386,11 +1391,14 @@ Stmt LowererImpl::lowerMergeLattice(MergeLattice caseLattice, IndexVar coordinat bool resolvedCoordDeclared = !modeIteratorsNonMergers.empty(); vector mergeLoopsVec; + std::cout << "Lattice: " << caseLattice << std::endl; for (MergePoint point : loopLattice.points()) { // Each iteration of this loop generates a while loop for one of the merge // points in the merge lattice. IndexStmt zeroedStmt = zero(statement, getExhaustedAccesses(point, caseLattice)); + std::cout << "Var: " << coordinateVar << " Merge Point: " << point << " Statement: " << statement << " Zeroed: " << zeroedStmt << std::endl; MergeLattice sublattice = caseLattice.subLattice(point); +// std::cout << "sublattice: " << sublattice << std::endl; Stmt mergeLoop = lowerMergePoint(sublattice, coordinate, coordinateVar, zeroedStmt, reducedAccesses, resolvedCoordDeclared); mergeLoopsVec.push_back(mergeLoop); } @@ -1565,21 +1573,29 @@ Stmt LowererImpl::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, I { vector result; - if (caseLattice.anyModeIteratorIsLeaf() && caseLattice.needExplicitZeroChecks()) { - // Can check value array of some tensor - Stmt body = lowerMergeCasesWithExplicitZeroChecks(coordinate, coordinateVar, stmt, caseLattice, reducedAccesses); - result.push_back(body); - return Block::make(result); - } +// if (caseLattice.anyModeIteratorIsLeaf() && caseLattice.needExplicitZeroChecks()) { +// // Can check value array of some tensor +// Stmt body = lowerMergeCasesWithExplicitZeroChecks(coordinate, coordinateVar, stmt, caseLattice, reducedAccesses); +// result.push_back(body); +// return Block::make(result); +// } // Emitting structural cases so unconditionally apply lattice optimizations. - MergeLattice loopLattice = caseLattice.getLoopLattice(); +// MergeLattice loopLattice = caseLattice.getLoopLattice(); + MergeLattice loopLattice = caseLattice; + +// std::cout << "LoopLattice: " << loopLattice << " CaseLattice: " << caseLattice << std::endl; + std::cout << " CaseLattice: " << caseLattice << std::endl; vector appenders; vector inserters; tie(appenders, inserters) = splitAppenderAndInserters(loopLattice.results()); - if (loopLattice.iterators().size() == 1) { + auto skip = true; +// auto skip = !(this->funcname == "compute"); +// std::cout << "skip value " << stmt << " " << skip << std::endl; + + if (loopLattice.iterators().size() == 1 && !loopLattice.points()[0].isOmitter() && skip) { // Just one iterator so no conditional taco_iassert(!loopLattice.points()[0].isOmitter()); Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, @@ -1588,21 +1604,83 @@ Stmt LowererImpl::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, I } else if (!loopLattice.points().empty()) { vector> cases; +// std::cout << "Accessible iterators: " << this->accessibleIterators << std::endl; for (MergePoint point : loopLattice.points()) { +// std::cout << "In the loop lattice lowering phase " << point << std::endl; + + struct ReadyTensors : IndexNotationVisitor { + std::set readyTensors; + std::set definedIndexVars; + + void visit(const AccessNode* node) { + bool ready = true; + for (auto& ivar : node->indexVars) { + if (!util::contains(definedIndexVars, ivar)) { + ready = false; + break; + } + } + if (ready) { + readyTensors.insert(node->tensorVar); + } + } + }; - // if(point.isOmitter()) { - // continue; - // } + auto rt = ReadyTensors(); + rt.definedIndexVars = this->definedIndexVars; + stmt.accept(&rt); + + std::cout << "readyTensors: " << util::join(rt.readyTensors) << std::endl; + + auto skipPoint = false; + for (auto& rl : loopLattice.points().front().locators()) { + + auto rlit = rl.getTensor(); + TensorVar rltv; + for (auto& kv : this->tensorVars) { + if (kv.second == rlit) { + rltv = kv.first; + break; + } + } + + if (util::contains(rt.readyTensors, rltv)) { + std::cout << "Not considering tensorvar: " << rltv << std::endl; + continue; + } + + if (!util::contains(point.locators(), rl)) { + std::cout << "skipping point: " << point << std::endl; + skipPoint = true; + } + } + + if (skipPoint) continue; + + if(point.isOmitter() && hasNoForAlls(stmt)) { +// std::cout << "omitting point: " << point << std::endl; + continue; + } // Construct case expression vector coordComparisons = compareToResolvedCoordinate(point.rangers(), coordinate, coordinateVar); vector omittedRegionIterators = loopLattice.retrieveRegionIteratorsToOmit(point); if (!point.isOmitter()) { +// omittedRegionIterators = filter(omittedRegionIterators, [](const Iterator& it) { +// auto iterTensorVar = it.getTensor(); +// +//// auto tensor +//// this->tensorVars +//// this->tens +// return false; +// }); std::vector neqComparisons = compareToResolvedCoordinate(omittedRegionIterators, coordinate, coordinateVar); append(coordComparisons, neqComparisons); } +// std::cout << util::join(coordComparisons) << std::endl; + coordComparisons = filter(coordComparisons, [](const Expr& e) { return e.defined(); }); // Construct case body @@ -1683,13 +1761,16 @@ std::vector LowererImpl::constructInnerLoopCasePreamble(ir::Expr coord continue; } Expr caseName = Var::make(itAccesses[i].getTensorVar().getName() + "_isNonZero", taco::Bool); + taco_iassert(nonZeroCase.defined()); Stmt declaration = VarDecl::make(caseName, nonZeroCase); result.push_back(declaration); iteratorToConditionMap[tensorIterators[i]] = caseName; } for(size_t i = modeItersWithIndexCases.size(); i < valueComparisons.size(); ++i) { + if (!valueComparisons[i].defined()) continue; Expr caseName = Var::make(itAccesses[i].getTensorVar().getName() + "_isNonZero", taco::Bool); + taco_iassert(valueComparisons[i].defined()); Stmt declaration = VarDecl::make(caseName, valueComparisons[i]); result.push_back(declaration); iteratorToConditionMap[tensorIterators[i]] = caseName; diff --git a/test/op_factory.h b/test/op_factory.h index 1871528fa..a5962e198 100644 --- a/test/op_factory.h +++ b/test/op_factory.h @@ -71,6 +71,7 @@ struct IntersectGenDeMorgan { struct xorGen { IterationAlgebra operator()(const std::vector& regions) { +// return Intersect(regions[0], regions[1]); IterationAlgebra noIntersect = Complement(Intersect(regions[0], regions[1])); return Intersect(noIntersect, Union(regions[0], regions[1])); } @@ -124,6 +125,14 @@ struct identityFunc { struct GeneralAdd { ir::Expr operator()(const std::vector &v) { +// return ir::Literal::make(int(v.size())); + +// if (!v.size()) { +// return 0; +// } + +// return 1; + taco_iassert(v.size() >= 2) << "Add operator needs at least two operands"; ir::Expr add = ir::Add::make(v[0], v[1]); diff --git a/test/tests-lower.cpp b/test/tests-lower.cpp index 1912a691d..52b1c9eee 100644 --- a/test/tests-lower.cpp +++ b/test/tests-lower.cpp @@ -1853,4 +1853,42 @@ TEST_STMT(XorTestOrder2, } ) +TEST(weija, weija) { + auto dim = 2; +// Tensor A("A", {dim, dim}, {Dense, Sparse}); + Tensor A("A", {dim, dim}, CSC); + +// Tensor A("A", {dim, dim}, {Dense, Sparse}); + Tensor Z("Z", {dim}, {Sparse}); + Tensor B("B", {dim, dim}, {Dense, Dense}); +// Tensor B("B", {dim, dim}, {Dense, Sparse}); + Tensor C("C", {dim, dim}, {Dense, Dense}); + Tensor D("D", {dim, dim}, {Dense, Dense}); + Tensor E("E", {dim, dim}, {Dense, Dense}); + + A.insert({0, 0}, 1); A.insert({1, 1}, 1); A.pack(); + B.insert({0, 0}, 1); B.insert({0, 1}, 1); B.pack(); + Z.insert({0}, 1); Z.insert({1}, 1); Z.pack(); + + IndexVar i("i"), j("j"), k("k"), l("l"), m("m"); + C(i, j) = xorOp(A(i, j), Z(j)); +// C(i, j) = A(i, j) * Z(i); + auto stmt =C.getAssignment().concretize().reorder({j, i}); + std::cout << stmt << std::endl; + C.compile(stmt); +// C(i, j) = xorOp(A(i, k), B(k, j)); +// E(i, l) = xorOp(xorOp(A(i, k), B(k, j)), D(j, l)); +// C(i, j) = xorOp(A(i, j), B(i, j)); +// C(i, j) = A(i, j) * B(j, i); +// C(i, j) = xorOp(A(i, j), B(j, i)); +// C.compile(C.getAssignment().concretize().reorder({i, j, k})); +// std::cout << "starting" << std::endl; +// C.compile(); + std::cout << C.getSource() << std::endl; + C.evaluate(); +// E.compile(); +// std::cout << E.getSource() << std::endl; + std::cout << C << std::endl; +} + }}