Skip to content

Gather To Slice op in stableHlo #550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 241 additions & 0 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6935,6 +6935,246 @@ struct NoopReverse final : OpRewritePattern<mlir::stablehlo::ReverseOp> {
}
};

bool check_periodicity(std::vector<int> &sliceStart,
std::vector<int> &sliceStride, std::vector<int> &prev,
int dim, int k) {
if (sliceStart[k] == -1) {
sliceStart[k] = dim;
prev[k] = dim;
return true;
}
if (dim - prev[k] == sliceStride[k]) {
prev[k] = dim;
return true;
} else
return false;
}

int get_index(const std::vector<int> &strides, const std::vector<int> &dims,
int rank_lower, int rank_upper, int iter) {
int index = 0;
for (int i = rank_lower; i >= rank_upper; i--) {
int mod = iter % dims[i];
iter = iter / dims[i];
index += mod * strides[i];
}
return index;
}

// Rank 0 to n-1, where 0 is the slowest moving dimension (outermost)
bool isGatherPeriodic(const std::vector<int> &data,
const std::vector<int> &dims, int rank, int x_dim,
std::vector<int> &sliceStart, std::vector<int> &sliceEnd,
std::vector<int> &sliceStride) {
// Calculate total elements and strides
int total_size = 1;
std::vector<int> strides(rank);
for (int i = 0; i < rank; i++)
total_size *= dims[i];

strides[rank - 1] = 1;
for (int i = rank - 2; i >= 0; i--)
strides[i] = strides[i + 1] * dims[i + 1];

int outer_batch_size = 1;
for (int i = 0; i <= x_dim; i++) {
outer_batch_size *= dims[i];
}
int inner_batch_size = 1;
for (int i = x_dim + 1; i < rank; i++) {
inner_batch_size *= dims[i];
}

int vec_length = dims[x_dim];

// Run 2 iterations of each rank to get the stride on each rank
bool strideFound = false;
for (int i = 0; i < outer_batch_size / vec_length; i++) {
int index_outer = get_index(strides, dims, x_dim - 1, 0, i);
for (int j = 0; j < inner_batch_size; j++) {
int index_inner = get_index(strides, dims, rank - 1, x_dim + 1, j);
for (int k = 0; k < vec_length; k++) {
int index = index_outer + k * inner_batch_size + index_inner;
int value = data[index];
if (sliceStride[k] == -1) {
sliceStride[k] = value;
} else {
sliceStride[k] = value - sliceStride[k];
strideFound = true;
}
}
if (strideFound)
break;
}
if (strideFound)
break;
}

// Run all the iterations to check if the strides match
std::vector<int> prev(vec_length, -1);
for (int i = 0; i < outer_batch_size / vec_length; i++) {
int index_outer = get_index(strides, dims, x_dim - 1, 0, i);
for (int j = 0; j < inner_batch_size; j++) {
int index_inner = get_index(strides, dims, rank - 1, x_dim + 1, j);
for (int k = 0; k < vec_length; k++) {
int index = index_outer + k * inner_batch_size + index_inner;
int value = data[index];
auto res = check_periodicity(sliceStart, sliceStride, prev, value, k);
if (!res)
return false;
}
}
}
for (int k = 0; k < vec_length; k++) {
if(sliceStride[k] >= 0) {
sliceEnd[k] = prev[k] + 1;
}
else if (sliceStride[k] < 0) {
sliceEnd[k] = prev[k];
sliceStart[k] = sliceStart[k] + 1;
}
}
return true;
}

