[mlir][EmitC] Expand the MemRefToEmitC pass - Lowering AllocOp
#148257
Conversation
if (!memrefType.hasStaticShape())
  return rewriter.notifyMatchFailure(
      allocOp.getLoc(), "cannot transform alloc op with dynamic shape");
Is this always a limitation? I'd imagine it's just something we can't handle for now, but could potentially handle in the future (e.g. if the size of the alloc is the result of some function, you could evaluate the function and then use the result in the call to allocate). If we think it may be possible, add a TODO to figure that out. I'm not 100% on this, so I'll defer to folks who grasp the minutiae of the two dialects more firmly.
Cool cool, I'll mark it as TODO and await comments on this during review.
int64_t totalSize =
    memrefType.getNumElements() * memrefType.getElementTypeBitWidth() / 8;
In some contexts bits/byte aren't guaranteed to be 8. IDK if that's the case here or if there's an API we can use to guarantee we use the right constants. If this pattern is used elsewhere, it's fine. I just know we've run into similar issues on the LLVM side, and it's often hard to run down.
I found CHAR_BIT within the code base, but I'm leaving a TODO for now in case there is a better API.
auto alignment = allocOp.getAlignment();
if (alignment) {
Suggested change:
- auto alignment = allocOp.getAlignment();
- if (alignment) {
+ if (auto alignment = allocOp.getAlignment()) {
I don't see alignment getting used outside of this block...
Thanks for the pointer!
@llvm/pr-subscribers-mlir
Author: Jaden Angella (Jaddyen)
Changes: This aims to lower memref.alloc to emitc.call_opaque "malloc".
Full diff: https://github.com/llvm/llvm-project/pull/148257.diff
2 Files Affected:
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index db244d1d1cac8..ee6b7d89a76a6 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -77,6 +77,43 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
}
};
+struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ mlir::Location loc = allocOp.getLoc();
+ auto memrefType = allocOp.getType();
+ if (!memrefType.hasStaticShape())
+ // TODO: Handle Dynamic shapes in the future. If the size
+ // of the allocation is the result of some function, we could
+ // potentially evaluate the function and use the result in the call to
+ // allocate.
+ return rewriter.notifyMatchFailure(
+ allocOp.getLoc(), "cannot transform alloc op with dynamic shape");
+
+ // TODO: Is there a better API to determine the number of bits in a byte in
+ // MLIR?
+ int64_t totalSize = memrefType.getNumElements() *
+ memrefType.getElementTypeBitWidth() / CHAR_BIT;
+ if (auto alignment = allocOp.getAlignment()) {
+ int64_t alignVal = alignment.value();
+ totalSize = (totalSize + alignVal - 1) / alignVal * alignVal;
+ }
+ mlir::Value sizeBytes = rewriter.create<emitc::ConstantOp>(
+ loc, rewriter.getIndexType(),
+ rewriter.getIntegerAttr(rewriter.getIndexType(), totalSize));
+ auto mallocPtrType = emitc::PointerType::get(rewriter.getContext(),
+ memrefType.getElementType());
+ auto mallocCall = rewriter.create<emitc::CallOpaqueOp>(
+ loc, mallocPtrType, rewriter.getStringAttr("malloc"),
+ mlir::ValueRange{sizeBytes});
+
+ rewriter.replaceOp(allocOp, mallocCall);
+ return success();
+ }
+};
+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
@@ -222,6 +259,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
- ConvertStore>(converter, patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
+ ConvertLoad, ConvertStore>(converter, patterns.getContext());
}
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index d37fd1de90add..23e1c20670f8c 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -8,6 +8,14 @@ func.func @alloca() {
return
}
+// CHECK-LABEL: alloc()
+func.func @alloc() {
+ // CHECK-NEXT: %0 = "emitc.constant"() <{value = 3996 : index}> : () -> index
+ // CHECK-NEXT: %1 = emitc.call_opaque "malloc"(%0) : (index) -> !emitc.ptr<i32>
+ %alloc = memref.alloc() : memref<999xi32>
+ return
+}
+
// -----
// CHECK-LABEL: memref_store
int64_t totalSize = memrefType.getNumElements() *
                    memrefType.getElementTypeBitWidth() / CHAR_BIT;
CHAR_BIT is probably not what you want. That definition will provide the bits/char in this TU when you're compiling the compiler (e.g. mlir-opt), not the target program being compiled (I'm using "compiled" pretty loosely here; "lowered" or "translated" are probably more accurate). The number of bits/byte is going to depend on the target you're compiling for, and if you wanted to use CHAR_BIT for EmitC, you'd have to emit it in the C code as an expression with the right header. It's probably fine for now to assume 8 bits/byte, but a more experienced MLIR maintainer would know for sure.
This is indeed making an assumption about layout (see getMemRefEltSizeInBytes, for example, which accounts for it). Would you be able to query the data layout analysis? (It's optional for non-LLVM paths at the moment, so check whether it's available or could be used here.)
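For illustration, a minimal sketch of what querying the data layout could look like, assuming mlir::DataLayout from DataLayoutInterfaces is acceptable in this pass (the names and return types below are assumptions, not the final patch):

// Hedged sketch: derive the element size from the data layout analysis
// instead of hard-coding 8 bits per byte. DataLayout::closest walks up to
// the nearest op carrying a data layout spec (typically the module).
mlir::DataLayout dataLayout = mlir::DataLayout::closest(allocOp);
llvm::TypeSize elemSize = dataLayout.getTypeSize(memrefType.getElementType());
int64_t totalSize = memrefType.getNumElements() * elemSize.getFixedValue();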
The memref-to-LLVM lowering implements a version of sizeof for finding the element size in bytes (see ConvertToLLVMPattern::getSizeInBytes). Since we're emitting C code, you could emit the computation as malloc()'s parameter, e.g.
%c = emitc.literal "int" : !emitc.opaque<"type">
%e = emitc.call_opaque "sizeof", %c : !emitc.size_t
%d = emitc.constant 57 : !emitc.size_t
%s = emitc.mul %e, %d : !emitc.size_t
%m = emitc.call_opaque "malloc", %s : !emitc.ptr<!emitc.opaque<"void">>
which should translate to
size_t v0 = sizeof(int);
size_t v1 = 57;
size_t v2 = v0 * v1;
void* v3 = malloc(v2);
The form-expressions pass should fold this code into a single expression, i.e.
void* v3 = malloc(sizeof(int) * 57);
And the C compiler irons out such static calculations anyway.
ack, addressed in new change.
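For reference, a hedged C++ sketch of how the pattern might build the sizeof-based computation described above (the builder overloads, the emitc.size_t handling, and the hard-coded "int32_t" element type name are assumptions; the updated patch may differ):

// Hedged sketch, not the final patch: build sizeof(<elem>) * <count> so the
// emitted C computes the allocation size.
MLIRContext *ctx = rewriter.getContext();
Type sizeT = emitc::SizeTType::get(ctx);
// A literal of opaque "type" whose text is the C element type, so sizeof()
// is emitted over the type name; "int32_t" is hard-coded here for illustration.
Value elemTy = rewriter.create<emitc::LiteralOp>(
    loc, emitc::OpaqueType::get(ctx, "type"), rewriter.getStringAttr("int32_t"));
Value elemSize =
    rewriter
        .create<emitc::CallOpaqueOp>(loc, sizeT, rewriter.getStringAttr("sizeof"),
                                     mlir::ValueRange{elemTy})
        .getResult(0);
Value numElems = rewriter.create<emitc::LiteralOp>(
    loc, sizeT,
    rewriter.getStringAttr(std::to_string(memrefType.getNumElements())));
Value sizeBytes = rewriter.create<emitc::MulOp>(loc, sizeT, elemSize, numElems);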
    memrefType.getElementTypeBitWidth() / CHAR_BIT;
if (auto alignment = allocOp.getAlignment()) {
  int64_t alignVal = alignment.value();
  totalSize = (totalSize + alignVal - 1) / alignVal * alignVal;
llvm/Support/MathExtras.h has some helpers that could be used here instead.
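For example, a minimal sketch assuming llvm::alignTo from llvm/Support/MathExtras.h is the helper intended:

// Round totalSize up to the requested alignment with the LLVM helper
// instead of the hand-written divide/multiply.
if (auto alignment = allocOp.getAlignment())
  totalSize = llvm::alignTo(totalSize, alignment.value());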
    rewriter.getIntegerAttr(rewriter.getIndexType(), totalSize));
auto mallocPtrType = emitc::PointerType::get(rewriter.getContext(),
                                             memrefType.getElementType());
auto mallocCall = rewriter.create<emitc::CallOpaqueOp>(
I think this should now be emitc::CallOpaqueOp::create(rewriter, ...) since a recent change.
Should've been.
I've addressed this in the new changes.
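A minimal sketch of the suggested form, assuming the newer static create entry points that take the builder as the first argument and mirroring the argument list from the quoted code:

// Same call as above, expressed through the op's static create method.
auto mallocCall = emitc::CallOpaqueOp::create(
    rewriter, loc, mallocPtrType, rewriter.getStringAttr("malloc"),
    mlir::ValueRange{sizeBytes});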
    rewriter.getIntegerAttr(rewriter.getIndexType(), totalSize));
auto mallocPtrType = emitc::PointerType::get(rewriter.getContext(),
                                             memrefType.getElementType());
auto mallocCall = rewriter.create<emitc::CallOpaqueOp>(
IINM, assigning a void* to, say, an int* in C++ requires an explicit cast or the -fpermissive compiler flag. Since we don't have a clear marking of the target C variant in the program, we should probably emit an explicit cast as the common ground.
ack, addressed in new changes.
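A hedged sketch of what emitting the explicit cast could look like, assuming emitc.cast accepts the void-pointer to element-pointer conversion (not necessarily the exact form in the updated patch):

// Call malloc with a void* result, then cast to the element pointer type so
// the emitted code compiles as both C and C++.
Type voidPtrType = emitc::PointerType::get(
    rewriter.getContext(), emitc::OpaqueType::get(rewriter.getContext(), "void"));
auto mallocCall = rewriter.create<emitc::CallOpaqueOp>(
    loc, voidPtrType, rewriter.getStringAttr("malloc"), mlir::ValueRange{sizeBytes});
Value typedPtr =
    rewriter.create<emitc::CastOp>(loc, mallocPtrType, mallocCall.getResult(0));
rewriter.replaceOp(allocOp, typedPtr);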
    memrefType.getElementTypeBitWidth() / CHAR_BIT;
if (auto alignment = allocOp.getAlignment()) {
  int64_t alignVal = alignment.value();
  totalSize = (totalSize + alignVal - 1) / alignVal * alignVal;
Adding the alignment value to the size won't affect alignment by itself. It could be used for emitting code that moves the start address to an aligned address, but that doesn't seem to be done here, and doing so would create another problem: the aligned pointer will not be the allocated address, which should later be passed to free() (which is why LLVM-dialect memref descriptors carry both allocated and aligned pointers). I think lowering to a single pointer would be safe when there's no alignment requirement and when the alignment required is under the target's malloc() alignment. WDYT @marbre, @simon-camp, @mgehre-amd?
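One way to encode that restriction, as a hedged sketch (the malloc alignment threshold is a placeholder assumption, not something queried from the target):

// Bail out when the requested alignment may exceed what plain malloc guarantees;
// a later revision could use aligned_alloc or carry allocated/aligned pointers.
constexpr int64_t kAssumedMallocAlignment = 16; // assumption, target-dependent
if (auto alignment = allocOp.getAlignment()) {
  if (static_cast<int64_t>(alignment.value()) > kAssumedMallocAlignment)
    return rewriter.notifyMatchFailure(
        allocOp.getLoc(), "alignment exceeds assumed malloc guarantee");
}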
struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
The malloc() function requires including the relevant header file ("stdlib.h" for C, <cstdlib> for C++). The pass would have to add such an emitc.include op to the module, or a forward declaration of malloc() using emitc.declare_func (in which case it can use emitc.call instead of emitc.call_opaque).
Yeap! I've addressed this in the new change. Thanks for pointing this out.
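For reference, a hedged sketch of inserting the include at module scope (assuming emitc.include's generated builder accepts the header name and a unit attribute marking a standard include; a real patch would also avoid duplicate includes):

// Insert the include once at the top of the module so the emitted malloc resolves.
auto module = allocOp->getParentOfType<ModuleOp>();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
rewriter.create<emitc::IncludeOp>(loc, rewriter.getStringAttr("stdlib.h"),
                                  /*is_standard_include=*/rewriter.getUnitAttr());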
This aims to lower memref.alloc to emitc.call_opaque "malloc".
From:
To:
Which is then translated as: