/* derived from https://github.com/prophile/bfc/blob/master/bfc.c */
/* expose popen/pclose (POSIX) even when compiling in a strict ISO C mode */
#define _POSIX_C_SOURCE 200809L

#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>

/* the number of i64 cells ("registers") allocated for the Brainfuck tape */
static const int REG_COUNT = 8192;
/* counter that provides a unique numeric suffix for each emitted group of
 * LLVM SSA names and labels (%idxN, %ptrN, loopHeaderN, ...) */
static int rv = 1;

/* stack of the IDs of the currently open loops, innermost on top; its
 * capacity is the maximum '[' nesting depth the compiler can handle */
static int loopStack[4096];
/* number of entries currently on loopStack (the next free slot) */
static int loopStackIndex = 0;

/*
 * GEN(fmt, ...) - printf-style write of IR text.
 *
 * The macro deliberately writes to a variable named `fp`, so it can only be
 * used inside functions that have a `FILE *fp` in scope. LLVM's '%' sigils
 * must be spelled "%%" in the format string.
 */
#define GEN(...) \
    do { \
        fprintf(fp, __VA_ARGS__); \
    } while (0)

/*
 * Write the start of the LLVM IR module to fp.
 *
 * First come the I/O helpers: @op_out truncates a cell to i32 and passes it
 * to putchar, and @op_in sign-extends getchar's i32 result to i64. Then
 * @program is opened. Its entry block allocates the cell pointer %index
 * (initialised to 0) and an array of REG_COUNT i64 cells on the stack. It
 * zero-fills the array with the memset intrinsic. The function is left open
 * so that subsequent emitted instructions become part of @program.
 */
static void emit_header(FILE *fp)
{
    /* runtime helpers wrapping the C library's character I/O */
    GEN("define internal void @op_out(i64 %%val) nounwind\n"
        "{\n"
        "entry:\n"
        "    %%conv = trunc i64 %%val to i32\n"
        "    %%call = tail call i32 @putchar ( i32 %%conv ) nounwind\n"
        "    ret void\n"
        "}\n"
        "\n"
        "declare i32 @putchar(i32) nounwind\n"
        "\n"
        "define internal i64 @op_in() nounwind\n"
        "{\n"
        "entry:\n"
        "    %%call = tail call i32 @getchar() nounwind\n"
        "    %%conv = sext i32 %%call to i64\n"
        "    ret i64 %%conv\n"
        "}\n"
        "\n"
        "declare i32 @getchar() nounwind\n"
        "\n");

    /* @program prologue: cell pointer, cell array, pointer to cell 0 and a
     * memset of REG_COUNT * 8 bytes (8 bytes per i64 cell) to clear it */
    GEN("define void @program() nounwind\n"
        "{\n"
        "entry:\n"
        "    %%index = alloca i64, align 8\n"
        "    %%registers = alloca [%d x i64], align 8\n"
        "    store i64 0, i64* %%index\n"
        "    %%regroot = getelementptr [%d x i64]* %%registers, i64 0, i64 0\n"
        "    %%ptrconv = bitcast i64* %%regroot to i8*\n"
        "    call void @llvm.memset.i64(i8* %%ptrconv, i8 0, i64 %d, i32 8)\n",
        REG_COUNT, REG_COUNT, REG_COUNT * 8);
}

/*
 * Write the end of the LLVM IR module to fp. It closes @program with
 * `ret void` and declares the memset intrinsic called by the prologue. It
 * then defines the C entry point @main, which just calls @program and
 * returns 0.
 */
static void emit_footer(FILE *fp)
{
    /* terminate @program */
    GEN("    ret void\n"
        "}\n"
        "\n");

    /* intrinsic used by the prologue to clear the cell array */
    GEN("declare void @llvm.memset.i64(i8*, i8, i64, i32)\n");

    /* process entry point, which simply calls @program */
    GEN("define i32 @main() nounwind\n"
        "{\n"
        "entry:\n"
        "    call void @program() nounwind\n"
        "    ret i32 0\n"
        "}\n"
        "\n");
}

/*
 * Emit IR that adds `amount` to the current cell. `amount` is the net
 * effect of a run of '+' and '-', so this computes
 * cell[%index] += amount. An ID is drawn from rv even when amount is 0.
 * In that case no instructions are written.
 */
