#include <cerrno>
#include <algorithm>
#include <atomic>
#include <exception>
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <numeric>
#include <sstream>
#include <thread>
#include <vector>
#include <enyx/hw/error.hpp>
#include <enyx/cores/namespace.hpp>
#include "../lib/DeviceDiscovery.hpp"
#include "../lib/SignalHandler.hpp"
#include "DataValidation.hpp"

ENYX_CORES_NAMESPACE_BEGIN

// Builds the application: opens the accelerator, enumerates its cores,
// sets up the runtime and prepares the TX pattern buffer.
// Throws std::runtime_error when more sessions are configured than the
// stack exposes (runtime_.session_count).
template<typename Runtime>
Application<Runtime>::Application(ApplicationConfiguration const & configuration)
    : configuration_(configuration)
    , accelerator_(create_accelerator(configuration.accelerator))
    , mmio_(get_first_mmio(accelerator_))
    , core_tree_(hw::enumerate_cores(mmio_))
    , send_buffer_(BUFFER_SIZE)
    , runtime_(configuration, accelerator_, core_tree_)
{
    auto const requested = configuration_.session_configurations.size();
    auto const available = runtime_.session_count;

    // Refuse configurations the hardware stack cannot honour.
    if (requested > available)
    {
        std::ostringstream message;
        message << "Requested more sessions (" << requested
                << ") than stack capability (" << available << ")";
        throw std::runtime_error{message.str()};
    }

    // Fill the TX buffer with an incrementing pattern 0, 1, 2, ...
    // (the transmit path sends slices of this buffer).
    std::iota(send_buffer_.begin(), send_buffer_.end(), 0U);
}

// Loopback worker for data source `source_id`: every payload received on a
// session is sent back on that same session. Runs until exit is requested.
// Throws std::runtime_error for data on an unconfigured session and
// std::system_error when polling the source fails.
template<typename Runtime>
void
Application<Runtime>::loopback_task(std::size_t source_id)
{
    auto echo = [this](std::uint32_t session_id,
                       std::uint8_t const * data,
                       std::uint32_t size) {
        auto & session = runtime_.sessions.at(session_id);

        // Traffic on a session we never configured is unexpected.
        if (! session.configuration)
        {
            std::ostringstream error;
            error << "Unexpected message on session " << session_id;
            throw std::runtime_error{error.str()};
        }

        session.received_bytes += size;

        // Busy-retry while the TX buffer is full. The downside is that a
        // session whose RX bandwidth exceeds its TX bandwidth can slow down
        // the other sessions; copying the data aside instead would let the
        // tester's memory grow without bound, as RX has no flow control.
        // The loop also ends when the send is cancelled on exit request.
        while (try_to_send_data(session_id, data, size)
                == std::errc::resource_unavailable_try_again)
            continue;
    };

    for (;;)
    {
        if (is_exit_requested())
            break;

        if (runtime_.poll_source(source_id, echo) < 0)
            throw std::system_error{errno, std::generic_category()};
    }
}

// Starts one asynchronous loopback worker per data source and registers
// each resulting future in tasks_.
template<typename Runtime>
void
Application<Runtime>::create_loopback_data_polling_tasks()
{
    for (std::size_t i = 0U, e = runtime_.data_sources.size(); i != e; ++i)
    {
        auto task = std::async(std::launch::async, [this, i] {
            // Project guard object. NOTE(review): presumably requests exit
            // when the task throws — confirm against its definition.
            exit_on_exception guard{};
            loopback_task(i);
        });

        // Workers started by earlier iterations may already be calling
        // close_session(), which pushes to tasks_ under tasks_mutex_, so
        // this push must hold the same mutex to avoid a data race.
        const std::lock_guard<std::mutex> lock{tasks_mutex_};
        tasks_.push(std::move(task));
    }
}

