4. Kaleidoscope:加入 JIT 與最佳化器支援

4.1. 第四章 簡介

歡迎來到「使用 LLVM 實作語言」教學的第四章。第 1-3 章描述了一個簡單語言的實作,並加入了產生 LLVM IR 的支援。本章將介紹兩種新技術:為您的語言加入最佳化器支援,以及加入 JIT 編譯器支援。這些新增的功能將展示如何為 Kaleidoscope 語言產生優良、有效率的程式碼。

4.2. 簡單的常數摺疊

我們在第三章的示範既優雅又易於擴展。可惜的是,它產生的程式碼並不理想。然而,在編譯簡單程式碼時,IRBuilder 的確為我們提供了顯而易見的最佳化

ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
        %addtmp = fadd double 3.000000e+00, %x
        ret double %addtmp
}

這段程式碼並不是由解析輸入所建立的 AST 的逐字轉錄。那會是

ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
        %addtmp = fadd double 2.000000e+00, 1.000000e+00
        %addtmp1 = fadd double %addtmp, %x
        ret double %addtmp1
}

特別是像上面看到的常數摺疊,這是一種非常常見且非常重要的最佳化:以至於許多語言實作者都在他們的 AST 表示中實作了常數摺疊支援。

有了 LLVM,您在 AST 中不需要這種支援。由於所有建立 LLVM IR 的呼叫都是透過 LLVM IR 建立器進行的,因此建立器本身會在您呼叫它時檢查是否有常數摺疊的機會。如果有的話,它就會直接進行常數摺疊並返回常數,而不是建立指令。

嗯,這很簡單:)。在實務上,我們建議在產生這樣的程式碼時,始終使用 IRBuilder。使用它沒有「語法上的開銷」(您不需要在任何地方都使用常數檢查來醜化您的編譯器),而且它可以在某些情況下大幅減少產生的 LLVM IR 數量(特別是對於具有巨集預處理器或使用大量常數的語言)。

另一方面,IRBuilder 的限制在於,它會在程式碼建立時進行所有內聯分析。如果您採用稍微複雜一點的範例

ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
        %addtmp = fadd double 3.000000e+00, %x
        %addtmp1 = fadd double %x, 3.000000e+00
        %multmp = fmul double %addtmp, %addtmp1
        ret double %multmp
}

在這種情況下,乘法的 LHS 和 RHS 是相同的值。我們真的希望看到它產生「tmp = x+3; result = tmp*tmp;」,而不是計算兩次「x+3」。

遺憾的是,無論進行多少本地分析都無法偵測並更正此問題。這需要兩個轉換:表達式的重新關聯(使加法在語法上相同)和共同子表達式消除 (CSE) 來刪除冗餘的加法指令。幸運的是,LLVM 以「遍歷」的形式提供了大量可供使用的優化。

4.3. LLVM 優化遍歷

LLVM 提供了許多優化遍歷,它們執行許多不同種類的操作,並具有不同的取捨。與其他系統不同,LLVM 不會堅持一套優化方案適用於所有語言和所有情況的錯誤觀念。LLVM 允許編譯器實作者完全決定使用哪些優化、以何種順序以及在何種情況下使用。

舉個具體的例子,LLVM 既支持「整個模組」遍歷,這些遍歷會查看儘可能多的程式碼主體(通常是整個檔案,但如果在連結時執行,則可能是整個程式的很大一部分)。它還支持並包含「每個函式」遍歷,這些遍歷一次只對一個函式進行操作,而不查看其他函式。如需有關遍歷及其執行方式的詳細資訊,請參閱如何編寫遍歷文件和LLVM 遍歷列表

對於 Kaleidoscope,我們目前正在動態產生函式,一次一個,就像使用者輸入它們一樣。在這種情況下,我們並非追求極致的優化體驗,但我們也希望儘可能地捕捉到簡單快速的東西。因此,我們選擇在使用者輸入函式時執行一些每個函式的優化。如果我們想製作一個「靜態 Kaleidoscope 編譯器」,我們將使用與現在完全相同的程式碼,只是我們會延遲執行優化器,直到解析完整個檔案為止。

除了函式和模組遍歷之間的區別之外,遍歷還可以分為轉換遍歷和分析遍歷。轉換遍歷會改變 IR,而分析遍歷會計算其他遍歷可以使用的資訊。為了添加轉換遍歷,必須事先註冊它所依賴的所有分析遍歷。