/// Converts gather ops to slice ops in case we have a single set of constant
/// indices.
struct GatherToSliceOp final : OpRewritePattern<mlir::stablehlo::GatherOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::GatherOp gather,
PatternRewriter &rewriter) const override {
DenseIntElementsAttr index;
// 1. Check following preconditions for converting gather to slice op
// Preconditions:
// i. Check if all dims are collapsed?
// ii. Check if sliceSizes are 1 in each dim (else more complicated and
// there could be a subset that
// could still be transformed to slice op if there is no overlap)
// iii. Based on indexVector dim, check if correspoding values in each dim
// are either constant or strided
// 2. If so, convert the gather to a slice op
// 3. If not, return failure
// 4. To convert to slice operation,

// Get all the other properties of the gather operation
auto startIndices = gather.getStartIndices();
auto gatherOperands = gather.getOperands();
auto gatherDimensionNumbers = gather.getDimensionNumbers();
auto gatherSliceSizes = gather.getSliceSizes();
auto gatherIndexVectorDim = gatherDimensionNumbers.getIndexVectorDim();
auto gatherStartIndexMap = gatherDimensionNumbers.getStartIndexMap();
auto collapsedSliceDims = gatherDimensionNumbers.getCollapsedSliceDims();
auto offsetDims =
gatherDimensionNumbers.getOffsetDims(); // Non collapsed dimensions

////Check rank of gather operands
// auto gatherOperandsRank = gatherOperands.getType().getRank();

//// TODO: Currently only handling simplified case
//// Check collapseSlicedims size equals gatherOperandsRank
// if(collapsedSliceDims.size() != gatherOperandsRank)
// return failure();

// Check if sliceSizes are 1 in each dim
if (llvm::any_of(gatherSliceSizes, [](int64_t size) { return size != 1; }))
return failure();

// From start indices check if it's constant, else return failure
if (!matchPattern(startIndices, m_Constant(&index)))
return failure();

// Currently only handling constant index case with dense elements
std::vector<int> indices;
if (auto denseAttr = llvm::dyn_cast<mlir::DenseElementsAttr>(index)) {
for (auto value : denseAttr.getValues<APInt>()) {
int64_t intValue = value.getSExtValue();
indices.push_back(intValue);
}
} else {
return failure();
}

// Process indices in row-major order
auto tensorType = index.getType().cast<mlir::RankedTensorType>();
auto rank = tensorType.getRank();
auto shape = tensorType.getShape();

// dims : gatherIndexVectorDim, innermostDim .... outermostDim
int indexVectorSize = shape[gatherIndexVectorDim];

std::vector<int> dims(shape.size());
for (int i = 0; i < shape.size(); i++) {
dims[i] = shape[i];
}

std::vector<int> sliceStart(indexVectorSize, -1);
std::vector<int> sliceEnd(indexVectorSize, -1);
std::vector<int> sliceStride(indexVectorSize, -1);

auto isPeriodic =
isGatherPeriodic(indices, dims, rank, gatherIndexVectorDim, sliceStart,
sliceEnd, sliceStride);
if (!isPeriodic)
return failure();

SmallVector<int64_t, 4> sliceStartI64(sliceStart.begin(), sliceStart.end());
SmallVector<int64_t, 4> sliceEndI64(sliceEnd.begin(), sliceEnd.end());
SmallVector<int64_t, 4> sliceStrideI64(sliceStride.begin(),
sliceStride.end());

// Create the slice type
SmallVector<int64_t, 4> sliceShape(sliceStrideI64.size());
SmallVector<int64_t, 4> collapsedShape;
SmallVector<int64_t, 4> reverseDims;
bool reverse = false;
for (int i = 0; i < sliceStrideI64.size(); i++) {
sliceShape[i] = sliceEndI64[i] - sliceStartI64[i];
//Reverse the slice if it's negative
if(sliceShape[i] < 0) {
sliceShape[i] = -sliceShape[i];
sliceStrideI64[i] = -sliceStrideI64[i];
auto temp = sliceStartI64[i];
sliceStartI64[i] = sliceEndI64[i];
sliceEndI64[i] = temp;
reverseDims.push_back(i);
reverse = true;
}

if (sliceShape[i] != 1)
collapsedShape.push_back(sliceShape[i]);
}
Type elementType = gather.getType().getElementType();
auto sliceType = RankedTensorType::get(sliceShape, elementType);

// Fix the constant dims
for (int i = 0; i < sliceStrideI64.size(); i++) {
if (sliceStrideI64[i] == 0)
sliceStrideI64[i] = 1;
}

// Creating the slice op
Value sliceOp = rewriter.create<mlir::stablehlo::SliceOp>(
gather.getLoc(), sliceType, gather.getOperand(),
rewriter.getDenseI64ArrayAttr(sliceStartI64),
rewriter.getDenseI64ArrayAttr(sliceEndI64),
rewriter.getDenseI64ArrayAttr(sliceStrideI64));

if(reverse) {
sliceOp = rewriter.create<mlir::stablehlo::ReverseOp>(gather.getLoc(), sliceOp, reverseDims);
}

// Create result type and reshape operation
auto collapsedType = RankedTensorType::get(collapsedShape, elementType);
Value sliceReshaped = rewriter.create<mlir::stablehlo::ReshapeOp>(
gather.getLoc(), collapsedType, sliceOp);

rewriter.replaceOp(gather, sliceReshaped);

return success();
}
};

/// Converts gather ops to slice ops in case we have a single set of constant
/// indices.
struct GatherOpCanon final : OpRewritePattern<mlir::stablehlo::GatherOp> {
Expand Down Expand Up @@ -8215,6 +8455,7 @@ struct EnzymeHLOOptPass

// clang-format off
patterns.add<
GatherToSliceOp,
BroadcastInDimOpCanon,
ChainedDynamicBroadcastInDimCanonicalization,
CompareOpCanon,
Expand Down
Loading
Loading