// Copyright (c) 2012-2017 VideoStitch SAS
// Copyright (c) 2018 stitchEm

// Fills a disk in a linear pixel buffer.
// The pixel (x, y) handled by this thread is set to `color` when it lies inside the width x height area and its
// squared distance to the center (aX, aY) is strictly below `t`. Every other pixel is left untouched.
//   dst           destination buffer, written at dst[width * y + x]
//   width, height extent of the drawable area, in pixels
//   aX, aY        disk center
//   t             threshold compared against the SQUARED distance to the center
//   color         value stored into covered pixels
__global__ void diskKernel(global_mem uint32_t* dst, unsigned width, unsigned height, float aX, float aY, float t,
                           uint32_t color) {
  const unsigned x = get_global_id_x();
  const unsigned y = get_global_id_y();
  // Threads that fall outside the image have nothing to do.
  if (x >= width || y >= height) {
    return;
  }
  // Offset from the pixel to the center; the unsigned coordinates are implicitly converted to float.
  const float offX = aX - x;
  const float offY = aY - y;
  if (offX * offX + offY * offY < t) {
    dst[width * y + x] = color;
  }
}
// Surface variant of the disk fill.
// The pixel (x, y) handled by this thread is written with `color` through surface_write_i() when it lies inside the
// width x height area and its squared distance to the center (aX, aY) is strictly below `t`.
//   dst           destination surface
//   width, height extent of the drawable area, in pixels
//   aX, aY        disk center
//   t             threshold compared against the SQUARED distance to the center
//   color         value written to covered pixels
__global__ void diskSourceKernel(surface_t dst, unsigned width, unsigned height, float aX, float aY, float t,
                                 uint32_t color) {
  const unsigned x = get_global_id_x();
  const unsigned y = get_global_id_y();
  // Skip threads that do not map to a pixel of the image.
  if (x >= width || y >= height) {
    return;
  }
  // Offset from the pixel to the center; the unsigned coordinates are implicitly converted to float.
  const float offX = aX - x;
  const float offY = aY - y;
  const float sqrDst = offX * offX + offY * offY;
  if (sqrDst < t) {
    surface_write_i(color, dst, x, y);
  }
}

// Draws the segment from a = (aX, aY) to b = (bX, bY) into a linear pixel buffer.
// For the pixel (x, y) handled by this thread, the squared distance to the closest point of the segment is
// computed; the pixel is set to `color` when 4 * squaredDistance < t. Every other pixel is left untouched.
// A zero-length segment (a == b) is drawn as a single point located at a.
//   dst            destination buffer, written at dst[width * y + x]
//   width, height  extent of the drawable area, in pixels
//   aX, aY, bX, bY segment end points
//   t              threshold compared against 4 * squared distance, i.e. against (2 * distance)^2
//                  NOTE(review): presumably a squared line thickness - confirm against the caller
//   color          value stored into covered pixels
__global__ void lineKernel(global_mem uint32_t* dst, unsigned width, unsigned height, float aX, float aY, float bX,
                           float bY, float t, uint32_t color) {
  unsigned x = get_global_id_x();
  unsigned y = get_global_id_y();
  if (x < width && y < height) {
    float sqrLen = (aX - bX) * (aX - bX) + (aY - bY) * (aY - bY);
    // Normalized position of the pixel's orthogonal projection along the segment: 0 at a, 1 at b.
    // For a degenerate segment sqrLen is 0 and the division would yield NaN (0 / 0); NaN fails every
    // comparison below, so nothing would ever be drawn. Keep p = 0 in that case to measure the distance to a.
    float p = 0.0f;
    if (sqrLen > 0.0f) {
      p = (((float)x - aX) * (bX - aX) + ((float)y - aY) * (bY - aY)) / sqrLen;
    }
    float sqrDst;
    if (p <= 0.0f) {
      // Projection falls before a: the closest point of the segment is a.
      sqrDst = (aX - (float)x) * (aX - (float)x) + (aY - (float)y) * (aY - (float)y);
    } else if (p >= 1.0f) {
      // Projection falls after b: the closest point of the segment is b.
      sqrDst = (bX - (float)x) * (bX - (float)x) + (bY - (float)y) * (bY - (float)y);
    } else {
      // Projection lies strictly inside the segment.
      float projX = aX + p * (bX - aX);
      float projY = aY + p * (bY - aY);
      sqrDst = (projX - (float)x) * (projX - (float)x) + (projY - (float)y) * (projY - (float)y);
    }
    if (sqrDst * 4.0f < t) {
      dst[width * y + x] = color;
    }
  }
}
// Surface variant of the segment drawing kernel.
// For the pixel (x, y) handled by this thread, the squared distance to the closest point of the segment from
// a = (aX, aY) to b = (bX, bY) is computed; the pixel is written with `color` through surface_write_i() when
// 4 * squaredDistance < t. A zero-length segment (a == b) is drawn as a single point located at a.
//   dst            destination surface
//   width, height  extent of the drawable area, in pixels
//   aX, aY, bX, bY segment end points
//   t              threshold compared against 4 * squared distance, i.e. against (2 * distance)^2
//                  NOTE(review): presumably a squared line thickness - confirm against the caller
//   color          value written to covered pixels
__global__ void lineSourceKernel(surface_t dst, unsigned width, unsigned height, float aX, float aY, float bX, float bY,
                                 float t, uint32_t color) {
  unsigned x = get_global_id_x();
  unsigned y = get_global_id_y();
  if (x < width && y < height) {
    float sqrLen = (aX - bX) * (aX - bX) + (aY - bY) * (aY - bY);
    // Normalized position of the pixel's orthogonal projection along the segment: 0 at a, 1 at b.
    // For a degenerate segment sqrLen is 0 and the division would yield NaN (0 / 0); NaN fails every
    // comparison below, so nothing would ever be drawn. Keep p = 0 in that case to measure the distance to a.
    float p = 0.0f;
    if (sqrLen > 0.0f) {
      p = (((float)x - aX) * (bX - aX) + ((float)y - aY) * (bY - aY)) / sqrLen;
    }
    float sqrDst;
    if (p <= 0.0f) {
      // Projection falls before a: the closest point of the segment is a.
      sqrDst = (aX - (float)x) * (aX - (float)x) + (aY - (float)y) * (aY - (float)y);
    } else if (p >= 1.0f) {
      // Projection falls after b: the closest point of the segment is b.
      sqrDst = (bX - (float)x) * (bX - (float)x) + (bY - (float)y) * (bY - (float)y);
    } else {
      // Projection lies strictly inside the segment.
      float projX = aX + p * (bX - aX);
      float projY = aY + p * (bY - aY);
      sqrDst = (projX - (float)x) * (projX - (float)x) + (projY - (float)y) * (projY - (float)y);
    }
    if (sqrDst * 4.0f < t) {
      surface_write_i(color, dst, x, y);
    }
  }
}