// RX worker for data source `source_id`. Each received payload is checked
// against the expected pattern and the per-session byte count is tracked.
// When the configured size has been fully received and the shutdown policy
// is RECEIVE_COMPLETE, the session is closed. Runs until exit is requested.
// Throws std::runtime_error on unexpected, corrupted or overrunning data,
// and std::system_error when polling the source fails.
template<typename Runtime>
void
Application<Runtime>::rx_task(std::size_t source_id)
{
    auto check = [this](std::uint32_t session_id,
                        std::uint8_t const * data,
                        std::uint32_t size) {
        // Always throws: formats "<what><session_id>" into the error.
        auto fail = [session_id](char const * what) {
            std::ostringstream error;
            error << what << session_id;
            throw std::runtime_error{error.str()};
        };

        auto & session = runtime_.sessions.at(session_id);
        if (! session.configuration)
            fail("Unexpected message on session ");

        auto const & expected = *session.configuration;

        // Validate this chunk at its position in the session's stream.
        if (! is_rx_data_valid(expected, session.received_bytes, data, size))
            fail("Corrupted data session ");

        session.received_bytes += size;

        if (session.received_bytes > expected.size)
            fail("Overrun bytes on session ");

        bool const complete = session.received_bytes == expected.size;
        if (complete && expected.shutdown_policy
                            == SessionConfiguration::RECEIVE_COMPLETE)
            close_session(session_id, expected.close_delay);
    };

    for (;;)
    {
        if (is_exit_requested())
            break;

        if (runtime_.poll_source(source_id, check) < 0)
            throw std::system_error{errno, std::generic_category()};
    }
}

// Starts one asynchronous RX validation worker per data source and
// registers each resulting future in tasks_.
template<typename Runtime>
void
Application<Runtime>::create_rx_data_polling_tasks()
{
    for (std::size_t i = 0U, e = runtime_.data_sources.size(); i != e; ++i)
    {
        auto task = std::async(std::launch::async, [this, i] {
            // Project guard object. NOTE(review): presumably requests exit
            // when the task throws — confirm against its definition.
            exit_on_exception guard{};
            rx_task(i);
        });

        // Already-running workers (RX workers started earlier, or TX
        // workers in --mode=both) may call close_session(), which pushes to
        // tasks_ under tasks_mutex_; guard this push with the same mutex.
        const std::lock_guard<std::mutex> lock{tasks_mutex_};
        tasks_.push(std::move(task));
    }
}

// TX worker for data sink `sink_id`. The sessions are split into contiguous
// partitions of ceil(session_count / sink count) ids, and this worker
// repeatedly walks its own partition. For each READY session it sends the
// next chunk of the pattern buffer, bounded by the bandwidth budget, the
// remaining configured size, the max payload size and the buffer end.
// Runs until exit is requested.
template<typename Runtime>
void
Application<Runtime>::tx_task(std::size_t sink_id)
{
    auto const sink_count = runtime_.data_sinks.size();
    auto const partition_size = runtime_.session_count / sink_count +
            bool(runtime_.session_count % sink_count);
    auto const first_session = partition_size * sink_id;
    auto const partition_end = first_session + partition_size;

    for (;;)
    {
        if (is_exit_requested())
            return;

        for (auto session_id = first_session;
             session_id != partition_end;
             ++session_id)
        {
            // The sink count need not divide the session count (e.g. 64 / 5),
            // so the last partition may extend past the real sessions.
            if (session_id >= runtime_.session_count)
                break;

            auto & session = runtime_.sessions[session_id];

            // Per the original design, get_tx_state() reads with acquire
            // ordering, so fields updated by the event polling thread
            // (bandwidth_throttle included) are up to date afterwards.
            // Skip sessions whose connection is not established yet.
            if (session.get_tx_state() != TxState::READY)
                continue;

            // Bytes we may send now without exceeding the configured rate.
            auto const budget = session.bandwidth_throttle.theoretical_sent_bytes()
                    - session.sent_bytes;
            if (budget == 0)
                continue;

            auto chunk = std::uint64_t(session.configuration->size) - session.sent_bytes;
            if (chunk == 0)
                continue;

            chunk = std::min(session.get_payload_max_size(), chunk);
            chunk = std::min(budget, chunk);

            // Start at sent_bytes modulo 256 in the pattern buffer so the
            // emitted bytes follow the stream position. This assumes the
            // buffer holds uint8_t, as its use as `uint8_t const *` below
            // suggests.
            std::size_t const offset = std::uint8_t(session.sent_bytes);
            chunk = std::min(chunk, BUFFER_SIZE - offset);

            try_to_send_data(session_id, &send_buffer_[offset], chunk);
        }
    }
}

// Starts one asynchronous TX worker per data sink and registers each
// resulting future in tasks_.
template<typename Runtime>
void
Application<Runtime>::create_tx_data_sending_tasks()
{
    for (std::size_t i = 0U, e = runtime_.data_sinks.size(); i != e; ++i)
    {
        auto task = std::async(std::launch::async, [this, i] {
            // Project guard object. NOTE(review): presumably requests exit
            // when the task throws — confirm against its definition.
            exit_on_exception guard{};
            tx_task(i);
        });

        // TX workers started by earlier iterations may reach close_session()
        // (via try_to_send_data), which pushes to tasks_ under tasks_mutex_;
        // guard this push with the same mutex to avoid a data race.
        const std::lock_guard<std::mutex> lock{tasks_mutex_};
        tasks_.push(std::move(task));
    }
}

// Entry point: applies the safe-mode setting to the stack, then runs the
// mode selected in the configuration. An unrecognised mode does nothing.
template<typename Runtime>
void
Application<Runtime>::run()
{
    std::cout << "Starting.." << std::endl;

    auto & stack = runtime_.stack;
    if (! configuration_.safe_mode)
        stack.disable_safe_mode();
    else
        stack.enable_safe_mode();

    auto const mode = configuration_.mode;
    if (mode == ApplicationConfiguration::LOOPBACK)
        run_loopback();
    else if (mode == ApplicationConfiguration::RX)
        run_rx();
    else if (mode == ApplicationConfiguration::TX)
        run_tx();
    else if (mode == ApplicationConfiguration::BOTH)
        run_rxtx();
}

// Sends `size` bytes of `data` on session `session_id` through the sink
// that owns that session.
//
// Returns:
// - an empty error code on success;
// - std::errc::resource_unavailable_try_again when the runtime reports the
//   session's TX buffer as full (nothing sent, the caller may retry);
// - std::errc::operation_canceled when exit is requested while the sink
//   keeps reporting EAGAIN.
// Throws std::system_error("sink::send()") when the sink fails with any
// error other than EAGAIN.
// Once the configured size has been fully sent, the session is closed if
// its shutdown policy is SEND_COMPLETE.
template<typename Runtime>
std::error_code
Application<Runtime>::try_to_send_data(std::uint32_t session_id,
                                       std::uint8_t const * data,
                                       std::uint32_t size)
{
    auto & session = runtime_.sessions.at(session_id);

    if (runtime_.is_tx_buffer_full(session, size))
        return make_error_code(std::errc::resource_unavailable_try_again);

    // Select the sink owning this session. The partition size must be
    // rounded up exactly as in tx_task() and open_sessions(). Truncating it
    // (64 sessions / 5 sinks -> 12 instead of 13) routes sessions to a sink
    // driven by another thread, sends ids 60..63 past the end of data_sinks,
    // and divides by zero when there are more sinks than sessions.
    auto const sink_count = runtime_.data_sinks.size();
    auto const partition_size = runtime_.session_count / sink_count +
            bool(runtime_.session_count % sink_count);
    auto & sink = runtime_.data_sinks[session_id / partition_size];

    int failure = 0;
    int send_error = 0;
    bool sink_backpressure_detected = false;
    for (;;)
    {
        if (is_exit_requested())
            return make_error_code(std::errc::operation_canceled);

        failure = sink.send(session_id, data, size);
        if (! failure)
            break;

        // Capture errno immediately so nothing later can clobber it.
        send_error = errno;
        if (send_error != EAGAIN)
            break;

        sink_backpressure_detected = true;
    }
    // A back-pressure here implies:
    // - a hardware hang,
    // - a TCP credit mismatch,
    // - or a PCIe buffer smaller than the minimum required to
    //   sustain the TX bandwidth.
    session.sink_backpressure += sink_backpressure_detected;

    if (failure)
        throw std::system_error{send_error, std::generic_category(),
                                "sink::send()"};

    session.sent_bytes += size;

    if (session.sent_bytes == session.configuration->size)
    {
        if (session.configuration->shutdown_policy
                == SessionConfiguration::SEND_COMPLETE)
            close_session(session_id, session.configuration->close_delay);
    }

    return std::error_code{};
}

