// Copyright (c) 2012-2017 VideoStitch SAS
// Copyright (c) 2018 stitchEm

#include "backend/common/coredepth/sphereSweepParams.h"

// Kernel that computes depth for a reference input by sphere sweeping
// over the other available inputs

// TODODEPTH the input projection is currently hardcoded (change defines below)

// Bretagne:
// #define fromSphereToInput SphereToFisheye

// Orah:
//#define fromSphereToInput SphereToExternal
//#define fromInputToSphere ExternalToSphere
//#define isWithin isWithinCropCircle

// Rectilinear
//#define fromSphereToInput SphereToRect
//#define fromInputToSphere RectToSphere
//#define isWithin isWithinCropRect

// Calibrated - rectangular
// Active projection selection. These names are passed to FUNCTION_NAME_4(...) below to pick the
// projection helpers used by the kernels (presumably by composing a function name — confirm in the
// header that defines FUNCTION_NAME_4). isWithin is the crop test applied to mapped input coordinates.
#define fromSphereToInput SphereToExternal
#define fromInputToSphere ExternalToSphere
#define isWithin isWithinCropRect

// Forward distortion steps (rig sphere -> input); both are currently the no-op transform.
#define distortionMetersTransform noopDistortionTransform
#define distortionPixelsTransform noopDistortionTransform

// Inverse distortion steps (input -> sphere); both are currently the no-op transform.
#define inverseDistortionMetersTransform noopDistortionTransform
#define inverseRadialPixelsTransform noopDistortionTransform

// Sphere scale is iterated in powers of two
// from 2^MIN_SPHERE_SCALE to 2^MAX_SPHERE_SCALE
// (the kernels interpolate a "scale level" between the two bounds and apply exp2() to it)

// IMPORTANT set these according to the scene to get usable results
#define MIN_SPHERE_SCALE -3  // 0
#define MAX_SPHERE_SCALE 5

// Sentinel used as the initial minimum score and as the initial splatting bounds.
#define BIG_FLOAT 1e9

// if USE_ZMNCC is 1, using zero-mean normalized cross correlation
// otherwise, L2 norm
#define USE_ZMNCC 0

// Set to 1 to enable the device-side printf traces (they are tied to hardcoded pixel coordinates).
#define DEBUG_PRINT 0

// Weight of the scale smoothness term; passed to sweep() by sphereSweepInputKernelStep.
const __device__ float scaleSmoothnessLambda = 10.0f;

// Clamps v to the closed interval [minVal, maxVal].
// The lower bound is tested first, so it wins if the bounds are inverted.
// A NaN input fails both comparisons and is returned unchanged.
inline __device__ __host__ float clampf_vs(float v, float minVal, float maxVal) {
  if (v < minVal) {
    return minVal;
  }
  if (v > maxVal) {
    return maxVal;
  }
  return v;
}

// Projects a patch of 3D points, given in scaled rig sphere coordinates, into one input
// and samples the input texture at the projected positions.
//
// values      [out] PATCH_SIZE samples; a sample that falls outside the crop area is set to zero
// written     [out] true only if every sample of the patch fell inside the crop area
// texture     input image to sample
// inputParams geometry of the input (transform, scale, distortion, center shift, size, crop)
// onRigSphere PATCH_SIZE points to project
// debug       when set, prints the mapped coordinates of the middle sample of the patch
template <int PATCH_SIZE>
__device__ void mapInput(float4* values, bool* written, const read_only image2d_t texture,
                         struct InputParams inputParams, const float3* onRigSphere, const bool debug) {
  // The visibility flag has to describe the whole patch. Resetting it for every sample would make it
  // reflect the last sample only, and patches partially outside the input would then be scored
  // against the zero-filled samples written below.
  bool allSamplesInside = true;

  for (int k = 0; k < PATCH_SIZE; ++k) {
    float2 inputCoords = FUNCTION_NAME_4(mapRigSphericalToInput, fromSphereToInput, distortionMetersTransform,
                                         distortionPixelsTransform)(
        onRigSphere[k], inputParams.transform, inputParams.scale, inputParams.distortion, inputParams.centerShift);

    // shift the mapped coordinates by half the texture size before the crop test and the texture read
    inputCoords.x += inputParams.texWidth / 2.0f;
    inputCoords.y += inputParams.texHeight / 2.0f;

    if (debug && k == PATCH_SIZE / 2) {
      printf("mapped point %f, %f\n", inputCoords.x, inputCoords.y);
    }

    if (isWithin(inputCoords, (float)inputParams.texWidth, (float)inputParams.texHeight, (float)inputParams.cropLeft,
                 (float)inputParams.cropRight, (float)inputParams.cropTop, (float)inputParams.cropBottom)) {
      values[k] = read_texture_vs(texture, inputCoords);
    } else {
      values[k] = make_float4(0.f, 0.f, 0.f, 0.f);
      allSamplesInside = false;
    }
  }

  *written = allSamplesInside;
}

