module Render.Code.Lit
  ( raySphereIntersection
  , hgPhase
  , structLight
  , structMaterial
  , shadowFuns
  , litMain
  , brdfSpecular
  ) where

import Render.Code (Code(..), trimming)

-- | GLSL helper @raySphereIntersection(rayOrigin, rayDir, sphereCenter, sphereRadius)@.
--
-- Returns @vec2(t0, t1)@: the near and far ray parameters at which the ray
-- crosses the sphere surface. When the ray misses, it returns
-- @vec2(-M_MAX, -M_MAX)@.
--
-- The quadratic is solved in reduced form (@disc = b*b - c@), which is only
-- valid when @rayDir@ is normalised. @M_MAX@ is not defined here and must be
-- provided by the including shader.
raySphereIntersection :: Code
raySphereIntersection = Code
  [trimming|
    vec2 raySphereIntersection(vec3 rayOrigin, vec3 rayDir, vec3 sphereCenter, float sphereRadius) {
      vec3 tmp = rayOrigin - sphereCenter;

      float b = dot(rayDir, tmp);
      float c = dot(tmp, tmp) - sphereRadius * sphereRadius;

      float disc = b * b - c;

      if(disc < 0.0) return vec2(-M_MAX, -M_MAX);

      float disc_sqrt = sqrt(disc);

      float t0 = -b - disc_sqrt;
      float t1 = -b + disc_sqrt;

      return vec2(t0, t1);
    }
  |]

-- | GLSL helper @hgPhase(nu, g)@: the Cornette–Shanks variant of the
-- Henyey–Greenstein phase function.
--
-- @nu@ is the cosine of the scattering angle and @g@ is the asymmetry
-- parameter. The evaluated expression is
--
-- @3 (1 - g^2) (1 + nu^2) / (2 (2 + g^2) (1 + g^2 - 2 g nu)^(3/2))@
--
-- That is 4π times the normalised Cornette–Shanks phase function; the
-- @1/(4π)@ factor is not applied here. NOTE(review): presumably callers fold
-- it in — confirm.
hgPhase :: Code
hgPhase = Code
  [trimming|
    float hgPhase(float nu, float g) {
      float g2 = g * g;
      return
        (
          3.0 *
          (1.0 - g2) *
          (1.0 + nu * nu)
        ) /
        (
          2.0 *
          (2.0 + g2) *
          pow(
            1.0 + g2 - 2.0 * g * nu,
            1.5
          )
        );
    }
  |]

-- | GLSL declaration of the @Light@ struct read by 'litMain' (@lights[l]@).
--
-- The per-field comments describe how each component is used. 'litMain'
-- reads @shadow.w > 0@ as "has a shadow map" and @shadow.z@ as the shadow-map
-- layer index. NOTE(review): field order/size must match the host-side buffer
-- layout — verify against the uploading code.
structLight :: Code
structLight = Code
  [trimming|
    struct Light {
      mat4 viewProjection; // bring model positions into light-space
      vec4 shadow;         // offset-x, offset-y, shadowmap index, size
      vec4 position;       // alpha: unused
      vec4 direction;      // alpha: unused
      vec4 color;          // alpha: energy
      // vec2 cutoff;      // inner / outer
    };
  |]

-- | GLSL declaration of the @Material@ struct.
--
-- It holds the scalar/vector material factors and integer texture indices.
-- NOTE(review): presumably an index below zero means "no texture", as
-- @envCubeId > -1@ is used elsewhere — confirm. Field order/size must match
-- the host-side buffer layout.
structMaterial :: Code
structMaterial = Code
  [trimming|
    struct Material {
      vec4 baseColor;
      vec2 metallicRoughness;
      vec4 emissive;
      float normalScale;
      float alphaCutoff;

      int baseColorTex;
      int metallicRoughnessTex;
      int emissiveTex;
      int normalTex;
      int ambientOcclusionTex;
    };
  |]

