// Copyright (c) 2012-2017 VideoStitch SAS
// Copyright (c) 2018 stitchEm

#ifndef RANSACPROBLEM_HPP_
#define RANSACPROBLEM_HPP_

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <random>
#include <vector>

#include "lmfit/lmmin.hpp"

// #define RANSAC_VERBOSE

namespace VideoStitch {
namespace Util {

/**
 * A RANSAC layer over any Solver.
 *
 * Each iteration selects a random subset of minSamplesForFit input samples and fits a model on it with the wrapped
 * Solver_t. It then evaluates residuals over all samples and counts the samples the derived class deems consensual
 * (see isConsensualSample()). The best accepted model is refitted on its consensus set and returned.
 *
 * The wrapped solver is constructed with a pointer to this object's internal selection mask (bitSet). That mask is
 * how the subset / consensus set is communicated to it.
 * NOTE(review): because of that internal pointer, a copy of a RansacSolver would have a solver pointing at the
 * source object's mask. Do not copy instances.
 *
 * NOT thread-safe.
 */
template <typename Solver_t>
class RansacSolver : public Solver<typename Solver_t::Problem_t> {
 public:
  /**
   * Creates a RANSAC model.
   * @param problem The problem to solve; forwarded to the base Solver and to the wrapped Solver_t.
   * @param minSamplesForFit Minimum number of samples required to fit a model.
   * @param numIters Number of random subsets tried.
   * @param minConsensusSamples A model is only accepted if strictly more than this many samples are consensual.
   * @param gen Optional random engine used for subset selection (not owned). If null, populateRandom() uses a
   *            fallback engine seeded from std::rand().
   * @param debug Forwarded to the wrapped Solver_t constructor.
   * @param useFloatPrecision Forwarded to the wrapped Solver_t constructor.
   */
  RansacSolver(const typename Solver_t::Problem_t& problem, int minSamplesForFit, int numIters, int minConsensusSamples,
               std::default_random_engine* gen = nullptr, bool debug = false, bool useFloatPrecision = false)
      : Solver<typename Solver_t::Problem_t>(problem),
        minSamplesForFit(minSamplesForFit),
        numIters(numIters),
        minConsensusSamples(minConsensusSamples),
        bitSet(problem.getNumInputSamples()),
        solver(problem, bitSet.data(), debug, useFloatPrecision),
        gen(gen) {}

  virtual ~RansacSolver() {}

  /** Returns the control structure of the wrapped solver. */
  lm_control_struct& getControl() { return solver.getControl(); }

  /**
   * Runs RANSAC, discarding the inlier set and residuals.
   * @param params In: initial guess. Out: best model found (only updated when a model is accepted).
   * @return true if a model was accepted.
   */
  bool run(std::vector<double>& params) {
    std::vector<char> inlierIndices;
    std::vector<double> outputResiduals;
    return run(params, inlierIndices, outputResiduals);
  }

