4. Kaleidoscope:新增 JIT 與最佳化器支援

4.1. 第 4 章 簡介

歡迎來到「使用 LLVM 實作語言」教學的第 4 章。第 1-3 章描述了一個簡單語言的實作,並新增了產生 LLVM IR 的支援。本章將介紹兩種新技術:為您的語言新增最佳化器支援,以及新增 JIT 編譯器支援。這些新增功能將示範如何為 Kaleidoscope 語言取得良好且有效率的程式碼。

4.2. 瑣碎的常數摺疊

我們在第 3 章中的示範既優雅又易於擴充。不幸的是,它沒有產生很棒的程式碼。然而,IRBuilder 在編譯簡單程式碼時,確實給了我們顯而易見的最佳化。

ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
        %addtmp = fadd double 3.000000e+00, %x
        ret double %addtmp
}

這段程式碼並不是解析輸入所建構的 AST 的字面轉錄。那會是

ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
        %addtmp = fadd double 2.000000e+00, 1.000000e+00
        %addtmp1 = fadd double %addtmp, %x
        ret double %addtmp1
}

如上所示,常數摺疊尤其是一種非常常見且非常重要的最佳化:如此重要,以至於許多語言實作者在其 AST 表示法中實作了常數摺疊支援。

使用 LLVM,您不需要在 AST 中提供此支援。由於所有對建構 LLVM IR 的呼叫都透過 LLVM IR 建構器,因此建構器本身會檢查在您呼叫它時是否存在常數摺疊的機會。如果有的話,它只會執行常數摺疊並傳回常數,而不是建立指令。

嗯,這很簡單:)。實際上,我們建議在產生像這樣的程式碼時,始終使用 IRBuilder。它的使用沒有「語法開銷」(您不必到處使用常數檢查來醜化您的編譯器),並且在某些情況下(特別是對於具有巨集前處理器或使用大量常數的語言),它可以大幅減少產生的 LLVM IR 數量。

另一方面,IRBuilder 受限於它在程式碼建構時執行所有內聯分析的事實。如果您採用稍微複雜的範例

ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
        %addtmp = fadd double 3.000000e+00, %x
        %addtmp1 = fadd double %x, 3.000000e+00
        %multmp = fmul double %addtmp, %addtmp1
        ret double %multmp
}

在這種情況下,乘法的 LHS 和 RHS 是相同的值。我們真的希望看到它產生「tmp = x+3; result = tmp*tmp;」而不是計算兩次「x+3」。

不幸的是,任何程度的局部分析都無法偵測和修正此問題。這需要兩個轉換:表達式的重新結合(使加法在詞彙上相同)和通用子表達式消除 (CSE) 以刪除冗餘的加法指令。幸運的是,LLVM 提供了廣泛的最佳化,您可以以「pass」的形式使用。

4.3. LLVM 最佳化Pass

LLVM 提供了許多最佳化 pass,它們執行許多不同種類的事情,並且具有不同的權衡考量。與其他系統不同,LLVM 不堅持認為一組最佳化適用於所有語言和所有情況的錯誤觀念。LLVM 允許編譯器實作者完全決定要使用哪些最佳化、以什麼順序以及在什麼情況下使用。

舉一個具體的例子,LLVM 同時支援「全模組」pass,它會在盡可能大的程式碼體中查找(通常是一個完整的文件,但如果在連結時執行,這可能是整個程式的很大一部分)。它還支援並包含「逐函數」pass,它們一次僅對單一函數進行操作,而不查看其他函數。有關 pass 及其運作方式的更多資訊,請參閱如何編寫 Pass 文件和 LLVM Pass 列表

對於 Kaleidoscope,我們目前正在使用者輸入函數時,即時逐個產生函數。我們並不是在這種設定中追求終極的最佳化體驗,但我們也希望在可能的情況下捕捉到簡單且快速的東西。因此,我們將選擇在使用者輸入函數時執行一些逐函數最佳化。如果我們想要製作「靜態 Kaleidoscope 編譯器」,我們將完全使用我們現在擁有的程式碼,除了我們會延遲執行最佳化器,直到整個檔案都被解析完畢。

除了函數和模組 pass 之間的區別之外,pass 還可以分為轉換和分析 pass。轉換 pass 會變更 IR,而分析 pass 會計算其他 pass 可以使用的資訊。為了新增轉換 pass,它所依賴的所有分析 pass 都必須預先註冊。

為了開始進行逐函數最佳化,我們需要設定一個 FunctionPassManager 來保存和組織我們想要執行的 LLVM 最佳化。一旦我們有了這個,我們就可以新增一組要執行的最佳化。我們需要為我們要最佳化的每個模組建立新的 FunctionPassManager,因此我們將新增到前一章中建立的函數 (InitializeModule())。