為了進行每個函式的優化,我們需要設置一個FunctionPassManager 來保存和組織我們想要執行的 LLVM 優化。一旦我們有了它,我們就可以添加一組要執行的優化。對於我們想要優化的每個模組,我們都需要一個新的 FunctionPassManager,因此我們將添加到上一章建立的函式中(InitializeModule()

void InitializeModuleAndManagers(void) {
  // Open a new context and module.
  TheContext = std::make_unique<LLVMContext>();
  TheModule = std::make_unique<Module>("KaleidoscopeJIT", *TheContext);
  TheModule->setDataLayout(TheJIT->getDataLayout());

  // Create a new builder for the module.
  Builder = std::make_unique<IRBuilder<>>(*TheContext);

  // Create new pass and analysis managers.
  TheFPM = std::make_unique<FunctionPassManager>();
  TheLAM = std::make_unique<LoopAnalysisManager>();
  TheFAM = std::make_unique<FunctionAnalysisManager>();
  TheCGAM = std::make_unique<CGSCCAnalysisManager>();
  TheMAM = std::make_unique<ModuleAnalysisManager>();
  ThePIC = std::make_unique<PassInstrumentationCallbacks>();
  TheSI = std::make_unique<StandardInstrumentations>(*TheContext,
                                                    /*DebugLogging*/ true);
  TheSI->registerCallbacks(*ThePIC, TheMAM.get());
  ...

在初始化全域模組 TheModule 和 FunctionPassManager 之後,我們需要初始化框架的其他部分。四個 AnalysisManagers 允許我們添加跨越 IR 階層四個級別運行的分析遍歷。PassInstrumentationCallbacks 和 StandardInstrumentations 是遍歷檢測框架所必需的,允許開發人員自訂遍歷之間發生的事情。

設置好這些管理器後,我們使用一系列「addPass」呼叫來添加一堆 LLVM 轉換遍歷

// Add transform passes.
// Do simple "peephole" optimizations and bit-twiddling optzns.
TheFPM->addPass(InstCombinePass());
// Reassociate expressions.
TheFPM->addPass(ReassociatePass());
// Eliminate Common SubExpressions.
TheFPM->addPass(GVNPass());
// Simplify the control flow graph (deleting unreachable blocks, etc).
TheFPM->addPass(SimplifyCFGPass());

在這種情況下,我們選擇添加四個優化遍歷。我們在此選擇的遍歷是一組非常標準的「清理」優化,適用於各種程式碼。我不會深入探討它們的作用,但相信我,它們是一個很好的起點:)。

接下來,我們註冊轉換遍歷使用的分析遍歷。

  // Register analysis passes used in these transform passes.
  PassBuilder PB;
  PB.registerModuleAnalyses(*TheMAM);
  PB.registerFunctionAnalyses(*TheFAM);
  PB.crossRegisterProxies(*TheLAM, *TheFAM, *TheCGAM, *TheMAM);
}

設定好 PassManager 後,我們需要使用它。我們在建構好新建立的函式後執行它 (在 FunctionAST::codegen() 中),但在將它傳回給客戶端之前

if (Value *RetVal = Body->codegen()) {
  // Finish off the function.
  Builder.CreateRet(RetVal);

  // Validate the generated code, checking for consistency.
  verifyFunction(*TheFunction);

  // Optimize the function.
  TheFPM->run(*TheFunction, *TheFAM);

  return TheFunction;
}

如您所見,這非常簡單。 FunctionPassManager 會最佳化並更新 LLVM Function*,以改善(希望如此)其主體。完成此操作後,我們可以再次嘗試上面的測試

ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
        %addtmp = fadd double %x, 3.000000e+00
        %multmp = fmul double %addtmp, %addtmp
        ret double %multmp
}

如預期,我們現在獲得了經過良好最佳化的程式碼,從該函式的每次執行中都省下了一個浮點加法指令。

LLVM 提供了各種可在特定情況下使用的最佳化。有一些關於各種遍歷的文檔,但它不是很完整。另一個尋找想法的好來源是查看 Clang 執行的遍歷以開始使用。“opt”工具允許您從命令列試驗遍歷,以便您可以查看它們是否有效。

現在我們的編譯前端輸出了合理的程式碼,讓我們來談談如何執行它!

4.4. 新增 JIT 編譯器

可以使用各種工具來處理以 LLVM IR 提供的程式碼。例如,您可以對其執行最佳化(如我們上面所做的),您可以將其以文字或二進制形式傾印出來,您可以將程式碼編譯為某些目標的組合檔案 (.s),或者您可以將其 JIT 編譯。LLVM IR 表示的好處在於它是編譯器許多不同部分之間的「通用貨幣」。

在本節中,我們將在我們的解譯器中新增 JIT 編譯器支援。我們希望 Kaleidoscope 的基本概念是讓使用者像現在這樣輸入函式主體,但立即評估他們輸入的頂層表達式。例如,如果他們輸入“1 + 2;”,我們應該評估並列印出 3。如果他們定義了一個函式,他們應該能夠從命令列呼叫它。

為了做到這一點,我們首先準備環境來為當前的原生目標建立程式碼,並宣告和初始化 JIT。這是透過呼叫一些 InitializeNativeTarget\* 函式並新增一個全域變數 TheJIT,並在 main 中初始化它來完成的

static std::unique_ptr<KaleidoscopeJIT> TheJIT;
...
int main() {
  InitializeNativeTarget();
  InitializeNativeTargetAsmPrinter();
  InitializeNativeTargetAsmParser();

  // Install standard binary operators.
  // 1 is lowest precedence.
  BinopPrecedence['<'] = 10;
  BinopPrecedence['+'] = 20;
  BinopPrecedence['-'] = 20;
  BinopPrecedence['*'] = 40; // highest.

  // Prime the first token.
  fprintf(stderr, "ready> ");
  getNextToken();

  TheJIT = std::make_unique<KaleidoscopeJIT>();

  // Run the main "interpreter loop" now.
  MainLoop();

  return 0;
}

我們還需要為 JIT 設置數據佈局

