module Render.Lit.Material.Code
  ( vert
  , frag
  ) where

import RIO

import Render.Code (Code, glsl)
import Render.Code.Lit (litMain, shadowFuns, structLight, structMaterial, brdfSpecular)
import Render.DescSets.Set0.Code (set0binding0, set0binding1, set0binding2, set0binding3, set0binding4, set0binding5, set0binding6)

-- | Vertex stage of the lit, material-indexed pipeline.
--
-- Transforms each vertex by the per-instance model matrix (@iModel@) and then
-- by @scene.view@ and @scene.projection@. It passes both UV sets and the
-- per-vertex material index through to the fragment stage (the index is @flat@,
-- so it is not interpolated). It also builds a tangent-space basis @fTBN@ from
-- the tangent, re-orthogonalised against the normal, for normal mapping.
--
-- NOTE(review): tangents and normals are transformed by @iModel@ itself rather
-- than by its inverse transpose. That is only exact for rotations and uniform
-- scaling; confirm that non-uniformly scaled instances are not expected.
vert :: Code
vert = fromString
  [glsl|
    #version 450

    invariant gl_Position;

    ${set0binding0}

    // vertexPos
    layout(location = 0) in vec3 vPosition;
    // vertexAttrs
    layout(location = 1) in vec2 vTexCoord0;
    layout(location = 2) in vec2 vTexCoord1;
    layout(location = 3) in vec3 vNormal;
    layout(location = 4) in vec3 vTangent;
    layout(location = 5) in uint vMaterial;

    // transformMat
    layout(location = 6) in mat4 iModel;

    layout(location = 0)      out  vec4 fPosition;
    layout(location = 1)      out  vec2 fTexCoord0;
    layout(location = 2)      out  vec2 fTexCoord1;
    layout(location = 3) flat out  uint fMaterial;
    layout(location = 4)      out  mat3 fTBN;

    void main() {
      fPosition = iModel * vec4(vPosition, 1.0);

      gl_Position
        = scene.projection
        * scene.view
        * fPosition;

      fTexCoord0 = vTexCoord0;
      fTexCoord1 = vTexCoord1;

      vec3 t = normalize(vec3(iModel * vec4(vTangent, 0.0)));
      vec3 n = normalize(vec3(iModel * vec4(vNormal, 0.0)));
      vec3 to = normalize(t - dot(t, n) * n); // re-orthogonalize T with respect to N
      fTBN = mat3(to, cross(n, to), n);

      fMaterial = vMaterial;
    }
  |]

-- | Fragment stage of the lit, material-indexed pipeline.
--
-- The shader works in these steps:
--
-- * Fetch the @Material@ record for the flat material index.
-- * Start from its base colour, metallic/roughness and emissive factors.
-- * Modulate those factors by the optional textures. A negative texture index
--   (@> -1@ fails) skips that texture.
-- * Discard fragments whose alpha falls below @alphaCutoff@, then premultiply
--   the colour by alpha.
-- * Pick the shading normal from the normal map (via @fTBN@) or from the
--   interpolated geometric normal.
-- * Splice in the shared lighting body ('litMain'), which presumably writes
--   @oColor@, and finally add the emissive term.
--
-- The metallic-roughness and normal textures are re-encoded with
-- @pow(x, 1/2.2)@. The shader's own XXX comments explain this: those textures
-- are assumed to be sampled through an sRGB view.
--
-- NOTE(review): the shader declares @early_fragment_tests@ and also uses
-- @discard@ for the alpha cutoff. With early fragment tests, depth/stencil
-- tests and writes happen before the shader runs, so alpha-cut fragments can
-- still write depth. Confirm the pipeline disables depth writes here or relies
-- on a depth prepass.
frag :: Code
frag = fromString
  [glsl|
    #version 450
    #extension GL_EXT_nonuniform_qualifier : enable

    layout(early_fragment_tests) in;

    // XXX: copypasta from Lit.Colored
    // TODO: move to spec constant
    const uint MAX_LIGHTS = 255;
    const float PCF_STEP = 1.5 / 4096;

    const uint MAX_MATERIALS = 2048;

    ${structLight}
    ${structMaterial}

    ${set0binding0}
    ${set0binding1}
    ${set0binding2}
    ${set0binding3}
    ${set0binding4} // lights
    ${set0binding5} // shadowmap
    ${set0binding6} // materials

    layout(location = 0)      in vec4 fPosition;
    layout(location = 1)      in vec2 fTexCoord0;
    layout(location = 2)      in vec2 fTexCoord1;
    layout(location = 3) flat in uint fMaterial;
    layout(location = 4)      in mat3 fTBN;

    layout(location = 0) out vec4 oColor;

    ${shadowFuns}
    ${brdfSpecular}

    void main() {
      Material material = materials[fMaterial];
      vec4 baseColor = material.baseColor;
      float metallic = material.metallicRoughness[0];
      float roughness = material.metallicRoughness[1];
      float nonOcclusion = 1.0;
      vec4 emissive = material.emissive;

      if (material.baseColorTex > -1) {
        baseColor *= texture(
          sampler2D(
            textures[nonuniformEXT(material.baseColorTex)],
            samplers[0]
          ),
          fTexCoord0
        );
      }

      if (baseColor.a < material.alphaCutoff) {
        discard;
      }

      baseColor.rgb *= baseColor.a;

      if (material.metallicRoughnessTex > -1) {
        vec3 packed = texture(
          sampler2D(
            textures[nonuniformEXT(material.metallicRoughnessTex)],
            samplers[0]
          ),
          fTexCoord0
        ).rgb;
        // XXX: assuming sRGB textures, even for AMR
        packed = pow(packed, vec3(1.0/2.2));
        nonOcclusion -= packed.r;
        metallic *= packed.b;
        roughness *= packed.g;
      }

      // // TODO: combine with MR as channel R.
      // float occlusion = texture(
      //   sampler2D(
      //     textures[nonuniformEXT(max(0, material.ambientOcclusionTex))],
      //     samplers[0]
      //   ),
      //   fTexCoord0
      // ).r;
      // nonOcclusion -= pow(occlusion, 1.0/2.2);

      if (material.emissiveTex > -1) {
        emissive *= texture(
          sampler2D(
            textures[nonuniformEXT(material.emissiveTex)],
            samplers[0]
          ),
          fTexCoord0
        );
      }

      vec3 normal = fTBN[2];
      if (material.normalTex > -1) {
        vec3 normalsColor = texture(
          sampler2D(
            textures[nonuniformEXT(material.normalTex)],
            samplers[0]
          ),
          fTexCoord0
        ).rgb;

        // XXX: convert normal non-colors to linear values from sRGB texture colorspace
        vec3 normals = pow(normalsColor, vec3(1.0/2.2)) * 2.0 - 1.0;

        normal = normalize(fTBN * normals);
      } else {
        normal = normalize(normal);
      }

      ${litMain}

      oColor.rgb += pow(emissive.rgb, vec3(2.2));
    }
  |]