-- | GLSL shadow-map sampling helpers.
--
-- * @shadow_factor(shadowCoord, mapIx, offset)@ does a single depth-compare
--   lookup in layer @mapIx@ of @shadowmaps@. It returns @0.0@ (fully
--   shadowed) when the light-space coordinate falls outside @[-1, 1]^3@.
-- * @filterPCF(shadowCoord, mapIx)@ averages a 3x3 grid of such lookups,
--   spaced @PCF_STEP@ apart.
--
-- @shadowmaps@ and @PCF_STEP@ must be declared by the including shader.
-- NOTE(review): @shadowmaps@ is presumably a @sampler2DArrayShadow@, given the
-- @vec4(u, v, layer, depth)@ lookup — confirm.
shadowFuns :: Code
shadowFuns = Code
  [trimming|
    float shadow_factor(vec3 shadowCoord, float mapIx, vec2 offset) {
      if (abs(shadowCoord.x) > 1.0 ||
          abs(shadowCoord.y) > 1.0 ||
          abs(shadowCoord.z) > 1.0)
            return 0.0; // XXX: 1.0 would be better for directional

      // NDC xy [-1, 1] -> texture uv [0, 1]; z is the depth reference.
      vec4 uvwi = vec4(shadowCoord.xy * 0.5 + 0.5 + offset, mapIx, shadowCoord.z);
      return texture(shadowmaps, uvwi);
    }

    float filterPCF(vec3 shadowCoord, float mapIx) {
      float shadowFactor = 0.0;
      int count = 0;
      int range = 1;

      for (int x = -range; x <= range; x++) {
        for (int y = -range; y <= range; y++) {
          shadowFactor += shadow_factor(
            shadowCoord,
            mapIx,
            vec2(x, y) * PCF_STEP
          );
          count++;
        }
      }
      // Explicit conversion: do not rely on implicit int -> float promotion.
      return shadowFactor / float(count);
    }
  |]

