#include <string.h>
#include <stdio.h>
#include <errno.h>
#include <stdint.h>
#include <signal.h>
/* unix only */
#include <stdlib.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/termios.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/sendfile.h>

enum { R_R0 = 0, R_R1, R_R2, R_R3, R_R4, R_R5, R_R6, R_R7, R_PC, R_COND, R_COUNT };

enum {
    FL_POS = 1 << 0, /* P */
    FL_ZRO = 1 << 1, /* Z */
    FL_NEG = 1 << 2, /* N */
};

enum { OP_BR = 0, OP_ADD, OP_LD, OP_ST, OP_JSR, OP_AND, OP_LDR, OP_STR, OP_RTI, OP_NOT, OP_LDI, OP_STI, OP_JMP, OP_RES, OP_LEA, OP_TRAP };

enum {
    MR_KBSR = 0xFE00, /* keyboard status */
    MR_KBDR = 0xFE02  /* keyboard data */
};

enum {
    TRAP_GETC = 0x20,  /* get character from keyboard, not echoed onto the terminal */
    TRAP_OUT = 0x21,   /* output a character */
    TRAP_PUTS = 0x22,  /* output a word string */
    TRAP_IN = 0x23,    /* get character from keyboard, echoed onto the terminal */
    TRAP_PUTSP = 0x24, /* output a byte string */
    TRAP_HALT = 0x25,   /* halt the program */
};

#define MEMORY_MAX (1 << 16)
uint16_t memory[MEMORY_MAX];  /* 65536 locations */
uint16_t reg[R_COUNT];

struct termios original_tio;

void disable_input_buffering() {
    tcgetattr(STDIN_FILENO, &original_tio);
    struct termios new_tio = original_tio;
    new_tio.c_lflag &= ~ICANON & ~ECHO;
    tcsetattr(STDIN_FILENO, TCSANOW, &new_tio);
}

void restore_input_buffering() {
    tcsetattr(STDIN_FILENO, TCSANOW, &original_tio);
}

uint16_t check_key() {
    fd_set readfds;
    FD_ZERO(&readfds);
    FD_SET(STDIN_FILENO, &readfds);

    struct timeval timeout;
    timeout.tv_sec = 0;
    timeout.tv_usec = 0;
    return select(1, &readfds, NULL, NULL, &timeout) != 0;
}

void handle_interrupt(int signal) {
    restore_input_buffering();
    printf("\n");
    fflush(stdout);
    exit(-2);
}

uint16_t sign_extend(uint16_t x, int bit_count) {
    if ((x >> (bit_count - 1)) & 1) {
        x |= (0xFFFF << bit_count);
    }
    return x;
}

uint16_t swap16(uint16_t x) {
    return (x << 8) | (x >> 8);
}

void update_flags(uint16_t r) {
    if (reg[r] == 0) {
        reg[R_COND] = FL_ZRO;
    }
    else if (reg[r] >> 15) { /* a 1 in the left-most bit indicates negative */
        reg[R_COND] = FL_NEG;
    }
    else {
        reg[R_COND] = FL_POS;
    }
}

uint16_t read_image_file(FILE* file) {
    /* the origin tells us where in memory to place the image */
    uint16_t origin;
    fread(&origin, sizeof(origin), 1, file);
    origin = swap16(origin);

    /* we know the maximum file size so we only need one fread */
    uint16_t max_read = MEMORY_MAX - origin;
    uint16_t* p = memory + origin;
    size_t read = fread(p, sizeof(uint16_t), max_read, file);

    /* swap to little endian */
    while (read-- > 0) {
        *p = swap16(*p);
        ++p;
    }
    return origin;
}

uint16_t read_image_buf(uint16_t* buf, int size) {
    /* the origin tells us where in memory to place the image */
    uint16_t origin = *buf;
    origin = swap16(origin);

    buf ++;

    /* we know the maximum file size so we only need one fread */
    uint16_t max_read = MEMORY_MAX - origin;
    uint16_t* p = memory + origin;
    memcpy(p, buf, size);

    /* swap to little endian */
    for (int i = 1; i < size/2; i ++) {
        *p = swap16(*p);
        ++p;
    }
    return origin;
}

uint16_t read_image(const char* image_path) {
    FILE* file = fopen(image_path, "rb");
    if (!file) { return 0; };
    uint16_t origin = read_image_file(file);
    fclose(file);
    return origin;
}

void mem_write(uint16_t address, uint16_t val) {
    memory[address] = val;
}

uint16_t mem_read(uint16_t address) {
    if (address == MR_KBSR) {
        if (check_key()) {
            memory[MR_KBSR] = (1 << 15);
            memory[MR_KBDR] = getchar();
        }
        else {
            memory[MR_KBSR] = 0;
        }
    }
    return memory[address];
}


void trap_getc() {
    /* read a single ASCII char */
    reg[R_R0] = (uint16_t)getchar();
    update_flags(R_R0);
}

