Skip to content

Commit

Permalink
Refactor raw arg update
Browse files Browse the repository at this point in the history
  • Loading branch information
EwanC committed Oct 31, 2024
1 parent 89193f0 commit 331d19b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 30 deletions.
24 changes: 3 additions & 21 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
}
}

return std::move(UniqueDeps);
return UniqueDeps;
}

void graph_impl::markCGMemObjs(
Expand Down Expand Up @@ -1563,7 +1563,7 @@ void exec_graph_impl::updateImpl(std::shared_ptr<node_impl> Node) {
}
}

UpdateDesc.hNewKernel = nullptr;
UpdateDesc.hNewKernel = UrKernel;
UpdateDesc.numNewMemObjArgs = MemobjDescs.size();
UpdateDesc.pNewMemObjArgList = MemobjDescs.data();
UpdateDesc.numNewPointerArgs = PtrDescs.size();
Expand Down Expand Up @@ -1852,24 +1852,7 @@ void dynamic_parameter_impl::updateValue(const raw_kernel_arg *NewRawValue,
size_t RawArgSize = NewRawValue->MArgSize;
const void *RawArgData = NewRawValue->MArgData;

for (auto &[NodeWeak, ArgIndex] : MNodes) {
auto NodeShared = NodeWeak.lock();
if (NodeShared) {
dynamic_parameter_impl::updateCGArgValue(
NodeShared->MCommandGroup, ArgIndex, RawArgData, RawArgSize);
}
}

for (auto &DynCGInfo : MDynCGs) {
auto DynCG = DynCGInfo.DynCG.lock();
if (DynCG) {
auto &CG = DynCG->MKernels[DynCGInfo.CGIndex];
dynamic_parameter_impl::updateCGArgValue(CG, DynCGInfo.ArgIndex,
RawArgData, RawArgSize);
}
}

std::memcpy(MValueStorage.data(), RawArgData, RawArgSize);
updateValue(RawArgData, RawArgSize);
}

void dynamic_parameter_impl::updateValue(const void *NewValue, size_t Size) {
Expand Down Expand Up @@ -1987,7 +1970,6 @@ dynamic_command_group_impl::dynamic_command_group_impl(

void dynamic_command_group_impl::finalizeCGFList(
const std::vector<std::function<void(handler &)>> &CGFList) {
// True if kernels use sycl::nd_range, and false if using sycl::range
for (size_t CGFIndex = 0; CGFIndex < CGFList.size(); CGFIndex++) {
const auto &CGF = CGFList[CGFIndex];
// Handler defined inside the loop so it doesn't appear to the runtime
Expand Down
9 changes: 2 additions & 7 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
throw sycl::exception(sycl::errc::invalid,
"Cannot update execution range of a node with an "
"execution range of different dimensions than what "
"the node was original created with.");
"the node was originally created with.");
}

NDRDesc = sycl::detail::NDRDescT{ExecutionRange};
Expand All @@ -438,7 +438,7 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
throw sycl::exception(sycl::errc::invalid,
"Cannot update execution range of a node with an "
"execution range of different dimensions than what "
"the node was original created with.");
"the node was originally created with.");
}

NDRDesc = sycl::detail::NDRDescT{ExecutionRange};
Expand Down Expand Up @@ -1173,11 +1173,6 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
/// @param Deps List of dependent nodes
void addDepsToNode(std::shared_ptr<node_impl> Node,
std::vector<std::shared_ptr<node_impl>> &Deps) {
// Remove empty shared pointers from the list
auto EmptyElementIter =
std::remove(Deps.begin(), Deps.end(), std::shared_ptr<node_impl>());
Deps.erase(EmptyElementIter, Deps.end());

if (!Deps.empty()) {
for (auto &N : Deps) {
N->registerSuccessor(Node);
Expand Down
11 changes: 9 additions & 2 deletions sycl/source/handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,10 @@ event handler::finalize() {
// node can set it as a predecessor.
auto DependentNode = GraphImpl->getLastInorderNode(MQueue);
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
Deps = {DependentNode};
Deps;
if (DependentNode) {
Deps.push_back(DependentNode);
}
NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps);

// If we are recording an in-order queue remember the new node, so it
Expand All @@ -566,7 +569,11 @@ event handler::finalize() {
} else {
auto LastBarrierRecordedFromQueue = GraphImpl->getBarrierDep(MQueue);
std::vector<std::shared_ptr<ext::oneapi::experimental::detail::node_impl>>
Deps = {LastBarrierRecordedFromQueue};
Deps;

if (LastBarrierRecordedFromQueue) {
Deps.push_back(LastBarrierRecordedFromQueue);
}
NodeImpl = GraphImpl->add(NodeType, std::move(CommandGroup), Deps);

if (NodeImpl->MCGType == sycl::detail::CGType::Barrier) {
Expand Down

0 comments on commit 331d19b

Please sign in to comment.