void InitializeModuleAndManagers(void) {
  // Open a new context and module.
  TheContext = std::make_unique<LLVMContext>();
  TheModule = std::make_unique<Module>("KaleidoscopeJIT", *TheContext);
  TheModule->setDataLayout(TheJIT->getDataLayout());

  // Create a new builder for the module.
  Builder = std::make_unique<IRBuilder<>>(*TheContext);

  // Create new pass and analysis managers.
  TheFPM = std::make_unique<FunctionPassManager>();
  TheLAM = std::make_unique<LoopAnalysisManager>();
  TheFAM = std::make_unique<FunctionAnalysisManager>();
  TheCGAM = std::make_unique<CGSCCAnalysisManager>();
  TheMAM = std::make_unique<ModuleAnalysisManager>();
  ThePIC = std::make_unique<PassInstrumentationCallbacks>();
  TheSI = std::make_unique<StandardInstrumentations>(*TheContext,
                                                    /*DebugLogging*/ true);
  TheSI->registerCallbacks(*ThePIC, TheMAM.get());
  ...

在初始化全域模組 TheModule 和 FunctionPassManager 之後,我們需要初始化框架的其他部分。四個 AnalysisManager 允許我們新增在 IR 階層的四個層級上運行的分析 pass。PassInstrumentationCallbacks 和 StandardInstrumentations 是 pass instrumentation 框架所必需的,它允許開發人員自訂 pass 之間發生的事情。

一旦設定好這些管理器,我們就會使用一系列「addPass」呼叫來新增一堆 LLVM 轉換 pass。

// Add transform passes.
// Do simple "peephole" optimizations and bit-twiddling optzns.
TheFPM->addPass(InstCombinePass());
// Reassociate expressions.
TheFPM->addPass(ReassociatePass());
// Eliminate Common SubExpressions.
TheFPM->addPass(GVNPass());
// Simplify the control flow graph (deleting unreachable blocks, etc).
TheFPM->addPass(SimplifyCFGPass());

在這種情況下,我們選擇新增四個最佳化 pass。我們在這裡選擇的 pass 是一組相當標準的「清理」最佳化,它們適用於各種程式碼。我不會深入探討它們的作用,但請相信我,它們是一個很好的起點:)。

接下來,我們註冊轉換 pass 使用的分析 pass。

  // Register analysis passes used in these transform passes.
  PassBuilder PB;
  PB.registerModuleAnalyses(*TheMAM);
  PB.registerFunctionAnalyses(*TheFAM);
  PB.crossRegisterProxies(*TheLAM, *TheFAM, *TheCGAM, *TheMAM);
}

一旦設定好 PassManager,我們就需要使用它。我們在我們新建立的函數建構完成後(在 FunctionAST::codegen() 中),但在將其傳回給客戶端之前執行它。

if (Value *RetVal = Body->codegen()) {
  // Finish off the function.
  Builder.CreateRet(RetVal);

  // Validate the generated code, checking for consistency.
  verifyFunction(*TheFunction);

  // Optimize the function.
  TheFPM->run(*TheFunction, *TheFAM);

  return TheFunction;
}

如您所見,這非常簡單明瞭。FunctionPassManager 會就地最佳化和更新 LLVM Function*,從而(希望)改善其主體。有了這個,我們可以再次嘗試上面的測試。

ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
        %addtmp = fadd double %x, 3.000000e+00
        %multmp = fmul double %addtmp, %addtmp
        ret double %multmp
}

正如預期的那樣,我們現在得到了良好最佳化的程式碼,從這個函數的每次執行中節省了一個浮點加法指令。

LLVM 提供了各種各樣的最佳化,可以在特定情況下使用。一些關於各種 pass 的 文件是可用的,但它不是很完整。另一個好的想法來源可以來自查看 Clang 運行的 pass 以開始使用。「opt」工具允許您從命令列試驗 pass,以便您可以查看它們是否執行任何操作。

現在我們已經從前端獲得了合理的程式碼,讓我們來談談執行它吧!

4.4. 新增 JIT 編譯器

LLVM IR 中可用的程式碼可以應用各種工具。例如,您可以在其上執行最佳化(如我們上面所做的那樣),您可以以文字或二進制形式轉儲它,您可以將程式碼編譯為某些目標的組譯檔案 (.s),或者您可以 JIT 編譯它。關於 LLVM IR 表示法的好處是,它是編譯器許多不同部分之間的「通用貨幣」。

在本節中,我們將為我們的直譯器新增 JIT 編譯器支援。我們希望 Kaleidoscope 的基本概念是讓使用者像現在這樣輸入函數主體,但立即評估他們輸入的頂層表達式。例如,如果他們輸入「1 + 2;」,我們應該評估並印出 3。如果他們定義一個函數,他們應該能夠從命令列呼叫它。

為了做到這一點,我們首先準備環境,為目前的本機目標建立程式碼,並宣告和初始化 JIT。這是透過呼叫一些 InitializeNativeTarget\* 函數並新增一個全域變數 TheJIT,並在 main 中初始化它來完成的。

static std::unique_ptr<KaleidoscopeJIT> TheJIT;
...
int main() {
  InitializeNativeTarget();
  InitializeNativeTargetAsmPrinter();
  InitializeNativeTargetAsmParser();

  // Install standard binary operators.
  // 1 is lowest precedence.
  BinopPrecedence['<'] = 10;
  BinopPrecedence['+'] = 20;
  BinopPrecedence['-'] = 20;
  BinopPrecedence['*'] = 40; // highest.

  // Prime the first token.
  fprintf(stderr, "ready> ");
  getNextToken();

  TheJIT = std::make_unique<KaleidoscopeJIT>();

  // Run the main "interpreter loop" now.
  MainLoop();

  return 0;
}

我們還需要為 JIT 設定資料佈局。

