Skip to content

feat: add side_effect and backend_config to jit_call #425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def JITCallOp: EnzymeXLA_Op<"jit_call", [DeclareOpInterfaceMethods<SymbolUserOpI
let arguments = (ins
FlatSymbolRefAttr:$fn,
Variadic<AnyType>:$inputs,
DefaultValuedStrAttr<StrAttr, "">:$backend_config,
DefaultValuedOptionalAttr<BoolAttr, "false">:$has_side_effect,
DefaultValuedOptionalAttr<AnyAttrOf<[StrAttr, DictionaryAttr]>, "">:$backend_config,
OptionalAttr<AnyAttr>:$operand_layouts,
OptionalAttr<AnyAttr>:$result_layouts,
DefaultValuedOptionalAttr<
Expand Down
3 changes: 2 additions & 1 deletion src/enzyme_ad/jax/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ enzymexla::JITCallOp ReadOnlyArg<enzymexla::JITCallOp>::create(
ArrayRef<Type> resTys, ArrayAttr outputAliases) const {
return rewriter.create<enzymexla::JITCallOp>(
launchOp.getLoc(), resTys, launchOp.getFn(), launchOp.getInputs(),
launchOp.getBackendConfigAttr(), launchOp.getOperandLayoutsAttr(),
launchOp.getHasSideEffect(), launchOp.getBackendConfigAttr(),
launchOp.getOperandLayoutsAttr(),
/*resultLayouts*/ nullptr, outputAliases);
}

Expand Down
6 changes: 4 additions & 2 deletions src/enzyme_ad/jax/Passes/LowerJIT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -972,14 +972,16 @@ struct LowerJITPass
SmallVector<NamedAttribute> names;
names.push_back(
NamedAttribute(rewriter.getStringAttr("attr"), backendstr));
auto existingBackendConfig =
op->getAttrOfType<DictionaryAttr>("backend_config");
auto dattr = DictionaryAttr::get(op.getContext(), names);

Operation *replacement;
if (backend == "cuda")
replacement = rewriter.create<stablehlo::CustomCallOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(),
rewriter.getStringAttr("enzymexla_compile_gpu"),
/* has_side_effect*/ rewriter.getBoolAttr(false),
/* has_side_effect*/ op.getHasSideEffectAttr(),
/*backend_config*/ dattr,
/* api_version*/
CustomCallApiVersionAttr::get(
Expand All @@ -991,7 +993,7 @@ struct LowerJITPass
replacement = rewriter.create<stablehlo::CustomCallOp>(
op.getLoc(), op.getResultTypes(), op.getInputs(),
rewriter.getStringAttr("enzymexla_compile_cpu"),
/* has_side_effect*/ rewriter.getBoolAttr(false),
/* has_side_effect*/ op.getHasSideEffectAttr(),
/*backend_config*/ backendstr,
/* api_version*/
CustomCallApiVersionAttr::get(
Expand Down
12 changes: 6 additions & 6 deletions src/enzyme_ad/jax/Passes/LowerKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,9 @@ bool CompileGPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
auto replacement = rewriter.create<enzymexla::JITCallOp>(
kcall.getLoc(), kcall.getResultTypes(),
mlir::FlatSymbolRefAttr::get(kcall.getContext(), callName),
kcall.getInputs(), kcall.getBackendConfigAttr(),
kcall.getOperandLayoutsAttr(), kcall.getResultLayoutsAttr(),
kcall.getOutputOperandAliasesAttr());
kcall.getInputs(), /*has_side_effect*/ rewriter.getBoolAttr(false),
kcall.getBackendConfigAttr(), kcall.getOperandLayoutsAttr(),
kcall.getResultLayoutsAttr(), kcall.getOutputOperandAliasesAttr());
kcall.replaceAllUsesWith(replacement);
kcall.erase();
return true;
Expand Down Expand Up @@ -372,9 +372,9 @@ bool CompileCPUKernel(SymbolTableCollection &symbolTable, mlir::Location loc,
auto replacement = rewriter.create<enzymexla::JITCallOp>(
kcall.getLoc(), kcall.getResultTypes(),
mlir::FlatSymbolRefAttr::get(kcall.getContext(), callName),
kcall.getInputs(), kcall.getBackendConfigAttr(),
kcall.getOperandLayoutsAttr(), kcall.getResultLayoutsAttr(),
kcall.getOutputOperandAliasesAttr());
kcall.getInputs(), /*has_side_effect*/ rewriter.getBoolAttr(false),
kcall.getBackendConfigAttr(), kcall.getOperandLayoutsAttr(),
kcall.getResultLayoutsAttr(), kcall.getOutputOperandAliasesAttr());
kcall.replaceAllUsesWith(replacement);
kcall.erase();
return true;
Expand Down
17 changes: 17 additions & 0 deletions test/lit_tests/jit_call.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(lower-jit{jit=false backend=cpu},enzyme-hlo-opt)" | FileCheck %s

module {
llvm.func internal unnamed_addr fastcc @throw() attributes {dso_local, no_inline, sym_visibility = "private"} {
llvm.unreachable
}

// CHECK-LABEL: func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: stablehlo.custom_call @enzymexla_compile_cpu() {api_version = 3 : i32, backend_config = "\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00", has_side_effect = true} : () -> ()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't the backend_config remain the same and get passed to the underlying func?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs fixing, the backend config needs to be compiled as part of CompileCall and this backend config should just be the attribute (at least that is what I figured out from the discussion above)

enzymexla.jit_call @throw () {
has_side_effect = true,
backend_config = {bar = 42 : i32}
} : () -> ()
return %arg0 : tensor<4xf32>
}
}
2 changes: 1 addition & 1 deletion test/lit_tests/lowering/cpu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,6 @@ module {
// CHECK-NEXT: }

// CHECK: func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> {
// CHECK-NEXT: %0 = enzymexla.jit_call @kern$par0 (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
// CHECK-NEXT: %0 = enzymexla.jit_call @kern$par0 (%arg0) {backend_config = "", output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
// CHECK-NEXT: return %0 : tensor<64xi64>
// CHECK-NEXT: }
2 changes: 1 addition & 1 deletion test/lit_tests/lowering/cpujit.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ module {
// CHECK-LABEL: @main
// CHECK-SAME: (%[[ARG0:.+]]: tensor<64xi64>) -> tensor<64xi64> {
// CHECK-NEXT: %[[CALL:.+]] = stablehlo.custom_call @enzymexla_compile_cpu(%arg0)
// CHECK-SAME: {api_version = 3 : i32, backend_config = "\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00",
// CHECK-SAME: {api_version = 3 : i32, backend_config = "\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00",
// CHECK-SAME: output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
// CHECK-NEXT: return %[[CALL]] : tensor<64xi64>
2 changes: 1 addition & 1 deletion test/lit_tests/lowering/gpu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ module {
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK: func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> {
// CHECK-NEXT: %0 = enzymexla.jit_call @kern$call$1 (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
// CHECK-NEXT: %0 = enzymexla.jit_call @kern$call$1 (%arg0) {backend_config = "", output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
// CHECK-NEXT: return %0 : tensor<64xi64>
// CHECK-NEXT: }
4 changes: 2 additions & 2 deletions test/lit_tests/lowering/gpu2.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ module {
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK: func.func @main(%arg0: tensor<64xi64>) -> tensor<64xi64> {
// CHECK-NEXT: %0 = enzymexla.jit_call @kern$call$1 (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
// CHECK-NEXT: %0 = enzymexla.jit_call @kern$call$1 (%arg0) {backend_config = "", output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
// CHECK-NEXT: return %0 : tensor<64xi64>
// CHECK-NEXT: }
// CHECK: func.func @main2(%arg0: tensor<64xi64>) -> tensor<64xi64> {
// CHECK-NEXT: %0 = enzymexla.jit_call @kern$call$2 (%arg0) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
// CHECK-NEXT: %0 = enzymexla.jit_call @kern$call$2 (%arg0) {backend_config = "", output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [], operand_index = 0, operand_tuple_indices = []>]} : (tensor<64xi64>) -> tensor<64xi64>
// CHECK-NEXT: return %0 : tensor<64xi64>
// CHECK-NEXT: }
Loading