// Computes the matching cost of one patch across the inputs.
//
// Every input (up to the six textures passed in) is sampled through mapInput() at the PATCH_SIZE
// points of onRigSphere.
//
// Returns -1 when fewer than two inputs, or not all numInputs inputs, report the patch as written.
// Otherwise returns the sum of the four components of the aggregated cost, clamped to [0, 1000].
//
// USE_ZMNCC == 0: for each non-reference input, the squared differences to mean[] are accumulated over
//   the patch. mean[] is the reference patch when referenceTextureId >= 0, otherwise the per-sample
//   average of the inputs. The per-input terms are combined through log4()/exp4() and scaled by 512.
// USE_ZMNCC == 1: for each non-reference input, one minus a normalized cross correlation with the
//   reference patch is accumulated; the total is scaled by 20 / (inputsWritten - 1).
//
// stddevReferenceTexture [out] standard deviation of the reference patch. It is NOT written on the
//   early -1 return. In the USE_ZMNCC == 0 path it is set to zero when referenceTextureId < 0.
//
// NOTE(review): values/written are sized by NUM_INPUTS while exactly six textures are passed in;
//   this assumes numInputs <= NUM_INPUTS and numInputs <= 6 — confirm against the callers.
// NOTE(review): sqrt4/log4/exp4/fmaxf on float4 are presumably component-wise helpers — confirm.
template <int PATCH_SIZE>
__device__ float stddev(const float3* onRigSphere, const int referenceTextureId, read_only image2d_t texture0,
                        read_only image2d_t texture1, read_only image2d_t texture2, read_only image2d_t texture3,
                        read_only image2d_t texture4, read_only image2d_t texture5,
                        const struct InputParams6 inputParamsArray, int numInputs, float4& stddevReferenceTexture,
                        const bool debug) {
  // RGBA colors
  float4 values[NUM_INPUTS][PATCH_SIZE];

  // if the value is outside the input bounds, written will be false
  bool written[NUM_INPUTS];

  // fetch all input values at sphereScale
  // (one explicit call per texture: each texture is a separate kernel argument)
  if (numInputs > 0) {
    mapInput<PATCH_SIZE>(values[0], &written[0], texture0, inputParamsArray.params[0], onRigSphere, debug);
  }
  if (numInputs > 1) {
    mapInput<PATCH_SIZE>(values[1], &written[1], texture1, inputParamsArray.params[1], onRigSphere, debug);
  }
  if (numInputs > 2) {
    mapInput<PATCH_SIZE>(values[2], &written[2], texture2, inputParamsArray.params[2], onRigSphere, debug);
  }
  if (numInputs > 3) {
    mapInput<PATCH_SIZE>(values[3], &written[3], texture3, inputParamsArray.params[3], onRigSphere, debug);
  }
  if (numInputs > 4) {
    mapInput<PATCH_SIZE>(values[4], &written[4], texture4, inputParamsArray.params[4], onRigSphere, debug);
  }
  if (numInputs > 5) {
    mapInput<PATCH_SIZE>(values[5], &written[5], texture5, inputParamsArray.params[5], onRigSphere, debug);
  }

  // count the inputs that reported the patch as written; only entries below numInputs were initialized
  int inputsWritten = 0;
  for (int id = 0; id < numInputs; id++) {
    if (written[id]) {
      inputsWritten++;
    }
  }

  // do not compute any score if onRigSphere is not visible by all inputs
  // (past this point inputsWritten == numInputs and numInputs >= 2)
  if (inputsWritten <= 1 || inputsWritten != numInputs) {
    return -1.f;
  }

  // var
  float4 aggregator = make_float4(0.f, 0.f, 0.f, 0.f);
#if USE_ZMNCC
  // use additive zero-mean normalized cross correlation
  // compute normalized cross correlation by pre-computing sums and sums of squared pixels
  // NOTE(review): this path indexes sum[]/stddev[] with referenceTextureId, so it requires
  // referenceTextureId >= 0. The local array named stddev shadows this function's name.
  float4 sum[NUM_INPUTS], sqsum[NUM_INPUTS], stddev[NUM_INPUTS];
  for (int id = 0; id < numInputs; id++) {
    if (written[id]) {
      sum[id] = make_float4(0.f, 0.f, 0.f, 0.f);
      sqsum[id] = make_float4(0.f, 0.f, 0.f, 0.f);
      for (int k = 0; k < PATCH_SIZE; ++k) {
        sum[id] += values[id][k];
        sqsum[id] += values[id][k] * values[id][k];
      }
      // the variance is floored at zero before the square root
      stddev[id] =
          sqrt4(fmaxf((sqsum[id] - sum[id] * sum[id] / PATCH_SIZE) / (PATCH_SIZE), make_float4(0.f, 0.f, 0.f, 0.f)));
      if (debug) {
        printf("id %d, sum %f %f %f, sqsum %f %f %f, stddev %f %f %f\n", id, sum[id].x, sum[id].y, sum[id].z,
               sqsum[id].x, sqsum[id].y, sqsum[id].z, stddev[id].x, stddev[id].y, stddev[id].z);
      }
    }
  }

  // keep track of reference texture stddev
  stddevReferenceTexture = stddev[referenceTextureId];

  for (int id = 0; id < numInputs; id++) {
    if (written[id] && id != referenceTextureId /* avoid multiplying by a zero diff */) {
      float4 product = make_float4(0.f, 0.f, 0.f, 0.f);
      for (int k = 0; k < PATCH_SIZE; ++k) {
        product += values[id][k] * values[referenceTextureId][k];
      }
      const float4 cross_correlation =
          product / PATCH_SIZE - (sum[id] * sum[referenceTextureId]) / (PATCH_SIZE * PATCH_SIZE);
      const float min_standard_dev = 1.e-6f;
      // slightly change the normalized cross correlation divisor, taking the max of the stddev squares, and not their
      // product if one patch has a much different stddev from the other one, this will bring the cross-correlation
      // towards 0
      float4 divisor = fmaxf(stddev[id] * stddev[id], stddev[referenceTextureId] * stddev[referenceTextureId]);
      divisor = fmaxf(divisor, make_float4(min_standard_dev, min_standard_dev, min_standard_dev, min_standard_dev));
      // the cost is one minus the correlation for x, y, z; the fourth component does not contribute
      const float4 normalized_cross_correlation =
          make_float4(1.0f - cross_correlation.x / divisor.x, 1.0f - cross_correlation.y / divisor.y,
                      1.0f - cross_correlation.z / divisor.z, 0.f);
      if (debug) {
        printf(
            "i = %d, "
            "product %f %f %f, cross correlation %f %f %f "
            "dividsor %f %f %f "
            "stddev[id] %f %f %f, stddev[ref] %f %f %f, "
            "normalized_cross_correlation %f %f %f\n",
            id, product.x, product.y, product.z, cross_correlation.x, cross_correlation.y, cross_correlation.z,
            divisor.x, divisor.y, divisor.z, stddev[id].x, stddev[id].y, stddev[id].z, stddev[referenceTextureId].x,
            stddev[referenceTextureId].y, stddev[referenceTextureId].z, cross_correlation.x / divisor.x,
            cross_correlation.y / divisor.y, cross_correlation.z / divisor.z);
      }
      aggregator += normalized_cross_correlation;
    }
  }
  aggregator = aggregator * (20.f / (inputsWritten - 1));

  // a reference patch with (almost) no variation gets a fixed high cost
  if (stddev[referenceTextureId].x + stddev[referenceTextureId].y + stddev[referenceTextureId].z < 0.0001f) {
    aggregator = make_float4(1000, 1000, 1000, 1000);
  }
  if (debug) {
    printf("Input stddev %f %f %f %f\n", stddev[referenceTextureId].x, stddev[referenceTextureId].y,
           stddev[referenceTextureId].z, stddev[referenceTextureId].w);
  }
#else
  // computing standard deviation
  // if a reference input is given, the score is the standard deviation of all inputs compared to the reference one
  // otherwise, it is just the standard deviation of all inputs
  // (note that the latter may favor disoccluded areas)
  // mean
  // TODO construct while writing...
  float4 mean[PATCH_SIZE];
  float4 sum[PATCH_SIZE];
  for (int k = 0; k < PATCH_SIZE; ++k) {
    if (referenceTextureId < 0) {
      sum[k] = make_float4(0.0f, 0.0f, 0.0f, 0.0f);
      for (int id = 0; id < numInputs; id++) {
        sum[k] += values[id][k];
      }
      mean[k] = sum[k] / inputsWritten;
    } else {
      mean[k] = values[referenceTextureId][k];
    }
  }
  // use multiplicative cost through log domain additions
  for (int id = 0; id < numInputs; id++) {
    if (written[id] && id != referenceTextureId) {
      float4 accumulator = make_float4(0.f, 0.f, 0.f, 0.f);
      for (int k = 0; k < PATCH_SIZE; ++k) {
        const float4 diff = values[id][k] - mean[k];
        accumulator += diff * diff;
      }
      // we want log4(sqrt4(accumulator / PATCH_SIZE))
      // let's use log4() / 2 to take out sqrt4()
      aggregator += log4(accumulator / PATCH_SIZE) * 0.5f;
    }
  }
  // NOTE(review): when referenceTextureId < 0 the loop above adds inputsWritten terms, yet the sum is
  // still divided by inputsWritten - 1 — confirm this normalisation is intended.
  aggregator = exp4(aggregator / (inputsWritten - 1)) * 512.f;

  // compute stddev of reference patch
  if (referenceTextureId >= 0) {
    float4 sum = make_float4(0.f, 0.f, 0.f, 0.f);
    float4 sqsum = make_float4(0.f, 0.f, 0.f, 0.f);
    for (int k = 0; k < PATCH_SIZE; ++k) {
      sum += values[referenceTextureId][k];
      sqsum += values[referenceTextureId][k] * values[referenceTextureId][k];
    }
    // the variance is floored at zero before the square root
    stddevReferenceTexture =
        sqrt4(fmaxf((sqsum - sum * sum / PATCH_SIZE) / (PATCH_SIZE), make_float4(0.f, 0.f, 0.f, 0.f)));
  } else {
    stddevReferenceTexture = make_float4(0.f, 0.f, 0.f, 0.f);
  }
#endif

  // const float4 stddev = sqrt4(var);
  // collapse the four components into one scalar by summing them
  const float stddevsum = dot(aggregator, make_float4(1.f, 1.f, 1.f, 1.f));
  const float minVal = 0, maxVal = 1000;

  return clampf_vs(stddevsum, minVal, maxVal);
}