void InitializeModuleAndPassManager(void) {
  // Open a new context and module.
  TheContext = std::make_unique<LLVMContext>();
  TheModule = std::make_unique<Module>("my cool jit", TheContext);
  TheModule->setDataLayout(TheJIT->getDataLayout());

  // Create a new builder for the module.
  Builder = std::make_unique<IRBuilder<>>(*TheContext);

  // Create a new pass manager attached to it.
  TheFPM = std::make_unique<legacy::FunctionPassManager>(TheModule.get());
  ...

KaleidoscopeJIT 類別是一個專為這些教學建構的簡單 JIT,可在 LLVM 原始碼的 llvm-src/examples/Kaleidoscope/include/KaleidoscopeJIT.h 中找到。在後面的章節中,我們將研究它的運作方式,並使用新功能擴充它,但現在我們將把它視為給定的。它的 API 非常簡單:addModule 將 LLVM IR 模組新增到 JIT,使其函數可用於執行(其記憶體由 ResourceTracker 管理);而 lookup 允許我們查找指向已編譯程式碼的指標。

我們可以採用這個簡單的 API,並將我們解析頂層表達式的程式碼變更為如下所示。

static ExitOnError ExitOnErr;
...
static void HandleTopLevelExpression() {
  // Evaluate a top-level expression into an anonymous function.
  if (auto FnAST = ParseTopLevelExpr()) {
    if (FnAST->codegen()) {
      // Create a ResourceTracker to track JIT'd memory allocated to our
      // anonymous expression -- that way we can free it after executing.
      auto RT = TheJIT->getMainJITDylib().createResourceTracker();

      auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
      ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
      InitializeModuleAndPassManager();

      // Search the JIT for the __anon_expr symbol.
      auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));
      assert(ExprSymbol && "Function not found");

      // Get the symbol's address and cast it to the right type (takes no
      // arguments, returns a double) so we can call it as a native function.
      double (*FP)() = ExprSymbol.getAddress().toPtr<double (*)()>();
      fprintf(stderr, "Evaluated to %f\n", FP());

      // Delete the anonymous expression module from the JIT.
      ExitOnErr(RT->remove());
    }

如果解析和程式碼產生成功,下一步是將包含頂層表達式的模組新增到 JIT。我們透過呼叫 addModule 來完成此操作,這會觸發模組中所有函數的程式碼產生,並接受一個 ResourceTracker,可用於稍後從 JIT 中移除模組。一旦模組被新增到 JIT,它就不能再被修改,因此我們也透過呼叫 InitializeModuleAndPassManager() 開啟一個新模組來保存後續程式碼。

一旦我們將模組新增到 JIT,我們需要取得指向最終產生程式碼的指標。我們透過呼叫 JIT 的 lookup 方法,並傳遞頂層表達式函數的名稱:__anon_expr。由於我們剛剛新增了這個函數,我們斷言 lookup 傳回了結果。

接下來,我們透過在符號上呼叫 getAddress() 來取得 __anon_expr 函數的記憶體位址。回想一下,我們將頂層表達式編譯成一個獨立的 LLVM 函數,該函數不帶任何引數並傳回計算出的 double 值。由於 LLVM JIT 編譯器符合本機平台 ABI,這表示您可以直接將結果指標轉換為該類型的函數指標並直接呼叫它。這表示,JIT 編譯的程式碼與靜態連結到您的應用程式中的本機機器碼沒有區別。

最後,由於我們不支援重新評估頂層表達式,我們會在完成後從 JIT 中移除模組,以釋放相關的記憶體。但是,請回想一下,我們在幾行之前建立的模組(透過 InitializeModuleAndPassManager)仍然處於開啟狀態,並等待新增新程式碼。

僅僅透過這兩個變更,讓我們看看 Kaleidoscope 現在如何運作!

ready> 4+5;
Read top-level expression:
define double @0() {
entry:
  ret double 9.000000e+00
}

Evaluated to 9.000000

嗯,這看起來基本上是可行的。函數的轉儲顯示了我們為每個輸入的頂層表達式合成的「永遠傳回 double 且不帶引數的函數」。這展示了非常基本的功能,但我們可以做得更多嗎?

ready> def testfunc(x y) x + y*2;
Read function definition:
define double @testfunc(double %x, double %y) {
entry:
  %multmp = fmul double %y, 2.000000e+00
  %addtmp = fadd double %multmp, %x
  ret double %addtmp
}

ready> testfunc(4, 10);
Read top-level expression:
define double @1() {
entry:
  %calltmp = call double @testfunc(double 4.000000e+00, double 1.000000e+01)
  ret double %calltmp
}

Evaluated to 24.000000

ready> testfunc(5, 10);
ready> LLVM ERROR: Program used external function 'testfunc' which could not be resolved!

函數定義和呼叫也有效,但最後一行出現了一些非常錯誤的情況。呼叫看起來有效,那麼發生了什麼事?正如您可能從 API 中猜到的那樣,模組是 JIT 的分配單位,而 testfunc 與包含匿名表達式的模組屬於同一個模組。當我們從 JIT 中移除該模組以釋放匿名表達式的記憶體時,我們也刪除了 testfunc 的定義。然後,當我們嘗試第二次呼叫 testfunc 時,JIT 無法再找到它。

修復這個問題的最簡單方法是將匿名表達式放在與其餘函數定義不同的模組中。JIT 會很樂意跨模組邊界解析函數呼叫,只要每個被呼叫的函數都有原型,並且在被呼叫之前被新增到 JIT。透過將匿名表達式放在不同的模組中,我們可以刪除它,而不會影響其餘的函數。

事實上,我們將更進一步,將每個函數放在它自己的模組中。這樣做可以讓我們利用 KaleidoscopeJIT 的一個有用的屬性,這將使我們的環境更像 REPL:函數可以多次新增到 JIT(與模組不同,模組中每個函數都必須具有唯一的定義)。當您在 KaleidoscopeJIT 中查找符號時,它將始終傳回最新的定義。

ready> def foo(x) x + 1;
Read function definition:
define double @foo(double %x) {
entry:
  %addtmp = fadd double %x, 1.000000e+00
  ret double %addtmp
}

ready> foo(2);
Evaluated to 3.000000

ready> def foo(x) x + 2;
define double @foo(double %x) {
entry:
  %addtmp = fadd double %x, 2.000000e+00
  ret double %addtmp
}

ready> foo(2);
Evaluated to 4.000000

為了允許每個函數存在於它自己的模組中,我們需要一種方法來將先前的函數宣告重新產生到我們開啟的每個新模組中。

static std::unique_ptr<KaleidoscopeJIT> TheJIT;

...

Function *getFunction(std::string Name) {
  // First, see if the function has already been added to the current module.
  if (auto *F = TheModule->getFunction(Name))
    return F;

  // If not, check whether we can codegen the declaration from some existing
  // prototype.
  auto FI = FunctionProtos.find(Name);
  if (FI != FunctionProtos.end())
    return FI->second->codegen();

  // If no existing prototype exists, return null.
  return nullptr;
}

...

Value *CallExprAST::codegen() {
  // Look up the name in the global module table.
  Function *CalleeF = getFunction(Callee);

...

Function *FunctionAST::codegen() {
  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
  // reference to it for use below.
  auto &P = *Proto;
  FunctionProtos[Proto->getName()] = std::move(Proto);
  Function *TheFunction = getFunction(P.getName());
  if (!TheFunction)
    return nullptr;

為了啟用此功能,我們將從新增一個新的全域變數 FunctionProtos 開始,它保存每個函數的最新原型。我們還將新增一個方便的方法 getFunction(),以取代對 TheModule->getFunction() 的呼叫。我們的方便方法在 TheModule 中搜尋現有的函數宣告,如果找不到,則回退以從 FunctionProtos 產生新的宣告。在 CallExprAST::codegen() 中,我們只需要取代對 TheModule->getFunction() 的呼叫。在 FunctionAST::codegen() 中,我們需要先更新 FunctionProtos 映射,然後呼叫 getFunction()。完成此操作後,我們始終可以在目前模組中取得任何先前宣告的函數的函數宣告。

我們還需要更新 HandleDefinition 和 HandleExtern。

static void HandleDefinition() {
  if (auto FnAST = ParseDefinition()) {
    if (auto *FnIR = FnAST->codegen()) {
      fprintf(stderr, "Read function definition:");
      FnIR->print(errs());
      fprintf(stderr, "\n");
      ExitOnErr(TheJIT->addModule(
          ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
      InitializeModuleAndPassManager();
    }
  } else {
    // Skip token for error recovery.
     getNextToken();
  }
}

static void HandleExtern() {
  if (auto ProtoAST = ParseExtern()) {
    if (auto *FnIR = ProtoAST->codegen()) {
      fprintf(stderr, "Read extern: ");
      FnIR->print(errs());
      fprintf(stderr, "\n");
      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
    }
  } else {
    // Skip token for error recovery.
    getNextToken();
  }
}

在 HandleDefinition 中,我們新增兩行程式碼,將新定義的函數傳輸到 JIT 並開啟一個新模組。在 HandleExtern 中,我們只需要新增一行程式碼,將原型新增到 FunctionProtos。

警告

自 LLVM-9 以來,不允許在不同的模組中重複符號。這表示您不能像下面顯示的那樣在 Kaleidoscope 中重新定義函數。只需跳過這部分即可。

原因是較新的 OrcV2 JIT API 嘗試非常接近靜態和動態連結器的規則,包括拒絕重複的符號。要求符號名稱是唯一的,這讓我們可以使用(唯一的)符號名稱作為追蹤的鍵來支援符號的並行編譯。

完成這些變更後,讓我們再次嘗試我們的 REPL(這次我移除了匿名函數的轉儲,您現在應該了解它的概念了 :)

ready> def foo(x) x + 1;
ready> foo(2);
Evaluated to 3.000000

ready> def foo(x) x + 2;
ready> foo(2);
Evaluated to 4.000000

它運作了!

即使使用這個簡單的程式碼,我們也獲得了一些令人驚訝的強大功能 - 看看這個

ready> extern sin(x);
Read extern:
declare double @sin(double)

ready> extern cos(x);
Read extern:
declare double @cos(double)

ready> sin(1.0);
Read top-level expression:
define double @2() {
entry:
  ret double 0x3FEAED548F090CEE
}

Evaluated to 0.841471

ready> def foo(x) sin(x)*sin(x) + cos(x)*cos(x);
Read function definition:
define double @foo(double %x) {
entry:
  %calltmp = call double @sin(double %x)
  %multmp = fmul double %calltmp, %calltmp
  %calltmp2 = call double @cos(double %x)
  %multmp4 = fmul double %calltmp2, %calltmp2
  %addtmp = fadd double %multmp, %multmp4
  ret double %addtmp
}

ready> foo(4.0);
Read top-level expression:
define double @3() {
entry:
  %calltmp = call double @foo(double 4.000000e+00)
  ret double %calltmp
}

Evaluated to 1.000000

哇,JIT 怎麼知道 sin 和 cos?答案非常簡單:KaleidoscopeJIT 有一個直接的符號解析規則,它用於尋找任何給定模組中不可用的符號:首先,它搜尋已新增到 JIT 的所有模組,從最新的到最舊的,以尋找最新的定義。如果在 JIT 內部找不到定義,它會回退呼叫 Kaleidoscope 處理程序本身的「dlsym("sin")」。由於「sin」是在 JIT 的位址空間內定義的,因此它只是修補模組中的呼叫,以直接呼叫 libm 版本的 sin。但在某些情況下,這甚至更進一步:由於 sin 和 cos 是標準數學函數的名稱,因此當使用像「sin(1.0)」中那樣的常數呼叫時,常數摺疊器將直接評估函數呼叫以獲得正確的結果。

在未來,我們將看到如何調整此符號解析規則可用於啟用各種有用的功能,從安全性(限制 JIT 程式碼可用的符號集)到基於符號名稱的動態程式碼產生,甚至延遲編譯。

符號解析規則的一個直接好處是,我們現在可以透過編寫任意 C++ 程式碼來實作運算,從而擴充語言。例如,如果我們新增

#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif

/// putchard - putchar that takes a double and returns 0.
extern "C" DLLEXPORT double putchard(double X) {
  fputc((char)X, stderr);
  return 0;
}

請注意,對於 Windows,我們需要實際匯出函數,因為動態符號載入器將使用 GetProcAddress 來尋找符號。

現在,我們可以透過使用類似「extern putchard(x); putchard(120);」之類的東西來產生簡單的輸出到主控台,這會在主控台上印出小寫的 'x'(120 是 'x' 的 ASCII 碼)。類似的程式碼可用於在 Kaleidoscope 中實作檔案 I/O、主控台輸入和許多其他功能。

這完成了 Kaleidoscope 教學的 JIT 和最佳化器章節。在這一點上,我們可以編譯一種非圖靈完備的程式語言,以使用者驅動的方式最佳化和 JIT 編譯它。接下來,我們將研究 使用控制流程結構擴充語言,並在此過程中解決一些有趣的 LLVM IR 問題。

4.5. 完整程式碼列表

以下是我們執行範例的完整程式碼列表,其中增強了 LLVM JIT 和最佳化器。要建置此範例,請使用

# Compile
clang++ -g toy.cpp `llvm-config --cxxflags --ldflags --system-libs --libs core orcjit native` -O3 -o toy
# Run
./toy

如果您在 Linux 上編譯此程式碼,請確保也新增「-rdynamic」選項。這確保了外部函數在執行時能夠正確解析。

以下是程式碼

#include "../include/KaleidoscopeJIT.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Scalar/Reassociate.h"
#include "llvm/Transforms/Scalar/SimplifyCFG.h"
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <vector>

using namespace llvm;
using namespace llvm::orc;

//===----------------------------------------------------------------------===//
// Lexer
//===----------------------------------------------------------------------===//

// The lexer returns tokens [0-255] if it is an unknown character, otherwise one
// of these for known things.
enum Token {
  tok_eof = -1,

  // commands
  tok_def = -2,
  tok_extern = -3,

  // primary
  tok_identifier = -4,
  tok_number = -5
};

static std::string IdentifierStr; // Filled in if tok_identifier
static double NumVal;             // Filled in if tok_number

/// gettok - Return the next token from standard input.
static int gettok() {
  static int LastChar = ' ';

  // Skip any whitespace.
  while (isspace(LastChar))
    LastChar = getchar();

  if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
    IdentifierStr = LastChar;
    while (isalnum((LastChar = getchar())))
      IdentifierStr += LastChar;

    if (IdentifierStr == "def")
      return tok_def;
    if (IdentifierStr == "extern")
      return tok_extern;
    return tok_identifier;
  }

  if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
    std::string NumStr;
    do {
      NumStr += LastChar;
      LastChar = getchar();
    } while (isdigit(LastChar) || LastChar == '.');

    NumVal = strtod(NumStr.c_str(), nullptr);
    return tok_number;
  }

  if (LastChar == '#') {
    // Comment until end of line.
    do
      LastChar = getchar();
    while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');

    if (LastChar != EOF)
      return gettok();
  }

  // Check for end of file.  Don't eat the EOF.
  if (LastChar == EOF)
    return tok_eof;

  // Otherwise, just return the character as its ascii value.
  int ThisChar = LastChar;
  LastChar = getchar();
  return ThisChar;
}

//===----------------------------------------------------------------------===//
// Abstract Syntax Tree (aka Parse Tree)
//===----------------------------------------------------------------------===//

namespace {

/// ExprAST - Base class for all expression nodes.
class ExprAST {
public:
  virtual ~ExprAST() = default;

  virtual Value *codegen() = 0;
};

/// NumberExprAST - Expression class for numeric literals like "1.0".
class NumberExprAST : public ExprAST {
  double Val;

public:
  NumberExprAST(double Val) : Val(Val) {}

  Value *codegen() override;
};

/// VariableExprAST - Expression class for referencing a variable, like "a".
class VariableExprAST : public ExprAST {
  std::string Name;

public:
  VariableExprAST(const std::string &Name) : Name(Name) {}

  Value *codegen() override;
};

/// BinaryExprAST - Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
  char Op;
  std::unique_ptr<ExprAST> LHS, RHS;

public:
  BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
                std::unique_ptr<ExprAST> RHS)
      : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}

  Value *codegen() override;
};

/// CallExprAST - Expression class for function calls.
class CallExprAST : public ExprAST {
  std::string Callee;
  std::vector<std::unique_ptr<ExprAST>> Args;

public:
  CallExprAST(const std::string &Callee,
              std::vector<std::unique_ptr<ExprAST>> Args)
      : Callee(Callee), Args(std::move(Args)) {}

  Value *codegen() override;
};

/// PrototypeAST - This class represents the "prototype" for a function,
/// which captures its name, and its argument names (thus implicitly the number
/// of arguments the function takes).
class PrototypeAST {
  std::string Name;
  std::vector<std::string> Args;

public:
  PrototypeAST(const std::string &Name, std::vector<std::string> Args)
      : Name(Name), Args(std::move(Args)) {}

  Function *codegen();
  const std::string &getName() const { return Name; }
};

/// FunctionAST - This class represents a function definition itself.
class FunctionAST {
  std::unique_ptr<PrototypeAST> Proto;
  std::unique_ptr<ExprAST> Body;

public:
  FunctionAST(std::unique_ptr<PrototypeAST> Proto,
              std::unique_ptr<ExprAST> Body)
      : Proto(std::move(Proto)), Body(std::move(Body)) {}

  Function *codegen();
};

} // end anonymous namespace

