Fix MSVC CUDA build: is_unsigned_v not available in cutlass::platform #3229
Open
TxsharDev wants to merge 1 commit into
is_integral_v and is_unsigned_v in platform.h were imported from the STL inside a #if (201703L <= __cplusplus) guard. On MSVC, __cplusplus is not set to 201703L without /Zc:__cplusplus, so the imports never fire -- but CUTLASS_CXX17_OR_LATER correctly evaluates to 1 via _MSVC_LANG, enabling code that uses these traits. Result: symbol used but never declared. Breaks Windows CUDA builds of downstream projects (e.g. flash-attention with PyTorch >= 2.10). Fix: define is_integral_v and is_unsigned_v as constexpr bool variable templates using ::value, removing the fragile __cplusplus guard. Also import is_unsigned unconditionally (matching is_arithmetic, is_void, etc. which are already unconditional). Add unit test.
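The mismatch between the two guards can be illustrated with a minimal sketch. `SKETCH_CPLUSPLUS` and both helper functions are hypothetical names made up for this illustration; they only mimic the behaviour the description attributes to `CUTLASS_CPLUSPLUS` and `CUTLASS_CXX17_OR_LATER` and are not copied from CUTLASS.

```cpp
// Hypothetical stand-in for a language-version macro that prefers _MSVC_LANG
// when it exists (assumption: this mirrors what the description says
// CUTLASS_CPLUSPLUS does; the real definition may differ).
#if defined(_MSVC_LANG)
#  define SKETCH_CPLUSPLUS _MSVC_LANG
#else
#  define SKETCH_CPLUSPLUS __cplusplus
#endif

// Guard in the style of the old trait import: looks only at __cplusplus.
// MSVC reports 199711L here unless /Zc:__cplusplus is passed.
constexpr bool raw_guard_is_cxx17() { return __cplusplus >= 201703L; }

// Guard in the style of a _MSVC_LANG-aware check.
constexpr bool portable_guard_is_cxx17() { return SKETCH_CPLUSPLUS >= 201703L; }

// The failure mode: code enabled by the portable guard needs declarations
// that only the raw guard would have provided.
constexpr bool guards_disagree() {
  return portable_guard_is_cxx17() && !raw_guard_is_cxx17();
}
```

On a compiler where `__cplusplus` is accurate, `guards_disagree()` is false. With MSVC in C++17 mode and no `/Zc:__cplusplus`, it would be expected to be true, which is the configuration that breaks the build.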
Problem
Building any project that includes `cutlass/exmy_base.h` on Windows with MSVC + CUDA fails to compile because `is_unsigned_v` is not declared in `cutlass::platform`.
Hit this building flash-attention against PyTorch 2.10+ on Windows (CUDA 12.8, MSVC 14.43).
Root cause
`is_integral_v` and `is_unsigned_v` in `platform.h` are imported from the STL inside a `#if (201703L <= __cplusplus)` guard. On MSVC, `__cplusplus` is not set to `201703L` without `/Zc:__cplusplus`, so the imports never fire. But `CUTLASS_CXX17_OR_LATER` (which correctly checks `_MSVC_LANG` via `CUTLASS_CPLUSPLUS`) evaluates to `1`, enabling code that uses these traits. The result is a symbol that is used but never declared.
Fix
- `platform.h`: Define `is_integral_v` and `is_unsigned_v` as `constexpr bool` variable templates using `::value`, removing the fragile `__cplusplus` guard entirely. Also import `is_unsigned` unconditionally (matching `is_arithmetic`, `is_void`, etc., which are already unconditional in the same file).
- `exmy_base.h`: fixing `platform.h` makes the existing code work as intended.
- `platform_traits.cpp`: unit test to verify both traits are available and correct.
Verified
Built flash-attention 2.8.3 from source on Windows with this fix applied — MSVC 14.43, CUDA 12.8, PyTorch 2.11, Python 3.12. Clean build, all kernels compile.
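For readers who want to see the shape of the change, the approach described under Fix can be sketched roughly as follows. This is an illustration under stated assumptions, not the actual diff: the namespace `sketch_platform` and the helper `is_unsigned_integer` are hypothetical, and the real `platform.h` may differ in detail. The `static_assert` lines stand in for the kind of checks a unit test such as `platform_traits.cpp` might make.

```cpp
#include <type_traits>

namespace sketch_platform {  // hypothetical; stands in for cutlass::platform

// Import the class templates unconditionally, with no language-version guard.
using std::is_integral;
using std::is_unsigned;

// Define the _v helpers directly from ::value. Variable templates only need
// C++14, so no check of __cplusplus against 201703L is required.
template <typename T>
constexpr bool is_integral_v = is_integral<T>::value;

template <typename T>
constexpr bool is_unsigned_v = is_unsigned<T>::value;

// Hypothetical helper that uses both traits, similar in spirit to code that
// previously failed to compile when the _v names were missing.
template <typename T>
constexpr bool is_unsigned_integer() {
  return is_integral_v<T> && is_unsigned_v<T>;
}

}  // namespace sketch_platform

// Compile-time checks that both traits are available and correct.
static_assert(sketch_platform::is_integral_v<int>, "int is integral");
static_assert(!sketch_platform::is_integral_v<float>, "float is not integral");
static_assert(sketch_platform::is_unsigned_v<unsigned int>, "unsigned int is unsigned");
static_assert(!sketch_platform::is_unsigned_v<int>, "int is signed");
```

Because the definitions depend only on `::value`, they compile the same way whether or not `/Zc:__cplusplus` is passed to MSVC.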