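This post is a sequel to 淦!MLIR输出Hello World不应该这么难 ("Damn! Printing Hello World in MLIR Shouldn't Be This Hard"). It is also my attempt to answer some questions I had about Chapter 7 of the official MLIR Toy Tutorial.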

The code for this post is in the following GitHub repository:

https://github.com/mocusez/mlir-hello-world

The GitHub CI script in the repository shows how to build and run it.

Questions about the Toy Tutorial

The chapter adds a composite type, called struct, to the Toy dialect. In Toy source code it looks like this:

# A struct is defined by using the `struct` keyword followed by a name.

struct MyStruct {

# Inside of the struct is a list of variable declarations without initializers

# or shapes, which may also be other previously defined structs.

var a;

var b;

}

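The chapter only walks through the steps for defining a Type:

  1. Define the Type
  2. Define the TypeStorage
  3. Define the ODS (that is, TableGen)
  4. Define parsing and printing for the Type
  5. Define the operations on the Type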

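But it leaves many details vague:

  1. How should the Type and its TypeStorage be lowered?
  2. I am already using MLIR, so why do I need to handle parsing and printing of the Type? (And is that parsing and printing for the Toy language, or for the MLIR dialect?)

The second question is quite misleading. The actual answer is that it is for parsing MLIR's textual form (.mlir source), which is entirely separate from parsing the Toy language.

With these two questions in mind, here is the goal of this post: implement a simple hash table and print the result.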

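Overview of the approach

Define a Dict type in the Hello dialect, together with its operations (Put, Get, Delete), and lower everything to LLVM IR for execution.

Adding a custom Type

First, declare the type in the dialect's TableGen file: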

def DictType :

DialectType<Hello_Dialect, CPred<"::llvm::isa<DictType>($_self)">,

"Hello dict type">;

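If you are not sure how to write the CPred here, use this one as a template and adapt it.

In dialect.h, define the C++ class for the Type, which is used to create and inspect the type during conversion, and forward-declare the TypeStorage it uses: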

namespace hello {

struct DictTypeStorage;

}

class DictType : public mlir::Type::TypeBase<DictType, mlir::Type,

hello::DictTypeStorage> {

public:

/// Inherit some necessary constructors from 'TypeBase'.

using Base::Base;

/// Create an instance of a `DictType` with the given key and value types.

static DictType get(mlir::Type keyType, mlir::Type valueType);

/// Returns the key type of this dict type.

mlir::Type getKeyType();

/// Returns the value type of this dict type.

mlir::Type getValueType();

/// The name of this dict type.

static constexpr mlir::StringLiteral name = "hello.dict";

};

In dialect.cpp, fill in the Type's methods. getImpl() below returns the TypeStorage, from which the stored MLIR types are read:

DictType DictType::get(mlir::Type keyType, mlir::Type valueType) {

return Base::get(keyType.getContext(), keyType, valueType);

}

mlir::Type DictType::getKeyType() {

return getImpl()->keyType;

}

/// Returns the value type of this dict type.

mlir::Type DictType::getValueType() {

return getImpl()->valueType;

}

Then register the Type in the dialect's initialize() method:

void HelloDialect::initialize() {

addOperations<

#define GET_OP_LIST

#include "Hello/HelloOps.cpp.inc"

>();

addTypes<DictType>();

}

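Adding a custom TypeStorage

MLIR separates Type from TypeStorage. The TypeStorage can be understood as the actual storage for a Type's data and metadata.

Since DictTypeStorage is already declared in dialect.h, dialect.cpp only needs to define it as a subclass of mlir::TypeStorage.

operator== compares a stored instance against a key, which decides whether two types are equal.

hashKey hashes the key so that MLIR can look up whether an identical Type already exists.

construct uses the allocator to allocate and build a new storage instance (this allocation happens inside the compiler and does not show up in the lowering).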

struct DictTypeStorage : public mlir::TypeStorage {

/// The `KeyTy` defines what uniquely identifies this type.

/// For dict type, we unique on the key type and value type pair.

using KeyTy = std::pair<mlir::Type, mlir::Type>;

/// Constructor for the type storage instance.

DictTypeStorage(mlir::Type keyType, mlir::Type valueType)

: keyType(keyType), valueType(valueType) {}

/// Define the comparison function for the key type.

bool operator==(const KeyTy &key) const {

return key.first == keyType && key.second == valueType;

}

/// Define a hash function for the key type.

static llvm::hash_code hashKey(const KeyTy &key) {

return llvm::hash_combine(key.first, key.second);

}

/// Define a construction function for the key type.

static KeyTy getKey(mlir::Type keyType, mlir::Type valueType) {

return KeyTy(keyType, valueType);

}

/// Define a construction method for creating a new instance of this storage.

static DictTypeStorage *construct(mlir::TypeStorageAllocator &allocator,

const KeyTy &key) {

// Allocate the storage instance and construct it.

return new (allocator.allocate<DictTypeStorage>())

DictTypeStorage(key.first, key.second);

}

/// The key and value types of the dict.

mlir::Type keyType;

mlir::Type valueType;

};
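To make the roles of KeyTy, hashKey, operator== and construct more concrete, here is a minimal standalone sketch of key-based uniquing in plain C++. It does not use MLIR at all: ToyType, ToyDictStorage, KeyHash and ToyUniquer are hypothetical names invented for this illustration, and MLIR's real uniquer is considerably more involved.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for mlir::Type: a type is just its name here.
using ToyType = std::string;

// Plays the role of DictTypeStorage: holds the data that identifies the type.
struct ToyDictStorage {
  using KeyTy = std::pair<ToyType, ToyType>;

  explicit ToyDictStorage(const KeyTy &key)
      : keyType(key.first), valueType(key.second) {}

  // Same idea as DictTypeStorage::operator==: compare storage against a key.
  bool operator==(const KeyTy &key) const {
    return key.first == keyType && key.second == valueType;
  }

  ToyType keyType;
  ToyType valueType;
};

// Plays the role of hashKey: combine the hashes of both parts of the key.
struct KeyHash {
  std::size_t operator()(const ToyDictStorage::KeyTy &key) const {
    std::size_t h1 = std::hash<std::string>{}(key.first);
    std::size_t h2 = std::hash<std::string>{}(key.second);
    return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
  }
};

// Assumed, simplified uniquer: one storage instance per distinct key.
class ToyUniquer {
public:
  const ToyDictStorage *get(const ToyType &keyType, const ToyType &valueType) {
    ToyDictStorage::KeyTy key(keyType, valueType);
    auto it = storages.find(key);
    if (it == storages.end()) {
      // Plays the role of construct(): build the storage only on first use.
      it = storages.emplace(key, std::make_unique<ToyDictStorage>(key)).first;
    }
    assert(*it->second == key && "stored instance must match its key");
    return it->second.get();
  }

private:
  std::unordered_map<ToyDictStorage::KeyTy, std::unique_ptr<ToyDictStorage>,
                     KeyHash>
      storages;
};
```

In this model, asking twice for the same (key, value) pair returns the same pointer, while a different pair gets its own instance. That is the property that lets two Type values be compared cheaply by their storage pointer.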

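Type conversion

The Dict type in the Hello dialect needs to be converted to a pointer. The conversion actually takes place while operations are lowered, by converting each Op's operand and result types.

One option is to create a class that inherits from LLVMTypeConverter and calls addConversion in its constructor. This should work in theory, but I have not tried it.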

class HelloTypeConverter : public mlir::LLVMTypeConverter {

public:

HelloTypeConverter(mlir::MLIRContext *ctx) : mlir::LLVMTypeConverter(ctx) {

addConversion([](hello::DictType type) {

return mlir::LLVM::LLVMPointerType::get(type.getContext());

});

}

};

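The other option is to call addConversion directly on the existing typeConverter.

There are two ways to write this (both passed my tests):

  1. Name the source type directly in the lambda's parameter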

typeConverter.addConversion([](DictType type) -> mlir::Type {

return mlir::LLVM::LLVMPointerType::get(type.getContext());

});

  2. Take a generic mlir::Type and check it with mlir::dyn_cast before returning

typeConverter.addConversion([](mlir::Type type) -> std::optional<mlir::Type> {

if (auto dictType = mlir::dyn_cast<DictType>(type))

return mlir::LLVM::LLVMPointerType::get(type.getContext());

return std::nullopt;

});

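Once the typeConverter has been passed in through the pattern's constructor, getTypeConverter() can be used to decide what type to return for a given input type: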

auto resultType = getTypeConverter()->convertType(op->getResult(0).getType());

if (!resultType) {

return mlir::failure();

}

For a DictType result, this code is therefore equivalent to

auto resultType = mlir::LLVM::LLVMPointerType::get(context);

In effect this produces a constant: resultType does not vary with the arguments passed to the Operation.
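The way a converter picks a result can be modelled in a few lines of plain C++. The sketch below is not MLIR code: ToyTypeConverter and makeHelloConverter are hypothetical names, types are represented as strings, and the only behaviour borrowed from MLIR is that callbacks returning std::nullopt mean "not mine, try another one" and that the most recently registered callback is tried first.

```cpp
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, string-based model of a type converter.
class ToyTypeConverter {
public:
  using Callback =
      std::function<std::optional<std::string>(const std::string &)>;

  void addConversion(Callback callback) {
    conversions.push_back(std::move(callback));
  }

  // Try callbacks from the most recently added to the oldest.
  std::optional<std::string> convertType(const std::string &type) const {
    for (auto it = conversions.rbegin(); it != conversions.rend(); ++it) {
      if (std::optional<std::string> result = (*it)(type))
        return result;
    }
    return std::nullopt; // No callback knew how to convert this type.
  }

private:
  std::vector<Callback> conversions;
};

// Assumed setup: builtin types stay as they are, every dict becomes a pointer.
inline ToyTypeConverter makeHelloConverter() {
  ToyTypeConverter converter;
  converter.addConversion(
      [](const std::string &type) -> std::optional<std::string> {
        if (type == "i32" || type == "!llvm.ptr")
          return type;
        return std::nullopt;
      });
  converter.addConversion(
      [](const std::string &type) -> std::optional<std::string> {
        // The key and value types are ignored, so the result is constant.
        if (type.rfind("!hello.dict<", 0) == 0)
          return std::string("!llvm.ptr");
        return std::nullopt;
      });
  return converter;
}
```

Because the dict callback ignores the key and value types, every !hello.dict<...> maps to the same !llvm.ptr in this model, which mirrors the fixed resultType described above.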

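Lowering the operations on the Type

A Dict (hash table) has put, get, and delete operations, plus create and free for allocating and releasing its memory. Each of them needs a corresponding Op (Operation) defined in the dialect.

In the code below, assemblyFormat is defined so that the types can be parsed correctly from a textual .mlir file: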

class Hello_Op<string mnemonic, list<Trait> traits = []> :

Op<Hello_Dialect, mnemonic, traits>;

def Dict_CreateOp : Hello_Op<"dict.create", [Pure]> {

let summary = "Create a new dict<string,i32>";

let results = (outs DictType:$dict);

let assemblyFormat = "attr-dict `:` type($dict)";

}

def Dict_FreeOp : Hello_Op<"dict.free", []> {

let summary = "Free the dict<string,i32> memory";

let arguments = (ins DictType:$dict);

let assemblyFormat = "$dict attr-dict `:` type($dict)";

}

def Dict_PutOp : Hello_Op<"dict.put", []> {

let summary = "Insert string->i32";

let arguments = (ins DictType:$dict, StrAttr:$key, I32Attr:$value);

let results = (outs DictType:$out);

let assemblyFormat = "$dict `,` $key `=` $value attr-dict `:` type($dict) `->` type($out)";

}

def Dict_GetOp : Hello_Op<"dict.get", []> {

let summary = "Lookup string->i32, returns i32";

let arguments = (ins DictType:$dict, StrAttr:$key);

let results = (outs I32:$value);

let assemblyFormat = "$dict `,` $key attr-dict `:` type($dict) `->` type($value)";

}

def Dict_DeleteOp : Hello_Op<"dict.delete", []> {

let summary = "Delete key string";

let arguments = (ins DictType:$dict, StrAttr:$key);

let results = (outs DictType:$out);

let assemblyFormat = "$dict `,` $key attr-dict `:` type($dict) `->` type($out)";

}

Each Op also needs its own lowering. Here is the implementation for Dict_CreateOp as an example:

class DictCreateOpLowering : public mlir::ConversionPattern {

public:

explicit DictCreateOpLowering(mlir::TypeConverter &typeConverter,

mlir::MLIRContext *context)

: mlir::ConversionPattern(

typeConverter, hello::CreateOp::getOperationName(), 1, context) {}

mlir::LogicalResult

matchAndRewrite(mlir::Operation *op, mlir::ArrayRef<mlir::Value> operands,

mlir::ConversionPatternRewriter &rewriter) const override {

auto loc = op->getLoc();

auto resultType =

getTypeConverter()->convertType(op->getResult(0).getType());

if (!resultType) {

return mlir::failure();

}

mlir::ModuleOp module = op->getParentOfType<mlir::ModuleOp>();

auto createMapRef = getOrInsertCreateMap(rewriter, module);

auto callOp = rewriter.create<mlir::LLVM::CallOp>(

loc, resultType, createMapRef, mlir::ValueRange{});

rewriter.replaceOp(op, callOp.getResult());

return mlir::success();

}

private:

mlir::FlatSymbolRefAttr getOrInsertCreateMap(mlir::PatternRewriter &rewriter,

mlir::ModuleOp module) const {

auto *context = module.getContext();

if (module.lookupSymbol<mlir::LLVM::LLVMFuncOp>("create_map"))

return mlir::SymbolRefAttr::get(context, "create_map");

// Create function type: () -> !llvm.ptr

auto resultType = mlir::LLVM::LLVMPointerType::get(context);

auto fnType =

mlir::LLVM::LLVMFunctionType::get(resultType, std::nullopt, false);

// Insert function declaration

mlir::PatternRewriter::InsertionGuard insertGuard(rewriter);

rewriter.setInsertionPointToStart(module.getBody());

rewriter.create<mlir::LLVM::LLVMFuncOp>(module.getLoc(), "create_map",

fnType);

return mlir::SymbolRefAttr::get(context, "create_map");

}

};

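What this lowering does: it turns Dict_CreateOp into a call to the function create_map, which is declared in LLVM IR and implemented in C. Once that works, the output can be produced with C's printf().

The next part is very important: you need to declare how the custom Type is parsed and printed, that is, printType and parseType.

To do this, set useDefaultTypePrinterParser to 1 in the dialect definition: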

def Hello_Dialect : Dialect {

let name = "hello";

let summary = "A hello out-of-tree MLIR dialect.";

let description = [{

This dialect is minimal example to implement hello-world kind of sample code

for MLIR.

}];

let cppNamespace = "::hello";

let useDefaultTypePrinterParser = 1;

}

Define printType and parseType in the corresponding dialect.cpp:


mlir::Type HelloDialect::parseType(mlir::DialectAsmParser &parser) const {

llvm::StringRef typeTag;

if (parser.parseKeyword(&typeTag))

return mlir::Type();

if (typeTag == "dict") {

if (parser.parseLess())

return mlir::Type();

mlir::Type keyType;

if (parser.parseType(keyType))

return mlir::Type();

if (parser.parseComma())

return mlir::Type();

mlir::Type valueType;

if (parser.parseType(valueType))

return mlir::Type();

if (parser.parseGreater())

return mlir::Type();

return DictType::get(keyType, valueType);

}

parser.emitError(parser.getNameLoc(), "unknown hello type: ") << typeTag;

return mlir::Type();

}

void HelloDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &printer) const {

if (auto dictType = mlir::dyn_cast<DictType>(type)) {

printer << "dict<";

printer.printType(dictType.getKeyType());

printer << ", ";

printer.printType(dictType.getValueType());

printer << ">";

return;

}

llvm_unreachable("unhandled hello type");

}
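The parse/print pair is easier to reason about as a round trip on plain strings. The following standalone sketch mimics what parseType and printType do for the text after the `!hello.` prefix (MLIR strips that prefix before the dialect hook is called). It is only an illustration: DictSpec, trim, parseDictType and printDictType are hypothetical names, the key and value types are kept as raw strings, and nested types that themselves contain commas are not handled.

```cpp
#include <cctype>
#include <cstddef>
#include <optional>
#include <string>
#include <utility>

// Hypothetical result of parsing: (key type text, value type text).
using DictSpec = std::pair<std::string, std::string>;

// Remove leading and trailing whitespace.
inline std::string trim(const std::string &text) {
  std::size_t begin = 0, end = text.size();
  while (begin < end && std::isspace(static_cast<unsigned char>(text[begin])))
    ++begin;
  while (end > begin && std::isspace(static_cast<unsigned char>(text[end - 1])))
    --end;
  return text.substr(begin, end - begin);
}

// Parse "dict<KEY, VALUE>"; return std::nullopt on any malformed input.
inline std::optional<DictSpec> parseDictType(const std::string &text) {
  const std::string prefix = "dict<";
  if (text.size() <= prefix.size() ||
      text.compare(0, prefix.size(), prefix) != 0 || text.back() != '>')
    return std::nullopt;

  // Text between '<' and the final '>'.
  std::string body =
      text.substr(prefix.size(), text.size() - prefix.size() - 1);
  std::size_t comma = body.find(',');
  if (comma == std::string::npos)
    return std::nullopt;

  std::string key = trim(body.substr(0, comma));
  std::string value = trim(body.substr(comma + 1));
  if (key.empty() || value.empty())
    return std::nullopt;
  return DictSpec{key, value};
}

// Print in the same canonical form that printType above produces.
inline std::string printDictType(const DictSpec &spec) {
  return "dict<" + spec.first + ", " + spec.second + ">";
}
```

Parsing accepts some variation in whitespace, while printing always emits one canonical spelling, so print(parse(text)) normalizes the input. The real hooks behave the same way, except that they delegate the inner types to MLIR's own parser and printer.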

Run it!

Prepare a piece of MLIR code:

module {

func.func @test_dict_operations() -> i32 {

%dict0 = hello.dict.create : !hello.dict<index, i32>

%dict1 = hello.dict.put %dict0, "first1" = 100 : !hello.dict<index, i32> -> !hello.dict<index, i32>

%dict2 = hello.dict.put %dict1, "second" = 200 : !hello.dict<index, i32> -> !hello.dict<index, i32>

%dict3 = hello.dict.put %dict2, "third" = 300 : !hello.dict<index, i32> -> !hello.dict<index, i32>

%val1 = hello.dict.get %dict3, "first1" : !hello.dict<index, i32> -> i32

%val2 = hello.dict.get %dict3, "second" : !hello.dict<index, i32> -> i32

%val3 = hello.dict.get %dict3, "third" : !hello.dict<index, i32> -> i32

%dict4 = hello.dict.delete %dict3, "second" : !hello.dict<index, i32> -> !hello.dict<index, i32>

hello.dict.free %dict4 : !hello.dict<index, i32>

func.return %val2 : i32

}

}

Lower the MLIR to LLVM IR and take a look at the printed LLVM IR:

; ModuleID = 'LLVMDialectModule'

source_filename = "LLVMDialectModule"

target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"

target triple = "x86_64-pc-linux-gnu"

@key_third = internal constant [6 x i8] c"third\00"

@key_second = internal constant [7 x i8] c"second\00"

@key_first1 = internal constant [7 x i8] c"first1\00"

declare void @free_map(ptr) local_unnamed_addr

declare void @delete(ptr, ptr) local_unnamed_addr

declare ptr @get(ptr, ptr) local_unnamed_addr

declare void @put(ptr, ptr, i32) local_unnamed_addr

declare ptr @create_map() local_unnamed_addr

define i32 @test_dict_operations() local_unnamed_addr {

%1 = tail call ptr @create_map()

tail call void @put(ptr %1, ptr nonnull @key_first1, i32 100)

tail call void @put(ptr %1, ptr nonnull @key_second, i32 200)

tail call void @put(ptr %1, ptr nonnull @key_third, i32 300)

%2 = tail call ptr @get(ptr %1, ptr nonnull @key_first1)

%3 = tail call ptr @get(ptr %1, ptr nonnull @key_second)

%4 = load i32, ptr %3, align 4

%5 = tail call ptr @get(ptr %1, ptr nonnull @key_third)

tail call void @delete(ptr %1, ptr nonnull @key_second)

tail call void @free_map(ptr %1)

ret i32 %4

}

!llvm.module.flags = !{!0}

!0 = !{i32 2, !"Debug Info Version", i32 3}

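Most of it has turned into C/LLVM IR function calls, which makes this a familiar problem: it behaves just like the printf() call in the previous post.

Implement the interface functions in C/C++, then compile and run: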

../../build/bin/hello-opt dict.mlir -emit=llvm > dict.ll

clang-20 dict.c dict.ll -o dict

#include <stdio.h>

extern int test_dict_operations();

int main() {

printf("Starting dictionary test...\n");

int result = test_dict_operations();

printf("Result from test_dict_operations: %d\n", result);

return 0;

}
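The driver above still needs definitions for the functions that the generated IR only declares. As a rough sketch of what they could look like, the block below follows the signatures visible in the LLVM IR (create_map returns a pointer, put takes a key and an i32, get returns a pointer to the stored i32). Everything else is an assumption: StringI32Map is a hypothetical backing store built on std::unordered_map, and because `delete` is a reserved word in C++, the removal function is named delete_key here. To link against the IR as shown, that function would have to be exported as `delete`, for example from a C file.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>

namespace {
// Hypothetical backing store for the <string, i32> dictionary.
struct StringI32Map {
  std::unordered_map<std::string, std::int32_t> entries;
};

StringI32Map *asMap(void *map) {
  assert(map && "dictionary handle must not be null");
  return static_cast<StringI32Map *>(map);
}
} // namespace

extern "C" {

// Matches: declare ptr @create_map()
void *create_map() { return new StringI32Map(); }

// Matches: declare void @put(ptr, ptr, i32)
void put(void *map, const char *key, std::int32_t value) {
  assert(key && "key must not be null");
  asMap(map)->entries[std::string(key)] = value;
}

// Matches: declare ptr @get(ptr, ptr)
// Returns a pointer to the stored value, or nullptr if the key is absent.
std::int32_t *get(void *map, const char *key) {
  assert(key && "key must not be null");
  auto &entries = asMap(map)->entries;
  auto it = entries.find(std::string(key));
  return it == entries.end() ? nullptr : &it->second;
}

// Corresponds to: declare void @delete(ptr, ptr)
// Renamed in this sketch because `delete` is a C++ keyword.
void delete_key(void *map, const char *key) {
  assert(key && "key must not be null");
  asMap(map)->entries.erase(std::string(key));
}

// Matches: declare void @free_map(ptr)
void free_map(void *map) { delete static_cast<StringI32Map *>(map); }

} // extern "C"
```

Note that the generated IR loads from the pointer returned by get without checking it, so looking up a missing key with this sketch would dereference a null pointer. A real runtime would need to agree with the lowering on how a missing key is reported.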

The expected output is

Starting dictionary test...

Result from test_dict_operations: 200


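Closing thoughts

The limitation of this approach is clear: it only fits a hash table of shape <string,i32> and does not work for other types. This could be solved by generating LLVM IR dynamically, but how to do that is a separate topic 😂

I still recommend trying it hands-on yourself to deepen your understanding.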