//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//

/// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
/// token the parser is looking at.  getNextToken reads another token from the
/// lexer and updates CurTok with its results.
static int CurTok;
static int getNextToken() { return CurTok = gettok(); }

/// BinopPrecedence - This holds the precedence for each binary operator that is
/// defined.
static std::map<char, int> BinopPrecedence;

/// GetTokPrecedence - Get the precedence of the pending binary operator token.
static int GetTokPrecedence() {
  if (!isascii(CurTok))
    return -1;

  // Make sure it's a declared binop.
  int TokPrec = BinopPrecedence[CurTok];
  if (TokPrec <= 0)
    return -1;
  return TokPrec;
}

/// LogError* - These are little helper functions for error handling.
std::unique_ptr<ExprAST> LogError(const char *Str) {
  fprintf(stderr, "Error: %s\n", Str);
  return nullptr;
}

std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
  LogError(Str);
  return nullptr;
}

static std::unique_ptr<ExprAST> ParseExpression();

/// numberexpr ::= number
static std::unique_ptr<ExprAST> ParseNumberExpr() {
  auto Result = std::make_unique<NumberExprAST>(NumVal);
  getNextToken(); // consume the number
  return std::move(Result);
}

/// parenexpr ::= '(' expression ')'
static std::unique_ptr<ExprAST> ParseParenExpr() {
  getNextToken(); // eat (.
  auto V = ParseExpression();
  if (!V)
    return nullptr;

  if (CurTok != ')')
    return LogError("expected ')'");
  getNextToken(); // eat ).
  return V;
}

