/*
 * Sample Brainfuck JIT compiler using LLVM 2.8
 */
#include <stack>
#include <fstream>
#include <iostream>

#include <llvm/Module.h>
#include <llvm/Function.h>
#include <llvm/PassManager.h>
#include <llvm/CallingConv.h>
#include <llvm/Analysis/Verifier.h>
#include <llvm/Support/IRBuilder.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/Target/TargetData.h>
#include <llvm/LinkAllPasses.h>
#include <llvm/ExecutionEngine/JIT.h>
#include <llvm/Target/TargetSelect.h>
#include <llvm/LLVMContext.h>
using namespace llvm;

/*
 * Bookkeeping for one Brainfuck loop ('[' ... ']') while makeFunc emits it.
 * Each BasicBlock / Value pair records a point in the control flow together
 * with the tape-head pointer that is live at that point.
 */
struct bfLoopInfo {
  // Block that ends with the '[' test, and the head pointer at that test.
  BasicBlock* beforeBlock;
  Value* beforeValue;

  // First block of the loop body; startValue is the phi node at its top that
  // merges the head pointer arriving from beforeBlock and from endBlock.
  BasicBlock* startBlock;
  PHINode* startValue;

  // Block that ends with the ']' test, and the head pointer at that test.
  BasicBlock* endBlock;
  Value* endValue;

  // Block where execution continues once the loop exits.
  BasicBlock* afterBlock;
  // NOTE(review): afterValue is never assigned or read in this file.
  Value* afterValue;
};

/*
 * Returns true when every '[' in `source` is closed by a later ']' and no ']'
 * appears without an open '[' before it.
 */
static bool bfBracketsBalanced(const char* source)
{
  long depth = 0;
  for (; *source; ++source) {
    if (*source == '[') {
      ++depth;
    } else if (*source == ']' && --depth < 0) {
      return false;  // a ']' with no open '[' before it
    }
  }
  return depth == 0;  // anything still open is an unmatched '['
}

/*
 * Translates a Brainfuck program into an LLVM function `void main()`.
 *
 * The generated function allocates `tapeSize` 8-bit cells on its stack,
 * stores zero into every cell, and then runs `source` against that tape.
 * Characters other than the eight Brainfuck commands are ignored. I/O goes
 * through the C library's getchar/putchar; getchar's result is truncated to
 * 8 bits before it is stored, so EOF (-1) becomes 0xFF. The data pointer is
 * not bounds-checked: a program that moves it off either end of the tape
 * has undefined behaviour.
 *
 * module   - module that receives `main` and the getchar/putchar declarations
 * source   - NUL-terminated program text
 * tapeSize - number of cells on the tape; must be positive
 *
 * Returns the new function, or NULL (leaving `module` untouched) when
 * `source` is NULL, `tapeSize` is not positive, or the brackets are
 * unbalanced.
 */
