#pragma once

#include <cstdint>

#include <stdexcept>
#include <vector>
#include <queue>
#include <array>
#include <unordered_map>

#include <enyx/cores/flash/flash.hpp>
#include <enyx/hw/result.hpp>

ENYX_CORES_NAMESPACE_BEGIN

namespace flash {
namespace mocking {

/// Maximum number of 32-bit words held by a single emulated flash register.
static constexpr std::uint8_t MAX_REGISTER_WORDS = 2;
/*
 * The flash_mng core uses 32-bit (4-byte) words and the flash API is hard
 * wired to match that.
 */
static constexpr std::uint8_t FLASH_WORD_SIZE = 4;
/// Content of a flash word after an erase: every bit set.
static constexpr std::uint32_t ERASED_FLASH_WORD = 0xFFFFFFFFu;

/**
 * Create a fake NAND flash component
 *
 * This can be wired into a mocked flash_mng core.
 *
 * Memory is modelled as an array of 32-bit words indexed by word address;
 * the exec_* commands take a word address and a word count. Data moves
 * through two FIFOs: write() loads write_fifo before an exec_write*()
 * command consumes it, and exec_read*() commands fill read_fifo for read()
 * to drain. The FIFO depth depends on the boost/dual flags
 * (see update_fifos_size()).
 */
struct emulated_flash {
    /**
     * Build an erased flash of 2^flash_width words split in sector_count
     * sectors, with FIFOs sized for the current boost/dual flags (both off).
     *
     * @param flash_width log2 of the number of 32-bit words in the flash
     * @param sector_count number of sectors in the flash
     * @throws std::invalid_argument when flash_width is too large for the
     *         byte size to fit in 32 bits, or when sector_count is zero or
     *         does not fit in 32 bits
     */
    emulated_flash(std::uint64_t flash_width,
                   std::uint64_t sector_count)
        : flash_width(flash_width)
        , flash_mem(checked_word_count(flash_width), ERASED_FLASH_WORD)
        , fifo_size(1)
        , word_size(FLASH_WORD_SIZE)
        , sector_count(checked_sector_count(sector_count))
        , flash_size(static_cast<std::uint32_t>(flash_mem.size() * word_size))
        // sector_count has been validated (non-zero) by the time we get here,
        // because members are initialized in declaration order.
        , sector_size(flash_size / this->sector_count)
        , boost(false)
        , dual(false)
    {
        update_fifos_size();
    }

    /**
     * Pop one word from the read FIFO.
     *
     * @throws std::runtime_error if the read FIFO is empty
     */
    std::uint32_t read()
    {
        if (read_fifo.empty()) {
            throw std::runtime_error("Failed to read data: read fifo is empty");
        }
        std::uint32_t const value = read_fifo.front();
        read_fifo.pop();
        return value;
    }

    /**
     * Push the low 32 bits of value into the write FIFO.
     *
     * @throws std::runtime_error if the write FIFO already holds fifo_size words
     */
    void write(std::uint64_t value)
    {
        if (write_fifo.size() >= fifo_size)
            throw std::runtime_error("Failed to write data: write fifo is full");
        write_fifo.push(static_cast<std::uint32_t>(value));
    }

    /**
     * Copy size words starting at word address addr into the read FIFO.
     *
     * @throws std::runtime_error if the read FIFO is not empty or is smaller
     *         than size
     * @throws std::out_of_range if the range does not fit in the flash; the
     *         read FIFO is left untouched in that case
     */
    void exec_read(std::uint32_t addr, std::uint32_t size)
    {
        if (!read_fifo.empty()) {
            throw std::runtime_error("Failed to read memory: read fifo is not empty");
        }
        if (size > fifo_size) {
            throw std::runtime_error("Failed to read memory: read fifo is too small");
        }
        check_range(addr, size);

        for (std::uint64_t i = 0; i < size; ++i)
            read_fifo.push(flash_mem[addr + i]);
    }