void trap_out() {
    putc((char)reg[R_R0], stdout);
    fflush(stdout);
}

void trap_puts() {
    /* one char per word */
    uint16_t* c = memory + reg[R_R0];
    while (*c)
    {
        putc((char)*c, stdout);
        ++c;
    }
    fflush(stdout);
}

void trap_in() {
    printf("Enter a character: ");
    char c = getchar();
    putc(c, stdout);
    fflush(stdout);
    reg[R_R0] = (uint16_t)c;
    update_flags(R_R0);
}

void trap_putsp() {
    /* one char per byte (two bytes per word)
       here we need to swap back to
       big endian format */
    uint16_t* c = memory + reg[R_R0];
    while (*c)
    {
        char char1 = (*c) & 0xFF;
        putc(char1, stdout);
        char char2 = (*c) >> 8;
        if (char2) putc(char2, stdout);
        ++c;
    }
    fflush(stdout);
}

void trap_halt() {
    puts("HALT");
    fflush(stdout);
    exit(0);
}

// 64-bit aligned
//
typedef void (*trap_handler_t)(void);
trap_handler_t *trap_vector = (trap_handler_t *) memory;

void init_trap_vector(void) {
    trap_vector[TRAP_GETC]  = trap_getc;
    trap_vector[TRAP_OUT]   = trap_out;
    trap_vector[TRAP_PUTS]  = trap_puts;
    trap_vector[TRAP_IN]    = trap_in;
    trap_vector[TRAP_PUTSP] = trap_putsp;
    trap_vector[TRAP_HALT]  = trap_halt;
}

int execute_trap(uint16_t op, uint16_t instr) {
    reg[R_R7] = reg[R_PC];

    uint16_t trapvect = instr & 0xFF;
    printf("trap %d\n", trapvect);
    fflush(stdout);
    trap_handler_t handler = trap_vector[trapvect];
    if (handler) {
        handler();
        return 1;
    }
    return 0;
}

int execute_instruction(uint16_t op, uint16_t instr) {
    switch (op) {
        case OP_ADD:
            {
                /* destination register (DR) */
                uint16_t r0 = (instr >> 9) & 0x7;
                /* first operand (SR1) */
                uint16_t r1 = (instr >> 6) & 0x7;
                /* whether we are in immediate mode */
                uint16_t imm_flag = (instr >> 5) & 0x1;

                if (imm_flag) {
                    uint16_t imm5 = sign_extend(instr & 0x1F, 5);
                    reg[r0] = reg[r1] + imm5;
                }
                else {
                    uint16_t r2 = instr & 0x7;
                    reg[r0] = reg[r1] + reg[r2];
                }

                update_flags(r0);
            }
            break;
        case OP_AND:
            {
                uint16_t r0 = (instr >> 9) & 0x7;
                uint16_t r1 = (instr >> 6) & 0x7;
                uint16_t imm_flag = (instr >> 5) & 0x1;

                if (imm_flag) {
                    uint16_t imm5 = sign_extend(instr & 0x1F, 5);
                    reg[r0] = reg[r1] & imm5;
                }
                else {
                    uint16_t r2 = instr & 0x7;
                    reg[r0] = reg[r1] & reg[r2];
                }
                update_flags(r0);
            }
            break;
        case OP_NOT:
            {
                uint16_t r0 = (instr >> 9) & 0x7;
                uint16_t r1 = (instr >> 6) & 0x7;

                reg[r0] = ~reg[r1];
                update_flags(r0);
            }
            break;
        case OP_BR:
            {
                uint16_t pc_offset = sign_extend(instr & 0x1FF, 9);
                uint16_t cond_flag = (instr >> 9) & 0x7;
                if (cond_flag & reg[R_COND]) {
                    reg[R_PC] += pc_offset;
                }
            }
            break;
        case OP_JMP:
            {
                /* Also handles RET */
                uint16_t r1 = (instr >> 6) & 0x7;
                reg[R_PC] = reg[r1];
            }
            break;
        case OP_JSR:
            {
                uint16_t long_flag = (instr >> 11) & 1;
                reg[R_R7] = reg[R_PC];
                if (long_flag) {
                    uint16_t long_pc_offset = sign_extend(instr & 0x7FF, 11);
                    reg[R_PC] += long_pc_offset;  /* JSR */
                }
                else {
                    uint16_t r1 = (instr >> 6) & 0x7;
                    reg[R_PC] = reg[r1]; /* JSRR */
                }
            }
            break;
        case OP_LD:
            {
                uint16_t r0 = (instr >> 9) & 0x7;
                uint16_t pc_offset = sign_extend(instr & 0x1FF, 9);
                reg[r0] = mem_read(reg[R_PC] + pc_offset);
                update_flags(r0);
            }
            break;
        case OP_LDI:
            {
                /* destination register (DR) */
                uint16_t r0 = (instr >> 9) & 0x7;
                /* PCoffset 9*/
                uint16_t pc_offset = sign_extend(instr & 0x1FF, 9);
                /* add pc_offset to the current PC, look at that memory location to get the final address */
                reg[r0] = mem_read(mem_read(reg[R_PC] + pc_offset));
                update_flags(r0);
            }
            break;
        case OP_LDR:
            {
                uint16_t r0 = (instr >> 9) & 0x7;
                uint16_t r1 = (instr >> 6) & 0x7;
                uint16_t offset = sign_extend(instr & 0x3F, 6);
                reg[r0] = mem_read(reg[r1] + offset);
                update_flags(r0);
            }
            break;
        case OP_LEA:
            {
                uint16_t r0 = (instr >> 9) & 0x7;
                uint16_t pc_offset = sign_extend(instr & 0x1FF, 9);
                reg[r0] = reg[R_PC] + pc_offset;
                update_flags(r0);
            }
            break;
        case OP_ST:
            {
                uint16_t r0 = (instr >> 9) & 0x7;
                uint16_t pc_offset = sign_extend(instr & 0x1FF, 9);
                mem_write(reg[R_PC] + pc_offset, reg[r0]);
            }
            break;
        case OP_STI:
            {
                uint16_t r0 = (instr >> 9) & 0x7;
                uint16_t pc_offset = sign_extend(instr & 0x1FF, 9);
                mem_write(mem_read(reg[R_PC] + pc_offset), reg[r0]);
            }
            break;
        case OP_STR:
            {
                uint16_t r0 = (instr >> 9) & 0x7;
                uint16_t r1 = (instr >> 6) & 0x7;
                uint16_t offset = sign_extend(instr & 0x3F, 6);
                mem_write(reg[r1] + offset, reg[r0]);
            }
            break;
        case OP_TRAP:
            reg[R_R7] = reg[R_PC];

            return execute_trap(op, instr);
            break;
        case OP_RES:
        case OP_RTI:
        default:
            abort();
            break;
    }
    return 1;
}