  /**
   * Runs RANSAC.
   * @param params In: initial guess. Out: best model found, refitted on its consensus set. Only updated when a model
   *               is accepted.
   * @param inlierIndices Out: per-input-sample mask (1 = consensual, 0 = not) of the best model. Untouched if no model
   *                      is accepted.
   * @param outputResiduals Out: residuals of the best model over all samples. Untouched if no model is accepted.
   * @return true if a model was accepted, false otherwise. Returns false without iterating when minSamplesForFit is
   *         negative or exceeds the number of input samples.
   *
   * NOTE: inlierIndices and outputResiduals describe the candidate model *before* the final refit on its consensus
   * set, not the refitted model stored in params.
   */
  bool run(std::vector<double>& params, std::vector<char>& inlierIndices, std::vector<double>& outputResiduals) {
    // A negative sample count is meaningless (it would wrap around when converted to size_t), and we cannot draw
    // more samples than there are.
    if (minSamplesForFit < 0 || (int)bitSet.size() < minSamplesForFit) {
      return false;
    }
    int bestNumConsensual = 0;
    std::vector<double> curModel(params.size());
    // Sample i owns getNumValuesPerSample() consecutive entries, starting at i * getNumValuesPerSample().
    std::vector<double> residuals(solver.getProblem().getNumOutputValues());
    for (int iter = 0; iter < numIters; ++iter) {
#ifdef RANSAC_VERBOSE
      std::cout << "iter " << iter << ":" << std::endl;
#endif
      // Start from the current best estimate (initially the caller's guess).
      curModel = params;
      // Select random subset of size minSamplesForFit. bitSet = maybeInliersSet.
      populateRandom(bitSet, minSamplesForFit);
      // Fit model on subset (the wrapped solver was constructed with bitSet.data()).
      if (!solver.run(curModel)) {
        continue;
      }
#ifdef RANSAC_VERBOSE
      for (size_t k = 0; k < params.size(); ++k) {
        std::cout << "  " << curModel[k] << std::endl;
      }
#endif

      // validate() only looks at the model, so reject invalid candidates before paying for a full residual
      // evaluation and consensus pass.
      if (!validate(curModel.data())) {
#ifdef RANSAC_VERBOSE
        std::cout << "estimated rotation out of the presets" << std::endl;
#endif
        continue;
      }

      // Evaluate residuals of the candidate over all samples.
      bool requestBreakNotPossible = false;
      solver.getProblem().eval(curModel.data(), (int)residuals.size(), residuals.data(), nullptr, 0,
                               &requestBreakNotPossible);

      // Rebuild bitSet as the consensus set of the candidate.
      int numConsensual = 0;
      for (int i = 0; i < solver.getProblem().getNumInputSamples(); ++i) {
        if (isConsensualSample(residuals.data() + i * solver.getProblem().getNumValuesPerSample())) {
          ++numConsensual;
          bitSet[i] = 1;
        } else {
          bitSet[i] = 0;
        }
      }
#ifdef RANSAC_VERBOSE
      std::cout << "numConsensual: " << numConsensual << "/" << solver.getProblem().getNumInputSamples() << " "
                << minConsensusSamples << " " << bestNumConsensual << std::endl;
#endif

      if (numConsensual > minConsensusSamples && numConsensual > bestNumConsensual) {
#ifdef RANSAC_VERBOSE
        std::cout << "new best : " << numConsensual << std::endl;
        std::cout << " model:" << std::endl;
        for (size_t k = 0; k < params.size(); ++k) {
          std::cout << "  " << curModel[k] << std::endl;
        }
#endif
        // Refit on the whole consensus set (bitSet now marks the consensual samples).
        if (!solver.run(curModel)) {
          continue;
        }
#ifdef RANSAC_VERBOSE
        std::cout << " model2:" << std::endl;
        for (size_t k = 0; k < params.size(); ++k) {
          std::cout << "  " << curModel[k] << std::endl;
        }
#endif
        params = curModel;
        bestNumConsensual = numConsensual;
        inlierIndices = bitSet;
        outputResiduals = residuals;

        // If all samples are inliers, no other candidate can do better: stop looking.
        if (numConsensual == solver.getProblem().getNumInputSamples()) {
          return true;
        }
      }
    }
    // bestNumConsensual is only raised (to a positive value) when a model is accepted.
    return bestNumConsensual > 0;
  }

 private:
  /**
   * Lets derived classes reject a fitted model (e.g. out of preset bounds). The default accepts everything.
   * @param values The model parameters.
   */
  virtual bool validate(double* /*values*/) const { return true; }

  /**
   * Implements the criterion for a consensual sample.
   * @param values The getNumValuesPerSample() residual values for this sample.
   */
  virtual bool isConsensualSample(double* values) const = 0;

  /**
   * Implements the random selection. The default is purely random selection, but some algorithms may have further
   * constraints.
   * @param v Vector to populate: entries set to 1 are selected, 0 are not.
   * @param numBitsSets Number of samples to select; clamped to v.size().
   */
  virtual void populateRandom(std::vector<char>& v, size_t numBitsSets) const {
    // Mark the first numBitsSets entries, never writing past the end of v.
    for (size_t i = 0; i < v.size(); ++i) {
      v[i] = (i < numBitsSets) ? 1 : 0;
    }
    if (gen) {
      std::shuffle(v.begin(), v.end(), *gen);
    } else {
      // std::random_shuffle was deprecated in C++14 and removed in C++17. Its default randomness source is
      // unspecified, but is std::rand() in common implementations. Seeding the fallback engine from std::rand() keeps
      // std::srand() in control of reproducibility.
      std::default_random_engine fallbackGen(static_cast<std::default_random_engine::result_type>(std::rand()));
      std::shuffle(v.begin(), v.end(), fallbackGen);
    }
  }

  const int minSamplesForFit;
  const int numIters;
  const int minConsensusSamples;
  // Per-input-sample selection mask (1 = selected, 0 = not selected). In run() it holds either the random fitting
  // subset or the consensus set. Must be declared before `solver`, which is constructed with bitSet.data().
  std::vector<char> bitSet;
  Solver_t solver;
  // Optional random engine; never deleted by this class.
  std::default_random_engine* gen;
};
}  // namespace Util
}  // namespace VideoStitch

#endif