ransac.hpp 5.44 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
// Copyright (c) 2012-2017 VideoStitch SAS
// Copyright (c) 2018 stitchEm

#ifndef RANSACPROBLEM_HPP_
#define RANSACPROBLEM_HPP_

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <random>
#include <vector>

#include "lmfit/lmmin.hpp"

// #define RANSAC_VERBOSE

namespace VideoStitch {
namespace Util {

/**
 * A RANSAC layer over any Solver.
 *
 * Each iteration:
 *  1. selects a random subset of minSamplesForFit samples by writing 0/1 flags into bitSet
 *     (the wrapped Solver_t receives bitSet.data() at construction);
 *  2. fits a model on that subset;
 *  3. flags every sample accepted by isConsensualSample() as part of the consensus set;
 *  4. if the model passes validate() and its consensus set is large enough and the best so far,
 *     refits on the consensus set and keeps the result.
 *
 * NOT thread-safe.
 *
 * NOTE(review): because the wrapped solver is constructed with bitSet.data(), a copy of a
 * RansacSolver holds a solver that still points into the ORIGINAL object's bitSet buffer.
 * Avoid copying instances.
 */
template <typename Solver_t>
class RansacSolver : public Solver<typename Solver_t::Problem_t> {
 public:
  /**
   * Creates a RANSAC model.
   * @param problem The problem to solve; also forwarded to the wrapped Solver_t.
   * @param minSamplesForFit Minimum number of samples required to fit a model.
   * @param numIters Number of RANSAC iterations to run.
   * @param minConsensusSamples A candidate is only accepted if its consensus set has strictly more
   *        samples than this.
   * @param gen Optional random engine used for sample selection. It is not owned and must outlive
   *        this object. If null, std::rand() drives the selection.
   * @param debug Forwarded to the Solver_t constructor.
   * @param useFloatPrecision Forwarded to the Solver_t constructor.
   */
  RansacSolver(const typename Solver_t::Problem_t& problem, int minSamplesForFit, int numIters, int minConsensusSamples,
               std::default_random_engine* gen = nullptr, bool debug = false, bool useFloatPrecision = false)
      : Solver<typename Solver_t::Problem_t>(problem),
        minSamplesForFit(minSamplesForFit),
        numIters(numIters),
        minConsensusSamples(minConsensusSamples),
        bitSet(problem.getNumInputSamples()),
        // bitSet is declared before solver, so it is already sized when its buffer is handed over.
        solver(problem, bitSet.data(), debug, useFloatPrecision),
        gen(gen) {}

  virtual ~RansacSolver() {}

  /** Access to the wrapped solver's control structure. */
  lm_control_struct& getControl() { return solver.getControl(); }

  /**
   * Runs RANSAC, discarding inlier flags and residuals.
   * @param params See the three-argument overload.
   * @return See the three-argument overload.
   */
  bool run(std::vector<double>& params) {
    std::vector<char> inlierIndices;
    std::vector<double> outputResiduals;
    return run(params, inlierIndices, outputResiduals);
  }