int main(int argc, const char* argv[]) {
    setbuf(stdin, NULL);
    // setbuf(stdout, NULL);
    // setbuf(stderr, NULL);
    uint16_t PC_START = 0x0;
    if (argc < 2) {
        /* show usage string */
        printf("How many bytes would you like to send?\n");
        fflush(stdout);
        int x;
        scanf("%d", &x);
        printf("Ok: %d\n", x);
        fflush(stdout);
        uint16_t *buf = malloc(x);
        int n = read(0, buf, x);
        if (x != n) {
            printf("Received %d bytes\n", n);
            fflush(stdout);
        }
        PC_START = read_image_buf(buf, n);
    }
    else {
        PC_START = read_image(argv[1]);
        if (PC_START == 0) {
            printf("failed to load image: %s\n", argv[1]);
            fflush(stdout);
            exit(1);
        }
    }
    // We're not using OS/supervisor stack, so extra space is fine
    if (PC_START < 0x1000) {
        printf("PC_START too low: 0x%x\n", PC_START);
        fflush(stdout);
        exit(1);
    }
    printf("Loaded at 0x%x (%p)\n", PC_START, memory + PC_START);
    fflush(stdout);
    signal(SIGINT, handle_interrupt);
    disable_input_buffering();

    /* since exactly one condition flag should be set at any given time, set the Z flag */
    reg[R_COND] = FL_ZRO;

    /* set the PC to starting position */
    reg[R_PC] = PC_START;

    int running = 1;
    init_trap_vector();
    while (running) {
        /* FETCH */
        uint16_t instr = mem_read(reg[R_PC]++);
        uint16_t op = instr >> 12;
        // printf("%d %d\n", instr, op);
        if (instr == 0) {
            exit(0);
        }
        running = execute_instruction(op, instr);

    }
    restore_input_buffering();
}

// TODO remove entirely
// TODO fire little timmy for adding this trap in the first place
void trap_sendfile_deprecated() {
    printf("congrats!\n");
    fflush(stdout);
    int fd = open("flag1", O_RDONLY);
    sendfile(STDOUT_FILENO, fd, 0, 32);
}

void buffer_overflow_admin_only() {
    char buf[256] = {0};
    printf("%p: remember to overflow this buffer for diagnostic purposes only\n", buf);
    fflush(stdout);
    size_t read = fread(buf, sizeof(uint16_t), 0x256, stdin);
    printf("%s\n", buf);
    fflush(stdout);
}

// https://x64.syscall.sh/
void instructions() {
    asm volatile("pop %rax; ret");
    asm volatile("pop %rdi; ret");
    asm volatile("pop %rsi; ret");
    asm volatile("pop %rdx; ret");
    asm volatile("pop %r10; ret");
    asm volatile("syscall; ret");
}