    /**
     * Program size words starting at word address addr with data popped
     * from the write FIFO. Programming can only clear bits: each word is
     * ANDed with the new value.
     *
     * Every check runs before anything is modified, so a failing command
     * leaves both the memory and the write FIFO unchanged.
     *
     * @throws std::runtime_error if the write FIFO holds fewer than size words
     * @throws std::out_of_range if the range does not fit in the flash
     */
    void exec_write(std::uint32_t addr, std::uint32_t size)
    {
        if (write_fifo.size() < size) {
            throw std::runtime_error("Failed to write to memory: write fifo is empty");
        }
        check_range(addr, size);

        for (std::uint64_t i = 0; i < size; ++i) {
            flash_mem[addr + i] &= write_fifo.front();
            write_fifo.pop();
        }
    }

    /**
     * Reset size words starting at word address addr to ERASED_FLASH_WORD.
     *
     * @throws std::out_of_range if the range does not fit in the flash;
     *         nothing is erased in that case
     */
    void exec_erase(std::uint32_t addr, std::uint32_t size)
    {
        check_range(addr, size);

        for (std::uint64_t i = 0; i < size; ++i)
            flash_mem[addr + i] = ERASED_FLASH_WORD;
    }

    /**
     * Copy the first size words of register addr into the read FIFO.
     *
     * @throws std::runtime_error if the read FIFO is not empty, if size is
     *         zero or above MAX_REGISTER_WORDS, or if the register has never
     *         been written
     */
    void exec_read_register(std::uint32_t addr, std::uint32_t size)
    {
        if (!read_fifo.empty()) {
            throw std::runtime_error("Failed to read register: read fifo is not empty");
        }
        if (size == 0) {
            throw std::runtime_error("Failed to read register: command size is zero");
        }
        if (size > MAX_REGISTER_WORDS) {
            throw std::runtime_error("Failed to read register: maximum size exceeded");
        }

        auto const reg = flash_regs.find(addr);
        if (reg == flash_regs.end())
            throw std::runtime_error("Failed to read register: not initialized");

        for (std::uint64_t i = 0; i < size; ++i) {
            read_fifo.push(reg->second[i]);
        }
    }

    /**
     * Move the whole write FIFO (exactly size words) into register addr.
     * A register written for the first time starts zero-filled, so words
     * beyond size read back as 0.
     *
     * @throws std::runtime_error if size is zero, above MAX_REGISTER_WORDS,
     *         or different from the number of words in the write FIFO
     */
    void exec_write_register(std::uint32_t addr, std::uint32_t size)
    {
        if (size == 0) {
            throw std::runtime_error("Failed to write to register: command size is zero");
        }
        if (size > MAX_REGISTER_WORDS) {
            throw std::runtime_error("Failed to write to register: maximum size exceeded");
        }
        if (size != write_fifo.size()) {
            throw std::runtime_error("Failed to write to register: command size does not match loaded data");
        }

        // operator[] value-initializes (zero-fills) a register on first use.
        auto & reg = flash_regs[addr];
        for (std::uint64_t i = 0; i < size; ++i) {
            reg[i] = write_fifo.front();
            write_fifo.pop();
        }
    }

    /// Drop any pending data in both FIFOs.
    void flush_fifos()
    {
        write_fifo = {};
        read_fifo = {};
    }

    /// Recompute fifo_size (in words) from the boost and dual flags.
    void update_fifos_size()
    {
        // The depths below are byte counts divided by the word size.
        if (boost && dual)
            fifo_size = 512 / FLASH_WORD_SIZE;
        else if (boost)
            fifo_size = 256 / FLASH_WORD_SIZE;
        else
            fifo_size = 4 / FLASH_WORD_SIZE;
    }

    /// log2 of the number of words in flash_mem.
    std::uint64_t flash_width;
    /// Flash content, one 32-bit word per entry, indexed by word address.
    std::vector<std::uint32_t> flash_mem;
    /// Words loaded by write() and consumed by exec_write*().
    std::queue<std::uint32_t> write_fifo;
    /// Words produced by exec_read*() and drained by read().
    std::queue<std::uint32_t> read_fifo;

    /// Depth of each FIFO in words (see update_fifos_size()).
    std::uint32_t fifo_size;
    /// Size of a flash word in bytes.
    std::uint32_t word_size;
    /// Number of sectors the flash is split into.
    std::uint32_t sector_count;
    /// Total flash size in bytes.
    std::uint32_t flash_size;
    /// Size of one sector in bytes.
    std::uint32_t sector_size;
    /// Boost flag; call update_fifos_size() after changing it.
    bool boost;
    /// Dual flag; call update_fifos_size() after changing it.
    bool dual;