void InitializeModuleAndPassManager(void) {
  // Open a new context and module.
  TheContext = std::make_unique<LLVMContext>();
  TheModule = std::make_unique<Module>("my cool jit", TheContext);
  TheModule->setDataLayout(TheJIT->getDataLayout());

  // Create a new builder for the module.
  Builder = std::make_unique<IRBuilder<>>(*TheContext);

  // Create a new pass manager attached to it.
  TheFPM = std::make_unique<legacy::FunctionPassManager>(TheModule.get());
  ...

KaleidoscopeJIT 類是一個專門為這些教程構建的簡單 JIT,可在 LLVM 原始程式碼中的 llvm-src/examples/Kaleidoscope/include/KaleidoscopeJIT.h 找到。在後面的章節中,我們將研究它是如何工作的,並使用新功能對其進行擴展,但現在我們將其視為已知。它的 API 非常簡單:addModule 將 LLVM IR 模組添加到 JIT 中,使其函式可供執行(其記憶體由 ResourceTracker 管理);而 lookup 允許我們查詢指向已編譯程式碼的指標。

我們可以使用這個簡單的 API 並更改我們解析頂層表達式的程式碼,使其看起來像這樣

static ExitOnError ExitOnErr;
...
static void HandleTopLevelExpression() {
  // Evaluate a top-level expression into an anonymous function.
  if (auto FnAST = ParseTopLevelExpr()) {
    if (FnAST->codegen()) {
      // Create a ResourceTracker to track JIT'd memory allocated to our
      // anonymous expression -- that way we can free it after executing.
      auto RT = TheJIT->getMainJITDylib().createResourceTracker();

      auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
      ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
      InitializeModuleAndPassManager();

      // Search the JIT for the __anon_expr symbol.
      auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));
      assert(ExprSymbol && "Function not found");

      // Get the symbol's address and cast it to the right type (takes no
      // arguments, returns a double) so we can call it as a native function.
      double (*FP)() = ExprSymbol.getAddress().toPtr<double (*)()>();
      fprintf(stderr, "Evaluated to %f\n", FP());

      // Delete the anonymous expression module from the JIT.
      ExitOnErr(RT->remove());
    }

如果解析和程式碼生成成功,下一步是將包含頂層表達式的模組添加到 JIT。我們通過呼叫 addModule 來做到這一點,這會觸發模組中所有函數的程式碼生成,並接受一個 ResourceTracker,可以用來稍後從 JIT 中移除模組。將模組添加到 JIT 後,就不能再修改它,因此我們還會通過呼叫 InitializeModuleAndPassManager() 打開一個新模組來存放後續的程式碼。

將模組添加到 JIT 後,我們需要取得指向最終生成程式碼的指標。我們通過呼叫 JIT 的 lookup 方法並傳遞頂層表達式函數的名稱來做到這一點:__anon_expr。由於我們剛剛添加了這個函數,我們斷言 lookup 返回了一個結果。

接下來,我們通過對符號呼叫 getAddress() 來取得 __anon_expr 函數的記憶體位址。回想一下,我們將頂層表達式編譯成一個獨立的 LLVM 函數,它不接受任何參數並返回計算出的雙精度浮點數。因為 LLVM JIT 編譯器與原生平台 ABI 相匹配,這意味著您可以將結果指標直接轉換為該類型的函數指標並直接呼叫它。這意味著,JIT 編譯的程式碼與靜態鏈接到應用程序中的原生機器碼沒有區別。

最後,由於我們不支持重新評估頂層表達式,因此我們在完成後從 JIT 中移除模組以釋放相關的記憶體。但是請記住,我們在幾行之前(通過 InitializeModuleAndPassManager)創建的模組仍然是打開的,並且正在等待添加新的程式碼。

只需進行這兩項更改,讓我們看看 Kaleidoscope 現在是如何工作的!

ready> 4+5;
Read top-level expression:
define double @0() {
entry:
  ret double 9.000000e+00
}

Evaluated to 9.000000

嗯,這看起來基本上是可行的。函數的轉儲顯示了我們為每個輸入的頂層表達式合成的“不帶參數且始終返回雙精度浮點數的函數”。這演示了非常基本的功能,但我們可以做得更多嗎?

ready> def testfunc(x y) x + y*2;
Read function definition:
define double @testfunc(double %x, double %y) {
entry:
  %multmp = fmul double %y, 2.000000e+00
  %addtmp = fadd double %multmp, %x
  ret double %addtmp
}

ready> testfunc(4, 10);
Read top-level expression:
define double @1() {
entry:
  %calltmp = call double @testfunc(double 4.000000e+00, double 1.000000e+01)
  ret double %calltmp
}

Evaluated to 24.000000

ready> testfunc(5, 10);
ready> LLVM ERROR: Program used external function 'testfunc' which could not be resolved!

函數定義和呼叫也可以正常工作,但最後一行出現了嚴重的錯誤。呼叫看起來是有效的,那麼發生了什麼?正如您可能從 API 中猜到的那樣,模組是 JIT 的分配單位,而 testfunc 與包含匿名表達式的模組屬於同一個模組。當我們從 JIT 中移除該模組以釋放匿名表達式的記憶體時,我們也將 testfunc 的定義一起刪除了。然後,當我們嘗試第二次呼叫 testfunc 時,JIT 就找不到它了。

解決此問題的最簡單方法是將匿名表達式放在與其他函數定義不同的模組中。只要每個被呼叫的函數都有一個原型,並且在被呼叫之前就被添加到 JIT 中,JIT 就會很樂意跨模組邊界解析函數呼叫。通過將匿名表達式放在不同的模組中,我們可以在不影響其他函數的情況下將其刪除。

實際上,我們將更進一步,將每個函數都放在自己的模組中。這樣做可以讓我們利用 KaleidoscopeJIT 的一個有用特性,使我們的環境更像 REPL:函數可以多次添加到 JIT 中(與模組不同,在模組中,每個函數都必須有唯一的定義)。當您在 KaleidoscopeJIT 中查找符號時,它將始終返回最新的定義。

ready> def foo(x) x + 1;
Read function definition:
define double @foo(double %x) {
entry:
  %addtmp = fadd double %x, 1.000000e+00
  ret double %addtmp
}

ready> foo(2);
Evaluated to 3.000000

ready> def foo(x) x + 2;
define double @foo(double %x) {
entry:
  %addtmp = fadd double %x, 2.000000e+00
  ret double %addtmp
}

ready> foo(2);
Evaluated to 4.000000

為了允許每個函數都位於自己的模組中,我們需要一種方法將先前的函數聲明重新生成到我們打開的每個新模組中

static std::unique_ptr<KaleidoscopeJIT> TheJIT;

...

Function *getFunction(std::string Name) {
  // First, see if the function has already been added to the current module.
  if (auto *F = TheModule->getFunction(Name))
    return F;

  // If not, check whether we can codegen the declaration from some existing
  // prototype.
  auto FI = FunctionProtos.find(Name);
  if (FI != FunctionProtos.end())
    return FI->second->codegen();

  // If no existing prototype exists, return null.
  return nullptr;
}

...

Value *CallExprAST::codegen() {
  // Look up the name in the global module table.
  Function *CalleeF = getFunction(Callee);

...

Function *FunctionAST::codegen() {
  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
  // reference to it for use below.
  auto &P = *Proto;
  FunctionProtos[Proto->getName()] = std::move(Proto);
  Function *TheFunction = getFunction(P.getName());
  if (!TheFunction)
    return nullptr;

為了啟用這個功能,我們首先新增一個新的全域變數 FunctionProtos,用來儲存每個函式的最新原型。我們也會新增一個便捷方法 getFunction(),用來取代呼叫 TheModule->getFunction()。我們的便捷方法會在 TheModule 中搜尋現有的函式宣告,如果找不到,就會從 FunctionProtos 產生新的宣告。在 CallExprAST::codegen() 中,我們只需要將呼叫 TheModule->getFunction() 的部分替換掉。在 FunctionAST::codegen() 中,我們需要先更新 FunctionProtos 映射表,然後再呼叫 getFunction()。完成這些步驟後,我們就可以在目前的模組中取得任何先前宣告過的函式宣告。

我們也需要更新 HandleDefinition 和 HandleExtern

static void HandleDefinition() {
  if (auto FnAST = ParseDefinition()) {
    if (auto *FnIR = FnAST->codegen()) {
      fprintf(stderr, "Read function definition:");
      FnIR->print(errs());
      fprintf(stderr, "\n");
      ExitOnErr(TheJIT->addModule(
          ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
      InitializeModuleAndPassManager();
    }
  } else {
    // Skip token for error recovery.
     getNextToken();
  }
}

static void HandleExtern() {
  if (auto ProtoAST = ParseExtern()) {
    if (auto *FnIR = ProtoAST->codegen()) {
      fprintf(stderr, "Read extern: ");
      FnIR->print(errs());
      fprintf(stderr, "\n");
      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
    }
  } else {
    // Skip token for error recovery.
    getNextToken();
  }
}

在 HandleDefinition 中,我們新增兩行程式碼,將新定義的函式傳輸到 JIT 並開啟新的模組。在 HandleExtern 中,我們只需要新增一行程式碼,將原型新增到 FunctionProtos 中。

警告

從 LLVM-9 開始,不允許在不同的模組中重複定義符號。這表示您無法像下方所示範例一樣,在 Kaleidoscope 中重新定義函式。請略過這個部分。

原因是較新的 OrcV2 JIT API 盡可能地與靜態和動態連結器的規則保持一致,包括拒絕重複的符號。要求符號名稱必須是唯一的,讓我們可以使用(唯一的)符號名稱作為索引鍵來追蹤並支援並行編譯。

完成這些變更後,讓我們再次嘗試使用 REPL(這次我刪除了匿名函式的傾印,您現在應該已經了解了 :))

ready> def foo(x) x + 1;
ready> foo(2);
Evaluated to 3.000000

ready> def foo(x) x + 2;
ready> foo(2);
Evaluated to 4.000000

成功了!

即使使用這麼簡單的程式碼,我們也能獲得一些強大的功能,請看以下範例

ready> extern sin(x);
Read extern:
declare double @sin(double)

ready> extern cos(x);
Read extern:
declare double @cos(double)

ready> sin(1.0);
Read top-level expression:
define double @2() {
entry:
  ret double 0x3FEAED548F090CEE
}

Evaluated to 0.841471

ready> def foo(x) sin(x)*sin(x) + cos(x)*cos(x);
Read function definition:
define double @foo(double %x) {
entry:
  %calltmp = call double @sin(double %x)
  %multmp = fmul double %calltmp, %calltmp
  %calltmp2 = call double @cos(double %x)
  %multmp4 = fmul double %calltmp2, %calltmp2
  %addtmp = fadd double %multmp, %multmp4
  ret double %addtmp
}

ready> foo(4.0);
Read top-level expression:
define double @3() {
entry:
  %calltmp = call double @foo(double 4.000000e+00)
  ret double %calltmp
}

Evaluated to 1.000000

哇,JIT 是怎麼知道 sin 和 cos 的?答案出奇地簡單:KaleidoscopeJIT 有一個簡單的符號解析規則,用於尋找任何給定模組中不存在的符號:首先,它會搜尋所有已新增到 JIT 的模組,從最新的到最舊的,以找到最新的定義。如果在 JIT 中找不到定義,它會退而求其次,呼叫 Kaleidoscope 進程本身的 “dlsym("sin")”。由於 “sin” 是在 JIT 的位址空間中定義的,因此它會直接將模組中的呼叫修補到 libm 版本的 sin。但在某些情況下,它甚至可以做得更多:由於 sin 和 cos 是標準數學函式的名稱,因此當使用常數呼叫函式時(例如上述的 “sin(1.0)”),常數摺疊器會直接將函式呼叫評估為正確的結果。

在未來,我們將會看到如何調整這個符號解析規則,以啟用各種實用的功能,從安全性(限制 JIT 程式碼可用的符號集)到基於符號名稱的動態程式碼產生,甚至是延遲編譯。

符號解析規則的一個直接好處是,我們現在可以透過撰寫任意的 C++ 程式碼來實作運算,進而擴展語言。例如,如果我們新增

#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif

/// putchard - putchar that takes a double and returns 0.
extern "C" DLLEXPORT double putchard(double X) {
  fputc((char)X, stderr);
  return 0;
}

請注意,對於 Windows,我們需要明確地匯出函式,因為動態符號載入器會使用 GetProcAddress 來尋找符號。

現在,我們可以使用類似 “extern putchard(x); putchard(120);” 的程式碼在控制台上產生簡單的輸出,這會在控制台上印出一個小寫的 ‘x’(120 是 ‘x’ 的 ASCII 代碼)。類似的程式碼可以用於在 Kaleidoscope 中實現檔案 I/O、控制台輸入和許多其他功能。

以上就是 Kaleidoscope 教學中關於 JIT 和優化器的章節。至此,我們可以編譯一種非圖靈完備的程式語言,並以使用者驅動的方式對其進行優化和 JIT 編譯。接下來,我們將探討使用控制流程結構擴展該語言,並在此過程中解決一些有趣的 LLVM IR 問題。

4.5. 完整程式碼清單

以下是我們運行範例的完整程式碼清單,其中添加了 LLVM JIT 和優化器。要建置此範例,請使用

# Compile
clang++ -g toy.cpp `llvm-config --cxxflags --ldflags --system-libs --libs core orcjit native` -O3 -o toy
# Run
./toy

如果您在 Linux 上編譯,請確保還要添加 “-rdynamic” 選項。這可以確保在執行時正確解析外部函數。

以下是程式碼

#include "../include/KaleidoscopeJIT.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Scalar/Reassociate.h"
#include "llvm/Transforms/Scalar/SimplifyCFG.h"
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <vector>

using namespace llvm;
using namespace llvm::orc;

//===----------------------------------------------------------------------===//
// Lexer
//===----------------------------------------------------------------------===//

// The lexer returns tokens [0-255] if it is an unknown character, otherwise one
// of these for known things.
enum Token {
  tok_eof = -1,

  // commands
  tok_def = -2,
  tok_extern = -3,

  // primary
  tok_identifier = -4,
  tok_number = -5
};

static std::string IdentifierStr; // Filled in if tok_identifier
static double NumVal;             // Filled in if tok_number

/// gettok - Return the next token from standard input.
static int gettok() {
  static int LastChar = ' ';

  // Skip any whitespace.
  while (isspace(LastChar))
    LastChar = getchar();

  if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
    IdentifierStr = LastChar;
    while (isalnum((LastChar = getchar())))
      IdentifierStr += LastChar;

    if (IdentifierStr == "def")
      return tok_def;
    if (IdentifierStr == "extern")
      return tok_extern;
    return tok_identifier;
  }

  if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
    std::string NumStr;
    do {
      NumStr += LastChar;
      LastChar = getchar();
    } while (isdigit(LastChar) || LastChar == '.');

    NumVal = strtod(NumStr.c_str(), nullptr);
    return tok_number;
  }

  if (LastChar == '#') {
    // Comment until end of line.
    do
      LastChar = getchar();
    while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');

    if (LastChar != EOF)
      return gettok();
  }

  // Check for end of file.  Don't eat the EOF.
  if (LastChar == EOF)
    return tok_eof;

  // Otherwise, just return the character as its ascii value.
  int ThisChar = LastChar;
  LastChar = getchar();
  return ThisChar;
}

//===----------------------------------------------------------------------===//
// Abstract Syntax Tree (aka Parse Tree)
//===----------------------------------------------------------------------===//

namespace {

/// ExprAST - Base class for all expression nodes.
class ExprAST {
public:
  virtual ~ExprAST() = default;

  virtual Value *codegen() = 0;
};

/// NumberExprAST - Expression class for numeric literals like "1.0".
class NumberExprAST : public ExprAST {
  double Val;

public:
  NumberExprAST(double Val) : Val(Val) {}

  Value *codegen() override;
};

/// VariableExprAST - Expression class for referencing a variable, like "a".
class VariableExprAST : public ExprAST {
  std::string Name;

public:
  VariableExprAST(const std::string &Name) : Name(Name) {}

  Value *codegen() override;
};

/// BinaryExprAST - Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
  char Op;
  std::unique_ptr<ExprAST> LHS, RHS;

public:
  BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
                std::unique_ptr<ExprAST> RHS)
      : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}

  Value *codegen() override;
};