Function* makeFunc(Module* module, const char* source, int tapeSize = 400)
{
  // Validate before emitting anything: an unmatched ']' would otherwise pop
  // an empty loop stack, and an unmatched '[' would leave its exit block
  // without a terminator (invalid IR).
  if (!source || tapeSize <= 0 || !bfBracketsBalanced(source))
    return NULL;

  LLVMContext& context = getGlobalContext();

  // Some useful types and constants
  const Type* voidType = Type::getVoidTy(context);
  const IntegerType* cellType = IntegerType::get(context, 8);
  const IntegerType* intType = IntegerType::get(context, 32);
  const PointerType* tapeType = PointerType::get(cellType, 0);
  Value* zero = ConstantInt::get(cellType, 0);
  Value* one = ConstantInt::get(cellType, 1);
  Value* minOne = ConstantInt::get(cellType, -1);

  // declare i32 @getchar()
  // getchar/putchar use C's real `int` signatures; cells are widened or
  // narrowed at each call rather than passing an i8 where the ABI wants int.
  Function* getcharFn = cast<Function>(
     module->getOrInsertFunction("getchar", intType, NULL));
  getcharFn->setCallingConv(CallingConv::C);

  // declare i32 @putchar(i32)
  Function* putcharFn = cast<Function>(
     module->getOrInsertFunction("putchar", intType, intType, NULL));
  putcharFn->setCallingConv(CallingConv::C);

  // Construct void main()
  Function* mainFn = cast<Function>(
     module->getOrInsertFunction("main", voidType, NULL));
  mainFn->setCallingConv(CallingConv::C);
  BasicBlock* block = BasicBlock::Create(context, "code", mainFn);
  IRBuilder<> builder(block);

  // Allocate the tape and zero every cell; `head` is the data pointer.
  Value* head = builder.CreateAlloca(cellType, ConstantInt::get(intType, tapeSize));
  Value* it = head;
  for (int i = 0; i < tapeSize; i++) {
    builder.CreateStore(zero, it);
    it = builder.CreateGEP(it, one);
  }

  // `block` is always the block the builder is currently appending to.
  std::stack<bfLoopInfo> loops;
  for (; *source; ++source) {
    switch (*source) {
      case '>': head = builder.CreateGEP(head, one); break;
      case '<': head = builder.CreateGEP(head, minOne); break;
      case '+': {
        Value* cell = builder.CreateLoad(head);
        builder.CreateStore(builder.CreateAdd(cell, one), head);
        break;
      }
      case '-': {
        Value* cell = builder.CreateLoad(head);
        builder.CreateStore(builder.CreateSub(cell, one), head);
        break;
      }
      case '.': {
        Value* cell = builder.CreateLoad(head);
        builder.CreateCall(putcharFn, builder.CreateZExt(cell, intType));
        break;
      }
      case ',': {
        Value* input = builder.CreateCall(getcharFn);
        builder.CreateStore(builder.CreateTrunc(input, cellType), head);
        break;
      }
      case '[': {
        bfLoopInfo loop;
        loop.beforeBlock = block;
        loop.beforeValue = head;
        loop.startBlock = BasicBlock::Create(context, "", mainFn);
        loop.afterBlock = BasicBlock::Create(context, "", mainFn);

        // Skip the body entirely when the current cell is already zero.
        Value* cell = builder.CreateLoad(head);
        builder.CreateCondBr(builder.CreateIsNotNull(cell),
                             loop.startBlock, loop.afterBlock);

        // The body is entered from beforeBlock now; the matching ']' adds
        // the back-edge incoming value later.
        builder.SetInsertPoint(loop.startBlock);
        loop.startValue = builder.CreatePHI(tapeType);
        loop.startValue->addIncoming(loop.beforeValue, loop.beforeBlock);

        loops.push(loop);
        block = loop.startBlock;
        head = loop.startValue;
        break;
      }
      case ']': {
        // Non-empty: bracket balance was checked on entry.
        bfLoopInfo loop = loops.top();
        loops.pop();
        loop.endBlock = block;
        loop.endValue = head;

        // Repeat the body while the current cell is non-zero.
        Value* cell = builder.CreateLoad(head);
        builder.CreateCondBr(builder.CreateIsNotNull(cell),
                             loop.startBlock, loop.afterBlock);
        loop.startValue->addIncoming(loop.endValue, loop.endBlock);

        // The exit block is reached either by skipping the loop at '[' or by
        // leaving it at ']', so merge the head pointer from both.
        block = loop.afterBlock;
        builder.SetInsertPoint(block);
        PHINode* headPhi = builder.CreatePHI(tapeType);
        headPhi->addIncoming(loop.beforeValue, loop.beforeBlock);
        headPhi->addIncoming(loop.endValue, loop.endBlock);
        head = headPhi;
        break;
      }
      default:
        break;  // every other character is a comment in Brainfuck
    }
  }

  // Close the function
  builder.CreateRetVoid();
  return mainFn;
}

int main(int argc, char* argv[])
{
  if (argc < 2) {
    std::cerr << "Usage: " << argv[0] << " bf_file" << std::endl;
    return -1;
  }
  std::ifstream sourceFile(argv[1]);
  std::string line, source;
  while (getline(sourceFile, line))
    source += line;
  
  // Setup a module and engine for JIT-ing
  std::string error;
  InitializeNativeTarget();
  Module* module = new Module("bfcode", getGlobalContext());
  ExecutionEngine *engine = EngineBuilder(module)
    .setErrorStr(&error)
    .setOptLevel(CodeGenOpt::Aggressive)
    .create();
  if (!engine) {
    std::cout << "No engine created: " << error << std::endl;
    return -1;
  }
  
  // Compile the Brainfuck to IR
  std::cout << "Parsing..." << std::flush;
  Function* func = makeFunc(module, source.c_str());
  std::cout << " done" << std::endl;
  
#if 1
  // Run optimization passes
  std::cout << "Optimizing..." << std::flush;
  FunctionPassManager pm(module);
  pm.add(new TargetData(*(engine->getTargetData())));
  pm.add(createVerifierPass());

  // Process
  pm.run(*func);
  std::cout << "done" << std::endl;
#endif

  // Compile
  std::cout << "Compiling..." << std::flush;
  void (*bf)() = (void (*)()) engine->getPointerToFunction(func);
  std::cout << " done" << std::endl;
 
  // and run!
  bf();

  return 0;
}