// Generates a pair of ring-drawing kernels:
//   <name>        writes into a linear buffer at dst[width * y + x]
//   <name>Source  writes into a surface through surface_write_i()
// A pixel (x, y) inside the width x height area receives `color` when its squared distance to
// (centerX, centerY) lies strictly between innerSqrRadius and outerSqrRadius AND `moreTest` holds.
// `moreTest` is pasted verbatim into the kernel body, so it may refer to the kernel parameters and to the
// locals x, y and sqrDst. `moreDesc` is a free-text description only; the expansion never uses it.
#define CIRCLE_KERNEL_IMPL(name, moreTest, moreDesc)                                                            \
  __global__ void name(global_mem uint32_t* dst, unsigned width, unsigned height, float centerX, float centerY, \
                       float innerSqrRadius, float outerSqrRadius, uint32_t color) {                            \
    const unsigned x = get_global_id_x();                                                                       \
    const unsigned y = get_global_id_y();                                                                       \
    if (x >= width || y >= height) {                                                                            \
      return;                                                                                                   \
    }                                                                                                           \
    const float offX = centerX - (float)x;                                                                      \
    const float offY = centerY - (float)y;                                                                      \
    const float sqrDst = offX * offX + offY * offY;                                                             \
    if (sqrDst > innerSqrRadius && sqrDst < outerSqrRadius && (moreTest)) {                                     \
      dst[width * y + x] = color;                                                                               \
    }                                                                                                           \
  }                                                                                                             \
                                                                                                                \
  __global__ void name##Source(surface_t dst, unsigned width, unsigned height, float centerX, float centerY,    \
                               float innerSqrRadius, float outerSqrRadius, uint32_t color) {                    \
    const unsigned x = get_global_id_x();                                                                       \
    const unsigned y = get_global_id_y();                                                                       \
    if (x >= width || y >= height) {                                                                            \
      return;                                                                                                   \
    }                                                                                                           \
    const float offX = centerX - (float)x;                                                                      \
    const float offY = centerY - (float)y;                                                                      \
    const float sqrDst = offX * offX + offY * offY;                                                             \
    if (sqrDst > innerSqrRadius && sqrDst < outerSqrRadius && (moreTest)) {                                     \
      surface_write_i(color, dst, x, y);                                                                        \
    }                                                                                                           \
  }

// Ring kernel instantiations. Each line below defines two kernels: <name> (linear buffer destination) and
// <name>Source (surface destination). The second argument is an extra per-pixel condition that must hold in
// addition to the ring test; the third argument is a human-readable description of that condition.

// Full ring: no extra condition.
CIRCLE_KERNEL_IMPL(circleKernel, true, .)
// Only pixels with x + 1 > centerX and y < centerY + 1 (described as the top right quarter).
CIRCLE_KERNEL_IMPL(circleTRKernel, x + 1 > centerX && y < centerY + 1, (top right quarter only))
// Only pixels with x + 1 > centerX and y > centerY - 1 (described as the bottom right quarter).
CIRCLE_KERNEL_IMPL(circleBRKernel, x + 1 > centerX && y > centerY - 1, (bottom right quarter only))
// Only pixels with y < centerY (described as the top half).
CIRCLE_KERNEL_IMPL(circleTKernel, y < centerY, (top only))
// Only pixels with y > centerY (described as the bottom half).
CIRCLE_KERNEL_IMPL(circleBKernel, y > centerY, (bottom only))