diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index ddd4a4d41..ed5c2d439 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -7712,11 +7712,11 @@ template struct CSE final : OpRewritePattern { continue; if (!isa(nop)) continue; + if (nop->getBlock() != op->getBlock()) + continue; if (!OperationEquivalence::isEquivalentTo( op, nop, OperationEquivalence::IgnoreLocations)) continue; - if (nop->getBlock() != op->getBlock()) - continue; if (nop->isBeforeInBlock(op)) { rewriter.replaceOp(op, nop); return success(); diff --git a/src/enzyme_ad/jax/Passes/Passes.h b/src/enzyme_ad/jax/Passes/Passes.h index b06cd3ff6..e21cd745a 100644 --- a/src/enzyme_ad/jax/Passes/Passes.h +++ b/src/enzyme_ad/jax/Passes/Passes.h @@ -22,6 +22,10 @@ class PatternRewriter; class AffineMap; class DominanceInfo; +namespace arith { +class ArithDialect; +} + namespace enzyme { void populateAffineCFGPatterns(RewritePatternSet &rpl); diff --git a/src/enzyme_ad/jax/Passes/Passes.td b/src/enzyme_ad/jax/Passes/Passes.td index fb63a96f0..489c6066c 100644 --- a/src/enzyme_ad/jax/Passes/Passes.td +++ b/src/enzyme_ad/jax/Passes/Passes.td @@ -555,5 +555,21 @@ def AffineToStableHLORaising : Pass<"raise-affine-to-stablehlo"> { ]; } +def SplitHugeBlocksPass : InterfacePass<"split-huge-blocks", "mlir::FunctionOpInterface"> { + let summary = "Split huge blocks into smaller blocks"; + let dependentDialects = [ + "cf::ControlFlowDialect", + ]; + let options = [ + Option< + /*C++ variable name=*/"max_num_operations", + /*CLI argument=*/"max_num_operations", + /*type=*/"int64_t", + /*default=*/"-1", + /*description=*/"Maximum number of operations per block, -1 being unlimited">, + + ]; +} + #endif diff --git a/src/enzyme_ad/jax/Passes/SplitHugeBlocks.cpp b/src/enzyme_ad/jax/Passes/SplitHugeBlocks.cpp new file mode 100644 index 000000000..feae6ef2e --- /dev/null +++ b/src/enzyme_ad/jax/Passes/SplitHugeBlocks.cpp @@ -0,0 +1,49 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "src/enzyme_ad/jax/Passes/Passes.h" + +namespace mlir { +namespace enzyme { +#define GEN_PASS_DEF_SPLITHUGEBLOCKSPASS +#include "src/enzyme_ad/jax/Passes/Passes.h.inc" +} // namespace enzyme +} // namespace mlir + +using namespace mlir; +using namespace mlir::enzyme; + +static void splitLargeBlock(RewriterBase &rewriter, Block *block, + uint64_t maxNumOperations) { + do { + Block::iterator it = block->begin(); + for (uint64_t i = 0; i < maxNumOperations; ++i) { + if (it == block->end()) + return; + it = std::next(it); + } + Block *current = block; + block = rewriter.splitBlock(block, it); + rewriter.setInsertionPointToEnd(current); + rewriter.create(rewriter.getUnknownLoc(), block); + } while (true); +} + +struct SplitHugeBlocksPass + : public enzyme::impl::SplitHugeBlocksPassBase { + using SplitHugeBlocksPassBase::SplitHugeBlocksPassBase; + + void runOnOperation() override { + if (max_num_operations == -1) + return; + auto context = getOperation()->getContext(); + IRRewriter rewriter(context); + SmallVector originalBlocks = llvm::map_to_vector( + getOperation().getFunctionBody(), [](Block &b) { return &b; }); + for (Block *block : originalBlocks) { + splitLargeBlock(rewriter, block, max_num_operations); + } + } +};