static void emit_add(FILE *fp, long amount)
{
    const int n = rv++;

    if (amount != 0) {
        /* load %index, address the cell, then load / add / store it back */
        GEN("    %%idx%d = load i64* %%index\n"
            "    %%ptr%d = getelementptr i64* %%regroot, i64 %%idx%d\n"
            "    %%tmp%d = load i64* %%ptr%d\n"
            "    %%add%d = add i64 %%tmp%d, %ld\n"
            "    store i64 %%add%d, i64* %%ptr%d\n",
            n,
            n, n,
            n, n,
            n, n, amount,
            n, n);
    }
}

/*
 * Emit IR that shifts the cell pointer by `amount`. `amount` is the net
 * effect of a run of '>' and '<', so this computes %index += amount. An
 * ID is drawn from rv even when amount is 0. In that case no instructions
 * are written.
 */
static void emit_move(FILE *fp, long amount)
{
    const int n = rv++;

    if (amount != 0) {
        /* load %index, add the offset, store it back */
        GEN("    %%idx%d = load i64* %%index\n"
            "    %%add%d = add i64 %%idx%d, %ld\n"
            "    store i64 %%add%d, i64* %%index\n",
            n,
            n, n, amount,
            n);
    }
}

/*
 * Emit IR for '.': load the current cell and pass it to @op_out.
 */
static void emit_out(FILE *fp)
{
    const int n = rv++;

    /* locate and load cell[%index], then hand it to the output helper */
    GEN("    %%idx%d = load i64* %%index\n"
        "    %%ptr%d = getelementptr i64* %%regroot, i64 %%idx%d\n"
        "    %%tmp%d = load i64* %%ptr%d\n"
        "    call void @op_out(i64 %%tmp%d) nounwind\n",
        n,
        n, n,
        n, n,
        n);
}

/*
 * Emit IR for ',': call @op_in and store its i64 result into the current
 * cell. @op_in is defined by emit_header as getchar sign-extended to i64.
 *
 * Declared static like every other emit_* helper; it is only used within
 * this file.
 */
static void emit_in(FILE *fp)
{
    int opID = rv++;
    // read one input value
    GEN("    %%tmp%d = call i64 @op_in() nounwind\n", opID);
    // locate the current cell
    GEN("    %%idx%d = load i64* %%index\n", opID);
    GEN("    %%ptr%d = getelementptr i64* %%regroot, i64 %%idx%d\n", opID, opID);
    // overwrite it with the input value
    GEN("    store i64 %%tmp%d, i64* %%ptr%d\n", opID, opID);
}

/* this emits a loop header and the beginning of the body, corresponding to a '[' */
static void emit_loop_open(FILE *fp)
{
    // push this entry onto the stack
    int opID = rv++;
    int stackIndex = loopStackIndex++;
    loopStack[stackIndex] = opID;
    // begin the loop header
    GEN("    br label %%loopHeader%d\n", opID);
    GEN("loopHeader%d:\n", opID);
    // the loop header checks the exit condition, so load the index...
    GEN("    %%idx%d = load i64* %%index\n", opID);
    // load the register...
    GEN("    %%ptr%d = getelementptr i64* %%regroot, i64 %%idx%d\n", opID, opID);
    GEN("    %%tmp%d = load i64* %%ptr%d\n", opID, opID);
    // compare to zero...
    GEN("    %%cmp%d = icmp eq i64 %%tmp%d, 0\n", opID, opID);
    // if zero, jump to the corresponding ']' which is the end of the loop
    // otherwise, run through the loop body
    GEN("    br i1 %%cmp%d, label %%loopExit%d, label %%loopBody%d\n", opID, opID, opID);
    // beginning of loop body
    GEN("loopBody%d:\n", opID);
}

/* this emits the loop exit, which corresponds to ']' */
static void emit_loop_close(FILE *fp)
{
    // get the top loop from the stack
    int stackIndex = --loopStackIndex;
    int loopID = loopStack[stackIndex];
    // jump to the header, which will jump to the loopExit if the exit condition is met
    GEN("    br label %%loopHeader%d\n", loopID);
    GEN("loopExit%d:\n", loopID);
    // continue
}