/// identifierexpr
///   ::= identifier
///   ::= identifier '(' expression* ')'
static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
  std::string IdName = IdentifierStr;

  getNextToken(); // eat identifier.

  if (CurTok != '(') // Simple variable ref.
    return std::make_unique<VariableExprAST>(IdName);

  // Call.
  getNextToken(); // eat (
  std::vector<std::unique_ptr<ExprAST>> Args;
  if (CurTok != ')') {
    while (true) {
      if (auto Arg = ParseExpression())
        Args.push_back(std::move(Arg));
      else
        return nullptr;

      if (CurTok == ')')
        break;

      if (CurTok != ',')
        return LogError("Expected ')' or ',' in argument list");
      getNextToken();
    }
  }

  // Eat the ')'.
  getNextToken();

  return std::make_unique<CallExprAST>(IdName, std::move(Args));
}

/// primary
///   ::= identifierexpr
///   ::= numberexpr
///   ::= parenexpr
static std::unique_ptr<ExprAST> ParsePrimary() {
  switch (CurTok) {
  default:
    return LogError("unknown token when expecting an expression");
  case tok_identifier:
    return ParseIdentifierExpr();
  case tok_number:
    return ParseNumberExpr();
  case '(':
    return ParseParenExpr();
  }
}

/// binoprhs
///   ::= ('+' primary)*
static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
                                              std::unique_ptr<ExprAST> LHS) {
  // If this is a binop, find its precedence.
  while (true) {
    int TokPrec = GetTokPrecedence();

    // If this is a binop that binds at least as tightly as the current binop,
    // consume it, otherwise we are done.
    if (TokPrec < ExprPrec)
      return LHS;

    // Okay, we know this is a binop.
    int BinOp = CurTok;
    getNextToken(); // eat binop

    // Parse the primary expression after the binary operator.
    auto RHS = ParsePrimary();
    if (!RHS)
      return nullptr;

    // If BinOp binds less tightly with RHS than the operator after RHS, let
    // the pending operator take RHS as its LHS.
    int NextPrec = GetTokPrecedence();
    if (TokPrec < NextPrec) {
      RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
      if (!RHS)
        return nullptr;
    }

    // Merge LHS/RHS.
    LHS =
        std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
  }
}