// Sweeps the patch centered on one reference-input pixel through numSteps sphere scale levels and
// keeps the level with the lowest cost.
//
// Scale levels are spaced evenly between minSphereScale and maxSphereScale (both are exponents: the
// actual scale is exp2(level)). For every level the patch is placed at that scale, scored with
// stddev<PATCH_SIZE>(), and a smoothness penalty
//   scaleSmoothness * |level - scaleReference| / (sum of reference patch stddev + epsilon)
// is added to valid (non-negative) scores.
//
// inputCoords          pixel coordinates in the reference input (only used to select the debug pixel)
// inputCoordsCentered  the same pixel, relative to the input center
// sweepInInputSpace    true: scale points around the reference input center;
//                      false: place points with tracePointToRigSphere
// referenceInputId     index of the reference input in inputParamsArray; must be >= 0
// scoreDevicePtr       optional (may be null) array of numSteps entries receiving the raw score per level,
//                      before the smoothness term; invalid scores are stored as 4096
// retVal         [out] packed RGBA debug color: gray level derived from the best index, or opaque red
//                      when no level produced a valid score
// retMinIdx      [out] index of the best level, or -1 when no level produced a valid score
template <size_t numSteps, int PATCH_WIDTH, int PATCH_HEIGHT, int PATCH_SIZE>
__device__ void sweep(const float2 inputCoords, const float2 inputCoordsCentered, const bool sweepInInputSpace,
                      const int referenceInputId, read_only image2d_t texture0, read_only image2d_t texture1,
                      read_only image2d_t texture2, read_only image2d_t texture3, read_only image2d_t texture4,
                      read_only image2d_t texture5, const float minSphereScale, const float maxSphereScale,
                      const struct InputParams6 inputParamsArray, const int numInputs,
                      __globalmem__ unsigned short* scoreDevicePtr, const float scaleSmoothness,
                      const float scaleReference, uint32_t* retVal, int* retMinIdx) {
#if USE_ZMNCC
  static_assert(PATCH_SIZE > 1, "ZMNCC works on patches");
#endif

  static_assert(PATCH_WIDTH % 2 == 1, "PATCH_WIDTH must be odd");
  static_assert(PATCH_HEIGHT % 2 == 1, "PATCH_HEIGHT must be odd");
  // the patch array below is filled with index i + j * PATCH_WIDTH and then consumed over PATCH_SIZE entries
  static_assert(PATCH_SIZE == PATCH_WIDTH * PATCH_HEIGHT, "PATCH_SIZE must equal PATCH_WIDTH * PATCH_HEIGHT");

  const int stepCount = (int)numSteps;

  const struct InputParams referenceInput = inputParamsArray.params[referenceInputId];
  const vsfloat3x4 poseInverseReferenceInput = referenceInput.inverseTransform;

  float scores[numSteps];

  // 3D points of the patch around the reference pixel, after applying the inverse reference transform
  float3 patchInputSphereCoords[PATCH_SIZE];

  for (int j = 0; j < PATCH_HEIGHT; ++j) {
    for (int i = 0; i < PATCH_WIDTH; ++i) {
      const float2 xy = {inputCoordsCentered.x + (i - PATCH_WIDTH / 2), inputCoordsCentered.y + (j - PATCH_HEIGHT / 2)};
      float3 pt;
      pt = FUNCTION_NAME_4(mapInputToCameraSphere, fromInputToSphere, inverseDistortionMetersTransform,
                           inverseRadialPixelsTransform)(xy, referenceInput.scale, referenceInput.distortion,
                                                         referenceInput.centerShift);
      pt = transformSphere(pt, poseInverseReferenceInput);

      patchInputSphereCoords[i + j * PATCH_WIDTH] = pt;
    }
  }

  // the origin pushed through the same transform: center used when scaling in input space
  const float3 depthCenter = transformSphere(make_float3(0.f, 0.f, 0.f), poseInverseReferenceInput);

#if DEBUG_PRINT
  const bool debug = ((int)inputCoords.x == 2068 && (int)inputCoords.y == 799);
#else
  const bool debug = false;
#endif

  float minVal = BIG_FLOAT;
  int minIdx = -1;

  // TODO do not repeatedly project the reference patch when sweeping in input space
  // since it does not vary in function of the sphere scale

  // With a single step there is no interval to divide: evaluate minSphereScale only, instead of
  // dividing by zero and sweeping NaN scales.
  const float sphereScaleStep = (stepCount > 1) ? (maxSphereScale - minSphereScale) / (float)(stepCount - 1) : 0.0f;
  for (int sphereScaleIdx = 0; sphereScaleIdx < stepCount; ++sphereScaleIdx) {
    const float sphereScaleLvl = minSphereScale + sphereScaleStep * sphereScaleIdx;
    const float sphereScale = exp2((float)sphereScaleLvl);
    float4 stddevReferenceTexture;

    float3 onRigSphere[PATCH_SIZE];

    for (int k = 0; k < PATCH_SIZE; ++k) {
      const float3 pt = patchInputSphereCoords[k];

      onRigSphere[k] = (sweepInInputSpace)
                           ?
                           // if sweeping depth in input space, scale the 3D point from the input 3D center
                           sphereScale * (pt - depthCenter) + depthCenter
                           :
                           // if not sweeping depth in input space, scale the 3D point so that it is at the right depth
                           // from the rig center
                           FUNCTION_NAME_4(tracePointToRigSphere, fromInputToSphere, inverseDistortionMetersTransform,
                                           inverseRadialPixelsTransform)(pt, poseInverseReferenceInput, sphereScale);
    }

    scores[sphereScaleIdx] =
        stddev<PATCH_SIZE>(onRigSphere, referenceInputId, texture0, texture1, texture2, texture3, texture4, texture5,
                           inputParamsArray, numInputs, stddevReferenceTexture, debug);

    if (debug) {
      printf("Sphere scale level idx %d, scale level %f, depth %f, score %f, stddevReferenceTexture %f %f %f\n",
             sphereScaleIdx, sphereScaleLvl, sphereScale, scores[sphereScaleIdx], stddevReferenceTexture.x,
             stddevReferenceTexture.y, stddevReferenceTexture.z);
    }

    // export the raw score (before the smoothness term); invalid scores are stored as 4096
    if (scoreDevicePtr) {
      const float sgmScore = (scores[sphereScaleIdx] >= 0.0f) ? min(scores[sphereScaleIdx], 4096.0f) : 4096.f;
      scoreDevicePtr[sphereScaleIdx] = (unsigned short)sgmScore;
    }

    // add scale smoothness term
    // (stddevReferenceTexture is only meaningful for valid scores, hence the guard)
    if (scores[sphereScaleIdx] >= 0.f) {
      const float stddevSum = dot(stddevReferenceTexture, make_float4(1.0f, 1.0f, 1.0f, 1.0f)) +
                              1.e-6f /* + epsilon to avoid dividing by zero */;
      // fabsf: explicit single-precision absolute value, never the integer abs()
      scores[sphereScaleIdx] += scaleSmoothness * fabsf(sphereScaleLvl - scaleReference) / stddevSum;
    }

    // "<=" keeps the last of several equally good levels
    if (scores[sphereScaleIdx] >= 0.f && scores[sphereScaleIdx] <= minVal) {
      minVal = scores[sphereScaleIdx];
      minIdx = sphereScaleIdx;
    }
  }

  if (debug) {
    printf("minIdx %d\n", minIdx);
  }

  uint32_t val = 0;

  if (minIdx == -1) {
    // no valid score at any level
    val = Image_RGBA_pack(255, 0, 0, 255);
  } else {
    // find depth with min score
    // map the best index to a gray level, index 0 being the brightest
    const int depth = 255 - minIdx * (256 / stepCount);
    val = Image_RGBA_pack(depth, depth, depth, 255);
  }

  *retVal = val;
  *retMinIdx = minIdx;
}