/// CallExprAST - Expression class for function calls.
class CallExprAST : public ExprAST {
  std::string Callee;
  std::vector<std::unique_ptr<ExprAST>> Args;

public:
  CallExprAST(const std::string &Callee,
              std::vector<std::unique_ptr<ExprAST>> Args)
      : Callee(Callee), Args(std::move(Args)) {}

  Value *codegen() override;
};

/// PrototypeAST - This class represents the "prototype" for a function,
/// which captures its name, and its argument names (thus implicitly the number
/// of arguments the function takes).
class PrototypeAST {
  std::string Name;
  std::vector<std::string> Args;

public:
  PrototypeAST(const std::string &Name, std::vector<std::string> Args)
      : Name(Name), Args(std::move(Args)) {}

  Function *codegen();
  const std::string &getName() const { return Name; }
};

/// FunctionAST - This class represents a function definition itself.
class FunctionAST {
  std::unique_ptr<PrototypeAST> Proto;
  std::unique_ptr<ExprAST> Body;

public:
  FunctionAST(std::unique_ptr<PrototypeAST> Proto,
              std::unique_ptr<ExprAST> Body)
      : Proto(std::move(Proto)), Body(std::move(Body)) {}

  Function *codegen();
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//

/// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
/// token the parser is looking at.  getNextToken reads another token from the
/// lexer and updates CurTok with its results.
static int CurTok;
static int getNextToken() { return CurTok = gettok(); }

/// BinopPrecedence - This holds the precedence for each binary operator that is
/// defined.
static std::map<char, int> BinopPrecedence;

/// GetTokPrecedence - Get the precedence of the pending binary operator token.
static int GetTokPrecedence() {
  if (!isascii(CurTok))
    return -1;

  // Make sure it's a declared binop.
  int TokPrec = BinopPrecedence[CurTok];
  if (TokPrec <= 0)
    return -1;
  return TokPrec;
}

/// LogError* - These are little helper functions for error handling.
std::unique_ptr<ExprAST> LogError(const char *Str) {
  fprintf(stderr, "Error: %s\n", Str);
  return nullptr;
}

std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
  LogError(Str);
  return nullptr;
}