-- | GLSL fragment-shader body that computes the final lit colour into @oColor@.
--
-- Direct lighting sums 'brdfSpecular' over @scene.numLights@ lights. Each
-- light is attenuated by PCF shadowing when @lights[l].shadow.w > 0@. Image
-- based lighting is added on top: irradiance from @cubes[scene.envCubeId]@,
-- plus a prefiltered reflection combined with the BRDF LUT. The result is
-- tone-mapped with Uncharted2.
--
-- The including shader must have these in scope:
--
-- * values: @baseColor@, @metallic@, @roughness@, @normal@, @nonOcclusion@,
--   @fPosition@
-- * resources: @scene@, @lights@, @cubes@, @samplers@, @textures@
-- * functions and constants from 'brdfSpecular' and 'shadowFuns'
litMain :: Code
litMain = Code
  [trimming|
    vec3 albedo =
      // XXX: not needed, we're in linear already
      // pow(baseColor.rgb, vec3(2.2));
      baseColor.rgb;
    vec3 F0 = mix(vec3(0.04), albedo, metallic);

    vec3 ray = scene.viewPosition.xyz - fPosition.xyz;
    vec3 rayDir = normalize(ray); // V
    float quadrance = dot(ray, ray);
    float distance = sqrt(quadrance);

    // XXX: provided by caller
    // vec3 normal = normalize(fNormal); // N

    vec3 Lo = vec3(0.0);
    for (int l = 0; l < scene.numLights; l++) {
      // XXX: directional lights' hit angle doesn't depend on fragment position
      vec3 lightDir = normalize(lights[l].direction.xyz); // L

      float shade = 1.0; // XXX: 0 - occluded, 1 - lit
      if (lights[l].shadow.w > 0) {
        vec4 light_space_pos = lights[l].viewProjection * fPosition;
        // Perspective divide into light NDC.
        vec4 shadowCoord = light_space_pos / light_space_pos.w;

        // TODO: pick on specialization constant
        shade = filterPCF(shadowCoord.xyz, lights[l].shadow.z);
        // shade = shadow_factor(shadowCoord.xyz, lights[l].shadow.z, vec2(0));
      }

      // V must be the normalised view direction: brdfSpecular builds the
      // half-vector as normalize(V + L) and clamps dot(N, V) to [0, 1].
      Lo += brdfSpecular(
        lightDir,
        rayDir,
        normal,
        F0,
        metallic,
        roughness,
        albedo,
        lights[l].color.rgb * lights[l].color.a
      ) * shade;
    }

    vec3 reflection = prefilteredReflection(reflect(rayDir, normal), roughness).rgb;

    // IBL
    vec3 irradiance = vec3(0);
    if (scene.envCubeId > -1) {
      irradiance = textureLod(
        samplerCube(
          cubes[nonuniformEXT(scene.envCubeId)],
          samplers[0]
        ),
        -normal,
        IRRADIANCE_LOD
      ).rgb;
    }

    float iblNdotV = max(dot(normal, rayDir), 0.0);

    // Specular reflectance
    vec2 ibl = texture(
      sampler2D(
        textures[BRDF_LUT],
        samplers[BRDF_LUT_SAMPLER]
      ),
      vec2(roughness, iblNdotV)
    ).rg;
    vec3 F = F_SchlickR(iblNdotV, F0, roughness);
    vec3 specular = nonOcclusion * reflection * (F * ibl.x + ibl.y);

    vec3 kD = (1.0 - F) * (1.0 - metallic);

    // Diffuse based on irradiance
    vec3 diffuseI = nonOcclusion * irradiance * albedo;
    vec3 ambient = kD * diffuseI + specular;

    // Combine with ambient
    vec3 color = Lo + ambient;

    // Tone mapping
    color =
      Uncharted2Tonemap(color * 4.5) /
      Uncharted2Tonemap(vec3(11.2)); // White point

    // Gamma correction
    // XXX: not needed, we're in linear already
    // color = pow(color, vec3(1.0/2.2));

    // Happily ever after
    oColor = vec4(color, baseColor.a);
  |]

