#include <stdlib.h>
#include <stdio.h>
#include <stdint.h>
#include <stdbool.h>
#include <string.h>

// Integer base-2 logarithm as a constant expression, so the results can be
// used for bit-field widths and other compile-time sizes.
//
// Each LOG_k(n) handles arguments up to 2*k bits wide: it checks whether n
// has any bit set at position k or above and, if so, adds k and recurses on
// the upper part (n >> k); otherwise it recurses on n unchanged.  LOG(n) is
// the entry point and covers arguments up to 32 bits wide, yielding
// floor(log2(n)) for n >= 1 (and 0 for n == 0).
//
// Note: the argument is expanded many times, so only pass side-effect-free
// expressions.
#define LOG_1(n) (((n) >= 2) ? 1 : 0)
#define LOG_2(n) (((n) >= 1<<2) ? (2 + LOG_1((n)>>2)) : LOG_1(n))
#define LOG_4(n) (((n) >= 1<<4) ? (4 + LOG_2((n)>>4)) : LOG_2(n))
#define LOG_8(n) (((n) >= 1<<8) ? (8 + LOG_4((n)>>8)) : LOG_4(n))
#define LOG(n)   (((n) >= 1<<16) ? (16 + LOG_8((n)>>16)) : LOG_8(n))
// BITS(n) rounds up instead of down: it adds 1 to LOG(n) whenever n is not a
// power of two (n & (n - 1) is non-zero), i.e. ceil(log2(n)) for n >= 1.
#define BITS(n) (LOG(n) + !!((n) & ((n) - 1)))

// One byte of simulated memory or cache data.
typedef uint8_t byte_t;
// An address into the simulated main memory (index into MAIN_MEMORY).
typedef uint64_t addr_t;

// The tag and block-offset fields extracted from an address.
typedef uint64_t tag_t;
typedef uint64_t offset_t;

// Simulation geometry, all sizes in bytes.
//   MAIN_MEMORY_BYTES - size of the simulated main memory
//   CACHE_BYTES       - total data capacity of the cache
//   CACHE_BLOCK_SIZE  - number of data bytes held by one cache line
#define MAIN_MEMORY_BYTES 32
#define CACHE_BYTES 8
#define CACHE_BLOCK_SIZE 4

// Number of lines in the cache.
#define CACHE_LINES (CACHE_BYTES / CACHE_BLOCK_SIZE)

// Widths of the address fields.  With the sizes above these evaluate to
// 2 offset bits, 1 index bit and 2 tag bits (5 address bits in total).
// NOTE(review): the offset width rounds up (BITS) while the index and address
// widths round down (LOG); the two agree only when every size is a power of
// two, so keep the sizes above powers of two.
#define NUM_OFFSET_BITS (BITS(CACHE_BLOCK_SIZE))
#define NUM_INDEX_BITS LOG(CACHE_LINES)
#define NUM_TAG_BITS (LOG(MAIN_MEMORY_BYTES) - NUM_OFFSET_BITS - NUM_INDEX_BITS)

// One cache line: bookkeeping bits plus a copy of one block of main memory.
typedef struct {
    // 1 when `tag` and `bytes` describe a block that has been loaded.
    uint64_t valid : 1;
    // Tag of the block currently held; only NUM_TAG_BITS bits are stored.
    tag_t tag : NUM_TAG_BITS;
    // The block's data bytes.
    byte_t bytes[CACHE_BLOCK_SIZE];
} cache_entry_t;

// The simulated main memory, addressed byte by byte.
byte_t MAIN_MEMORY[MAIN_MEMORY_BYTES];
// The simulated cache, one entry per line.
cache_entry_t CACHE[CACHE_LINES];

// Reset every byte of the simulated main memory to zero.
void init_main_memory() {
    memset(MAIN_MEMORY, 0, sizeof MAIN_MEMORY);
}