// Splats every pixel of one input into the panorama at the position given by its depth (sphere scale).
//
// One thread handles one input pixel (x, y from get_global_id_x/y); threads outside the input do no
// work but still take part in the barriers. Pixels are written bin by bin from the largest sphere
// scale to the smallest, with a barrier after each bin, so that within a block pixels with a smaller
// sphere scale are written after those with a larger one.
// NOTE(review): the barrier only orders threads of the same block — confirm whether cross-block
// ordering matters for the callers.
//
// g_odata      panorama buffer, panoWidth * panoHeight packed RGBA values
// panoScale    factor applied to the equirectangular coordinates to get panorama pixels
// texture      input image
// depth        per-pixel sphere scale of the input; values below exp2(MIN_SPHERE_SCALE) are never splatted
// inputParams  geometry of the input
// numInputs    unused
// offset       value added to the x component of the rig coordinates before projection
__global__ void splatInputWithDepthIntoPano(
    // pano
    global_mem uint32_t* g_odata, int panoWidth, int panoHeight, const float2 panoScale, read_only image2d_t texture,
    read_only image2d_t depth, const struct InputParams inputParams, int numInputs, float offset) {
  const int x = get_global_id_x();
  const int y = get_global_id_y();

  const int inputWidth = inputParams.texWidth;
  const int inputHeight = inputParams.texHeight;

  // Threads outside the input must not return early: every thread of the block has to reach the
  // barrier in the loop below, so the per-pixel work is predicated instead.
  const bool insideInput = x < inputWidth && y < inputHeight;

  uint32_t val = 0;
  float bestSphereScale = 0.0f;
  bool insidePano = false;
  int panoX = -1;
  int panoY = -1;

  if (insideInput) {
    const float2 inputCoords = make_float2((float)x, (float)y);
    const float2 inputCoordsCentered = {inputCoords.x - inputWidth / 2.0f, inputCoords.y - inputHeight / 2.0f};

    bestSphereScale = read_depth_vs(depth, inputCoords);

    // get the coordinates of input in pano at best sphere scale
    float3 rigCoords = FUNCTION_NAME_4(mapInputToRigSpherical, fromInputToSphere, inverseDistortionMetersTransform,
                                       inverseRadialPixelsTransform)(inputCoordsCentered, inputParams.inverseTransform,
                                                                     bestSphereScale, inputParams.scale,
                                                                     inputParams.distortion, inputParams.centerShift);

    rigCoords.x += offset;

    /** From spherical to output space */
    float2 panoCoords = SphereToErect(rigCoords);

    /** From output space to normalized space */
    panoCoords *= panoScale;

    const float4 color = read_texture_vs(texture, inputCoords) * 255.0f;
    val = Image_RGBA_pack(Image_clamp8((int32_t)color.x), Image_clamp8((int32_t)color.y),
                          Image_clamp8((int32_t)color.z), (int32_t)color.w);

    // IN PANO COORDS
    const float2 panoCoordsCenter = {panoCoords.x + panoWidth / 2.0f, panoCoords.y + panoHeight / 2.0f};
    const int2 panoCoordsInt = convert_int2(panoCoordsCenter);

    panoX = panoCoordsInt.x;
    panoY = panoCoordsInt.y;
    insidePano = panoX >= 0 && panoX < panoWidth && panoY >= 0 && panoY < panoHeight;
  }

  // upper end of the current bin; it only moves once this thread's bin has been found,
  // so that each thread writes at most once
  float upperBound = BIG_FLOAT;

  const float sphereScaleStep = (MAX_SPHERE_SCALE - MIN_SPHERE_SCALE) / (float)(numSphereScales - 1);
  // iterate through sphere scales, start splatting at far end
  for (int sphereScaleIdx = numSphereScales - 1; sphereScaleIdx >= 0; sphereScaleIdx--) {
    const float idxSphereScaleLvl = MIN_SPHERE_SCALE + sphereScaleStep * sphereScaleIdx;
    const float lowerBound = exp2(idxSphereScaleLvl);

    if (insideInput && bestSphereScale >= lowerBound && bestSphereScale < upperBound) {
      if (insidePano) {
        g_odata[panoY * panoWidth + panoX] = val;
      }
      upperBound = lowerBound;
    }

    // reached by all threads of the block, whatever they did above
    sync_threads(CLK_GLOBAL_MEM_FENCE);
  }
}