  /**
   * Runs RANSAC.
   * @param params In: copied into the working model at the start of every iteration (presumably
   *        the solver's starting point). Out: the refitted best model; only written when a
   *        candidate is accepted.
   * @param inlierIndices Out: per-sample 0/1 consensus flags of the best candidate; only written
   *        when a candidate is accepted.
   * @param outputResiduals Out: residuals of the best candidate evaluated BEFORE the consensus
   *        refit; only written when a candidate is accepted.
   * @return false if fewer input samples than minSamplesForFit exist, if minSamplesForFit is
   *         negative, or if no candidate with a non-empty consensus set was accepted.
   */
  bool run(std::vector<double>& params, std::vector<char>& inlierIndices, std::vector<double>& outputResiduals) {
    // A negative minSamplesForFit would wrap around when converted to size_t for populateRandom()
    // and cause writes past the end of bitSet.
    if (minSamplesForFit < 0 || static_cast<int>(bitSet.size()) < minSamplesForFit) {
      return false;
    }
    int bestNumConsensual = 0;
    std::vector<double> curModel(params.size());
    std::vector<double> residuals(solver.getProblem().getNumOutputValues());
    for (int iter = 0; iter < numIters; ++iter) {
#ifdef RANSAC_VERBOSE
      std::cout << "iter " << iter << ":" << std::endl;
#endif
      curModel = params;
      // Select a random subset of size minSamplesForFit. Here bitSet is the "maybe inliers" set.
      populateRandom(bitSet, static_cast<size_t>(minSamplesForFit));
      // Fit a model on the subset.
      if (!solver.run(curModel)) {
        continue;
      }
#ifdef RANSAC_VERBOSE
      for (size_t k = 0; k < params.size(); ++k) {
        std::cout << "  " << curModel[k] << std::endl;
      }
#endif
      // Evaluate the candidate on all samples.
      bool requestBreakNotPossible = false;
      solver.getProblem().eval(curModel.data(), static_cast<int>(residuals.size()), residuals.data(), nullptr, 0,
                               &requestBreakNotPossible);

      // Rebuild bitSet as the consensus set of this candidate.
      int numConsensual = 0;
      for (int i = 0; i < solver.getProblem().getNumInputSamples(); ++i) {
        if (isConsensualSample(residuals.data() + i * solver.getProblem().getNumValuesPerSample())) {
          ++numConsensual;
          bitSet[i] = 1;
        } else {
          bitSet[i] = 0;
        }
      }
#ifdef RANSAC_VERBOSE
      std::cout << "numConsensual: " << numConsensual << "/" << solver.getProblem().getNumInputSamples() << " "
                << minConsensusSamples << " " << bestNumConsensual << std::endl;
#endif

      // Let derived classes reject the candidate (e.g. an estimated rotation outside preset bounds).
      if (!validate(curModel.data())) {
#ifdef RANSAC_VERBOSE
        std::cout << "estimated rotation out of the presets" << std::endl;
#endif
        continue;
      }

      if (numConsensual > minConsensusSamples && numConsensual > bestNumConsensual) {
#ifdef RANSAC_VERBOSE
        std::cout << "new best : " << numConsensual << std::endl;
        std::cout << " model:" << std::endl;
        for (size_t k = 0; k < params.size(); ++k) {
          std::cout << "  " << curModel[k] << std::endl;
        }
#endif
        // Refit on the whole consensus set (bitSet now flags the consensus samples).
        if (!solver.run(curModel)) {
          continue;
        }
#ifdef RANSAC_VERBOSE
        std::cout << " model2:" << std::endl;
        for (size_t k = 0; k < params.size(); ++k) {
          std::cout << "  " << curModel[k] << std::endl;
        }
#endif
        params = curModel;
        bestNumConsensual = numConsensual;
        inlierIndices = bitSet;
        outputResiduals = residuals;

        // If every sample is an inlier, no better consensus can be found.
        if (numConsensual == solver.getProblem().getNumInputSamples()) {
          return true;
        }
      }
    }
    return bestNumConsensual != 0;
  }

 private:
  /**
   * Minimal UniformRandomBitGenerator backed by std::rand().
   * It keeps the null-`gen` fallback tied to the global std::srand() state, as the removed
   * std::random_shuffle overload was on common implementations.
   */
  struct CRandEngine {
    using result_type = unsigned int;
    static constexpr result_type min() { return 0; }
    static constexpr result_type max() { return static_cast<result_type>(RAND_MAX); }
    result_type operator()() const { return static_cast<result_type>(std::rand()); }
  };

  /**
   * Hook to reject a fitted candidate model before it can become the best one.
   * The default accepts every model.
   */
  virtual bool validate(double* /*values*/) const { return true; }

  /**
   * Implements the criterion for a consensual sample.
   * @param values Pointer to the getNumValuesPerSample() residual values of this sample.
   */
  virtual bool isConsensualSample(double* values) const = 0;

  /**
   * Implements the random selection. The default is a purely random selection, but some
   * algorithms may have further constraints.
   * @param v Vector of per-sample 0/1 flags to populate.
   * @param numBitsSets Number of samples to select (clamped to v.size()).
   */
  virtual void populateRandom(std::vector<char>& v, size_t numBitsSets) const {
    // Clamp so an oversized request cannot write past the end of v.
    const size_t numSelected = std::min(numBitsSets, v.size());
    std::fill(v.begin(), v.begin() + numSelected, static_cast<char>(1));
    std::fill(v.begin() + numSelected, v.end(), static_cast<char>(0));
    if (gen) {
      std::shuffle(v.begin(), v.end(), *gen);
    } else {
      // std::random_shuffle was deprecated in C++14 and removed in C++17.
      CRandEngine engine;
      std::shuffle(v.begin(), v.end(), engine);
    }
  }

  const int minSamplesForFit;
  const int numIters;
  const int minConsensusSamples;
  std::vector<char> bitSet;
  Solver_t solver;
  std::default_random_engine* gen;
};
}  // namespace Util
}  // namespace VideoStitch

#endif