import {
  gridShader, basicShader, vertexShader, checkerShader, clippingShader, maskShader, basicPremultiplyShader,
  unpremultiplyShader, antsShader, circleShader, ellipseShader, circleOutlineShader, ellipseOutlineShader,
  softBrush1Shader, softBrush2Shader, vertexColorShader, solidColorShader, drawingShader, testShader,
  rectOutlineShader, hslShader, maskPremultiplyShader, motionBlurShader, lineShader, toolShader, imageBrushShader, imageBrushCursorShader, fastNormalShader,
  maskUnpremultiplyShader, fastBrushShader, fastEraserShader, basicColorShader, fastMoveShader, meshShader,
  vertexMeshShader, spriteShader, hueSaturationLightnessShader, brightnessContrastShader,
  antAiShader, curvesShader, aiInpaintMaskShader, aiOutpaintMaskShader, perspectiveGridLineShader
} from '../generated/shaders';

// Re-exported unchanged from '../generated/shaders' (presumably the vertex stages, going by
// their names); unlike the other imported sources they are not entries of the `shaders` table.
export { vertexShader, vertexMeshShader };

/**
 * Fills the `toolShader` template: the first `DECL_HERE` placeholder receives `decl`
 * (declarations such as helper functions) and the first `CODE_HERE` placeholder receives `code`.
 *
 * Replacer callbacks are used so the inserted text is copied verbatim; with a plain string as the
 * replacement, `String.prototype.replace` would expand `$&`, `$$`, `` $` `` and `$'` sequences.
 * DECL_HERE is substituted first, so a `decl` that itself contained the literal text `CODE_HERE`
 * would receive `code` instead of the template's own placeholder.
 *
 * @param code - GLSL statements for the CODE_HERE slot.
 * @param decl - GLSL declarations for the DECL_HERE slot.
 * @returns The complete shader source.
 */
function baseFragmentShader(code: string, decl: string) {
  const withDecl = toolShader.replace('DECL_HERE', () => decl);
  return withDecl.replace('CODE_HERE', () => code);
}

/**
 * Builds a separable blend-mode shader on top of `toolShader`.
 *
 * `code` becomes the body of `vec3 blendColor(vec3 src, vec3 dst)`; the generated `blend()`
 * scales `src` by `opacity`, calls `blendColor` on the RGB parts and hands the result to
 * `alphaBlendOperator`. The template's code slot only evaluates `blend(src, dst)`.
 *
 * @param code - GLSL function body returning the blended `vec3`.
 * @returns The complete shader source.
 */
function basicFragmentShader(code: string) {
  const blendCall = `
    vec4 blended = blend(src, dst);
  `;
  const blendDeclarations = `
    vec3 blendColor(vec3 src, vec3 dst) {
      ${code}
    }
    vec4 blend(vec4 src, vec4 dst) {
      src *= opacity;
      vec3 rgb = blendColor(getRGB(src), getRGB(dst));
      return alphaBlendOperator(src, dst, rgb);
    }
  `;
  return baseFragmentShader(blendCall, blendDeclarations);
}

/**
 * Builds a non-separable (HSL-based) blend-mode shader on top of `toolShader`.
 *
 * The `hslShader` source is inlined first, then `code` becomes the body of
 * `vec3 blendColor(vec3 src, vec3 dst, float sa, float da)`. The generated `blend()` scales `src`
 * by `opacity` and always passes `sa = da = 1.0` before handing the colour to `alphaBlendOperator`.
 *
 * @param code - GLSL function body returning the blended `vec3`.
 * @returns The complete shader source.
 */
function hslFragmentShaderNew(code: string) {
  const blendCall = `
    vec4 blended = blend(src, dst);
  `;
  const blendDeclarations = `
    ${hslShader}
    vec3 blendColor(vec3 src, vec3 dst, float sa, float da) {
      ${code}
    }
    vec4 blend(vec4 src, vec4 dst) {
      src *= opacity;
      vec3 rgb = blendColor(getRGB(src), getRGB(dst), 1.0, 1.0);
      return alphaBlendOperator(src, dst, rgb);
    }
  `;
  return baseFragmentShader(blendCall, blendDeclarations);
}

/**
 * Lookup table of GLSL shader sources by name.
 *
 * Values are one of:
 * - a source imported from '../generated/shaders';
 * - the same source with `#define` lines prepended to enable a variant (GRAYSCALE, OPACITY, MASK,
 *   OPACITY_LOCKED);
 * - a blend-mode program generated from `toolShader` via basicFragmentShader,
 *   hslFragmentShaderNew or baseFragmentShader.
 *
 * Keys containing spaces ('color burn', 'soft light', ...) look like user-facing blend-mode names;
 * NOTE(review): confirm callers look these up by that exact string.
 * The loop after this table registers extra `...WithDraw` / `...WithDrawAndMask` entries at module
 * load, which is why the type is an open string index signature.
 */