// Full-range sphere sweep for the reference input: one thread per pixel of the tile
// (xblock, yblock) of size blockSizeX x blockSizeY.
//
// For its pixel, each thread sweeps all numSphereScales levels between MIN_SPHERE_SCALE and
// MAX_SPHERE_SCALE without any smoothness term, optionally stores the per-level scores in
// scoreVolumeDevicePtr (numSphereScales entries per pixel, may be null), and writes the winning
// sphere scale to dst. Pixels without a valid level get 0.
__global__ void sphereSweepInputKernel(surface_t dst, int outWidth, int outHeight,
                                       __globalmem__ unsigned short* scoreVolumeDevicePtr, read_only image2d_t texture0,
                                       read_only image2d_t texture1, read_only image2d_t texture2,
                                       read_only image2d_t texture3, read_only image2d_t texture4,
                                       read_only image2d_t texture5, const struct InputParams6 inputParamsArray,
                                       const int referenceInputId, int numInputs, int xblock, int yblock,
                                       int blockSizeX, int blockSizeY) {
  const int x = get_global_id_x() + xblock * blockSizeX;
  const int y = get_global_id_y() + yblock * blockSizeY;

  const int inputWidth = inputParamsArray.params[referenceInputId].texWidth;
  const int inputHeight = inputParamsArray.params[referenceInputId].texHeight;

  // threads past the input or past the current tile have nothing to do
  if (x >= inputWidth || y >= inputHeight || (int)get_global_id_x() >= blockSizeX ||
      (int)get_global_id_y() >= blockSizeY) {
    return;
  }

  const float2 inputCoords = make_float2((float)x, (float)y);
  const float2 inputCoordsCentered = {inputCoords.x - inputWidth / 2.0f, inputCoords.y - inputHeight / 2.0f};

  // scores of this pixel, when a score volume was provided
  __globalmem__ unsigned short* pixelScores = nullptr;
  if (scoreVolumeDevicePtr) {
    pixelScores = scoreVolumeDevicePtr + (x + y * inputWidth) * numSphereScales;
  }

  uint32_t debugColor = 0;
  int bestIdx;

  sweep<numSphereScales, patchWidth, patchHeight, patchSize>(
      inputCoords, inputCoordsCentered, true, referenceInputId, texture0, texture1, texture2, texture3, texture4,
      texture5, MIN_SPHERE_SCALE, MAX_SPHERE_SCALE, inputParamsArray, numInputs, pixelScores,
      0.0f /* no smoothness on scale field */, 0.0f /* no reference scale */, &debugColor, &bestIdx);

  if (x >= outWidth || y >= outHeight) {
    return;
  }

  // 0 marks a pixel for which no level could be scored
  float bestSphereScale = 0.0f;
  if (bestIdx >= 0) {
    // convert the winning index back to a scale level, then to a scale
    const float levelStep = (float)(MAX_SPHERE_SCALE - MIN_SPHERE_SCALE - 0.01f) / (float)(numSphereScales - 1);
    const float bestSphereScaleLvl = MIN_SPHERE_SCALE + levelStep * bestIdx;
    bestSphereScale = exp2((float)bestSphereScaleLvl);
  }
  surface_write_depth(bestSphereScale, dst, x, y);
}

