Skip to content

Commit db7dd5e

Browse files
committed
Pull in changes from flashinfer-ai#962 into flashinfer-ai#944.
1 parent 0e3c83f commit db7dd5e

File tree

1 file changed

+47
-9
lines changed

1 file changed

+47
-9
lines changed

libflashinfer/CMakeLists.txt

+47-9
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,6 @@ FetchContent_Declare(
1414
GIT_REPOSITORY https://github.com/microsoft/mscclpp.git
1515
GIT_TAG 11e62024d3eb190e005b4689f8c8443d91a6c82e)
1616

17-
set(BOOST_ENABLE_CMAKE ON)
18-
FetchContent_Declare(boost_math
19-
GIT_REPOSITORY https://github.com/boostorg/math.git)
20-
FetchContent_MakeAvailable(boost_math)
21-
2217
# -----------------------------------------------------------------------------#
2318

2419
find_package(Python3 REQUIRED)
@@ -81,10 +76,21 @@ if(FLASHINFER_ENABLE_BF16)
8176
add_definitions(-DFLASHINFER_ENABLE_BF16)
8277
endif(FLASHINFER_ENABLE_BF16)
8378

79+
8480
# Translate the FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS option into the string
# variable USE_FP16_QK_REDUCTIONS ("true"/"false"), which is later forwarded to
# the generator via --use_fp16_qk_reductions. Boost.Math is fetched only in the
# enabled branch, so the Boost::math target exists only when the option is on.
if(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
81+
# --- Dependencies specific to FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS -------#
82+
include(FetchContent)
83+
84+
set(BOOST_ENABLE_CMAKE ON)
85+
# NOTE(review): no GIT_TAG is given here, so the checkout follows a moving
# branch head instead of a fixed revision (the mscclpp declaration earlier in
# this file is pinned to a commit). Consider pinning a tag or commit hash so
# builds are reproducible.
FetchContent_Declare(boost_math
86+
GIT_REPOSITORY https://github.com/boostorg/math.git)
87+
FetchContent_MakeAvailable(boost_math)
88+
# --------------------------------------------------------------------------#
8589
set(USE_FP16_QK_REDUCTIONS "true")
90+
message(STATUS "USE_FP16_QK_REDUCTIONS=${USE_FP16_QK_REDUCTIONS}")
8691
else(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
8792
set(USE_FP16_QK_REDUCTIONS "false")
93+
message(STATUS "USE_FP16_QK_REDUCTIONS=${USE_FP16_QK_REDUCTIONS}")
8894
endif(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
8995

9096
# generate kernel inst
@@ -98,6 +104,34 @@ message(STATUS "FLASHINFER_POS_ENCODING_MODES=${POS_ENCODING_MODES}")
98104
message(STATUS "FLASHINFER_USE_FP16_QK_REDUCTIONS=${USE_FP16_QK_REDUCTIONS}")
99105
message(STATUS "FLASHINFER_MASK_MODES=${MASK_MODES}")
100106

107+
#----------------------- SM90 head dims computation ------------------------#
# Builds HEAD_DIMS_SM90, a list of "A,B" tuples, in two passes:
#   1. for every value d in HEAD_DIMS, keep "d,d" if that tuple is listed in
#      SM90_ALLOWED_HEAD_DIMS;
#   2. always add the allowed tuples whose two components differ (with the
#      list below that is "192,128"), independent of HEAD_DIMS.
# Duplicates are removed at the end.
108+
set(SM90_ALLOWED_HEAD_DIMS "64,64" "128,128" "256,256" "192,128")
109+
set(HEAD_DIMS_SM90 "")
110+
111+
# Pass 1: square tuples derived from the configured HEAD_DIMS.
foreach(DIM_VAL ${HEAD_DIMS})
112+
string(CONCAT TUPLE_VAL "${DIM_VAL}" "," "${DIM_VAL}")
113+
# list(FIND) stores -1 in RESULT when the tuple is not in the allowed list.
list(FIND SM90_ALLOWED_HEAD_DIMS ${TUPLE_VAL} RESULT)
114+
if(NOT ${RESULT} EQUAL -1)
115+
list(APPEND HEAD_DIMS_SM90 ${TUPLE_VAL})
116+
endif(NOT ${RESULT} EQUAL -1)
117+
endforeach(DIM_VAL)
118+
119+
# Pass 2: allowed tuples whose first and second component are not equal.
foreach(TUPLE_VAL ${SM90_ALLOWED_HEAD_DIMS})
120+
# Turn "A,B" into the two-element list "A;B" so each part can be read by index.
string(REPLACE "," ";" HEAD_DIMS_LIST ${TUPLE_VAL})
121+
list(GET HEAD_DIMS_LIST 0 K)
122+
list(GET HEAD_DIMS_LIST 1 V)
123+
if(NOT K EQUAL V)
124+
list(APPEND HEAD_DIMS_SM90 ${TUPLE_VAL})
125+
endif(NOT K EQUAL V)
126+
endforeach(TUPLE_VAL)
127+
128+
# Pass 2 always appends at least one tuple with the list above, so the list
# is non-empty when it is de-duplicated here.
list(REMOVE_DUPLICATES HEAD_DIMS_SM90)
129+
#---------------------------------------------------------------------------#
130+
131+
# Log SM90_ALLOWED_HEAD_DIMS and HEAD_DIMS_SM90
132+
message(STATUS "SM90_ALLOWED_HEAD_DIMS=${SM90_ALLOWED_HEAD_DIMS}")
133+
message(STATUS "HEAD_DIMS_SM90=${HEAD_DIMS_SM90}")
134+
101135
set(GENERATED_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/src/generated)
102136
cmake_path(GET GENERATED_SOURCE_DIR PARENT_PATH GENERATED_SOURCE_DIR_ROOT)
103137
file(MAKE_DIRECTORY ${GENERATED_SOURCE_DIR})
@@ -114,7 +148,7 @@ set(AOT_GENERATE_COMMAND
114148
set(AOT_GENERATE_DISPATCH_INC_COMMAND
115149
${Python3_EXECUTABLE} -m aot_build_utils.generate_dispatch_inc --path
116150
"${GENERATED_SOURCE_DIR}/dispatch.inc" --head_dims ${HEAD_DIMS}
117-
--head_dims_sm90 "64,64" # FIXME Make this configurable.
151+
# SM90 tuples go to their own option; reusing --head_dims here would repeat
# the flag already passed with ${HEAD_DIMS} and leave --head_dims_sm90 unset.
--head_dims_sm90 ${HEAD_DIMS_SM90}
118152
--pos_encoding_modes ${POS_ENCODING_MODES} --use_fp16_qk_reductions
119153
${USE_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES})
120154

@@ -168,7 +202,6 @@ string(
168202
"--threads=1 "
169203
"-Xfatbin=-compress-all "
170204
"-use_fast_math "
171-
"-DFLASHINFER_ENABLE_F16 "
172205
"--expt-relaxed-constexpr ")
173206
string(CONCAT FLASHINFER_CXXFLAGS "${WARNING_FLAGS}" "${SECURITY_FLAGS}")
174207

@@ -179,11 +212,16 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS}")
179212

180213
# Static library built from DECODE_KERNELS_SRCS; FLASHINFER_INCLUDE_DIR is a
# private include path. On the "+" side of this diff, Boost::math is linked
# only when FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS is enabled.
add_library(decode_kernels STATIC ${DECODE_KERNELS_SRCS})
181214
target_include_directories(decode_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR})
182-
target_link_libraries(decode_kernels PRIVATE Boost::math)
215+
if(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
216+
target_link_libraries(decode_kernels PRIVATE Boost::math)
217+
endif(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
183218

184219
# Static library built from PREFILL_KERNELS_SRCS; FLASHINFER_INCLUDE_DIR is a
# private include path. On the "+" side of this diff, Boost::math is linked
# only when FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS is enabled.
add_library(prefill_kernels STATIC ${PREFILL_KERNELS_SRCS})
185220
target_include_directories(prefill_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR})
186-
target_link_libraries(prefill_kernels PRIVATE Boost::math)
221+
if(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
222+
# NOTE(review): add_definitions() is directory-scoped, not target-scoped. Per
# the CMake documentation it applies to targets in this directory whether they
# were added before or after the call, so decode_kernels presumably receives
# -DFP16_QK_REDUCTION_SUPPORTED as well. If only prefill_kernels should see
# it, target_compile_definitions(prefill_kernels PRIVATE ...) would say so
# explicitly - confirm the intent before changing.
add_definitions(-DFP16_QK_REDUCTION_SUPPORTED)
223+
target_link_libraries(prefill_kernels PRIVATE Boost::math)
224+
endif(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
187225

188226
if(FLASHINFER_UNITTESTS)
189227
add_subdirectory(tests)

0 commit comments

Comments
 (0)