/// expression
///   ::= primary binoprhs
///
static std::unique_ptr<ExprAST> ParseExpression() {
  auto LHS = ParsePrimary();
  if (!LHS)
    return nullptr;

  return ParseBinOpRHS(0, std::move(LHS));
}

/// prototype
///   ::= id '(' id* ')'
static std::unique_ptr<PrototypeAST> ParsePrototype() {
  if (CurTok != tok_identifier)
    return LogErrorP("Expected function name in prototype");

  std::string FnName = IdentifierStr;
  getNextToken();

  if (CurTok != '(')
    return LogErrorP("Expected '(' in prototype");

  std::vector<std::string> ArgNames;
  while (getNextToken() == tok_identifier)
    ArgNames.push_back(IdentifierStr);
  if (CurTok != ')')
    return LogErrorP("Expected ')' in prototype");

  // success.
  getNextToken(); // eat ')'.

  return std::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
}

/// definition ::= 'def' prototype expression
static std::unique_ptr<FunctionAST> ParseDefinition() {
  getNextToken(); // eat def.
  auto Proto = ParsePrototype();
  if (!Proto)
    return nullptr;

  if (auto E = ParseExpression())
    return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
  return nullptr;
}

/// toplevelexpr ::= expression
static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
  if (auto E = ParseExpression()) {
    // Make an anonymous proto.
    auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
                                                 std::vector<std::string>());
    return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
  }
  return nullptr;
}

