
refactor: change to TORCH_LIBRARY #823

Merged
merged 27 commits into flashinfer-ai:main on Feb 13, 2025

Conversation

@abmfy (Contributor) commented Feb 13, 2025

This PR updates FlashInfer's C++/CUDA extensions from pybind11 modules to the TORCH_LIBRARY mechanism (torch.library), which has been recommended since PyTorch 2.5.

This work was mainly implemented in #764. We have verified that the issue in #820 was not caused by this PR, so we are reopening it.
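For context, the shape of the change is roughly the following: pybind11-style bindings are replaced by dispatcher registrations, so ops become visible as `torch.ops.<namespace>.<op>` in Python. This is a minimal sketch of the pattern using the standard PyTorch macros; the op name, schema, and `rmsnorm_cuda` implementation below are illustrative placeholders, not FlashInfer's actual API (this fragment also assumes a PyTorch installation to compile against):

```cpp
#include <torch/library.h>  // requires a PyTorch installation

// Before (pybind11-style binding):
// PYBIND11_MODULE(_kernels, m) {
//   m.def("rmsnorm", &rmsnorm_cuda);
// }

// After: declare the op schema once with the dispatcher...
TORCH_LIBRARY(flashinfer, m) {
  m.def("rmsnorm(Tensor input, Tensor weight, float eps) -> Tensor");
}

// ...and register a backend-specific implementation for it.
TORCH_LIBRARY_IMPL(flashinfer, CUDA, m) {
  m.impl("rmsnorm", &rmsnorm_cuda);  // rmsnorm_cuda defined elsewhere
}
```

From Python, the op is then reachable as `torch.ops.flashinfer.rmsnorm(...)`, and it composes with `torch.compile` and CUDA graphs in ways a raw pybind11 function does not.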

youkaichao and others added 27 commits January 29, 2025 22:23
Signed-off-by: youkaichao <[email protected]>
Many files in `include/flashinfer/attention/hopper` place `using namespace cute` directly inside `namespace flashinfer`; other files then write `using namespace flashinfer`, causing names from `cute` to leak into the global namespace.

This is a temporary fix to make compilations work.

Signed-off-by: abmfy <[email protected]>
@zhyncs (Member) commented Feb 13, 2025

Please wait a moment, the CUDA graph issue has been fixed by #822.

@yzh119 yzh119 merged commit dbb1e4e into flashinfer-ai:main Feb 13, 2025
yzh119 added a commit that referenced this pull request Feb 13, 2025
Follow-up of #823: the `CutlassSegmentGEMMSM90` API does not have the member `plan_info_vec`.
zhyncs pushed a commit that referenced this pull request Feb 13, 2025
Apply #662 again, now that #823 has been merged.

---------

Signed-off-by: youkaichao <[email protected]>
yzh119 added a commit that referenced this pull request Feb 13, 2025
Follow-up of #823: we should import `from .. import flashinfer_kernels, flashinfer_kernels_sm90` instead of `from .. import _kernels, _kernels_sm90`, otherwise JIT compilation will be used for all the code.

Also add some logic to catch "undefined symbol" errors, in case the AOT wheel compiles successfully but fails to load.