// Translates disparity levels into depth values (sphere scales), one thread per pixel of the tile
// (xblock, yblock) of size blockSizeX x blockSizeY.
//
// The disparity is taken from disparityDevicePtr (divided by 16) when it is non-null, otherwise it is
// read back from dst. The result is written to dst; it is 0 when the resulting scale level lies below
// MIN_SPHERE_SCALE or when the fourth component of the texture at that pixel is below 1.
__global__ void sphereSweepInputDisparityToDepthKernel(surface_t dst, int outWidth, int outHeight,
                                                       __globalmem__ short* disparityDevicePtr,
                                                       read_only image2d_t texture, int xblock, int yblock,
                                                       int blockSizeX, int blockSizeY) {
  const int x = get_global_id_x() + xblock * blockSizeX;
  const int y = get_global_id_y() + yblock * blockSizeY;

  // threads past the output or past the current tile have nothing to do
  if (x >= outWidth || y >= outHeight || (int)get_global_id_x() >= blockSizeX ||
      (int)get_global_id_y() >= blockSizeY) {
    return;
  }

  const float4 pixel = read_texture_vs(texture, make_float2(x, y));

  float disparity;
  if (disparityDevicePtr) {
    disparity = disparityDevicePtr[x + y * outWidth] / 16.0f; /* SGM resolution scaling factor */
  } else {
    surface_read_depth(&disparity, dst, x, y);
  }

  const float levelStep = (MAX_SPHERE_SCALE - MIN_SPHERE_SCALE - 0.01f) / (float)(numSphereScales - 1);
  const float bestSphereScaleLvl = MIN_SPHERE_SCALE + levelStep * disparity;

  const bool belowRange = bestSphereScaleLvl < MIN_SPHERE_SCALE;
  const bool masked = pixel.w < 1.f; /* consider undistortion masks */

  float bestSphereScale = exp2((float)bestSphereScaleLvl);
  if (belowRange || masked) {
    bestSphereScale = 0;
  }

  surface_write_depth(bestSphereScale, dst, x, y);
}