/// external ::= 'extern' prototype
static std::unique_ptr<PrototypeAST> ParseExtern() {
  getNextToken(); // eat extern.
  return ParsePrototype();
}

//===----------------------------------------------------------------------===//
// Code Generation
//===----------------------------------------------------------------------===//

static std::unique_ptr<LLVMContext> TheContext;
static std::unique_ptr<Module> TheModule;
static std::unique_ptr<IRBuilder<>> Builder;
static std::map<std::string, Value *> NamedValues;
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
static std::unique_ptr<FunctionPassManager> TheFPM;
static std::unique_ptr<LoopAnalysisManager> TheLAM;
static std::unique_ptr<FunctionAnalysisManager> TheFAM;
static std::unique_ptr<CGSCCAnalysisManager> TheCGAM;
static std::unique_ptr<ModuleAnalysisManager> TheMAM;
static std::unique_ptr<PassInstrumentationCallbacks> ThePIC;
static std::unique_ptr<StandardInstrumentations> TheSI;
static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
static ExitOnError ExitOnErr;

Value *LogErrorV(const char *Str) {
  LogError(Str);
  return nullptr;
}

Function *getFunction(std::string Name) {
  // First, see if the function has already been added to the current module.
  if (auto *F = TheModule->getFunction(Name))
    return F;

  // If not, check whether we can codegen the declaration from some existing
  // prototype.
  auto FI = FunctionProtos.find(Name);
  if (FI != FunctionProtos.end())
    return FI->second->codegen();

  // If no existing prototype exists, return null.
  return nullptr;
}

Value *NumberExprAST::codegen() {
  return ConstantFP::get(*TheContext, APFloat(Val));
}

Value *VariableExprAST::codegen() {
  // Look this variable up in the function.
  Value *V = NamedValues[Name];
  if (!V)
    return LogErrorV("Unknown variable name");
  return V;
}

Value *BinaryExprAST::codegen() {
  Value *L = LHS->codegen();
  Value *R = RHS->codegen();
  if (!L || !R)
    return nullptr;

  switch (Op) {
  case '+':
    return Builder->CreateFAdd(L, R, "addtmp");
  case '-':
    return Builder->CreateFSub(L, R, "subtmp");
  case '*':
    return Builder->CreateFMul(L, R, "multmp");
  case '<':
    L = Builder->CreateFCmpULT(L, R, "cmptmp");
    // Convert bool 0/1 to double 0.0 or 1.0
    return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
  default:
    return LogErrorV("invalid binary operator");
  }
}

Value *CallExprAST::codegen() {
  // Look up the name in the global module table.
  Function *CalleeF = getFunction(Callee);
  if (!CalleeF)
    return LogErrorV("Unknown function referenced");

  // If argument mismatch error.
  if (CalleeF->arg_size() != Args.size())
    return LogErrorV("Incorrect # arguments passed");

  std::vector<Value *> ArgsV;
  for (unsigned i = 0, e = Args.size(); i != e; ++i) {
    ArgsV.push_back(Args[i]->codegen());
    if (!ArgsV.back())
      return nullptr;
  }

  return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
}

Function *PrototypeAST::codegen() {
  // Make the function type:  double(double,double) etc.
  std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
  FunctionType *FT =
      FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);

  Function *F =
      Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());

  // Set names for all arguments.
  unsigned Idx = 0;
  for (auto &Arg : F->args())
    Arg.setName(Args[Idx++]);

  return F;
}

