From d7f841b8e3124bdf9143786ed04dbe90a74805fc Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 22 Mar 2025 20:35:50 -0500 Subject: [PATCH 1/2] feat: early fail if not correct region --- deps/ReactantExtra/API.cpp | 4 ++++ src/Ops.jl | 5 +++++ src/mlir/IR/Operation.jl | 26 ++++++++++++++++++++++++-- 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index fc1ce453e7..5c0f0a85b3 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -2404,3 +2404,7 @@ extern "C" void dump_operation(Operation *op, const char *filename) { extern "C" bool pjrt_device_is_addressable(PjRtDevice *device) { return device->IsAddressable(); } + +extern "C" mlir::Operation *mlirGetParentOfTypeFunctionOp(mlir::Operation *op) { + return op->getParentOfType(); +} diff --git a/src/Ops.jl b/src/Ops.jl index 7c4ebbd2c6..75cb3f869c 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -126,6 +126,11 @@ end result_inference=false, ) + parent_func_op = MLIR.IR.get_parent_of_type_function_op(cstop) + if parent_func_op == C_NULL + error("Constant must be created inside a Function Op.") + end + res = MLIR.IR.result(cstop) tres = TracedRArray{T,N}((), res, size(x)) constants[value] = tres diff --git a/src/mlir/IR/Operation.jl b/src/mlir/IR/Operation.jl index 32f42b6838..eb9b6d8fd2 100644 --- a/src/mlir/IR/Operation.jl +++ b/src/mlir/IR/Operation.jl @@ -331,8 +331,20 @@ function create_operation_common( end end +function create_operation_common_with_checks(args...; operands=nothing, kwargs...) + op = create_operation_common(args...; operands, kwargs...) + # if !isnothing(operands) + # parent_function_op = get_parent_of_type_function_op(op) + # if parent_function_op != C_NULL + # function_op_region = parent_region(parent_function_op) + # # TODO: add the checks + # end + # end + return op +end + function create_operation(args...; kwargs...) - res = create_operation_common(args...; kwargs...) + res = create_operation_common_with_checks(args...; kwargs...) if _has_block() push!(block(), res) end @@ -340,7 +352,17 @@ function create_operation(args...; kwargs...) end function create_operation_at_front(args...; kwargs...) - res = create_operation_common(args...; kwargs...) + res = create_operation_common_with_checks(args...; kwargs...) Base.pushfirst!(block(), res) return res end + +function get_parent_of_type_function_op(op::Operation) + GC.@preserve op begin + funcop = @ccall API.mlir_c.mlirGetParentOfTypeFunctionOp( + op::API.MlirOperation + )::API.MlirOperation + end + funcop.ptr == C_NULL && return C_NULL + return Operation(funcop, false) +end From a3912bccb79270756100c2f0758f92c76c611520 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 22 Mar 2025 20:39:58 -0500 Subject: [PATCH 2/2] feat: more constant checks --- src/Ops.jl | 12 ++++++++++++ src/mlir/IR/Operation.jl | 21 ++++++++++++++------- src/mlir/IR/Value.jl | 2 ++ 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 75cb3f869c..0f13fa097a 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -206,6 +206,12 @@ for (T, mlir_func) in ( splatattr = MLIR.API.$mlir_func(tt, number) cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location) + + parent_func_op = MLIR.IR.get_parent_of_type_function_op(cst_op) + if parent_func_op == C_NULL + error("Constant must be created inside a Function Op.") + end + cst = MLIR.IR.result(cst_op) ta = TracedRArray{$T,length(shape)}((), cst, shape) return ta @@ -226,6 +232,12 @@ end tt = MLIR.IR.TensorType(shape, MLIR.IR.Type(T)) splatattr = MLIR.API.mlirDenseElementsAttrSplatGet(tt, _fill_element_attr(element)) cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location) + + parent_func_op = MLIR.IR.get_parent_of_type_function_op(cst_op) + if parent_func_op == C_NULL + error("Constant must be created inside a Function Op.") + end + cst = MLIR.IR.result(cst_op) ta = TracedRArray{T,length(shape)}((), cst, shape) return ta diff --git a/src/mlir/IR/Operation.jl b/src/mlir/IR/Operation.jl index eb9b6d8fd2..c1768b44a7 100644 --- a/src/mlir/IR/Operation.jl +++ b/src/mlir/IR/Operation.jl @@ -68,6 +68,12 @@ Gets the operation that owns this operation, returning null if the operation is parent_op(operation::Operation) = Operation(API.mlirOperationGetParentOperation(operation), false) +""" + parent_region(op) +Gets the region that owns this operation. +""" +parent_region(operation::Operation) = parent_region(block(operation)) + """ rmfromparent!(op) @@ -333,13 +339,14 @@ end function create_operation_common_with_checks(args...; operands=nothing, kwargs...) op = create_operation_common(args...; operands, kwargs...) - # if !isnothing(operands) - # parent_function_op = get_parent_of_type_function_op(op) - # if parent_function_op != C_NULL - # function_op_region = parent_region(parent_function_op) - # # TODO: add the checks - # end - # end + if !isnothing(operands) + parent_function_op = get_parent_of_type_function_op(op) + if parent_function_op != C_NULL + function_op_region = parent_region(parent_function_op) + operand_region = parent_region.(operands) + # TODO: add the checks + end + end return op end diff --git a/src/mlir/IR/Value.jl b/src/mlir/IR/Value.jl index a24632d934..38c877f763 100644 --- a/src/mlir/IR/Value.jl +++ b/src/mlir/IR/Value.jl @@ -121,3 +121,5 @@ function Base.show(io::IO, value::Value) API.mlirValuePrint(value, c_print_callback, ref) end end + +parent_region(value::Value) = parent_region(owner(value))