{-# LANGUAGE OverloadedLists #-}

-- XXX: TypeError in Compatible generates unused constraint argument
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

module Engine.Vulkan.Pipeline.Compute
  ( Config(..)
  , Configure

  , Pipeline(..)
  , allocate
  , create
  , destroy

  , bind
  , Compute
  ) where

import RIO

import Data.Kind (Type)
import Data.List qualified as List
import Data.Tagged (Tagged(..))
import Data.Vector qualified as Vector
import GHC.Stack (callStack, getCallStack, srcLocModule, withFrozenCallStack)
import UnliftIO.Resource (MonadResource, ReleaseKey)
import UnliftIO.Resource qualified as Resource
import Vulkan.Core10 qualified as Vk
import Vulkan.Core12.Promoted_From_VK_EXT_descriptor_indexing qualified as Vk12
import Vulkan.CStruct.Extends (SomeStruct(..), pattern (:&), pattern (::&))
import Vulkan.Utils.Debug qualified as Debug
import Vulkan.Zero (Zero(..))

import Engine.Vulkan.DescSets (Bound(..), Compatible)
import Engine.Vulkan.Types (HasVulkan(..), MonadVulkan, DsBindings, getPipelineCache)

import Engine.Vulkan.Pipeline (Pipeline(..), destroy)
import Engine.Vulkan.Shader qualified as Shader

-- | Everything needed to build a compute 'Pipeline' with 'create' / 'allocate'.
--
-- The @dsl@ index tags the descriptor set layouts; the resulting
-- 'Pipeline' carries the same index.
data Config (dsl :: [Type]) spec = Config
  { cComputeCode        :: ByteString
    -- ^ Compute shader code, handed to 'Shader.create' for the compute stage
    -- (presumably SPIR-V — TODO confirm).
  , cDescLayouts        :: Tagged dsl [DsBindings]
    -- ^ One entry per descriptor set. Each entry's bindings and binding flags
    -- become one descriptor set layout.
  , cPushConstantRanges :: Vector Vk.PushConstantRange
    -- ^ Push constant ranges of the pipeline layout.
  , cSpecialization     :: spec
    -- ^ Specialization constants, applied via 'Shader.withSpecialization'.
  }

-- | Empty tag type. Compute pipelines are typed as
-- @Pipeline dsl Compute Compute@, so it fills both the vertex and the
-- instance parameter of 'Pipeline'.
data Compute

-- | Maps a pipeline type (plus a specialization type) to the configuration
-- used to build it. This module only covers compute pipelines:
-- @Configure (Pipeline dsl Compute Compute) spec = Config dsl spec@.
type family Configure pipeline spec where
  Configure (Pipeline dsl Compute Compute) spec = Config dsl spec

-- | Build a compute pipeline inside a 'MonadResource' scope.
--
-- The Vulkan context is taken from the reader environment. The pipeline is
-- built with 'create', and 'destroy' is registered as its release action.
-- The returned 'ReleaseKey' identifies that release action.
--
-- The caller's call stack is frozen (via 'withFrozenCallStack') around the
-- whole acquisition.
allocate
  :: ( MonadVulkan env m
     , MonadResource m
     , HasCallStack
     , Shader.Specialization spec
     )
  => Config dsl spec
  -> m (ReleaseKey, Pipeline dsl Compute Compute)
allocate config = withFrozenCallStack do
  ctx <- ask
  Resource.allocate
    (create ctx config)
    (destroy ctx)

-- | Create a compute pipeline, together with its descriptor set layouts and
-- pipeline layout.
--
-- For each entry of 'cDescLayouts', the bindings and binding flags are
-- split apart. They become one descriptor set layout, with the flags chained
-- in as 'Vk12.DescriptorSetLayoutBindingFlagsCreateInfo'. The pipeline
-- layout combines these set layouts with 'cPushConstantRanges'.
--
-- Created objects are named with 'Debug.nameObject'. The name is the
-- @|@-joined module names from the current call stack, which is why this
-- function requires 'HasCallStack'. Without it, 'callStack' is always
-- empty here.
--
-- A non-'Vk.SUCCESS' result from 'Vk.createComputePipelines' is rethrown
-- via 'throwString'. The caller owns the returned 'Pipeline'; 'allocate'
-- registers 'destroy' as its release action.
create
  :: ( HasVulkan ctx
     , MonadUnliftIO m
     , Shader.Specialization spec
     , HasCallStack
     )
  => ctx
  -> Config dsl spec
  -> m (Pipeline dsl Compute Compute)
create context Config{..} = do
  -- XXX: copypasta from Pipeline.create
  let
    originModule :: ByteString
    originModule =
      fromString . List.intercalate "|" $
        map (srcLocModule . snd) (getCallStack callStack)

  dsLayouts <- Vector.forM (Vector.fromList $ unTagged cDescLayouts) \bindsFlags -> do
    let
      (binds, flags) = List.unzip bindsFlags

      setCI =
        zero
          { Vk.bindings = Vector.fromList binds
          }
        ::& zero
          { Vk12.bindingFlags = Vector.fromList flags
          }
        :& ()

    Vk.createDescriptorSetLayout device setCI Nothing

  -- TODO: get from outside
  -- NOTE(review): if shader or pipeline creation below fails, 'layout' and
  -- 'dsLayouts' are not destroyed by this function.
  layout <- Vk.createPipelineLayout device (layoutCI dsLayouts) Nothing
  Debug.nameObject device layout originModule

  -- Compute stuff begins...

  -- 'Shader.withSpecialization' passes the specialization info to a
  -- continuation. Pipeline creation (the consumer of the shader's stage
  -- infos) therefore stays inside it, in case that info refers to memory
  -- scoped to the callback.
  -- The shader module is released after pipeline creation, on every path
  -- ('bracket'), instead of only on success.
  Shader.withSpecialization cSpecialization \specInfo ->
    bracket
      (Shader.create context [(Vk.SHADER_STAGE_COMPUTE_BIT, cComputeCode)] specInfo)
      (Shader.destroy context)
      \shader -> do
        let
          cis = Vector.singleton . SomeStruct $
            pipelineCI (Shader.sPipelineStages shader) layout

        Vk.createComputePipelines device cache cis Nothing >>= \case
          (Vk.SUCCESS, pipelines) ->
            case pipelines of
              [one] -> do
                Debug.nameObject device one originModule
                pure Pipeline
                  { pipeline     = one
                  , pLayout      = Tagged layout
                  , pDescLayouts = Tagged dsLayouts
                  }
              _ ->
                error "assert: exactly one pipeline requested"
          (err, _) ->
            throwString $ "createComputePipelines: " <> show err

  where
    device = getDevice context
    cache = getPipelineCache context

    -- Pipeline layout over the created set layouts and the configured
    -- push constant ranges.
    layoutCI dsLayouts = Vk.PipelineLayoutCreateInfo
      { flags              = zero
      , setLayouts         = dsLayouts
      , pushConstantRanges = cPushConstantRanges
      }

    -- Compute pipelines take exactly one shader stage; only one stage
    -- (compute) is requested above.
    pipelineCI stages layout = zero
      { Vk.layout             = layout
      , Vk.stage              = stage
      , Vk.basePipelineHandle = zero
      }
      where
        stage = case stages of
          [one]   -> one
          _assert -> error "compute code has one stage"

-- | Bind a compute pipeline on a command buffer, then run the given 'Bound'
-- action.
--
-- The 'Compatible' constraint relates the pipeline's descriptor-set layout
-- index to the currently bound one. The inner action's @vertices@ /
-- @instances@ indices are replaced by the outer ones.
bind
  :: ( Compatible pipeLayout boundLayout
     , MonadIO m
     )
  => Vk.CommandBuffer
  -> Pipeline pipeLayout vertices instances
  -> Bound boundLayout vertices instances m ()
  -> Bound boundLayout oldVertices oldInstances m ()
bind cb Pipeline{pipeline} (Bound attrAction) = do
  Bound $ Vk.cmdBindPipeline cb Vk.PIPELINE_BIND_POINT_COMPUTE pipeline
  Bound attrAction