module Render.Unlit.TileMap.Code
  ( vert
  , frag
  ) where

import RIO

import Render.Code (Code, glsl)
import Render.Samplers qualified as Samplers
import Render.DescSets.Set0.Code (set0binding0, set0binding1, set0binding2)

-- | Vertex shader for the unlit tile-map pipeline.
--
-- Transforms @vPosition@ by the @iModel@ matrix and then by the scene view
-- and projection matrices. The @scene@ block is expected to be declared by
-- the interpolated 'set0binding0'. The shader forwards to the fragment stage:
--
-- * @fPixCoord@, the pixel coordinate
--   @vTexCoord * iViewportSize + iViewOffset@;
-- * @fTexCoord@, that pixel coordinate divided by the map texture size and
--   by the tile size;
-- * the texture ids and tile size unchanged, plus the tileset offset and
--   border pre-multiplied by the inverse tileset texture size.
--
-- Inputs at locations 2 to 9 carry the tile-map parameters. Their @i@ prefix
-- suggests per-instance attributes; NOTE(review): confirm against the
-- vertex input layout of the pipeline.
vert :: Code
vert = fromString
  [glsl|
    #version 450

    ${set0binding0}

    // vertexPos
    layout(location = 0) in vec3 vPosition;
    // vertexAttrs
    layout(location = 1) in vec2 vTexCoord;
    // tilemapParams
    layout(location = 2) in ivec4 iTextureIds; // combined: tileset, tileset sampler, map, repeat
    layout(location = 3) in vec2  iViewOffset;
    layout(location = 4) in vec2  iViewportSize;
    layout(location = 5) in vec2  iMapTextureSize;
    layout(location = 6) in vec2  iTilesetTextureSize;
    layout(location = 7) in vec2  iTileSize;
    layout(location = 8) in vec2  iTilesetOffset;
    layout(location = 9) in vec2  iTilesetBorder;

    // transformMat
    layout(location = 10) in mat4 iModel;

    layout(location = 0)      out  vec2 fTexCoord;
    layout(location = 1)      out  vec2 fPixCoord;
    layout(location = 2) flat out ivec4 fTextureIds;
    layout(location = 3) flat out  vec2 fTilesetTextureSizeInv;
    layout(location = 4) flat out  vec2 fTileSize;
    layout(location = 5) flat out  vec2 fTilesetOffset;
    layout(location = 6) flat out  vec2 fTilesetBorder;

    void main() {
      vec4 fPosition = iModel * vec4(vPosition, 1.0);

      gl_Position
        = scene.projection
        * scene.view
        * fPosition;

      fPixCoord = (vTexCoord * iViewportSize) + iViewOffset;
      fTexCoord = fPixCoord / iMapTextureSize / iTileSize;

      fTextureIds = iTextureIds;
      fTilesetTextureSizeInv = 1.0 / iTilesetTextureSize;
      fTileSize = iTileSize;
      fTilesetOffset = iTilesetOffset * fTilesetTextureSizeInv;
      fTilesetBorder = iTilesetBorder * fTilesetTextureSizeInv;
    }
  |]

-- | Fragment shader for the unlit tile-map pipeline.
--
-- Looks up the tile for this fragment and draws the matching tileset texel:
--
-- * It samples the map texture (@fTextureIds[2]@) at @fTexCoord@ using the
--   @nearest@ sampler.
-- * The @xy@ channels of that sample give the tile position in the tileset.
-- * Bits @0x20@, @0x40@ and @0x80@ of the @w@ channel request an
--   anti-diagonal flip (x and y swapped), a flip along x and a flip along y.
-- * It reads the resulting texel from the tileset texture (@fTextureIds[0]@)
--   with the sampler at index @fTextureIds[1]@.
-- * When @fTextureIds[3]@ (repeat) is zero, it discards fragments whose map
--   UV falls outside the unit square.
--
-- The @textures@ and @samplers@ arrays are expected to be declared by the
-- interpolated 'set0binding1' and 'set0binding2'.
--
-- NOTE(review): tile positions and flags are decoded with a @256.0@ scale.
-- If the map is an 8-bit normalised texture, a stored value of 255 decodes
-- to 256. Confirm the map encoding.
frag :: Code
frag = fromString
  [glsl|
    #version 450
    #extension GL_EXT_nonuniform_qualifier : enable

    ${set0binding1}
    ${set0binding2}

    layout(location = 0) in vec2 fTexCoord;
    layout(location = 1) in vec2 fPixCoord;

    // combined: tileset, tileset sampler, map, repeat
    layout(location = 2) flat in ivec4 fTextureIds;
    layout(location = 3) flat in vec2 fTilesetTextureSizeInv;
    layout(location = 4) flat in vec2 fTileSize;
    layout(location = 5) flat in vec2 fTilesetOffset;
    layout(location = 6) flat in vec2 fTilesetBorder;

    layout(location = 0) out vec4 oColor;

    int tilesetTextureIx = fTextureIds[0];
    int tilesetSamplerIx = fTextureIds[1]; // XXX: can have repeat, but not mips
    int mapTextureIx     = fTextureIds[2];
    int repeatTiles      = fTextureIds[3];

    // TODO
    // const vec4 fTextureGamma = vec4(1.0);

    void main() {
      if (repeatTiles == 0 && (fTexCoord.x < 0.0 || fTexCoord.y < 0.0 || fTexCoord.x > 1.0 || fTexCoord.y > 1.0)) {
        discard;
      }

      vec4 map = textureLod(
        sampler2D(
          textures[nonuniformEXT(mapTextureIx)],
          samplers[$samplerId]
        ),
        fTexCoord,
        0
      );

      vec2 tilePos = floor(map.xy * 256.0);
      vec2 spriteOffset = tilePos * fTileSize;

      vec2 spriteCoord = mod(fPixCoord, fTileSize);
      uint flags = uint(map.w * 256.0);

      // XXX: Anti-diagonal flip first, to allow rotation with H/V flips
      if ((flags & 0x20) == 0x20)
        spriteCoord = spriteCoord.yx;
      if ((flags & 0x40) == 0x40)
        spriteCoord.x = fTileSize.x - spriteCoord.x;
      if ((flags & 0x80) == 0x80)
        spriteCoord.y = fTileSize.y - spriteCoord.y;

      vec2 tilesetPadding = fTilesetOffset + tilePos * fTilesetBorder;

      vec2 spriteUV =
        spriteOffset * fTilesetTextureSizeInv +
        spriteCoord * fTilesetTextureSizeInv +
        tilesetPadding;

      oColor = textureLod(
        sampler2D(
          textures[nonuniformEXT(tilesetTextureIx)],
          samplers[nonuniformEXT(tilesetSamplerIx)]
        ),
        spriteUV,
        0
      );
    }
  |]
  where
    -- Index of the @nearest@ entry in 'Samplers.indices', used for the
    -- map-texture lookup. Presumably this keeps the encoded tile data from
    -- being blended between texels.
    samplerId :: Int32
    samplerId = Samplers.nearest Samplers.indices