Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU] Add thread tile size inference for scatter #19694

Merged
merged 4 commits into from
Jan 14, 2025

Conversation

qedawkins
Copy link
Contributor

@qedawkins qedawkins commented Jan 13, 2025

This adds basic logic for picking thread distribution tile sizes for
scatter ops. This allows fallback distribution to kick in for scatter.

Additionally splits derived_thread_config tests into their own file.
The existing test for GPUApplyTilingLevel has grown rather large. The
tests for derived_thread_config within that file aren't really testing
the pass but the logic in how we pick tile sizes. As a result it makes
more sense to split out those tests into a separate file. In the future
the logic for derived_thread_configs will probably need to move
elsewhere but this is good cleanup to start.

The existing test for GPUApplyTilingLevel has grown rather large. The
tests for derived_thread_config within that file aren't really testing
the pass but the logic in how we pick tile sizes. As a result it makes
more sense to split out those tests into a separate file. In the future
the logic for derived_thread_configs will probably need to move
elsewhere but this is good cleanup to start.
This adds basic logic for picking thread distribution tile sizes for
scatter ops. This allows fallback distribution to kick in for scatter.
@qedawkins qedawkins enabled auto-merge (squash) January 14, 2025 01:44
@qedawkins qedawkins disabled auto-merge January 14, 2025 01:44
@qedawkins qedawkins enabled auto-merge (squash) January 14, 2025 01:45
.Default([](Operation *op) -> SmallVector<int64_t> { return {}; });
.Case([&](IREE::LinalgExt::ScatterOp scatterOp) -> SmallVector<int64_t> {
int64_t loopDepth = scatterOp.getLoopIteratorTypes().size();
SmallVector<int64_t> loopBounds =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weird question: can't this be a mix of static and dynamic bounds, in which case you'd want to preserve the static bound information you have?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, so static bounds get their value and dynamic ones get ShapedType::kDynamic. It would be useful to leverage divisibility information but for now just using static info is what's best supported. The tiling interface already supports OpFoldResult for this reason so it's something we can transition to in the future.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was just confused by why any dynamic dimensions made the whole list dynamic

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, the code is maybe a bit misleading. It's only all dynamic as a fallback in case getStaticLoopBounds fails which it never does. It returns the mixed list like you were thinking.

@qedawkins qedawkins merged commit 8d1d867 into iree-org:main Jan 14, 2025
35 of 36 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants