/***************************************************************************
 *
 * Copyright (c) 2020 Baidu.com, Inc. All Rights Reserved
 *
 **************************************************************************/

/**
 * @author LiuJie
 * @brief demo_multi_thread
 *
 **/

#include <assert.h>
#include <dirent.h>
#include <sys/stat.h>

#include <chrono>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <future>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include "easyedge/easyedge.h"
#include "easyedge/easyedge_config.h"
#include "easyedge/easyedge_gpu_turbo_config.h"

using namespace easyedge;

/**
 * @brief Tells whether `path` refers to an existing regular file.
 *
 * Uses stat(), so a symlink is judged by its target.
 * Any stat() failure (e.g. missing path) yields false.
 */
bool is_file(const std::string &path) {
    struct stat info;
    if (stat(path.c_str(), &info) < 0) {
        return false;
    }
    return S_ISREG(info.st_mode);
}


/**
 * @brief Appends "<path>/<name>" to `files` for every entry of directory `path`.
 *
 * The "." and ".." pseudo-entries are skipped. Every other entry is appended,
 * including subdirectories. Entries arrive in whatever order readdir() yields,
 * which is not sorted. If `path` cannot be opened as a directory, `files` is
 * left untouched.
 *
 * @param path  directory to list
 * @param files output list; new paths are appended to existing contents
 */
void list_file(std::string path, std::vector<std::string>& files) {
    // closedir() runs automatically on every exit path; a null DIR* is never passed to it.
    std::unique_ptr<DIR, int (*)(DIR*)> dir(opendir(path.c_str()), &closedir);
    if (!dir) {
        return;
    }
    while (const struct dirent *entry = readdir(dir.get())) {
        const std::string name = entry->d_name;
        if (name != "." && name != "..") {
            files.push_back(path + "/" + name);
        }
    }
}


/**
 * @brief Splits `origin` into `split_num` contiguous, order-preserving groups.
 *
 * Group sizes differ by at most one. The first `origin.size() % split_num`
 * groups receive one extra element. When there are fewer items than groups,
 * the trailing groups are left empty rather than aborting.
 *
 * @param origin    items to distribute
 * @param result    overwritten with exactly `split_num` groups (previous contents discarded);
 *                  left empty when `split_num <= 0`
 * @param split_num number of groups to produce
 */
void split_image(const std::vector<std::string>& origin, std::vector<std::vector<std::string>>& result, int split_num) {
    result.clear();
    if (split_num <= 0) {
        // Nothing sensible to split into; also avoids division by zero below.
        return;
    }
    const std::size_t groups = static_cast<std::size_t>(split_num);
    result.resize(groups);

    const std::size_t avg = origin.size() / groups;
    const std::size_t left = origin.size() % groups;
    std::size_t start = 0;
    for (std::size_t i = 0; i < groups; ++i) {
        const std::size_t count = i < left ? avg + 1 : avg;
        auto first = origin.begin() + static_cast<std::ptrdiff_t>(start);
        result[i].assign(first, first + static_cast<std::ptrdiff_t>(count));
        start += count;
    }
}


void print_result(std::unique_ptr<EdgePredictor>& predictor, std::vector<std::vector<EdgeResultData>>& result, std::vector<std::string>& img_files) {
    for (int i = 0; i < result.size(); ++i) {
        std::cout << "Results of image " << img_files[i] << ": " << std::endl;
        if (result[i].empty()) {
            std::cerr << "empty result" << std::endl;
            continue;
        }
        for (auto &v : result[i]) {
            std::cout << v.index << ", " << v.label << ", p:" << v.prob;
            if (predictor->model_kind() == EdgeModelKind::kObjectDetection) {
                std::cout << " loc: "
                        << v.x1 << ", " << v.y1 << ", " << v.x2 << ", " << v.y2;
            }
            std::cout << std::endl;
        }
    }
}


void infer_task(std::unique_ptr<EdgePredictor>& predictor,
        std::vector<std::string>& img_files,
        std::map<int /*predictor index*/, std::vector<std::vector<EdgeResultData>>>& results,
        int index) {
    std::cout << "Index: " << index << " | Thread id: " << std::this_thread::get_id() << " | Inference start" << std::endl;

    std::vector<cv::Mat> imgs;
    for (auto& img_file : img_files) {
        imgs.push_back(cv::imread(img_file));
    }

    int iterations = 1;
    auto t_start = std::chrono::high_resolution_clock::now();
    for (int i = 0; i < iterations; ++i) {
        results[index].clear();
        predictor->infer(imgs, results[index]);
    }
    auto t_end = std::chrono::high_resolution_clock::now();
    float time_cost = std::chrono::duration<float, std::milli>(t_end - t_start).count();
    std::cout << "Index: " << index << " | Thread id: " << std::this_thread::get_id() << " | Average time costs: " << time_cost / iterations << "ms" << std::endl;
}


void run_multi_infer(const std::string& model_dir, const std::vector<std::string>& img_files, const std::string& serial_num) {
    EdgePredictorConfig config;
    config.model_dir = model_dir;
    config.set_config(params::PREDICTOR_KEY_SERIAL_NUM, serial_num);
    config.set_config(params::PREDICTOR_KEY_GTURBO_MAX_BATCH_SIZE, 1);      // 优化的模型可以支持的最大batch_size，实际单次推理的图片数不能大于此值
    config.set_config(params::PREDICTOR_KEY_GTURBO_MAX_CONCURRENCY, 2);     // 设置device对应的卡可以使用的最大并发数
    config.set_config(params::PREDICTOR_KEY_GTURBO_FP16, false);            // 置true开启fp16模式推理会更快，精度会略微降低，但取决于硬件是否支持fp16，不是所有模型都支持fp16，参阅文档
    config.set_config(params::PREDICTOR_KEY_GTURBO_COMPILE_LEVEL, 1);       // 编译模型的策略，如果当前设置的max_batch_size与历史编译存储的不同，则重新编译模型

    auto predictor = global_controller()->CreateEdgePredictor(config);
    if (predictor->init() != EDGE_OK) {
        exit(-1);
    }

    // 按线程数切分图片
    std::vector<std::vector<std::string>> split_img_files(config.get_config<int>(params::PREDICTOR_KEY_GTURBO_MAX_CONCURRENCY));
    split_image(img_files, split_img_files, config.get_config<int>(params::PREDICTOR_KEY_GTURBO_MAX_CONCURRENCY));

    std::vector<std::thread> threads;
    std::map<int /*concurrency index*/, std::vector<std::vector<EdgeResultData>>> results;
    for (int i = 0; i < config.get_config<int>(params::PREDICTOR_KEY_GTURBO_MAX_CONCURRENCY); ++i) {
        assert(split_img_files[i].size() <= config.get_config<int>(params::PREDICTOR_KEY_GTURBO_MAX_BATCH_SIZE));
        threads.emplace_back(std::thread(infer_task, std::ref(predictor), std::ref(split_img_files[i]), std::ref(results), i));
    }

    for (auto& t : threads) {
        if (t.joinable()) {
            t.join();
        }
    }

    // 打印结果
    for (auto& result : results) {
        int index = result.first;
        std::cout << "[ Index: " << index << " ]" << std::endl;
        print_result(predictor, results[index], split_img_files[index]);
    }
}


/**
 * @brief Entry point: demo {model_dir} {image_file_or_directory} [serial_num]
 *
 * Collects the image paths and runs the multi-threaded inference demo.
 * Returns -1 on bad arguments or when no images are found, otherwise 0.
 */
int main(int argc, char *argv[]) {

    // serial_num is optional, so accept either two or three user arguments.
    if (argc < 3 || argc > 4) {
        std::cerr << "Usage: demo {model_dir} {image_directory} [serial_num]" << std::endl;
        return -1;
    }
    const std::string model_dir = argv[1];
    const std::string img_dir = argv[2];
    const std::string serial_num = argc == 4 ? argv[3] : "";

    EdgeLogConfig log_config;
    log_config.enable_debug = false;
    log_config.to_file = false;
    global_controller()->set_log_config(log_config);

    // Collect image paths: either the single file given, or every entry of the directory.
    std::vector<std::string> img_files;
    if (is_file(img_dir)) {
        img_files.push_back(img_dir);
    } else {
        list_file(img_dir, img_files);
    }
    if (img_files.empty()) {
        std::cerr << "No images found in " << img_dir << std::endl;
        return -1;
    }

    // Multi-threaded inference. Adding threads slows down each individual
    // inference; prefer batch inference where possible.
    run_multi_infer(model_dir, img_files, serial_num);

    std::cout << "All Done" << std::endl;

    return 0;
}