    /// Register contents keyed by register address. The key is as wide as
    /// the address parameter of exec_*_register() so that distinct
    /// addresses never alias each other.
    std::unordered_map<std::uint32_t,
        std::array<std::uint32_t, MAX_REGISTER_WORDS>> flash_regs;

private:
    /// Throw std::out_of_range unless words [addr, addr + size) lie inside
    /// flash_mem. Empty ranges touch nothing and are always accepted.
    void check_range(std::uint64_t addr, std::uint64_t size) const
    {
        if (size == 0)
            return;
        std::uint64_t const word_count = flash_mem.size();
        if (size > word_count || addr > word_count - size)
            throw std::out_of_range("emulated_flash: access outside flash memory");
    }

    /// Number of words for a flash of the given width. flash_size stores the
    /// byte size in 32 bits, so 2^width words must fit in that once scaled.
    static std::uint64_t checked_word_count(std::uint64_t width)
    {
        if (width >= 32 ||
            (std::uint64_t{1} << width) * FLASH_WORD_SIZE > 0xFFFFFFFFull)
            throw std::invalid_argument("emulated_flash: flash_width is too large");
        return std::uint64_t{1} << width;
    }

    /// Validate the sector count: it is a divisor and is stored in 32 bits.
    static std::uint32_t checked_sector_count(std::uint64_t count)
    {
        if (count == 0 || count > 0xFFFFFFFFull)
            throw std::invalid_argument("emulated_flash: invalid sector_count");
        return static_cast<std::uint32_t>(count);
    }
};

/**
 * Wrap the emulated flash into a basic mocked flash with major version 0.
 *
 * @tparam MockGeneratedFlash The generated mocked flash to inherit from
 * @tparam GeneratedFlash The generated c++ flash class coresponding to the
 * mocked one
 * @tparam FLASH_WIDTH emulated flash width
 * @tparam SECTOR_COUNT number of sector in emulated flash
 *
 * @note This class defines all sections as having the same start/end (0,
 * 0x512) change it if you need that in your test.
 */
template<class MockGeneratedFlash,
         typename GeneratedFlash,
         std::uint32_t FLASH_WIDTH,
         std::uint32_t SECTOR_COUNT>
struct flash_mng_0_mock : public MockGeneratedFlash {
    using flash_cmd = typename GeneratedFlash::flash_cmd_type_t;

    flash_mng_0_mock()
        : MockGeneratedFlash()
        , flash(FLASH_WIDTH,
                SECTOR_COUNT)
    {
        this->generics_addr_width.set(flash.flash_width);
        this->generics_sector_count.set(flash.sector_count);

        this->generics_pgm_addr_sta.set(0x0);
        this->generics_pgm_addr_end.set(0x512);
        this->generics_otp_addr_sta.set(0x0);
        this->generics_otp_addr_end.set(0x512);
        this->generics_prm_addr_sta.set(0x0);
        this->generics_prm_addr_end.set(0x512);
        this->generics_usr_addr_sta.set(0x0);
        this->generics_usr_addr_end.set(0x512);
        this->flash_cmd_data.on_write = [&] (std::uint64_t value) {
            flash.write(value);
            return hw::result<void>{};
        };

        this->flash_cmd_read_data.on_read = [&] () {
            return hw::result<std::uint64_t>{flash.read()};
        };

        this->on_flash_cmd_exec =
            [&] (std::uint64_t flash_cmd_type,
                 std::uint64_t flash_cmd_addr,
                 std::uint64_t flash_cmd_size) {

                if (flash_cmd_addr + flash_cmd_size >= flash.flash_size) {
                    throw std::runtime_error("Tried to read/write/erase outside flash");
                }

                switch(static_cast<flash_cmd>(flash_cmd_type)) {
                case flash_cmd::READ:
                    flash.exec_read(flash_cmd_addr, flash_cmd_size);
                    break;
                case flash_cmd::WRITE:
                    flash.exec_write(flash_cmd_addr, flash_cmd_size);
                    break;
                case flash_cmd::ERASE:
                    flash.exec_erase(flash_cmd_addr, flash_cmd_size);
                    break;
                case flash_cmd::READ_REGISTER:
                    flash.exec_read_register(flash_cmd_addr, flash_cmd_size);
                    break;
                case flash_cmd::WRITE_REGISTER:
                    flash.exec_write_register(flash_cmd_addr, flash_cmd_size);
                    break;
                default:
                    throw std::runtime_error("Unsupported command");
                }
                this->flash_cmd_exec.set(0);
            };

        this->flash_cmd_flush.on_write = [&] (std::uint64_t value) {
            if ((value & 1) == 1) {
                flash.flush_fifos();
            }

            return enyx::hw::result<void>{};
        };
    }