static std::unique_ptr<ExprAST> ParseExpression();

/// numberexpr ::= number
static std::unique_ptr<ExprAST> ParseNumberExpr() {
  auto Result = std::make_unique<NumberExprAST>(NumVal);
  getNextToken(); // consume the number
  return std::move(Result);
}

/// parenexpr ::= '(' expression ')'
static std::unique_ptr<ExprAST> ParseParenExpr() {
  getNextToken(); // eat (.
  auto V = ParseExpression();
  if (!V)
    return nullptr;

  if (CurTok != ')')
    return LogError("expected ')'");
  getNextToken(); // eat ).
  return V;
}

/// identifierexpr
///   ::= identifier
///   ::= identifier '(' expression* ')'
static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
  std::string IdName = IdentifierStr;

  getNextToken(); // eat identifier.

  if (CurTok != '(') // Simple variable ref.
    return std::make_unique<VariableExprAST>(IdName);

  // Call.
  getNextToken(); // eat (
  std::vector<std::unique_ptr<ExprAST>> Args;
  if (CurTok != ')') {
    while (true) {
      if (auto Arg = ParseExpression())
        Args.push_back(std::move(Arg));
      else
        return nullptr;

      if (CurTok == ')')
        break;

      if (CurTok != ',')
        return LogError("Expected ')' or ',' in argument list");
      getNextToken();
    }
  }

  // Eat the ')'.
  getNextToken();

  return std::make_unique<CallExprAST>(IdName, std::move(Args));
}