export const shaders: { [key: string]: string; } = {
  drawing: drawingShader,
  drawingGrayscale: `#define GRAYSCALE\n${drawingShader}`,
  basic: basicShader,
  basicColor: basicColorShader,
  basicPremultiply: basicPremultiplyShader,
  sprite: spriteShader,
  unpremultiply: unpremultiplyShader,
  test: testShader,
  solidColor: solidColorShader,
  ants: antsShader,
  antsAi: antAiShader,
  line: lineShader,
  vertexColor: vertexColorShader,
  circle: circleShader,
  circleOpacity: `#define OPACITY\n${circleShader}`,
  circleOutline: circleOutlineShader,
  imageBrush: imageBrushShader,
  imageBrushCursor: imageBrushCursorShader,
  imageBrushOpacity: `#define OPACITY\n${imageBrushShader}`,
  ellipse: ellipseShader,
  ellipseOutline: ellipseOutlineShader,
  rectOutline: rectOutlineShader,
  softBrush1: softBrush1Shader,
  softBrush2: softBrush2Shader,
  softBrush2Opacity: `#define OPACITY\n${softBrush2Shader}`,
  fastNormal: fastNormalShader,
  fastNormalWithMask: `#define MASK\n${fastNormalShader}`,
  fastMove: fastMoveShader,
  fastMoveWithMask: `#define MASK\n${fastMoveShader}`,
  fastBrush: fastBrushShader,
  fastBrushWithMask: `#define MASK\n${fastBrushShader}`,
  fastBrushOpacityLocked: `#define OPACITY_LOCKED\n${fastBrushShader}`,
  fastBrushOpacityLockedWithMask: `#define OPACITY_LOCKED\n#define MASK\n${fastBrushShader}`,
  fastEraser: fastEraserShader,
  fastEraserWithMask: `#define MASK\n${fastEraserShader}`,
  // Separable blend modes: each snippet is the body of blendColor(vec3 src, vec3 dst),
  // see basicFragmentShader.
  normal: basicFragmentShader(`return src;`),
  darken: basicFragmentShader(`return min(src, dst);`),
  multiply: basicFragmentShader(`return src * dst;`), // TODO: use optimized version ?
  'color burn': basicFragmentShader(`
    vec3 res;
    res.r = (dst.r == 1.0) ? 1.0 : ((src.r != 0.0) ? 1.0 - min(1.0, (1.0 - dst.r) / src.r) : 0.0);
    res.g = (dst.g == 1.0) ? 1.0 : ((src.g != 0.0) ? 1.0 - min(1.0, (1.0 - dst.g) / src.g) : 0.0);
    res.b = (dst.b == 1.0) ? 1.0 : ((src.b != 0.0) ? 1.0 - min(1.0, (1.0 - dst.b) / src.b) : 0.0);
    return res;
  `),
  // Writes (1 - src alpha) into RGB with alpha 1; dst and opacity are not used.
  aiMask: baseFragmentShader(`vec4 blended = blend(src, dst); `,
    ` vec4 blend(vec4 src, vec4 dst) {
      return vec4(1.0 - src[3], 1.0 - src[3], 1.0 - src[3], 1);
    }`),
  aiMaskInpaint: aiInpaintMaskShader,
  aiMaskOutpaint: aiOutpaintMaskShader,
  // Unused alternative color burn formulation (see 'color burn' above), kept for reference:
  // vec3 result;
  // result.r = (src.r == 0.0) ? src.r : max((1.0 - ((1.0 - dst.r) / src.r)), 0.0);
  // result.g = (src.g == 0.0) ? src.g : max((1.0 - ((1.0 - dst.g) / src.g)), 0.0);
  // result.b = (src.b == 0.0) ? src.b : max((1.0 - ((1.0 - dst.b) / src.b)), 0.0);

  lighten: basicFragmentShader(`return max(src, dst);`),
  screen: basicFragmentShader(`return ONE - (ONE - dst) * (ONE - src);`),
  // TODO: is the `forZero` term (from the commented-out step() variant below) needed here?
  'color dodge': basicFragmentShader(`
    vec3 res;
    res.r = (dst.r == 0.0) ? 0.0 : ((src.r == 1.0) ? 1.0 : min(dst.r / (1.0 - src.r), 1.0));
    res.g = (dst.g == 0.0) ? 0.0 : ((src.g == 1.0) ? 1.0 : min(dst.g / (1.0 - src.g), 1.0));
    res.b = (dst.b == 0.0) ? 0.0 : ((src.b == 1.0) ? 1.0 : min(dst.b / (1.0 - src.b), 1.0));
    return res;
  `),
  // Unused alternative color dodge formulation (branching version); note it has no dst == 0 case:
  // vec3 result;
  // result.r = (src.r==1.0)?src.r:min(dst.r/(1.0-src.r),1.0);
  // result.g = (src.g==1.0)?src.g:min(dst.g/(1.0-src.g),1.0);
  // result.b = (src.b==1.0)?src.b:min(dst.b/(1.0-src.b),1.0);
  // return result;

  // Unused alternative color dodge formulation (branchless step() version):
  // vec3 stepValues = step(ONE, src);
  // vec3 regular = min(dst / (ONE - src), ONE);
  // vec3 forZero = src;
  // return mix(regular, forZero, stepValues);

  // Overlay (Target = dst, Blend = src):
  //  (Target > 0.5) * (1 - (1-2*(Target-0.5)) * (1-Blend)) +
  // (Target <= 0.5) * ((2*Target) * Blend)
  // TODO: fix steps for <=
  // NOTE(review): step(HALF, dst) selects above05 when dst == 0.5, but both branches evaluate to
  // src there, so the boundary choice does not change the result; this TODO may be obsolete.
  overlay: basicFragmentShader(`
    vec3 steps = step(HALF, dst);
    vec3 above05 = (ONE - (ONE - 2.0 * (dst - HALF)) * (ONE - src));
    vec3 below05 = 2.0 * dst * src;
    return mix(below05, above05, steps);
  `),
  // Soft light, per channel, branch selected with step(HALF, src):
  //   src <  0.5: 2 * dst * src + dst^2 * (1 - 2 * src)
  //   src >= 0.5: sqrt(dst) * (2 * src - 1) + 2 * dst * (1 - src)
  'soft light': basicFragmentShader(`
    vec3 below05 = 2.0 * dst * src + dst * dst * (ONE - 2.0 * src);
    vec3 above05 = sqrt(dst) * (2.0 * src - ONE) + 2.0 * dst * (ONE - src);
    return mix(below05, above05, step(HALF, src));
  `),
  // Unused alternative soft light formulation (Target = dst, Blend = src):
  //  (Blend > 0.5) * (1 - (1-Target) * (1-(Blend-0.5))) +
  // (Blend <= 0.5) * (Target * (Blend+0.5))
  // vec3 steps = step(HALF, src);
  // vec3 above05 = ONE - (ONE - dst) * (ONE - (src - HALF));
  // vec3 below05 = dst * (src + HALF);
  // return mix(below05, above05, steps);

  // Hard light (Target = dst, Blend = src):
  //  (Blend > 0.5) * (1 - (1-Target) * (1-2*(Blend-0.5))) +
  // (Blend <= 0.5) * (Target * (2*Blend))
  'hard light': basicFragmentShader(`
    vec3 steps = step(HALF, src);
    vec3 above05 = ONE - (ONE - dst) * (ONE - 2.0 * (src - HALF));
    vec3 below05 = dst * (2.0 * src);
    return mix(below05, above05, steps);
  `),
  difference: basicFragmentShader(`return abs(dst - src);`),
  // Expands to dst + src - 2 * dst * src.
  exclusion: basicFragmentShader(`return HALF - 2.0 * (dst - HALF) * (src - HALF);`),
  // Non-separable modes: each snippet is the body of blendColor(src, dst, sa, da); the generated
  // blend() always passes sa = da = 1.0 (see hslFragmentShaderNew). setSat/getSat/setLum/getLum are
  // presumably defined in hslShader (not visible here).
  hue: hslFragmentShaderNew(`
    vec3 result = src * da;
    result = setSat(result, getSat(dst) * sa);
    return setLum(result, sa * da, getLum(dst) * sa);
  `),
  saturation: hslFragmentShaderNew(`
    vec3 result = dst * sa;
    result = setSat(result, getSat(src) * da);
    return setLum(result, sa * da, getLum(dst) * sa);
  `),
  color: hslFragmentShaderNew(`
    return setLum(src * sa, sa * da, getLum(dst) * sa);
  `),
  luminosity: hslFragmentShaderNew(`
    return setLum(dst * sa, sa * da, getLum(src) * da);
  `),
  clipping: clippingShader,
  mask: maskShader,
  maskPremultiply: maskPremultiplyShader,
  maskUnpremultiply: maskUnpremultiplyShader,
  checker: checkerShader,
  grid: gridShader,
  mesh: meshShader,
  hueSaturationLightnessShader: hueSaturationLightnessShader,
  brightnessContrastShader: brightnessContrastShader,
  curvesShader: curvesShader,
  motionBlurShader: motionBlurShader,
  perspectiveGridLine: perspectiveGridLineShader,
};

// Every shader whose source contains an `#ifdef TOOL` section gets two extra entries under derived
// keys: `<key>WithDraw` (TOOL defined) and `<key>WithDrawAndMask` (TOOL and MASKS defined).
// Object.keys returns a snapshot array, so the entries added inside the loop are not revisited.
for (const key of Object.keys(shaders)) {
  const shader = shaders[key];

  if (!shader.includes('#ifdef TOOL')) {
    continue;
  }

  shaders[`${key}WithDraw`] = `#define TOOL\n${shader}`;
  shaders[`${key}WithDrawAndMask`] = `#define TOOL\n#define MASKS\n${shader}`;
  // Clip variants are currently disabled; enabling them would register two more programs per
  // tool shader.
  // shaders[`${key}WithDrawClip`] = `#define CLIP\n#define TOOL\n${shader}`;
  // shaders[`${key}WithDrawAndMaskClip`] = `#define CLIP\n#define TOOL\n#define MASKS\n${shader}`;
}
