@@ -14,11 +14,6 @@ FetchContent_Declare(
14
14
GIT_REPOSITORY https://github.com/microsoft/mscclpp.git
15
15
GIT_TAG 11e62024d3eb190e005b4689f8c8443d91a6c82e)
16
16
17
- set (BOOST_ENABLE_CMAKE ON )
18
- FetchContent_Declare(boost_math
19
- GIT_REPOSITORY https://github.com/boostorg/math.git)
20
- FetchContent_MakeAvailable(boost_math)
21
-
22
17
# -----------------------------------------------------------------------------#
23
18
24
19
find_package (Python3 REQUIRED)
@@ -81,10 +76,21 @@ if(FLASHINFER_ENABLE_BF16)
81
76
add_definitions (-DFLASHINFER_ENABLE_BF16)
82
77
endif (FLASHINFER_ENABLE_BF16)
83
78
79
+
84
80
if (FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
81
+ # --- Dependencies specific to FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS -------#
82
+ include (FetchContent)
83
+
84
+ set (BOOST_ENABLE_CMAKE ON )
85
+ FetchContent_Declare(boost_math
86
+ GIT_REPOSITORY https://github.com/boostorg/math.git)
87
+ FetchContent_MakeAvailable(boost_math)
88
+ # --------------------------------------------------------------------------#
85
89
set (USE_FP16_QK_REDUCTIONS "true" )
90
+ message (STATUS "USE_FP16_QK_REDUCTIONS=${USE_FP16_QK_REDUCTIONS} " )
86
91
else (FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
87
92
set (USE_FP16_QK_REDUCTIONS "false" )
93
+ message (STATUS "USE_FP16_QK_REDUCTIONS=${USE_FP16_QK_REDUCTIONS} " )
88
94
endif (FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
89
95
90
96
# generate kernel inst
@@ -98,6 +104,34 @@ message(STATUS "FLASHINFER_POS_ENCODING_MODES=${POS_ENCODING_MODES}")
98
104
message (STATUS "FLASHINFER_USE_FP16_QK_REDUCTIONS=${USE_FP16_QK_REDUCTIONS} " )
99
105
message (STATUS "FLASHINFER_MASK_MODES=${MASK_MODES} " )
100
106
107
+ #----------------------- SM90 head dims computation ------------------------#
108
+ set (SM90_ALLOWED_HEAD_DIMS "64,64" "128,128" "256,256" "192,128" )
109
+ set (HEAD_DIMS_SM90 "" )
110
+
111
+ foreach (DIM_VAL ${HEAD_DIMS} )
112
+ string (CONCAT TUPLE_VAL "${DIM_VAL} " "," "${DIM_VAL} " )
113
+ list (FIND SM90_ALLOWED_HEAD_DIMS ${TUPLE_VAL} RESULT)
114
+ if (NOT ${RESULT} EQUAL -1)
115
+ list (APPEND HEAD_DIMS_SM90 ${TUPLE_VAL} )
116
+ endif (NOT ${RESULT} EQUAL -1)
117
+ endforeach (DIM_VAL)
118
+
119
+ foreach (TUPLE_VAL ${SM90_ALLOWED_HEAD_DIMS} )
120
+ string (REPLACE "," ";" HEAD_DIMS_LIST ${TUPLE_VAL} )
121
+ list (GET HEAD_DIMS_LIST 0 K)
122
+ list (GET HEAD_DIMS_LIST 1 V)
123
+ if (NOT K EQUAL V)
124
+ list (APPEND HEAD_DIMS_SM90 ${TUPLE_VAL} )
125
+ endif (NOT K EQUAL V)
126
+ endforeach (TUPLE_VAL)
127
+
128
+ list (REMOVE_DUPLICATES HEAD_DIMS_SM90)
129
+ #---------------------------------------------------------------------------#
130
+
131
+ # Log SM90_ALLOWED_HEAD_DIMS and HEAD_DIMS_SM90
132
+ message (STATUS "SM90_ALLOWED_HEAD_DIMS=${SM90_ALLOWED_HEAD_DIMS} " )
133
+ message (STATUS "HEAD_DIMS_SM90=${HEAD_DIMS_SM90} " )
134
+
101
135
set (GENERATED_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR} /src/generated )
102
136
cmake_path(GET GENERATED_SOURCE_DIR PARENT_PATH GENERATED_SOURCE_DIR_ROOT)
103
137
file (MAKE_DIRECTORY ${GENERATED_SOURCE_DIR} )
@@ -114,7 +148,7 @@ set(AOT_GENERATE_COMMAND
114
148
set (AOT_GENERATE_DISPATCH_INC_COMMAND
115
149
${Python3_EXECUTABLE} -m aot_build_utils.generate_dispatch_inc --path
116
150
"${GENERATED_SOURCE_DIR} /dispatch.inc" --head_dims ${HEAD_DIMS}
117
- --head_dims_sm90 "64,64" # FIXME Make this configurable.
151
+ --head_dims ${HEAD_DIMS_SM90}
118
152
--pos_encoding_modes ${POS_ENCODING_MODES} --use_fp16_qk_reductions
119
153
${USE_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES} )
120
154
@@ -168,7 +202,6 @@ string(
168
202
"--threads=1 "
169
203
"-Xfatbin=-compress-all "
170
204
"-use_fast_math "
171
- "-DFLASHINFER_ENABLE_F16 "
172
205
"--expt-relaxed-constexpr " )
173
206
string (CONCAT FLASHINFER_CXXFLAGS "${WARNING_FLAGS} " "${SECURITY_FLAGS} " )
174
207
@@ -179,11 +212,16 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS}")
179
212
180
213
add_library (decode_kernels STATIC ${DECODE_KERNELS_SRCS} )
181
214
target_include_directories (decode_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR} )
182
- target_link_libraries (decode_kernels PRIVATE Boost::math)
215
+ if (FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
216
+ target_link_libraries (decode_kernels PRIVATE Boost::math)
217
+ endif (FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
183
218
184
219
add_library (prefill_kernels STATIC ${PREFILL_KERNELS_SRCS} )
185
220
target_include_directories (prefill_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR} )
186
- target_link_libraries (prefill_kernels PRIVATE Boost::math)
221
+ if (FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
222
+ add_definitions (-DFP16_QK_REDUCTION_SUPPORTED)
223
+ target_link_libraries (prefill_kernels PRIVATE Boost::math)
224
+ endif (FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
187
225
188
226
if (FLASHINFER_UNITTESTS)
189
227
add_subdirectory (tests)
0 commit comments