/// primary
///   ::= identifierexpr
///   ::= numberexpr
///   ::= parenexpr
static std::unique_ptr<ExprAST> ParsePrimary() {
  switch (CurTok) {
  default:
    return LogError("unknown token when expecting an expression");
  case tok_identifier:
    return ParseIdentifierExpr();
  case tok_number:
    return ParseNumberExpr();
  case '(':
    return ParseParenExpr();
  }
}

/// binoprhs
///   ::= ('+' primary)*
static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
                                              std::unique_ptr<ExprAST> LHS) {
  // If this is a binop, find its precedence.
  while (true) {
    int TokPrec = GetTokPrecedence();

    // If this is a binop that binds at least as tightly as the current binop,
    // consume it, otherwise we are done.
    if (TokPrec < ExprPrec)
      return LHS;

    // Okay, we know this is a binop.
    int BinOp = CurTok;
    getNextToken(); // eat binop

    // Parse the primary expression after the binary operator.
    auto RHS = ParsePrimary();
    if (!RHS)
      return nullptr;

    // If BinOp binds less tightly with RHS than the operator after RHS, let
    // the pending operator take RHS as its LHS.
    int NextPrec = GetTokPrecedence();
    if (TokPrec < NextPrec) {
      RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
      if (!RHS)
        return nullptr;
    }

    // Merge LHS/RHS.
    LHS =
        std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
  }
}

/// expression
///   ::= primary binoprhs
///
static std::unique_ptr<ExprAST> ParseExpression() {
  auto LHS = ParsePrimary();
  if (!LHS)
    return nullptr;

  return ParseBinOpRHS(0, std::move(LHS));
}