// Assigns each configured session to a stack session id and opens it.
//
// Session ids are split into `partitions_count` contiguous partitions of
// ceil(session_count / partitions_count) ids. Each partition is polled by a
// dedicated thread, and the last partition may be shorter. To balance CPU
// usage, configurations are dealt round-robin across partitions: slot 0 of
// every partition, then slot 1, and so on. Ids past the real session count
// are skipped, so every id is used at most once. The constructor guarantees
// there are no more configurations than ids.
//
// `partitions_count` must be non-zero (callers check their stream lists are
// non-empty first).
template<typename Runtime>
void
Application<Runtime>::open_sessions(std::uint32_t partitions_count)
{
    auto const partition_size = runtime_.session_count / partitions_count +
            bool(runtime_.session_count % partitions_count);

    auto next = configuration_.session_configurations.begin();
    auto const last = configuration_.session_configurations.end();

    for (std::size_t slot = 0; slot != partition_size && next != last; ++slot)
    {
        for (std::size_t partition = 0;
             partition != partitions_count && next != last;
             ++partition)
        {
            std::size_t const session_id = partition * partition_size + slot;

            // Ids grow with the partition index, so once one is out of range
            // every remaining partition is too (short last partition).
            if (session_id >= runtime_.session_count)
                break;

            auto & session = runtime_.sessions.at(session_id);
            session.configuration = &*next;
            ++next;
            runtime_.open_session(session_id, session);
        }
    }
}

// Closes `session_id`, either right away (zero delay) or after `delay`.
// A delayed close runs on its own asynchronous task registered in tasks_,
// so the calling thread never sleeps.
template<typename Runtime>
void
Application<Runtime>::close_session(std::uint32_t session_id,
                                    std::chrono::milliseconds delay)
{
    if (delay != std::chrono::milliseconds::zero())
    {
        auto deferred_close = [this, session_id, delay] {
            exit_on_exception guard{};
            std::this_thread::sleep_for(delay);
            runtime_.close_session(session_id);
        };

        // tasks_ is shared with the polling threads: push under its mutex.
        std::lock_guard<std::mutex> const lock{tasks_mutex_};
        tasks_.push(std::async(std::launch::async, std::move(deferred_close)));
        return;
    }

    runtime_.close_session(session_id);
}

// Blocks until every registered task has finished, including tasks
// registered while waiting (e.g. delayed session closes). Then rethrows the
// first stored exception, in queue order, via get().
template<typename Runtime>
void
Application<Runtime>::wait_for_tasks_completion()
{
    std::vector<Task> completed;

    // Pops the oldest task under the lock; false once the queue is empty.
    auto pop_next = [this](Task & out) {
        std::lock_guard<std::mutex> const lock{tasks_mutex_};
        if (tasks_.empty())
            return false;
        out = std::move(tasks_.front());
        tasks_.pop();
        return true;
    };

    for (;;)
    {
        Task next;
        if (! pop_next(next))
            break;

        // Wait with the lock released so running tasks can still register
        // new ones.
        next.wait();
        completed.push_back(std::move(next));
    }

    std::for_each(completed.begin(), completed.end(),
                  [](Task & finished) { finished.get(); });
}