/*
 * State of the reader in bfp: whether it is currently accumulating a run of
 * arithmetic or pointer-move characters, or neither.
 */
typedef enum {
    BF_STATE_ARITHMETIC = 0, /* inside a run of '+' / '-' */
    BF_STATE_POINTER = 1,    /* inside a run of '>' / '<' */
    BF_STATE_NONE = 2        /* not accumulating a run */
} BF_State;

/*
 * Emit whatever the current reader state calls for.
 *
 * While a run is being accumulated (ARITHMETIC or POINTER), the run is
 * emitted as a single add or move of `amount` and *state is reset to
 * BF_STATE_NONE. `lastChar` is not used in that case. In BF_STATE_NONE,
 * the single instruction for `lastChar` ('.', ',', '[' or ']') is emitted.
 * Any other character is a Brainfuck comment and produces nothing.
 */
static void emit(FILE *out, BF_State *state, char lastChar, long amount)
{
    if (*state == BF_STATE_ARITHMETIC) {
        emit_add(out, amount);
        *state = BF_STATE_NONE;
        return;
    }
    if (*state == BF_STATE_POINTER) {
        emit_move(out, amount);
        *state = BF_STATE_NONE;
        return;
    }
    if (*state != BF_STATE_NONE)
        return;

    if (lastChar == '.')
        emit_out(out);
    else if (lastChar == ',')
        emit_in(out);
    else if (lastChar == '[')
        emit_loop_open(out);
    else if (lastChar == ']')
        emit_loop_close(out);
    /* every other character is ignored */
}

/* real compiler - read in as brainfuck, compile to out as LLVM IR */
static void bfp(FILE *in, FILE *out)
{
    long amount = 0;
    BF_State state = BF_STATE_NONE;
    // emit the header
    emit_header(out);
    // loop through everything
    while (!feof(in)) {
        // basically, a finite state machine
        char ch;
        fread(&ch, 1, 1, in);
        /* Do arithmetic */
        if (state == BF_STATE_ARITHMETIC) {
            // if it is a + or -, adjust the working value
            if (ch == '+')
                amount++;
            else if (ch == '-')
                amount--;
            else {
                /* emit the instruction, push it back to be read next loop */
                emit(out, &state, ch, amount);
                ungetc(ch, in);
                amount = 0;
            }
        }
        /* Do for pointers */
        else if (state == BF_STATE_POINTER) {
            if (ch == '>')
                amount++;
            else if (ch == '<')
                amount--;
            else {
                emit(out, &state, ch, amount);
                ungetc(ch, in);
                amount = 0;
            }
        }
        /* Just emit an operation or enter the appropriate state */
        else {
            if (ch == '+') {
                state = BF_STATE_ARITHMETIC;
                amount = 1;
            }
            else if (ch == '-') {
                state = BF_STATE_ARITHMETIC;
                amount = -1;
            }
            else if (ch == '>') {
                state = BF_STATE_POINTER;
                amount = 1;
            }
            else if (ch == '<') {
                state = BF_STATE_POINTER;
                amount = -1;
            }
            else {
                emit(out, &state, ch, 0);
            }
        }
    }
    // emit any trailing +s or -s (maybe this should be removed,
    // as it won't contribute to the output?)
    emit(out, &state, 0, amount);
    emit_footer(out);
}

/*
 * Compile Brainfuck from stdin and run it immediately. The generated IR is
 * written to a shell pipeline, `llvm-as | opt -std-compile-opts | lli`,
 * started with popen.
 *
 * Returns EXIT_FAILURE if the pipeline cannot be started or does not exit
 * successfully, and EXIT_SUCCESS otherwise.
 */
int main(void)
{
    FILE *asOut = popen("llvm-as | opt -std-compile-opts | lli", "w");
    if (asOut == NULL) {
        perror("popen");
        return EXIT_FAILURE;
    }

    /* compile stdin straight into the pipe */
    bfp(stdin, asOut);

    /* pclose waits for the pipeline and reports its termination status */
    int status = pclose(asOut);
    if (status == -1) {
        perror("pclose");
        return EXIT_FAILURE;
    }
    return status == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}