/// prototype
///   ::= id '(' id* ')'
static std::unique_ptr<PrototypeAST> ParsePrototype() {
  if (CurTok != tok_identifier)
    return LogErrorP("Expected function name in prototype");

  std::string FnName = IdentifierStr;
  getNextToken();

  if (CurTok != '(')
    return LogErrorP("Expected '(' in prototype");

  std::vector<std::string> ArgNames;
  while (getNextToken() == tok_identifier)
    ArgNames.push_back(IdentifierStr);
  if (CurTok != ')')
    return LogErrorP("Expected ')' in prototype");

  // success.
  getNextToken(); // eat ')'.

  return std::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
}

/// definition ::= 'def' prototype expression
static std::unique_ptr<FunctionAST> ParseDefinition() {
  getNextToken(); // eat def.
  auto Proto = ParsePrototype();
  if (!Proto)
    return nullptr;

  if (auto E = ParseExpression())
    return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
  return nullptr;
}

/// toplevelexpr ::= expression
static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
  if (auto E = ParseExpression()) {
    // Make an anonymous proto.
    auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
                                                 std::vector<std::string>());
    return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
  }
  return nullptr;
}

/// external ::= 'extern' prototype
static std::unique_ptr<PrototypeAST> ParseExtern() {
  getNextToken(); // eat extern.
  return ParsePrototype();
}

//===----------------------------------------------------------------------===//
// Code Generation
//===----------------------------------------------------------------------===//

static std::unique_ptr<LLVMContext> TheContext;
static std::unique_ptr<Module> TheModule;
static std::unique_ptr<IRBuilder<>> Builder;
static std::map<std::string, Value *> NamedValues;
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
static std::unique_ptr<FunctionPassManager> TheFPM;
static std::unique_ptr<LoopAnalysisManager> TheLAM;
static std::unique_ptr<FunctionAnalysisManager> TheFAM;
static std::unique_ptr<CGSCCAnalysisManager> TheCGAM;
static std::unique_ptr<ModuleAnalysisManager> TheMAM;
static std::unique_ptr<PassInstrumentationCallbacks> ThePIC;
static std::unique_ptr<StandardInstrumentations> TheSI;
static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
static ExitOnError ExitOnErr;

Value *LogErrorV(const char *Str) {
  LogError(Str);
  return nullptr;
}

Function *getFunction(std::string Name) {
  // First, see if the function has already been added to the current module.
  if (auto *F = TheModule->getFunction(Name))
    return F;

  // If not, check whether we can codegen the declaration from some existing
  // prototype.
  auto FI = FunctionProtos.find(Name);
  if (FI != FunctionProtos.end())
    return FI->second->codegen();

  // If no existing prototype exists, return null.
  return nullptr;
}

Value *NumberExprAST::codegen() {
  return ConstantFP::get(*TheContext, APFloat(Val));
}

Value *VariableExprAST::codegen() {
  // Look this variable up in the function.
  Value *V = NamedValues[Name];
  if (!V)
    return LogErrorV("Unknown variable name");
  return V;
}

Value *BinaryExprAST::codegen() {
  Value *L = LHS->codegen();
  Value *R = RHS->codegen();
  if (!L || !R)
    return nullptr;

  switch (Op) {
  case '+':
    return Builder->CreateFAdd(L, R, "addtmp");
  case '-':
    return Builder->CreateFSub(L, R, "subtmp");
  case '*':
    return Builder->CreateFMul(L, R, "multmp");
  case '<':
    L = Builder->CreateFCmpULT(L, R, "cmptmp");
    // Convert bool 0/1 to double 0.0 or 1.0
    return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
  default:
    return LogErrorV("invalid binary operator");
  }
}

Value *CallExprAST::codegen() {
  // Look up the name in the global module table.
  Function *CalleeF = getFunction(Callee);
  if (!CalleeF)
    return LogErrorV("Unknown function referenced");

  // If argument mismatch error.
  if (CalleeF->arg_size() != Args.size())
    return LogErrorV("Incorrect # arguments passed");

  std::vector<Value *> ArgsV;
  for (unsigned i = 0, e = Args.size(); i != e; ++i) {
    ArgsV.push_back(Args[i]->codegen());
    if (!ArgsV.back())
      return nullptr;
  }

  return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
}

Function *PrototypeAST::codegen() {
  // Make the function type:  double(double,double) etc.
  std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
  FunctionType *FT =
      FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);

  Function *F =
      Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());

  // Set names for all arguments.
  unsigned Idx = 0;
  for (auto &Arg : F->args())
    Arg.setName(Args[Idx++]);

  return F;
}