// Refinement step of the coarse-to-fine sphere sweep: one thread per pixel of the tile
// (xblock, yblock) of size blockSizeX x blockSizeY.
//
// Each thread reads the depth found at the coarser level (srcNextLevel, at half the pixel
// coordinates), sweeps numSphereScalesRefine levels in a window around it and writes the refined
// sphere scale to dst. Pixels whose coarser depth is not positive, or for which no level can be
// scored, get 0.
//
// nextLevelWidth / nextLevelHeight  size of srcNextLevel; all reads from it are clamped to this extent
// searchSpan                        minimum half-width of the search window, in scale levels; it is
//                                   widened to cover the depths of the coarser-level neighbors
__global__ void sphereSweepInputKernelStep(surface_t dst, int outWidth, int outHeight, surface_t srcNextLevel,
                                           int nextLevelWidth, int nextLevelHeight, read_only image2d_t texture0,
                                           read_only image2d_t texture1, read_only image2d_t texture2,
                                           read_only image2d_t texture3, read_only image2d_t texture4,
                                           read_only image2d_t texture5, const struct InputParams6 inputParamsArray,
                                           const int referenceInputId, int numInputs, int xblock, int yblock,
                                           int blockSizeX, int blockSizeY, float searchSpan) {
  // numSphereScalesRefine must be an odd number,
  // to ensure that the depth value from the coarser level is considered at the finer level
  static_assert(numSphereScalesRefine % 2 == 1, "numSphereScalesRefine must be an odd number");

  const int x = get_global_id_x() + xblock * blockSizeX;
  const int y = get_global_id_y() + yblock * blockSizeY;

  const int inputWidth = inputParamsArray.params[referenceInputId].texWidth;
  const int inputHeight = inputParamsArray.params[referenceInputId].texHeight;

  uint32_t val = 0;

  if (x < inputWidth && y < inputHeight && (int)get_global_id_x() < blockSizeX && (int)get_global_id_y() < blockSizeY) {
    const float2 inputCoords = make_float2((float)x, (float)y);
    const float2 inputCoordsCentered = {inputCoords.x - inputWidth / 2.0f, inputCoords.y - inputHeight / 2.0f};

    // Reads the coarser level with the coordinates clamped to its extent: the +/-2 neighborhood of
    // border pixels (and x / 2 on odd-sized levels) would otherwise address texels outside the surface.
    auto readCoarserDepth = [&](const int cx, const int cy) -> float {
      int sx = (cx >= nextLevelWidth) ? nextLevelWidth - 1 : cx;
      int sy = (cy >= nextLevelHeight) ? nextLevelHeight - 1 : cy;
      if (sx < 0) {
        sx = 0;
      }
      if (sy < 0) {
        sy = 0;
      }
      float coarserDepth;
      surface_read_depth(&coarserDepth, srcNextLevel, sx, sy);
      return coarserDepth;
    };

    // blocky
    const float depthLowerLevel = readCoarserDepth(x / 2, y / 2);

    // with interpolation
    // NOTE: to enable, also change .surface() in sphereSweep.cu to .texture()
    // float depthLowerLevel = tex2D<float>(srcNextLevel, inputCoords.x / 2.f, inputCoords.y / 2.f);

    float bestSphereScale = 0.0f;

    if (depthLowerLevel > 0.0f) {
      // adapt the searchSpan to the neighboring depths from the coarser levels
      // TODO PERF neighboring threads are loading the same values, potentially use shared memory
      const float depthScale = log2(depthLowerLevel);
      float minSphereScale = depthLowerLevel;
      float maxSphereScale = depthLowerLevel;

      // track the smallest and largest depth among the four diagonal neighbors at distance 2
      auto checkMinMaxNeighboringDepthAtCoarserLevel = [&](const int cx, const int cy) -> void {
        const float neighboringDepth = readCoarserDepth(cx, cy);
        if (minSphereScale > neighboringDepth) {
          minSphereScale = neighboringDepth;
        }
        if (maxSphereScale < neighboringDepth) {
          maxSphereScale = neighboringDepth;
        }
      };
      checkMinMaxNeighboringDepthAtCoarserLevel(x / 2 - 2, y / 2 - 2);
      checkMinMaxNeighboringDepthAtCoarserLevel(x / 2 + 2, y / 2 - 2);
      checkMinMaxNeighboringDepthAtCoarserLevel(x / 2 + 2, y / 2 + 2);
      checkMinMaxNeighboringDepthAtCoarserLevel(x / 2 - 2, y / 2 + 2);

      // TODO check if the manually set level-based searchSpan can be disregarded and trust only the neighborhood
      // defined one pros: more depth homogeneity cons: will not really refine homogeneous areas and recover finer
      // details
      searchSpan = max(searchSpan, max(depthScale - log2(minSphereScale), log2(maxSphereScale) - depthScale));

      // from here on minSphereScale / maxSphereScale hold scale levels (exponents), not depths
      minSphereScale = clampf_vs(depthScale - searchSpan, (float)MIN_SPHERE_SCALE, (float)MAX_SPHERE_SCALE);
      maxSphereScale = clampf_vs(depthScale + searchSpan, (float)MIN_SPHERE_SCALE, (float)MAX_SPHERE_SCALE);

      // if previous value was invalid, search the whole space again
      if (minSphereScale == maxSphereScale) {
        minSphereScale = MIN_SPHERE_SCALE;
        maxSphereScale = MAX_SPHERE_SCALE;
      }

      int minIdx;
      sweep<numSphereScalesRefine, patchWidthStep, patchHeightStep, patchSizeStep>(
          inputCoords, inputCoordsCentered, true, referenceInputId, texture0, texture1, texture2, texture3, texture4,
          texture5, minSphereScale, maxSphereScale, inputParamsArray, numInputs, nullptr, scaleSmoothnessLambda,
          depthScale, &val, &minIdx);

      // handle invalid value (not enough contributing inputs)
      if (minIdx != -1) {
        const float sphereScaleStep =
            (float)(maxSphereScale - minSphereScale - 0.01f) / (float)(numSphereScalesRefine - 1);
        float bestSphereScaleLvl = minSphereScale + sphereScaleStep * minIdx;

        bestSphereScale = exp2((float)bestSphereScaleLvl);

#if DEBUG_PRINT
        if (x == 100 && y == 100) {
          printf(
              "depthLowerLevel: %f, depthScale: %f, minSphereScale: %f, maxSphereScale: %f, bestSphereScaleLvl: %f, "
              "bestSphereScale: %f\n",
              depthLowerLevel, depthScale, minSphereScale, maxSphereScale, bestSphereScaleLvl, bestSphereScale);
        }
#endif
      }
    }

    if (x < outWidth && y < outHeight) {
      surface_write_depth(bestSphereScale, dst, x, y);

      // debug: copy lower level value without sweeping
      // surface_write_depth(depthLowerLevel, dst, x, y);
    }
  }
}