// --mode=loopback: echoes every received payload back on its session.
// Requires equal, non-zero counts of a2c streams (sources) and c2a streams
// (sinks). Blocks until all tasks have completed.
template<typename Runtime>
void
Application<Runtime>::run_loopback()
{
    // Perform sanity checks
    if (runtime_.data_sources.size() != runtime_.data_sinks.size())
    {
        std::ostringstream error;
        error << "--mode=loopback requires that a2c stream(s) count ("
              << runtime_.data_sources.size()
              << ") equals c2a stream(s) count ("
              << runtime_.data_sinks.size() << ")";

        // The rationale behind this restriction is that having fewer sinks
        // than sources implies that sources (i.e. threads) would use a
        // single sink concurrently, which would require locking.

        throw std::runtime_error{error.str()};
    }

    if (runtime_.data_sinks.empty())
        throw std::runtime_error{"--mode=loopback requires at least"
                                 " one c2a stream"};

    // Configure sessions
    open_sessions(runtime_.data_sources.size());
    std::cout << "Started." << std::endl;

    // Create polling threads
    create_loopback_data_polling_tasks();
    runtime_.create_monitoring_tasks(tasks_);

    wait_for_tasks_completion();
}

// --mode=rx: receives and validates data on every configured session.
// Requires at least one a2c stream. Blocks until all tasks have completed.
template<typename Runtime>
void
Application<Runtime>::run_rx()
{
    auto const & sources = runtime_.data_sources;
    if (sources.empty())
        throw std::runtime_error{"--mode=rx requires at least one a2c stream"};

    // One session partition per source.
    open_sessions(sources.size());
    std::cout << "Started." << std::endl;

    // Launch RX validation workers plus the runtime's monitoring tasks,
    // then wait for all of them.
    create_rx_data_polling_tasks();
    runtime_.create_monitoring_tasks(tasks_);
    wait_for_tasks_completion();
}

// --mode=tx: sends the configured payload on every configured session.
// Requires at least one c2a stream. Blocks until all tasks have completed.
template<typename Runtime>
void
Application<Runtime>::run_tx()
{
    auto const & sinks = runtime_.data_sinks;
    if (sinks.empty())
        throw std::runtime_error{"--mode=tx requires at least one c2a stream"};

    // One session partition per sink, matching tx_task()'s split.
    open_sessions(sinks.size());
    std::cout << "Started." << std::endl;

    // Launch TX workers plus the runtime's monitoring tasks, then wait for
    // all of them.
    create_tx_data_sending_tasks();
    runtime_.create_monitoring_tasks(tasks_);
    wait_for_tasks_completion();
}

// --mode=both: sends and validates data concurrently on every configured
// session. Requires at least one a2c and one c2a stream. Blocks until all
// tasks have completed.
template<typename Runtime>
void
Application<Runtime>::run_rxtx()
{
    auto const & sources = runtime_.data_sources;
    auto const & sinks = runtime_.data_sinks;

    if (sources.empty())
        throw std::runtime_error{"--mode=both requires at least one a2c stream"};
    if (sinks.empty())
        throw std::runtime_error{"--mode=both requires at least one c2a stream"};

    // Sessions are spread according to the source (RX) partitioning.
    open_sessions(sources.size());
    std::cout << "Started." << std::endl;

    // TX workers first, then RX workers and monitoring; wait for all.
    create_tx_data_sending_tasks();
    create_rx_data_polling_tasks();
    runtime_.create_monitoring_tasks(tasks_);
    wait_for_tasks_completion();
}

ENYX_CORES_NAMESPACE_END