Function *FunctionAST::codegen() {
  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
  // reference to it for use below.
  auto &P = *Proto;
  FunctionProtos[Proto->getName()] = std::move(Proto);
  Function *TheFunction = getFunction(P.getName());
  if (!TheFunction)
    return nullptr;

  // Create a new basic block to start insertion into.
  BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
  Builder->SetInsertPoint(BB);

  // Record the function arguments in the NamedValues map.
  NamedValues.clear();
  for (auto &Arg : TheFunction->args())
    NamedValues[std::string(Arg.getName())] = &Arg;

  if (Value *RetVal = Body->codegen()) {
    // Finish off the function.
    Builder->CreateRet(RetVal);

    // Validate the generated code, checking for consistency.
    verifyFunction(*TheFunction);

    // Run the optimizer on the function.
    TheFPM->run(*TheFunction, *TheFAM);

    return TheFunction;
  }

  // Error reading body, remove function.
  TheFunction->eraseFromParent();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Top-Level parsing and JIT Driver
//===----------------------------------------------------------------------===//

static void InitializeModuleAndManagers() {
  // Open a new context and module.
  TheContext = std::make_unique<LLVMContext>();
  TheModule = std::make_unique<Module>("KaleidoscopeJIT", *TheContext);
  TheModule->setDataLayout(TheJIT->getDataLayout());

  // Create a new builder for the module.
  Builder = std::make_unique<IRBuilder<>>(*TheContext);

  // Create new pass and analysis managers.
  TheFPM = std::make_unique<FunctionPassManager>();
  TheLAM = std::make_unique<LoopAnalysisManager>();
  TheFAM = std::make_unique<FunctionAnalysisManager>();
  TheCGAM = std::make_unique<CGSCCAnalysisManager>();
  TheMAM = std::make_unique<ModuleAnalysisManager>();
  ThePIC = std::make_unique<PassInstrumentationCallbacks>();
  TheSI = std::make_unique<StandardInstrumentations>(*TheContext,
                                                     /*DebugLogging*/ true);
  TheSI->registerCallbacks(*ThePIC, TheMAM.get());

  // Add transform passes.
  // Do simple "peephole" optimizations and bit-twiddling optzns.
  TheFPM->addPass(InstCombinePass());
  // Reassociate expressions.
  TheFPM->addPass(ReassociatePass());
  // Eliminate Common SubExpressions.
  TheFPM->addPass(GVNPass());
  // Simplify the control flow graph (deleting unreachable blocks, etc).
  TheFPM->addPass(SimplifyCFGPass());

  // Register analysis passes used in these transform passes.
  PassBuilder PB;
  PB.registerModuleAnalyses(*TheMAM);
  PB.registerFunctionAnalyses(*TheFAM);
  PB.crossRegisterProxies(*TheLAM, *TheFAM, *TheCGAM, *TheMAM);
}

static void HandleDefinition() {
  if (auto FnAST = ParseDefinition()) {
    if (auto *FnIR = FnAST->codegen()) {
      fprintf(stderr, "Read function definition:");
      FnIR->print(errs());
      fprintf(stderr, "\n");
      ExitOnErr(TheJIT->addModule(
          ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
      InitializeModuleAndManagers();
    }
  } else {
    // Skip token for error recovery.
    getNextToken();
  }
}

static void HandleExtern() {
  if (auto ProtoAST = ParseExtern()) {
    if (auto *FnIR = ProtoAST->codegen()) {
      fprintf(stderr, "Read extern: ");
      FnIR->print(errs());
      fprintf(stderr, "\n");
      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
    }
  } else {
    // Skip token for error recovery.
    getNextToken();
  }
}

static void HandleTopLevelExpression() {
  // Evaluate a top-level expression into an anonymous function.
  if (auto FnAST = ParseTopLevelExpr()) {
    if (FnAST->codegen()) {
      // Create a ResourceTracker to track JIT'd memory allocated to our
      // anonymous expression -- that way we can free it after executing.
      auto RT = TheJIT->getMainJITDylib().createResourceTracker();

      auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
      ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
      InitializeModuleAndManagers();

      // Search the JIT for the __anon_expr symbol.
      auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));

      // Get the symbol's address and cast it to the right type (takes no
      // arguments, returns a double) so we can call it as a native function.
      double (*FP)() = ExprSymbol.getAddress().toPtr<double (*)()>();
      fprintf(stderr, "Evaluated to %f\n", FP());

      // Delete the anonymous expression module from the JIT.
      ExitOnErr(RT->remove());
    }
  } else {
    // Skip token for error recovery.
    getNextToken();
  }
}

/// top ::= definition | external | expression | ';'
static void MainLoop() {
  while (true) {
    fprintf(stderr, "ready> ");
    switch (CurTok) {
    case tok_eof:
      return;
    case ';': // ignore top-level semicolons.
      getNextToken();
      break;
    case tok_def:
      HandleDefinition();
      break;
    case tok_extern:
      HandleExtern();
      break;
    default:
      HandleTopLevelExpression();
      break;
    }
  }
}

//===----------------------------------------------------------------------===//
// "Library" functions that can be "extern'd" from user code.
//===----------------------------------------------------------------------===//

#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif

/// putchard - putchar that takes a double and returns 0.
extern "C" DLLEXPORT double putchard(double X) {
  fputc((char)X, stderr);
  return 0;
}

/// printd - printf that takes a double prints it as "%f\n", returning 0.
extern "C" DLLEXPORT double printd(double X) {
  fprintf(stderr, "%f\n", X);
  return 0;
}

//===----------------------------------------------------------------------===//
// Main driver code.
//===----------------------------------------------------------------------===//

int main() {
  InitializeNativeTarget();
  InitializeNativeTargetAsmPrinter();
  InitializeNativeTargetAsmParser();

  // Install standard binary operators.
  // 1 is lowest precedence.
  BinopPrecedence['<'] = 10;
  BinopPrecedence['+'] = 20;
  BinopPrecedence['-'] = 20;
  BinopPrecedence['*'] = 40; // highest.

  // Prime the first token.
  fprintf(stderr, "ready> ");
  getNextToken();

  TheJIT = ExitOnErr(KaleidoscopeJIT::Create());

  InitializeModuleAndManagers();

  // Run the main "interpreter loop" now.
  MainLoop();

  return 0;
}

下一步:擴展語言:控制流程