    /**
     * Change the mode of the flash core
     */
    void set_modes(bool boost, bool dual) {
        flash.boost = boost;
        flash.dual = dual;
        this->flash_boost_mode.set(boost);
        this->flash_dual.set(dual);

        flash.update_fifos_size();
    }

    emulated_flash flash;
};


/**
 * Wrap the emulated flash into a basic mocked flash with major version 1.
 *
 * @tparam MockGeneratedFlash The generated mocked flash to inherit from
 * @tparam GeneratedFlash The generated c++ flash class coresponding to the
 * mocked one
 * @tparam FLASH_WIDTH emulated flash width
 * @tparam SECTOR_COUNT number of sector in emulated flash
 *
 * @note This class defines all sections as having the same start/end (0,
 * 0x512) change it if you need that in your test.
 */
template<class MockGeneratedFlash,
         typename GeneratedFlash,
         std::uint32_t FLASH_WIDTH,
         std::uint32_t SECTOR_COUNT>
struct flash_mng_2_mock : MockGeneratedFlash {
    using flash_cmd = typename GeneratedFlash::flash_cmd_type_t;
    flash_mng_2_mock()
        : MockGeneratedFlash()
        , flash(FLASH_WIDTH,
                SECTOR_COUNT)
    {
        this->generics1_flash_addr_width.set(flash.flash_width);
        this->generics1_flash_sector_count.set(flash.sector_count);

        this->section_count.set(4);

        this->get_section_name_string = [] (uint64_t page_id) {
            if (page_id == 0)
                return SECTION_PROGRAM;
            else if (page_id == 1)
                return SECTION_USER;
            else if (page_id == 2)
                return SECTION_PARAMS;
            else if (page_id == 3)
                return SECTION_PARAMS;

            return "";
        };

        this->section_addr_start.set(0x0);
        this->section_addr_end.set(0x512);

        this->flash_cmd_data.on_write = [&] (std::uint64_t value) {
            flash.write(value);
            return hw::result<void>{};
        };

        this->flash_cmd_read_data.on_read = [&] () {
            return hw::result<std::uint64_t>{flash.read()};
        };

        this->on_flash_cmd_exec =
            [&] (std::uint64_t flash_cmd_type,
                 std::uint64_t flash_cmd_addr,
                 std::uint64_t flash_cmd_size) {

                if (flash_cmd_addr + flash_cmd_size >= flash.flash_size) {
                    throw std::runtime_error("Tried to read/write/erase outside flash");
                }

                switch(static_cast<flash_cmd>(flash_cmd_type)) {
                case flash_cmd::READ:
                    flash.exec_read(flash_cmd_addr, flash_cmd_size);
                    break;
                case flash_cmd::WRITE:
                    flash.exec_write(flash_cmd_addr, flash_cmd_size);
                    break;
                case flash_cmd::ERASE:
                    flash.exec_erase(flash_cmd_addr, flash_cmd_size);
                    break;
                case flash_cmd::READ_REGISTER:
                    flash.exec_read_register(flash_cmd_addr, flash_cmd_size);
                    break;
                case flash_cmd::WRITE_REGISTER:
                    flash.exec_write_register(flash_cmd_addr, flash_cmd_size);
                    break;
                case flash_cmd::FLUSH_FIFOS:
                    flash.flush_fifos();
                    break;
                default:
                    throw std::runtime_error("Unsupported command");
                }
                this->flash_cmd_exec.set(0);
            };
    }

    void set_modes(bool boost, bool dual) {
        flash.boost = boost;
        flash.dual = dual;
        this->generics2_boost_mode.set(boost);
        this->generics2_dual_memory.set(dual);

        flash.update_fifos_size();
    }

    enyx::flash::mocking::emulated_flash flash;
};

}
}

ENYX_CORES_NAMESPACE_END
