I dusted off my toy compiler this weekend, and had a bit of fun implementing IF and ELSE support. The MLIR+LLVM compiler infrastructure is impressively versatile, allowing me to implement a feature like this in a handful of hours. It was, however, an invasive feature addition, requiring grammar changes, parser/builder, and lowering adjustments, as well as additional lowering for the SCF dialect ops that were used.
I have an ELIF element in the grammar too, but haven’t implemented it yet (and haven’t done much testing yet either). These are the grammar elements I now have defined:
ifelifelse
: ifStatement
elifStatement*
elseStatement?
;
ifStatement
: IF_TOKEN BRACE_START_TOKEN booleanValue BRACE_END_TOKEN SCOPE_START_TOKEN statement* SCOPE_END_TOKEN
;
elifStatement
: ELIF_TOKEN BRACE_START_TOKEN booleanValue BRACE_END_TOKEN SCOPE_START_TOKEN statement* SCOPE_END_TOKEN
;
elseStatement
: ELSE_TOKEN SCOPE_START_TOKEN statement* SCOPE_END_TOKEN
;
Previously, I had a single monolithic ifelifelse rule, but the generated parser context class was horrendously complicated:
class IfelifelseContext : public antlr4::ParserRuleContext {
public:
IfelifelseContext(antlr4::ParserRuleContext *parent, size_t invokingState);
virtual size_t getRuleIndex() const override;
antlr4::tree::TerminalNode *IF_TOKEN();
std::vector<antlr4::tree::TerminalNode *> BRACE_START_TOKEN();
antlr4::tree::TerminalNode* BRACE_START_TOKEN(size_t i);
std::vector<BooleanValueContext *> booleanValue();
BooleanValueContext* booleanValue(size_t i);
std::vector<antlr4::tree::TerminalNode *> BRACE_END_TOKEN();
antlr4::tree::TerminalNode* BRACE_END_TOKEN(size_t i);
std::vector<antlr4::tree::TerminalNode *> SCOPE_START_TOKEN();
antlr4::tree::TerminalNode* SCOPE_START_TOKEN(size_t i);
std::vector<antlr4::tree::TerminalNode *> SCOPE_END_TOKEN();
antlr4::tree::TerminalNode* SCOPE_END_TOKEN(size_t i);
std::vector<StatementContext *> statement();
StatementContext* statement(size_t i);
std::vector<antlr4::tree::TerminalNode *> ELIF_TOKEN();
antlr4::tree::TerminalNode* ELIF_TOKEN(size_t i);
antlr4::tree::TerminalNode *ELSE_TOKEN();
virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override;
virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override;
};
In particular, I don’t see a way to determine whether a given statement belongs to the IF, the ELIF, or the ELSE body. Splitting the grammar into separate rules was much better, and left me with:
class IfelifelseContext : public antlr4::ParserRuleContext {
public:
IfelifelseContext(antlr4::ParserRuleContext *parent, size_t invokingState);
virtual size_t getRuleIndex() const override;
IfStatementContext *ifStatement();
std::vector<ElifStatementContext *> elifStatement();
ElifStatementContext* elifStatement(size_t i);
ElseStatementContext *elseStatement();
virtual void enterRule(antlr4::tree::ParseTreeListener *listener) override;
virtual void exitRule(antlr4::tree::ParseTreeListener *listener) override;
};
which allows me to drill down into each of the lower-level elements. In particular, I can define ANTLR4 parse tree walker callbacks:
void exitIfStatement(ToyParser::IfStatementContext *ctx) override;
void exitElseStatement(ToyParser::ElseStatementContext *ctx) override;
void enterIfelifelse( ToyParser::IfelifelseContext *ctx ) override;
and drive the codegen from there. Using enter/exit callbacks for the various statement types locks you into a particular style of programming. In particular, I don’t have a function for parsing (and emitting) a generic statement that I could call on each body statement, so I have to emit the IF/ELSE body statements indirectly: I point the builder’s insertion point at the right block and let the ordinary statement callbacks emit into it. Here’s an example:
mlir::Value conditionPredicate = MLIRListener::parsePredicate( loc, booleanValue );
insertionPointStack.push_back( builder.saveInsertionPoint() );
// Create the scf.if — it will be inserted at the current IP
auto ifOp = builder.create<mlir::scf::IfOp>( loc, conditionPredicate );
mlir::Block &thenBlock = ifOp.getThenRegion().front();
builder.setInsertionPointToStart( &thenBlock );
Then, in the exitIfStatement callback, if there is an ELSE statement to process, I create a block for the else region, set the insertion point to it, and let the statement processing continue. When the ELSE processing is done, I restore the insertion point to just after the scf.if, and the rest of the function generation can proceed. Here’s an example of the MLIR dump before any lowering:
module {
func.func @main() -> i32 {
"toy.scope"() ({
"toy.declare"() <{type = i32}> {sym_name = "x"} : () -> ()
%c3_i64 = arith.constant 3 : i64
"toy.assign"(%c3_i64) <{var_name = @x}> : (i64) -> ()
%0 = "toy.load"() <{var_name = @x}> : () -> i32
%c4_i64 = arith.constant 4 : i64
%1 = "toy.less"(%0, %c4_i64) : (i32, i64) -> i1
scf.if %1 {
%5 = "toy.string_literal"() <{value = "x < 4"}> : () -> !llvm.ptr
toy.print %5 : !llvm.ptr
%6 = "toy.string_literal"() <{value = "a second statement"}> : () -> !llvm.ptr
toy.print %6 : !llvm.ptr
} else {
%5 = "toy.string_literal"() <{value = "!(x < 4) -- should be dead code"}> : () -> !llvm.ptr
toy.print %5 : !llvm.ptr
}
%2 = "toy.load"() <{var_name = @x}> : () -> i32
%c5_i64 = arith.constant 5 : i64
%3 = "toy.less"(%c5_i64, %2) : (i64, i32) -> i1
scf.if %3 {
%5 = "toy.string_literal"() <{value = "x > 5"}> : () -> !llvm.ptr
toy.print %5 : !llvm.ptr
} else {
%5 = "toy.string_literal"() <{value = "!(x > 5) -- should see this"}> : () -> !llvm.ptr
toy.print %5 : !llvm.ptr
}
%4 = "toy.string_literal"() <{value = "Done."}> : () -> !llvm.ptr
toy.print %4 : !llvm.ptr
%c0_i32 = arith.constant 0 : i32
"toy.return"(%c0_i32) : (i32) -> ()
}) : () -> ()
"toy.yield"() : () -> ()
}
}
This is the IR for the following “program”:
INT32 x;
x = 3;
IF ( x < 4 )
{
PRINT "x < 4";
PRINT "a second statement";
}
ELSE
{
PRINT "!(x < 4) -- should be dead code";
};
IF ( x > 5 )
{
PRINT "x > 5";
}
ELSE
{
PRINT "!(x > 5) -- should see this";
};
PRINT "Done.";
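The insertion-point juggling that produces those nested scf.if bodies can be modelled without MLIR at all. The sketch below is a toy stand-in, not the ANTLR or MLIR API: `Block`, `IfOp` and `ToyListener` are hypothetical names, and a plain stack of saved insertion points plays the role that `builder.saveInsertionPoint()` and `builder.restoreInsertionPoint()` would play in the real listener. The callback names mirror the ones above; the `hasElse` parameter is an assumption (a real callback would presumably check whether the parent context's `elseStatement()` is non-null).

```cpp
#include <cassert>
#include <deque>
#include <string>
#include <vector>

// Toy stand-in for an MLIR block: an ordered list of op names.
struct Block {
  std::vector<std::string> ops;
};

// Toy stand-in for scf.if: a then block plus an optional else block.
struct IfOp {
  Block thenBlock;
  Block elseBlock;
  bool hasElse = false;
};

// Hypothetical listener model (not the ANTLR or MLIR API). The stack of saved
// insertion points plays the role of saveInsertionPoint()/restoreInsertionPoint().
class ToyListener {
 public:
  Block funcBody;          // the enclosing function body
  std::deque<IfOp> ifOps;  // deque: push_back never moves existing elements
  Block *ip = &funcBody;   // current insertion point

  // Every ordinary statement callback just emits at the current insertion point.
  void emitStatement(const std::string &op) { ip->ops.push_back(op); }

  // Save where we are, create the if there, and move into its then block.
  void enterIfelifelse() {
    ip->ops.push_back("scf.if");
    ifOps.emplace_back();
    saved.push_back({ip, &ifOps.back()});
    ip = &ifOps.back().thenBlock;
  }

  // Then body finished: switch to the else block, or resume after the if.
  void exitIfStatement(bool hasElse) {
    assert(!saved.empty() && "exitIfStatement without enterIfelifelse");
    if (hasElse) {
      saved.back().ifOp->hasElse = true;
      ip = &saved.back().ifOp->elseBlock;
    } else {
      restore();
    }
  }

  // Else body finished: resume right after the if.
  void exitElseStatement() { restore(); }

 private:
  struct Saved {
    Block *ip;   // insertion point in effect before the if
    IfOp *ifOp;  // the if being filled in
  };
  std::vector<Saved> saved;  // a stack, so nested IFs unwind correctly

  void restore() {
    assert(!saved.empty() && "unbalanced restore");
    ip = saved.back().ip;
    saved.pop_back();
  }
};
```

Because the saved entries form a stack, an IF nested inside a then or else body unwinds back to its enclosing block rather than jumping straight to the function body.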
There are existing mechanisms for lowering the SCF dialect, so that part doesn’t take much work. There was, however, one lowering quirk that I didn’t expect to have to deal with. I lower the toy.print op in two steps, first to toy.call, and then let the toy.call lowering kick in. However, I was doing that as part of a somewhat hacky toy.scope lowering step. That no longer worked, since I can now have calls that aren’t directly in the scope, but are nested in the body of an scf.if then or else block. To fix that, I had to switch to a more conventional CallOp lowering class:
class CallOpLowering : public ConversionPattern {
public:
  explicit CallOpLowering(MLIRContext* context)
      : ConversionPattern(toy::CallOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation* op, ArrayRef<Value> operands,
                                ConversionPatternRewriter& rewriter) const override {
    auto callOp = cast<toy::CallOp>(op);
    auto loc = callOp.getLoc();

    // Get the callee symbol reference (stored as the "callee" attribute)
    auto calleeAttr = callOp->getAttrOfType<FlatSymbolRefAttr>("callee");
    if (!calleeAttr)
      return failure();

    // Result types: empty for void, one type for a scalar return
    TypeRange resultTypes = callOp.getResultTypes();

    // Use the (already converted) operands handed to us by the conversion
    // framework, rather than the original callOp.getOperands()
    auto mlirCall = rewriter.create<mlir::func::CallOp>(loc, resultTypes,
                                                        calleeAttr, operands);

    if (!resultTypes.empty()) {
      // Non-void: replace uses of the single result
      rewriter.replaceOp(op, mlirCall.getResults());
    } else {
      // Void: no result to replace, just erase the op
      rewriter.eraseOp(op);
    }
    return success();
  }
};
I lower all calls (print and any other) first to mlir::func::CallOp, and then let the existing MLIR func dialect lowering kick in and do the rest.
After my “stage I” lowering, this program is mostly translated into the LLVM dialect:
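A plausible reading of why the scope-based approach broke, as a self-contained model: if a rewrite only looks at the ops directly inside the scope body, any call nested inside an scf.if region is never visited, whereas a pattern applied by the conversion driver fires wherever the op appears. `Op`, `rewriteDirectChildren`, `rewriteEverywhere` and `countOps` are hypothetical names, and this is a toy tree, not MLIR's `Operation`/`Region` classes.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy op tree: each op has a name and zero or more regions of nested ops
// (e.g. an scf.if has a then region and an else region).
struct Op {
  std::string name;
  std::vector<std::vector<Op>> regions;
};

// Model of a scope-local rewrite: only ops directly in `body` are visited.
int rewriteDirectChildren(std::vector<Op> &body) {
  int rewritten = 0;
  for (Op &op : body) {
    if (op.name == "toy.call") {
      op.name = "func.call";
      ++rewritten;
    }
  }
  return rewritten;
}

// Model of a pattern applied by a conversion driver: every op is visited,
// at any nesting depth, so calls inside scf.if bodies are rewritten too.
int rewriteEverywhere(std::vector<Op> &body) {
  int rewritten = 0;
  for (Op &op : body) {
    if (op.name == "toy.call") {
      op.name = "func.call";
      ++rewritten;
    }
    for (std::vector<Op> &region : op.regions) rewritten += rewriteEverywhere(region);
  }
  return rewritten;
}

// Counts ops with the given name anywhere in the tree.
int countOps(const std::vector<Op> &body, const std::string &name) {
  int n = 0;
  for (const Op &op : body) {
    if (op.name == name) ++n;
    for (const std::vector<Op> &region : op.regions) n += countOps(region, name);
  }
  return n;
}
```

For a scope body holding one top-level call plus an if/else with three calls in its regions, the direct-children rewrite converts only the top-level call and leaves three behind; the recursive walk converts all four.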
module attributes {llvm.ident = "toycalculator V2"} {
llvm.mlir.global private constant @str_5(dense<[68, 111, 110, 101, 46]> : tensor<5xi8>) {addr_space = 0 : i32} : !llvm.array<5 x i8>
llvm.mlir.global private constant @str_4(dense<[33, 40, 120, 32, 62, 32, 53, 41, 32, 45, 45, 32, 115, 104, 111, 117, 108, 100, 32, 115, 101, 101, 32, 116, 104, 105, 115]> : tensor<27xi8>) {addr_space = 0 : i32} : !llvm.array<27 x i8>
llvm.mlir.global private constant @str_3(dense<[120, 32, 62, 32, 53]> : tensor<5xi8>) {addr_space = 0 : i32} : !llvm.array<5 x i8>
llvm.mlir.global private constant @str_2(dense<[33, 40, 120, 32, 60, 32, 52, 41, 32, 45, 45, 32, 115, 104, 111, 117, 108, 100, 32, 98, 101, 32, 100, 101, 97, 100, 32, 99, 111, 100, 101]> : tensor<31xi8>) {addr_space = 0 : i32} : !llvm.array<31 x i8>
llvm.mlir.global private constant @str_1(dense<[97, 32, 115, 101, 99, 111, 110, 100, 32, 115, 116, 97, 116, 101, 109, 101, 110, 116]> : tensor<18xi8>) {addr_space = 0 : i32} : !llvm.array<18 x i8>
func.func private @__toy_print_string(i64, !llvm.ptr)
llvm.mlir.global private constant @str_0(dense<[120, 32, 60, 32, 52]> : tensor<5xi8>) {addr_space = 0 : i32} : !llvm.array<5 x i8>
func.func @main() -> i32 attributes {llvm.debug.subprogram = #llvm.di_subprogram, compileUnit = , sourceLanguage = DW_LANG_C, file = <"if.toy" in ".">, producer = "toycalculator", isOptimized = false, emissionKind = Full>, scope = #llvm.di_file<"if.toy" in ".">, name = "main", linkageName = "main", file = <"if.toy" in ".">, line = 1, scopeLine = 1, subprogramFlags = Definition, type = >>} {
"toy.scope"() ({
%0 = llvm.mlir.constant(1 : i64) : i64
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64, bindc_name = "x"} : (i64) -> !llvm.ptr
llvm.intr.dbg.declare #llvm.di_local_variable, compileUnit = , sourceLanguage = DW_LANG_C, file = <"if.toy" in ".">, producer = "toycalculator", isOptimized = false, emissionKind = Full>, scope = #llvm.di_file<"if.toy" in ".">, name = "main", linkageName = "main", file = <"if.toy" in ".">, line = 1, scopeLine = 1, subprogramFlags = Definition, type = >>, name = "x", file = <"if.toy" in ".">, line = 1, alignInBits = 32, type = #llvm.di_basic_type> = %1 : !llvm.ptr
%2 = llvm.mlir.constant(3 : i64) : i64
%3 = llvm.trunc %2 : i64 to i32
llvm.store %3, %1 {alignment = 4 : i64} : i32, !llvm.ptr
%4 = llvm.load %1 : !llvm.ptr -> i32
%5 = llvm.mlir.constant(4 : i64) : i64
%6 = llvm.sext %4 : i32 to i64
%7 = llvm.icmp "slt" %6, %5 : i64
scf.if %7 {
%15 = llvm.mlir.addressof @str_0 : !llvm.ptr
%16 = llvm.mlir.constant(5 : i64) : i64
"toy.call"(%16, %15) <{callee = @__toy_print_string}> : (i64, !llvm.ptr) -> ()
%17 = llvm.mlir.addressof @str_1 : !llvm.ptr
%18 = llvm.mlir.constant(18 : i64) : i64
"toy.call"(%18, %17) <{callee = @__toy_print_string}> : (i64, !llvm.ptr) -> ()
} else {
%15 = llvm.mlir.addressof @str_2 : !llvm.ptr
%16 = llvm.mlir.constant(31 : i64) : i64
"toy.call"(%16, %15) <{callee = @__toy_print_string}> : (i64, !llvm.ptr) -> ()
}
%8 = llvm.load %1 : !llvm.ptr -> i32
%9 = llvm.mlir.constant(5 : i64) : i64
%10 = llvm.sext %8 : i32 to i64
%11 = llvm.icmp "slt" %9, %10 : i64
scf.if %11 {
%15 = llvm.mlir.addressof @str_3 : !llvm.ptr
%16 = llvm.mlir.constant(5 : i64) : i64
"toy.call"(%16, %15) <{callee = @__toy_print_string}> : (i64, !llvm.ptr) -> ()
} else {
%15 = llvm.mlir.addressof @str_4 : !llvm.ptr
%16 = llvm.mlir.constant(27 : i64) : i64
"toy.call"(%16, %15) <{callee = @__toy_print_string}> : (i64, !llvm.ptr) -> ()
}
%12 = llvm.mlir.addressof @str_5 : !llvm.ptr
%13 = llvm.mlir.constant(5 : i64) : i64
"toy.call"(%13, %12) <{callee = @__toy_print_string}> : (i64, !llvm.ptr) -> ()
%14 = llvm.mlir.constant(0 : i32) : i32
"toy.return"(%14) : (i32) -> ()
}) : () -> ()
"toy.yield"() : () -> ()
}
}
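The only things left from my toy dialect are “toy.scope”, “toy.call”, “toy.yield”, and “toy.return”, plus the scf.if ops themselves. After my second lowering pass, the scf.if ops have become llvm.cond_br/llvm.br branches between basic blocks, and everything other than the func dialect ops (func.func, call, and return) is in the MLIR LLVM dialect: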
llvm.mlir.global private constant @str_5(dense<[68, 111, 110, 101, 46]> : tensor<5xi8>) {addr_space = 0 : i32} : !llvm.array<5 x i8>
llvm.mlir.global private constant @str_4(dense<[33, 40, 120, 32, 62, 32, 53, 41, 32, 45, 45, 32, 115, 104, 111, 117, 108, 100, 32, 115, 101, 101, 32, 116, 104, 105, 115]> : tensor<27xi8>) {addr_space = 0 : i32} : !llvm.array<27 x i8>
llvm.mlir.global private constant @str_3(dense<[120, 32, 62, 32, 53]> : tensor<5xi8>) {addr_space = 0 : i32} : !llvm.array<5 x i8>
llvm.mlir.global private constant @str_2(dense<[33, 40, 120, 32, 60, 32, 52, 41, 32, 45, 45, 32, 115, 104, 111, 117, 108, 100, 32, 98, 101, 32, 100, 101, 97, 100, 32, 99, 111, 100, 101]> : tensor<31xi8>) {addr_space = 0 : i32} : !llvm.array<31 x i8>
llvm.mlir.global private constant @str_1(dense<[97, 32, 115, 101, 99, 111, 110, 100, 32, 115, 116, 97, 116, 101, 109, 101, 110, 116]> : tensor<18xi8>) {addr_space = 0 : i32} : !llvm.array<18 x i8>
func.func private @__toy_print_string(i64, !llvm.ptr)
llvm.mlir.global private constant @str_0(dense<[120, 32, 60, 32, 52]> : tensor<5xi8>) {addr_space = 0 : i32} : !llvm.array<5 x i8>
func.func @main() -> i32 attributes {llvm.debug.subprogram = #llvm.di_subprogram, compileUnit = , sourceLanguage = DW_LANG_C, file = <"if.toy" in ".">, producer = "toycalculator", isOptimized = false, emissionKind = Full>, scope = #llvm.di_file<"if.toy" in ".">, name = "main", linkageName = "main", file = <"if.toy" in ".">, line = 1, scopeLine = 1, subprogramFlags = Definition, type = >>} {
%0 = llvm.mlir.constant(1 : i64) : i64
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64, bindc_name = "x"} : (i64) -> !llvm.ptr
llvm.intr.dbg.declare #llvm.di_local_variable, compileUnit = , sourceLanguage = DW_LANG_C, file = <"if.toy" in ".">, producer = "toycalculator", isOptimized = false, emissionKind = Full>, scope = #llvm.di_file<"if.toy" in ".">, name = "main", linkageName = "main", file = <"if.toy" in ".">, line = 1, scopeLine = 1, subprogramFlags = Definition, type = >>, name = "x", file = <"if.toy" in ".">, line = 1, alignInBits = 32, type = #llvm.di_basic_type> = %1 : !llvm.ptr
%2 = llvm.mlir.constant(3 : i64) : i64
%3 = llvm.trunc %2 : i64 to i32
llvm.store %3, %1 {alignment = 4 : i64} : i32, !llvm.ptr
%4 = llvm.load %1 : !llvm.ptr -> i32
%5 = llvm.mlir.constant(4 : i64) : i64
%6 = llvm.sext %4 : i32 to i64
%7 = llvm.icmp "slt" %6, %5 : i64
llvm.cond_br %7, ^bb1, ^bb2
^bb1: // pred: ^bb0
%8 = llvm.mlir.addressof @str_0 : !llvm.ptr
%9 = llvm.mlir.constant(5 : i64) : i64
call @__toy_print_string(%9, %8) : (i64, !llvm.ptr) -> ()
%10 = llvm.mlir.addressof @str_1 : !llvm.ptr
%11 = llvm.mlir.constant(18 : i64) : i64
call @__toy_print_string(%11, %10) : (i64, !llvm.ptr) -> ()
llvm.br ^bb3
^bb2: // pred: ^bb0
%12 = llvm.mlir.addressof @str_2 : !llvm.ptr
%13 = llvm.mlir.constant(31 : i64) : i64
call @__toy_print_string(%13, %12) : (i64, !llvm.ptr) -> ()
llvm.br ^bb3
^bb3: // 2 preds: ^bb1, ^bb2
%14 = llvm.load %1 : !llvm.ptr -> i32
%15 = llvm.mlir.constant(5 : i64) : i64
%16 = llvm.sext %14 : i32 to i64
%17 = llvm.icmp "slt" %15, %16 : i64
llvm.cond_br %17, ^bb4, ^bb5
^bb4: // pred: ^bb3
%18 = llvm.mlir.addressof @str_3 : !llvm.ptr
%19 = llvm.mlir.constant(5 : i64) : i64
call @__toy_print_string(%19, %18) : (i64, !llvm.ptr) -> ()
llvm.br ^bb6
^bb5: // pred: ^bb3
%20 = llvm.mlir.addressof @str_4 : !llvm.ptr
%21 = llvm.mlir.constant(27 : i64) : i64
call @__toy_print_string(%21, %20) : (i64, !llvm.ptr) -> ()
llvm.br ^bb6
^bb6: // 2 preds: ^bb4, ^bb5
%22 = llvm.mlir.addressof @str_5 : !llvm.ptr
%23 = llvm.mlir.constant(5 : i64) : i64
call @__toy_print_string(%23, %22) : (i64, !llvm.ptr) -> ()
%24 = llvm.mlir.constant(0 : i32) : i32
return %24 : i32
}
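The step from scf.if to llvm.cond_br/llvm.br is a CFG flattening: end the current block with a conditional branch, give the then and else bodies their own blocks, and have both branch to a new continuation block. Upstream MLIR provides this as the SCF-to-ControlFlow conversion (the `-convert-scf-to-cf` pass), which is presumably what does the work here. The model below is only an illustration of the idea, not that implementation; `Node`, `BasicBlock` and `Flattener` are hypothetical names, and blocks are numbered in creation order.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Structured input: a plain statement, or an if/else with nested bodies.
struct Node {
  std::string text;            // statement text, or the condition of an if
  bool isIf = false;
  std::vector<Node> thenBody;
  std::vector<Node> elseBody;  // empty means "no else"
};

// Flat output: straight-line instructions plus one terminator.
struct BasicBlock {
  std::vector<std::string> instrs;
  std::string terminator;
};

// Hypothetical flattener: turns nested if/else into labelled blocks joined by
// cond_br/br, numbering blocks in creation order (^bb0 is the entry).
class Flattener {
 public:
  std::vector<BasicBlock> run(const std::vector<Node> &body) {
    blocks.assign(1, BasicBlock{});
    std::size_t last = lower(body, 0);
    blocks[last].terminator = "return";
    return blocks;
  }

 private:
  std::vector<BasicBlock> blocks;  // referenced by index, since growth may reallocate

  static std::string label(std::size_t i) { return "^bb" + std::to_string(i); }

  std::size_t newBlock() {
    blocks.emplace_back();
    return blocks.size() - 1;
  }

  // Emits `body` starting in block `cur`; returns the block control falls out of.
  std::size_t lower(const std::vector<Node> &body, std::size_t cur) {
    for (const Node &n : body) {
      assert(cur < blocks.size());
      if (!n.isIf) {
        blocks[cur].instrs.push_back(n.text);
        continue;
      }
      bool hasElse = !n.elseBody.empty();
      std::size_t thenB = newBlock();
      std::size_t thenEnd = lower(n.thenBody, thenB);
      std::size_t elseB = 0, elseEnd = 0;
      if (hasElse) {
        elseB = newBlock();
        elseEnd = lower(n.elseBody, elseB);
      }
      std::size_t merge = newBlock();
      blocks[cur].terminator = "cond_br " + n.text + ", " + label(thenB) + ", " +
                               label(hasElse ? elseB : merge);
      blocks[thenEnd].terminator = "br " + label(merge);
      if (hasElse) blocks[elseEnd].terminator = "br " + label(merge);
      cur = merge;
    }
    return cur;
  }
};
```

Fed the two IF/ELSE statements of the sample program, this yields seven blocks: ^bb0 ends in a cond_br to ^bb1/^bb2, both of which branch to ^bb3, which in turn branches to ^bb4/^bb5 and then ^bb6, the same shape as the dump above.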
Once compiled (without optimization) and disassembled, we are left with:
0000000000000000:
   0: push %rax
   1: movl $0x3,0x4(%rsp)
   9: xor %eax,%eax
   b: test %al,%al
   d: jne 2a
   f: mov $0x5,%edi
  14: mov $0x0,%esi
        15: R_X86_64_32 .rodata+0x62
  19: call 1e
        1a: R_X86_64_PLT32 __toy_print_string-0x4
  1e: mov $0x12,%edi
  23: mov $0x0,%esi
        24: R_X86_64_32 .rodata+0x50
  28: jmp 34
  2a: mov $0x1f,%edi
  2f: mov $0x0,%esi
        30: R_X86_64_32 .rodata+0x30
  34: call 39
        35: R_X86_64_PLT32 __toy_print_string-0x4
  39: movslq 0x4(%rsp),%rax
  3e: cmp $0x6,%rax
  42: jl 50
  44: mov $0x5,%edi
  49: mov $0x0,%esi
        4a: R_X86_64_32 .rodata+0x2b
  4e: jmp 5a
  50: mov $0x1b,%edi
  55: mov $0x0,%esi
        56: R_X86_64_32 .rodata+0x10
  5a: call 5f
        5b: R_X86_64_PLT32 __toy_print_string-0x4
  5f: mov $0x5,%edi
  64: mov $0x0,%esi
        65: R_X86_64_32 .rodata
  69: call 6e
        6a: R_X86_64_PLT32 __toy_print_string-0x4
  6e: xor %eax,%eax
  70: pop %rcx
  71: ret
Of course, we can enable optimization and get something much nicer. With -O2, by the time we are done with the LLVM lowering, constant propagation has folded away both conditions, leaving just:
; ModuleID = 'if.toy'
source_filename = "if.toy"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"
@str_5 = private constant [5 x i8] c"Done."
@str_4 = private constant [27 x i8] c"!(x > 5) -- should see this"
@str_1 = private constant [18 x i8] c"a second statement"
@str_0 = private constant [5 x i8] c"x < 4"
declare void @__toy_print_string(i64, ptr) local_unnamed_addr
define noundef i32 @main() local_unnamed_addr !dbg !4 {
#dbg_value(i32 3, !8, !DIExpression(), !10)
tail call void @__toy_print_string(i64 5, ptr nonnull @str_0), !dbg !11
tail call void @__toy_print_string(i64 18, ptr nonnull @str_1), !dbg !12
tail call void @__toy_print_string(i64 27, ptr nonnull @str_4), !dbg !13
tail call void @__toy_print_string(i64 5, ptr nonnull @str_5), !dbg !14
ret i32 0, !dbg !14
}
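Why the -O2 output collapses to four calls: once the store of 3 is forwarded to both loads, each comparison has constant operands, each conditional branch folds to an unconditional one, and the blocks that are no longer reachable disappear. The model below is a hand-rolled illustration of that net effect, not how LLVM structures it (LLVM spreads this work across several passes). `BB`, `Operand`, `foldBranches` and `reachablePrints` are hypothetical names, and the known value of x is passed in directly rather than discovered.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// One comparison operand: either the variable x or an integer literal.
struct Operand {
  bool isVar;
  long value;  // used only when !isVar
};

enum class Kind { Br, CondBrLess, Return };

// A basic block: the strings it prints, then its terminator.
struct BB {
  std::vector<std::string> prints;
  Kind kind;
  Operand lhs{}, rhs{};  // CondBrLess: go to t if lhs < rhs, else to f
  int t = -1, f = -1;    // successor block indices
};

// Fold each CondBrLess into an unconditional Br, given the known value of x
// (what store-to-load forwarding provides in the real pipeline).
void foldBranches(std::vector<BB> &cfg, long xValue) {
  auto eval = [&](const Operand &o) { return o.isVar ? xValue : o.value; };
  for (BB &b : cfg) {
    if (b.kind != Kind::CondBrLess) continue;
    b.t = eval(b.lhs) < eval(b.rhs) ? b.t : b.f;
    b.kind = Kind::Br;
  }
}

// Walk from the entry block along the now-unconditional branches. Blocks that
// are never reached contribute nothing, which is the effect of dead-block removal.
std::vector<std::string> reachablePrints(const std::vector<BB> &cfg) {
  std::vector<std::string> out;
  std::set<int> seen;
  int cur = 0;
  while (cur >= 0 && seen.insert(cur).second) {
    const BB &b = cfg[cur];
    assert(b.kind != Kind::CondBrLess && "fold branches first");
    out.insert(out.end(), b.prints.begin(), b.prints.end());
    cur = (b.kind == Kind::Br) ? b.t : -1;
  }
  return out;
}
```

Encoding the seven blocks of the lowered IR above (^bb0 through ^bb6) and folding with x = 3, the surviving prints are "x < 4", "a second statement", "!(x > 5) -- should see this", and "Done.", in that order, matching the four tail calls in the optimized IR.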
Our program is reduced to just the four print calls on the path that actually executes:
0000000000400470:
  400470: 50              push %rax
  400471: bf 05 00 00 00  mov $0x5,%edi
  400476: be 12 12 40 00  mov $0x401212,%esi
  40047b: e8 f0 fe ff ff  call 400370 <__toy_print_string@plt>
  400480: bf 12 00 00 00  mov $0x12,%edi
  400485: be 00 12 40 00  mov $0x401200,%esi
  40048a: e8 e1 fe ff ff  call 400370 <__toy_print_string@plt>
  40048f: bf 1b 00 00 00  mov $0x1b,%edi
  400494: be e0 11 40 00  mov $0x4011e0,%esi
  400499: e8 d2 fe ff ff  call 400370 <__toy_print_string@plt>
  40049e: bf 05 00 00 00  mov $0x5,%edi
  4004a3: be d0 11 40 00  mov $0x4011d0,%esi
  4004a8: e8 c3 fe ff ff  call 400370 <__toy_print_string@plt>
  4004ad: 31 c0           xor %eax,%eax
  4004af: 59              pop %rcx
  4004b0: c3              ret
We've seen how easy it was to implement enough control flow to almost make the toy language useful. Another side effect of using the MLIR+LLVM infrastructure is that IF/ELSE debugging support comes for free (having already paid the cost of figuring out how to emit the DWARF instrumentation). Here's an example:
> gdb -q out/if
Reading symbols from out/if...
(gdb) b main
Breakpoint 1 at 0x400471: file if.toy, line 3.
(gdb) run
Starting program: /home/pjoot/toycalculator/samples/out/if
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/lib64/libthread_db.so.1".
Breakpoint 1, main () at if.toy:3
3 x = 3;
(gdb) l
1 INT32 x;
2
3 x = 3;
4
5 IF ( x < 4 )
6 {
7 PRINT "x < 4";
8 PRINT "a second statement";
9 }
10 ELSE
(gdb) b 7
Breakpoint 2 at 0x40047f: file if.toy, line 7.
(gdb) c
Continuing.
Breakpoint 2, main () at if.toy:7
7 PRINT "x < 4";
(gdb) l
2
3 x = 3;
4
5 IF ( x < 4 )
6 {
7 PRINT "x < 4";
8 PRINT "a second statement";
9 }
10 ELSE
11 {
(gdb) p x
$1 = 3
(gdb) l
12 PRINT "!(x < 4) -- should be dead code";
13 };
14
15 IF ( x > 5 )
16 {
17 PRINT "x > 5";
18 }
19 ELSE
20 {
21 PRINT "!(x > 5) -- should see this";
(gdb) b 21
Breakpoint 3 at 0x4004c0: file if.toy, line 21.
(gdb) c
Continuing.
x < 4
a second statement
Breakpoint 3, main () at if.toy:21
21 PRINT "!(x > 5) -- should see this";
