{-# OPTIONS_GHC -fplugin Foreign.Storable.Generic.Plugin #-}

{-# LANGUAGE OverloadedLists #-}

module Render.DescSets.Sun
  ( Sun(..)
  , createSet0Ds
  , set0

  , pattern MAX_VIEWS

  , Buffer

  , SunInput(..)
  , initialSunInput

  , Process
  , spawn1
  , mkSun

  , Observer
  , newObserver1
  , observe1
  ) where

import RIO

import Control.Monad.Trans.Resource (MonadResource, ResourceT)
import Data.Tagged (Tagged(..))
import Data.Vector qualified as Vector
import Data.Vector.Storable qualified as VectorS
import Foreign.Storable.Generic (GStorable)
import Geomancy (Vec3, Vec4, vec3, vec4)
import Geomancy.Transform (Transform)
import Geomancy.Transform qualified as Transform
import Geomancy.Quaternion qualified as Quaternion
import Geomancy.Vec4 qualified as Vec4
import Vulkan.Core10 qualified as Vk
import Vulkan.CStruct.Extends (SomeStruct(..))
import Vulkan.NamedType ((:::))
import Vulkan.Zero (Zero(..))

import Engine.Camera qualified as Camera
import Engine.Types (StageRIO)
import Engine.Vulkan.DescSets ()
import Engine.Vulkan.Types (DsLayoutBindings, HasVulkan(..))
import Engine.Worker qualified as Worker
import Resource.Buffer qualified as Buffer
import Resource.Region qualified as Region
import Resource.Vulkan.DescriptorPool qualified as DescriptorPool

-- * Set0 data for light projection

-- | Maximum "guaranteed" amount for multiview passes
-- | Maximum "guaranteed" amount for multiview passes.
--
-- Vulkan requires @maxMultiviewViewCount@ to be at least 6 on implementations
-- that support multiview, hence this value. In this module it also sizes the
-- sun uniform buffer: 'createSet0Ds' allocates room for 'MAX_VIEWS' 'Sun'
-- records.
pattern MAX_VIEWS :: Int
pattern MAX_VIEWS = 6

-- | Per-light record uploaded to the set-0 uniform buffer of the shadow pass.
--
-- The 'GStorable' instance below derives the 'Storable' layout from the
-- 'Generic' representation, so field order and field types define the memory
-- layout. That layout presumably has to match the shader-side uniform block;
-- verify before reordering or retyping fields.
data Sun = Sun
  { sunViewProjection :: Transform
  , sunShadow         :: Vec4 -- offsetx, offsety, index, size -- XXX: only index is used
  , sunPosition       :: Vec4 -- XXX: alpha available for stuff
  , sunDirection      :: Vec4 -- XXX: alpha available for stuff
  , sunColor          :: Vec4 -- XXX: RGB premultiplied, alpha is available for stuff
  }
  deriving (Show, Generic)

-- | Generic-derived 'Storable' layout (field order matters, see 'Sun').
instance GStorable Sun

-- | Neutral sun: the 'mempty' (identity) view-projection, zeroed shadow,
-- position and colour, and a direction pointing along +Y.
instance Zero Sun where
  zero =
    Sun
      mempty          -- sunViewProjection
      0               -- sunShadow
      0               -- sunPosition
      (vec4 0 1 0 0)  -- sunDirection
      0               -- sunColor

-- * Shadow casting descriptor set

-- | Layout bindings for the shadow-casting descriptor set 0.
--
-- It has a single entry, 'set0bind0', paired with 'zero' (presumably "no
-- extra binding flags").
set0 :: Tagged Sun DsLayoutBindings
set0 = Tagged [(set0bind0, zero)]

-- | Binding 0 of set 0: one uniform buffer, visible to the vertex stage only,
-- with no immutable samplers.
set0bind0 :: Vk.DescriptorSetLayoutBinding
set0bind0 = Vk.DescriptorSetLayoutBinding
  { binding           = 0
  , stageFlags        = Vk.SHADER_STAGE_VERTEX_BIT
  , descriptorType    = Vk.DESCRIPTOR_TYPE_UNIFORM_BUFFER
  , descriptorCount   = 1
  , immutableSamplers = mempty
  }

-- * Setup

-- | Coherent uniform buffer holding 'Sun' records.
type Buffer = Buffer.Allocated 'Buffer.Coherent Sun

-- | Allocate and wire up the resources behind the shadow pass's set 0.
--
-- The function:
--
-- * allocates a region-local descriptor pool;
-- * allocates one descriptor set from that pool using the given layout;
-- * allocates a region-local coherent uniform buffer with room for
--   'MAX_VIEWS' 'Sun' records, all initialised to 'zero';
-- * points binding 0 of the set at that buffer via 'updateSet0Ds'.
--
-- It returns the tagged descriptor sets together with the buffer.
createSet0Ds
  :: Tagged '[Sun] Vk.DescriptorSetLayout
  -> ResourceT (StageRIO st)
      ( Tagged '[Sun] (Vector Vk.DescriptorSet)
      , Buffer
      )
createSet0Ds (Tagged layout) = do
  -- NOTE(review): the pool reserves two uniform-buffer descriptors, but only
  -- one (set 0, binding 0) is ever written. Confirm whether the spare is needed.
  pool <- Region.local $
    DescriptorPool.allocate (Just "Basic.Sun") 1
      [(Vk.DESCRIPTOR_TYPE_UNIFORM_BUFFER, 1 + 1)]

  sets <- DescriptorPool.allocateSetsFrom pool (Just "Basic.Sun") [layout]

  buffer <- Region.local $
    Buffer.allocateCoherent
      (Just "Basic.Sun.Data")
      Vk.BUFFER_USAGE_UNIFORM_BUFFER_BIT
      MAX_VIEWS
      initialSuns

  updateSet0Ds (Tagged sets) buffer
  pure (Tagged sets, buffer)
  where
    initialSuns = VectorS.replicate MAX_VIEWS zero

-- | Point binding 0 of the first descriptor set in @ds@ at the given sun
-- buffer.
--
-- The buffer is bound from offset 0 with 'Vk.WHOLE_SIZE', so the descriptor
-- covers every 'Sun' slot the buffer holds.
--
-- Calls 'error' if @ds@ is empty, which the @'[Sun]@ tag is meant to rule out.
updateSet0Ds
  :: Tagged '[Sun] (Vector Vk.DescriptorSet)
  -> Buffer.Allocated 'Buffer.Coherent Sun
  -> ResourceT (StageRIO st) ()
updateSet0Ds (Tagged ds) Buffer.Allocated{aBuffer} = do
  -- Read the device straight from the environment; no need to fetch the
  -- whole context with 'asks id' first.
  device <- asks getDevice
  Vk.updateDescriptorSets device [writeSet0b0] mempty

  where
    destSet0 = case Vector.headM ds of
      Nothing ->
        error "assert: descriptor sets promised to contain [Sun]"
      Just one ->
        one

    -- A single uniform-buffer write to binding 0, array element 0.
    writeSet0b0 = SomeStruct zero
      { Vk.dstSet          = destSet0
      , Vk.dstBinding      = 0
      , Vk.dstArrayElement = 0
      , Vk.descriptorCount = 1
      , Vk.descriptorType  = Vk.DESCRIPTOR_TYPE_UNIFORM_BUFFER
      , Vk.bufferInfo      = [set0bind0I]
      }
      where
        set0bind0I = Vk.DescriptorBufferInfo
          { Vk.buffer = aBuffer
          , Vk.offset = 0
          , Vk.range  = Vk.WHOLE_SIZE
          }

-- | User-facing parameters from which 'mkSun' derives a 'Sun' and its
-- bounding-box transform.
data SunInput = SunInput
  { siColor :: Vec4
    -- ^ Light colour, copied unchanged into 'sunColor'.

  , siInclination :: Float
    -- ^ Tilt angle used to orient the light (presumably radians; the initial
    -- value is a fraction of a turn).
  , siAzimuth     :: Float
    -- ^ Heading angle used to orient the light (presumably radians).
  , siRadius      :: Float
    -- ^ Distance of the light along its rotated +Z axis.
  , siTarget      :: Vec3
    -- ^ NOTE(review): not read by 'mkSun'; the light is always placed
    -- relative to the origin. Confirm whether this should offset the view.

  , siDepthRange :: Float
    -- ^ The orthographic projection scales Z by @1 / siDepthRange@.
  , siSize       :: Float
    -- ^ The orthographic projection scales X and Y by @1 / siSize@; the value
    -- is also written to the @w@ component of 'sunShadow'.
  , siShadowIx   :: Float
    -- ^ Written to the @z@ component of 'sunShadow' (the shadow index).
  }

-- | Default sun parameters:
--
-- * white light;
-- * inclination of +1/8 turn and azimuth of -1/8 turn;
-- * radius of half the camera far plane, with depth range equal to the full
--   far plane;
-- * 512 ortho size and shadow index -1 (presumably "unassigned").
initialSunInput :: SunInput
initialSunInput = SunInput
  { siColor       = vec4 1 1 1 1
  , siInclination = eighthTurn
  , siAzimuth     = negate eighthTurn
  , siRadius      = far / 2
  , siTarget      = 0
  , siDepthRange  = far -- TODO: use explicit near/far
  , siSize        = 512
  , siShadowIx    = -1
  }
  where
    eighthTurn = τ / 8
    far = Camera.PROJECTION_FAR

-- | Worker cell that takes 'SunInput' and produces the light's bounding-box
-- transform and 'Sun' record (computed by 'mkSun').
type Process = Worker.Cell SunInput ("bounding box" ::: Transform, Sun)

-- | Spawn a 'Process' that runs inputs through 'mkSun', starting from the
-- given input.
spawn1
  :: ( MonadResource m
     , MonadUnliftIO m
     )
  => SunInput
  -> m Process
spawn1 initial = Worker.spawnCell mkSun initial

-- | Derive the light's bounding-box transform and its 'Sun' uniform record.
--
-- Geometry:
--
-- * The light position is the rotated @(0, 0, siRadius)@.
-- * The light direction is the rotated @(0, 0, -1)@.
-- * The projection is a plain orthographic box, scaled by @1 / siSize@ in
--   X and Y and by @1 / siDepthRange@ in Z.
--
-- NOTE(review): 'siTarget' is bound by the record wildcard but never used.
mkSun :: SunInput -> ("bounding box" ::: Transform, Sun)
mkSun SunInput{..} = (boundingBox, sun)
  where
    sun = Sun
      { sunViewProjection = lightViewProjection
      , sunShadow         = vec4 0 0 siShadowIx siSize
      , sunPosition       = Vec4.fromVec3 lightPosition 0
      , sunDirection      = Vec4.fromVec3 lightDirection 0
      , sunColor          = siColor
      }

    -- World-space orientation of the light.
    orientation = Quaternion.intrinsic 0 siInclination siAzimuth
    lightPosition = Quaternion.rotate orientation (vec3 0 0 siRadius)
    lightDirection = Quaternion.rotate orientation (vec3 0 0 (-1))

    -- Just a scaled ortho box.
    lightProjection =
      Transform.scale3 (1 / siSize) (1 / siSize) (1 / siDepthRange)

    lightView =
      Transform.dirPos
        (Quaternion.extrinsic (-siAzimuth) (-siInclination) 0)
        (vec3 0 0 siRadius)

    lightViewProjection = lightProjection <> lightView

    boundingBox = mconcat
      [ Transform.inverse lightViewProjection
      , Transform.translate 0 0 0.5 -- shift origin to the near face
      , Transform.rotateX (τ / 4)   -- put the green side towards the light
      ]

-- | Observer holding the bounding-box transforms last produced by a 'Process'.
type Observer = Worker.ObserverIO (VectorS.Vector ("bounding box" ::: Transform))

-- | Create an 'Observer' whose initial value is an empty vector.
newObserver1 :: MonadIO m => m Observer
newObserver1 = Worker.newObserverIO VectorS.empty

-- | Propagate the process output into the GPU buffer and the observer.
--
-- The callback writes the single 'Sun' record into @sunData@ and discards the
-- result of the write. It returns the bounding box as a one-element vector,
-- which presumably becomes the observer's new value.
observe1 :: MonadUnliftIO m => Process -> Observer -> Buffer -> m ()
observe1 sunP sunOut sunData =
  Worker.observeIO_ sunP sunOut $ \_previous (bb, sun) -> do
    -- XXX: must stay the same or descsets must be updated with a new buffer
    void $ Buffer.updateCoherent (VectorS.singleton sun) sunData
    pure (VectorS.singleton bb)

-- | One full turn in radians (2π), at 'Float' precision.
τ :: Float
τ = pi * 2