-- | GLSL PBR helper library used by 'litMain'.
--
-- * Constants: @BRDF_LUT@, @BRDF_LUT_SAMPLER@, @MAX_REFLECTION_LOD@,
--   @IRRADIANCE_LOD@ (hard-coded; see TODO).
-- * Cook–Torrance terms: @D_GGX@ (normal distribution), @G_SchlicksmithGGX@
--   (geometric shadowing), and @F_Schlick@ / @F_SchlickR@ (Fresnel; the
--   latter is the roughness-aware variant used for IBL).
-- * @prefilteredReflection(R, roughness)@ linearly blends the two mip levels
--   of @cubes[scene.envCubeId]@ that bracket @roughness * MAX_REFLECTION_LOD@.
--   It returns black when there is no environment cube.
-- * @brdfSpecular(...)@ gives the direct-light contribution: Lambert diffuse
--   plus Cook–Torrance specular, times @dotNL * lightColor@. It expects
--   normalised @L@, @V@ and @N@.
-- * @Uncharted2Tonemap@ is the filmic tone-mapping curve.
--
-- @scene@, @cubes@ and @samplers@ must be declared by the including shader.
brdfSpecular :: Code
brdfSpecular = Code
  [trimming|
    // TODO: unhardcode
    const int BRDF_LUT = 2;
    const int BRDF_LUT_SAMPLER = 3; // linear/mip0/no-repeat
    const float MAX_REFLECTION_LOD = 9.0; // todo: param/const
    const float IRRADIANCE_LOD = 10.0; // todo: param/const

    // Normal Distribution function --------------------------------------
    float D_GGX(float dotNH, float roughness) {
      float alpha = roughness * roughness;
      float alpha2 = alpha * alpha;
      float denom = dotNH * dotNH * (alpha2 - 1.0) + 1.0;
      return (alpha2) / (3.14159265359 * denom*denom);
    }

    // Geometric Shadowing function --------------------------------------
    float G_SchlicksmithGGX(float dotNL, float dotNV, float roughness) {
      float r = (roughness + 1.0);
      float k = (r*r) / 8.0;
      float GL = dotNL / (dotNL * (1.0 - k) + k);
      float GV = dotNV / (dotNV * (1.0 - k) + k);
      return GL * GV;
    }

    // Fresnel function ----------------------------------------------------
    // The base is clamped because pow(x, y) is undefined for x < 0 and
    // cosTheta may slightly exceed 1.0 due to rounding.
    vec3 F_Schlick(float cosTheta, vec3 F0) {
      return F0 + (1.0 - F0) * pow(clamp(1.0 - cosTheta, 0.0, 1.0), 5.0);
    }

    vec3 F_SchlickR(float cosTheta, vec3 F0, float roughness) {
      return F0 + (max(vec3(1.0 - roughness), F0) - F0) * pow(clamp(1.0 - cosTheta, 0.0, 1.0), 5.0);
    }

    vec3 prefilteredReflection(vec3 R, float roughness) {
      vec3 color = vec3(0);

      if (scene.envCubeId > -1) {
        float lod = roughness * MAX_REFLECTION_LOD;
        float lodf = floor(lod);
        float lodc = ceil(lod);

        vec3 a = textureLod(
          samplerCube(
            cubes[nonuniformEXT(scene.envCubeId)],
            samplers[0]
          ),
          R,
          lodf
        ).rgb;

        vec3 b = textureLod(
          samplerCube(
            cubes[nonuniformEXT(scene.envCubeId)],
            samplers[0]
          ),
          R,
          lodc
        ).rgb;

        return mix(a, b, lod - lodf);

        // Unreachable alternative, kept for reference:
        // color = texture(
        //   samplerCube(
        //     cubes[nonuniformEXT(scene.envCubeId)],
        //     samplers[2] // XXX: linear/mip0/repeat
        //   ),
        //   fragUVW,
        //   10
        // );
      }
      return color;
    }

    vec3 brdfSpecular(vec3 L, vec3 V, vec3 N, vec3 F0, float metallic, float roughness, vec3 ALBEDO, vec3 lightColor) {
      // Precalculate vectors and dot products
      vec3 H = normalize (V + L);
      float dotNH = clamp(dot(N, H), 0.0, 1.0);
      float dotNV = clamp(dot(N, V), 0.0, 1.0);
      float dotNL = clamp(dot(N, L), 0.0, 1.0);
      float dotVH = clamp(dot(V, H), 0.0, 1.0);

      vec3 color = vec3(0.0);

      if (dotNL > 0.0) {
        // Keep roughness away from 0: D_GGX evaluates 0/0 at roughness 0
        // when dotNH == 1.
        float rroughness = max(0.05, roughness);
        // D = Normal distribution (Distribution of the microfacets)
        float D = D_GGX(dotNH, rroughness);
        // G = Geometric shadowing term (Microfacets shadowing)
        float G = G_SchlicksmithGGX(dotNL, dotNV, rroughness);
        // F = Fresnel factor, evaluated at the microfacet (half-vector) angle
        vec3 F = F_Schlick(dotVH, F0);
        vec3 spec = D * F * G / (4.0 * dotNL * dotNV + 0.001);
        vec3 kD = (vec3(1.0) - F) * (1.0 - metallic);
        color += (kD * ALBEDO / 3.1415926535897932384626433832795 + spec) * dotNL * lightColor;
      }

      return color;
    }

    vec3 Uncharted2Tonemap(vec3 color) {
      float A = 0.15;
      float B = 0.50;
      float C = 0.10;
      float D = 0.20;
      float E = 0.02;
      float F = 0.30;
      return ((color * (A * color + C * B) + D * E) / (color * (A * color + B) + D * F)) - E / F;
    }
  |]