// Invalidate every cache line and zero its data bytes.  The stored tag is
// left as it is; it is meaningless while the line is invalid.
void init_cache() {
    for (size_t line = 0; line < CACHE_LINES; line++) {
        CACHE[line].valid = false;
        memset(CACHE[line].bytes, 0, sizeof CACHE[line].bytes);
    }
}

void print_main_memory() {
    printf("Main Memory:\n");
    printf("    ");
    for (size_t i = 0; i < MAIN_MEMORY_BYTES; i++) {
        printf("%02x ", MAIN_MEMORY[i]);
    }
    printf("\n");
}

// Dump every cache line to stdout: a heading line, then one indented line per
// cache line showing its number, valid bit, tag and data bytes in hex.
void print_cache() {
    printf("Cache Memory:\n");
    for (size_t line = 0; line < CACHE_LINES; line++) {
        // `valid` and `tag` are bit-fields declared with a 64-bit type; how
        // such bit-fields promote when passed through "..." is
        // implementation-defined, so convert explicitly to the types that
        // %d and %x require.
        printf("    %3zu: valid=%d, tag=%02x, block=[",
               line, (int)CACHE[line].valid, (unsigned int)CACHE[line].tag);
        for (size_t i = 0; i < CACHE_BLOCK_SIZE; i++) {
            if (i > 0) {
                printf(" ");
            }
            printf("%02x", (unsigned int)CACHE[line].bytes[i]);
        }
        printf("]\n");
    }
}

byte_t load_byte(addr_t address) {
    tag_t tag = (address >> (NUM_OFFSET_BITS + NUM_INDEX_BITS)) & 0b11;
    size_t index = (address >> NUM_OFFSET_BITS) & 0b11;
    offset_t offset = address & 0b11;
    printf("\n");
    printf("Reading from address=%04lx with tag=%lu, index=%zu, offset=%lu...\n", address, tag, index, offset);

    if (CACHE[0].valid && CACHE[0].tag == tag) {
        print_main_memory();
        print_cache();
        return CACHE[0].bytes[offset];
    }

    // We didn't find it in cache; so, we have to load from main memory...
    CACHE[0].valid = false;
    CACHE[0].tag = tag;
    memmove(CACHE[0].bytes, &MAIN_MEMORY[address - offset], CACHE_BLOCK_SIZE);
    CACHE[0].valid = true;

    print_main_memory();
    print_cache();
    return CACHE[0].bytes[offset];
}

void store_byte(addr_t address, byte_t b) {
    // Write to cache
    tag_t tag = (address >> (NUM_OFFSET_BITS + NUM_INDEX_BITS)) & 0b11;
    size_t index = (address >> NUM_OFFSET_BITS) & 0b11;
    offset_t offset = address & 0b11;
    printf("\n");
    printf("Writing to address=%04lx with tag=%lu, index=%zu, offset=%lu...\n", address, tag, index, offset);

    if (CACHE[0].valid && CACHE[0].tag == tag) {
        CACHE[0].bytes[offset] = b;
    }
    else {
        bzero(CACHE[0].bytes, CACHE_BLOCK_SIZE);
        CACHE[0].tag = tag;
        CACHE[0].valid = true;
        CACHE[0].bytes[offset] = b;
    }
    
    // Write back to memory
    memmove(&MAIN_MEMORY[address - offset], CACHE[0].bytes, CACHE_BLOCK_SIZE);
    print_main_memory();
    print_cache();
}

int main(void) {
    init_main_memory();
    init_cache();

    print_main_memory();
    print_cache();

    store_byte(0b0000, 'a');
    store_byte(0b0001, 'd');
    store_byte(0b0010, 'a');
    store_byte(0b0011, 'm');
    store_byte(0b0100, '!');

    printf(">>%02x\n", load_byte(0b0000));
    printf(">>%02x\n", load_byte(0b0001));
    printf(">>%02x\n", load_byte(0b0010));
    printf(">>%02x\n", load_byte(0b0011));

    return 0;
}