Function *FunctionAST::codegen() {
  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
  // reference to it for use below.
  auto &P = *Proto;
  FunctionProtos[Proto->getName()] = std::move(Proto);
  Function *TheFunction = getFunction(P.getName());
  if (!TheFunction)
    return nullptr;

  // Create a new basic block to start insertion into.
  BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
  Builder->SetInsertPoint(BB);

  // Record the function arguments in the NamedValues map.
  NamedValues.clear();
  for (auto &Arg : TheFunction->args())
    NamedValues[std::string(Arg.getName())] = &Arg;

  if (Value *RetVal = Body->codegen()) {
    // Finish off the function.
    Builder->CreateRet(RetVal);

    // Validate the generated code, checking for consistency.
    verifyFunction(*TheFunction);

    // Run the optimizer on the function.
    TheFPM->run(*TheFunction, *TheFAM);

    return TheFunction;
  }

  // Error reading body, remove function.
  TheFunction->eraseFromParent();
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Top-Level parsing and JIT Driver
//===----------------------------------------------------------------------===//

static void InitializeModuleAndManagers() {
  // Open a new context and module.
  TheContext = std::make_unique<LLVMContext>();
  TheModule = std::make_unique<Module>("KaleidoscopeJIT", *TheContext);
  TheModule->setDataLayout(TheJIT->getDataLayout());

  // Create a new builder for the module.
  Builder = std::make_unique<IRBuilder<>>(*TheContext);

  // Create new pass and analysis managers.
  TheFPM = std::make_unique<FunctionPassManager>();
  TheLAM = std::make_unique<LoopAnalysisManager>();
  TheFAM = std::make_unique<FunctionAnalysisManager>();
  TheCGAM = std::make_unique<CGSCCAnalysisManager>();
  TheMAM = std::make_unique<ModuleAnalysisManager>();
  ThePIC = std::make_unique<PassInstrumentationCallbacks>();
  TheSI = std::make_unique<StandardInstrumentations>(*TheContext,
                                                     /*DebugLogging*/ true);
  TheSI->registerCallbacks(*ThePIC, TheMAM.get());

  // Add transform passes.
  // Do simple "peephole" optimizations and bit-twiddling optzns.
  TheFPM->addPass(InstCombinePass());
  // Reassociate expressions.
  TheFPM->addPass(ReassociatePass());
  // Eliminate Common SubExpressions.
  TheFPM->addPass(GVNPass());
  // Simplify the control flow graph (deleting unreachable blocks, etc).
  TheFPM->addPass(SimplifyCFGPass());

  // Register analysis passes used in these transform passes.
  PassBuilder PB;
  PB.registerModuleAnalyses(*TheMAM);
  PB.registerFunctionAnalyses(*TheFAM);
  PB.crossRegisterProxies(*TheLAM, *TheFAM, *TheCGAM, *TheMAM);
}

static void HandleDefinition() {
  if (auto FnAST = ParseDefinition()) {
    if (auto *FnIR = FnAST->codegen()) {
      fprintf(stderr, "Read function definition:");
      FnIR->print(errs());
      fprintf(stderr, "\n");
      ExitOnErr(TheJIT->addModule(
          ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
      InitializeModuleAndManagers();
    }
  } else {
    // Skip token for error recovery.
    getNextToken();
  }
}

static void HandleExtern() {
  if (auto ProtoAST = ParseExtern()) {
    if (auto *FnIR = ProtoAST->codegen()) {
      fprintf(stderr, "Read extern: ");
      FnIR->print(errs());
      fprintf(stderr, "\n");
      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
    }
  } else {
    // Skip token for error recovery.
    getNextToken();
  }
}

static void HandleTopLevelExpression() {
  // Evaluate a top-level expression into an anonymous function.
  if (auto FnAST = ParseTopLevelExpr()) {
    if (FnAST->codegen()) {
      // Create a ResourceTracker to track JIT'd memory allocated to our
      // anonymous expression -- that way we can free it after executing.
      auto RT = TheJIT->getMainJITDylib().createResourceTracker();

      auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
      ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
      InitializeModuleAndManagers();

      // Search the JIT for the __anon_expr symbol.
      auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));

      // Get the symbol's address and cast it to the right type (takes no
      // arguments, returns a double) so we can call it as a native function.
      double (*FP)() = ExprSymbol.toPtr<double (*)()>();
      fprintf(stderr, "Evaluated to %f\n", FP());

      // Delete the anonymous expression module from the JIT.
      ExitOnErr(RT->remove());
    }
  } else {
    // Skip token for error recovery.
    getNextToken();
  }
}

/// top ::= definition | external | expression | ';'
static void MainLoop() {
  while (true) {
    fprintf(stderr, "ready> ");
    switch (CurTok) {
    case tok_eof:
      return;
    case ';': // ignore top-level semicolons.
      getNextToken();
      break;
    case tok_def:
      HandleDefinition();
      break;
    case tok_extern:
      HandleExtern();
      break;
    default:
      HandleTopLevelExpression();
      break;
    }
  }
}

//===----------------------------------------------------------------------===//
// "Library" functions that can be "extern'd" from user code.
//===----------------------------------------------------------------------===//

#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif

/// putchard - putchar that takes a double and returns 0.
extern "C" DLLEXPORT double putchard(double X) {
  fputc((char)X, stderr);
  return 0;
}

/// printd - printf that takes a double prints it as "%f\n", returning 0.
extern "C" DLLEXPORT double printd(double X) {
  fprintf(stderr, "%f\n", X);
  return 0;
}

//===----------------------------------------------------------------------===//
// Main driver code.
//===----------------------------------------------------------------------===//

int main() {
  InitializeNativeTarget();
  InitializeNativeTargetAsmPrinter();
  InitializeNativeTargetAsmParser();

  // Install standard binary operators.
  // 1 is lowest precedence.
  BinopPrecedence['<'] = 10;
  BinopPrecedence['+'] = 20;
  BinopPrecedence['-'] = 20;
  BinopPrecedence['*'] = 40; // highest.

  // Prime the first token.
  fprintf(stderr, "ready> ");
  getNextToken();

  TheJIT = ExitOnErr(KaleidoscopeJIT::Create());

  InitializeModuleAndManagers();

  // Run the main "interpreter loop" now.
  MainLoop();

  return 0;
}

下一步:擴充語言:控制流程