-- | HsYAML node parsers that rebuild the "Data.SpirV.Reflect" data types
-- from SPIRV-Reflect's YAML output.
--
-- Several parsers work around quirks of that YAML (see the @BUG@ notes),
-- e.g. @[recursive]@ / @[forward pointer]@ placeholders inside member lists.
module Data.SpirV.Reflect.Yaml.Parsers where

import Prelude hiding (id)

import Control.Applicative ((<|>))
import Data.Functor ((<&>))
import Data.Maybe (catMaybes)
import Data.SpirV.Enum qualified as SpirV
import Data.Text (Text)
import Data.Text qualified as Text
import Data.Vector (Vector)
import Data.Vector qualified as Vector
import Data.Vector.Storable qualified as Storable
import Data.YAML ((.:))
import Data.YAML qualified as YAML

import Data.SpirV.Reflect.BlockVariable (BlockVariable)
import Data.SpirV.Reflect.BlockVariable qualified as BlockVariable
import Data.SpirV.Reflect.DescriptorBinding (DescriptorBinding)
import Data.SpirV.Reflect.DescriptorBinding qualified as DescriptorBinding
import Data.SpirV.Reflect.DescriptorSet (DescriptorSet)
import Data.SpirV.Reflect.DescriptorSet qualified as DescriptorSet
import Data.SpirV.Reflect.Enums qualified as Reflect
import Data.SpirV.Reflect.InterfaceVariable (InterfaceVariable)
import Data.SpirV.Reflect.InterfaceVariable qualified as InterfaceVariable
import Data.SpirV.Reflect.Module (Module)
import Data.SpirV.Reflect.Module qualified as Module
import Data.SpirV.Reflect.SpecializationConstant (SpecializationConstant)
import Data.SpirV.Reflect.SpecializationConstant qualified as SpecializationConstant
import Data.SpirV.Reflect.Traits qualified as Traits
import Data.SpirV.Reflect.TypeDescription (TypeDescription)
import Data.SpirV.Reflect.TypeDescription qualified as TypeDescription

-- | Document root: a mapping whose @module@ entry is parsed with 'moduleP'.
rootP :: NodeParser Module
rootP =
  YAML.withMap "Root" \m ->
    -- XXX: Discarding all_XXX fields, assuming they got unrolled by YAML decoder.
    m .: "module" >>= moduleP

-- | A reflected shader module. Every collection field is optional in the
-- YAML and becomes an empty vector when the key is absent (see 'seqOf').
moduleP :: NodeParser Module
moduleP =
  YAML.withMap "Module" \m -> do
    generator <- m .: "generator" <&> Reflect.Generator
    entry_point_name <- m .: "entry_point_name"
    entry_point_id <- m .: "entry_point_id"
    source_language <- m .: "source_language"
    source_language_version <- m .: "source_language_version"
    spirv_execution_model <- m .: "spirv_execution_model"
    shader_stage <- m .: "shader_stage"
    descriptor_bindings <- m .? "descriptor_bindings" `seqOf` descriptorBindingP
    descriptor_sets <- m .? "descriptor_sets" `seqOf` descriptorSetP
    input_variables <- m .? "input_variables" `seqOf` interfaceVariableP
    output_variables <- m .? "output_variables" `seqOf` interfaceVariableP
    push_constants <- m .? "push_constants" `seqOf` blockVariableP
    spec_constants <- m .? "spec_constants" `seqOf` specializationConstantP
    pure Module.Module{..}

-- | A descriptor binding.
--
-- @name@ falls back to 'mempty' when absent, and @decoration_flags@ falls
-- back to 'Reflect.DECORATION_NONE'. @uav_counter_binding@ is parsed
-- recursively with this same parser.
descriptorBindingP :: NodeParser DescriptorBinding
descriptorBindingP =
  YAML.withMap "DescriptorBinding" \m -> do
    spirv_id <- m .? "spirv_id"
    name <- m .| "name"
    binding <- m .: "binding"
    input_attachment_index <- m .: "input_attachment_index"
    set <- m .: "set"
    descriptor_type <- m .: "descriptor_type" <&> Reflect.DescriptorType
    resource_type <- m .: "resource_type" <&> Reflect.ResourceFlagBits
    image <- m .: "image" >>= traitsImageP
    block <- m .? "block" >>= traverse blockVariableP
    array <- m .: "array" >>= traitsArrayP
    count <- m .? "count"
    accessed <- m .: "accessed"
    uav_counter_id <- m .: "uav_counter_id"
    uav_counter_binding <- m .? "uav_counter_binding" >>= traverse descriptorBindingP
    -- NOTE(review): 'fromInteger' wraps silently if the element type is
    -- fixed-width and a YAML value does not fit it.
    byte_address_buffer_offsets <-
      m .? "byte_address_buffer_offsets"
        `seqOf` YAML.withInt "byte_address_buffer_offset" (pure . fromInteger)
    type_description <- m .? "type_description" >>= traverse typeDescriptionP
    word_offset <- m .: "word_offset" >>= descriptorBindingWordOffsetP
    decoration_flags <-
      m .? "decoration_flags"
        <&> maybe Reflect.DECORATION_NONE Reflect.DecorationFlagBits
    user_type <- m .? "user_type" >>= traverse userTypeP
    pure DescriptorBinding.DescriptorBinding{..}

-- | A user type given by its string name, resolved with 'Reflect.userTypeId'.
-- Unrecognized names fail, and the error reports the offending string.
userTypeP :: NodeParser Reflect.UserType
userTypeP =
  YAML.withStr "UserType" \str ->
    case Reflect.userTypeId str of
      Just userType ->
        pure userType
      Nothing ->
        fail $ "Unknown UserType: " <> Text.unpack str

-- | Word offsets of a descriptor binding's @binding@ and @set@ decorations.
descriptorBindingWordOffsetP :: NodeParser DescriptorBinding.WordOffset
descriptorBindingWordOffsetP =
  YAML.withMap "DescriptorBinding.WordOffset" \m -> do
    binding <- m .: "binding"
    set <- m .: "set"
    pure DescriptorBinding.WordOffset{..}

-- | A type description.
--
-- An all-default @traits@ mapping (equal to 'TypeDescription.emptyTraits')
-- is normalized to 'Nothing'. A missing @type_flags@ becomes
-- 'Reflect.TYPE_FLAG_UNDEFINED'.
typeDescriptionP :: NodeParser TypeDescription
typeDescriptionP =
  YAML.withMap "TypeDescription" \m -> do
    id <- m .? "id"
    op <- m .? "op" <&> fmap SpirV.Op
    type_name <- m .? "type_name"
    struct_member_name <- m .? "struct_member_name"
    storage_class <- m .: "storage_class" <&> SpirV.StorageClass
    type_flags <-
      m .? "type_flags"
        <&> maybe Reflect.TYPE_FLAG_UNDEFINED Reflect.TypeFlagBits
    traits <-
      m .: "traits" >>= traverse typeDescriptionTraitsP <&> \ts ->
        if ts == Just TypeDescription.emptyTraits then Nothing else ts
    -- BUG: YAML has no "copied" field
    copied <- m .? "copied"
    -- BUG: YAML uses `[forward pointer]` marker for recursive structures.
    -- Such placeholders are dropped from the member list.
    members <-
      m .? "members"
        `seqOfMaybes` alt forwardPointerP (fmap Just . typeDescriptionP)
    struct_type_description <- m .? "struct_type_description" >>= traverse typeDescriptionP
    pure TypeDescription.TypeDescription{..}

-- | Numeric, image and array traits of a type description.
typeDescriptionTraitsP :: NodeParser TypeDescription.Traits
typeDescriptionTraitsP =
  YAML.withMap "TypeDescription.Traits" \m -> do
    numeric <- m .: "numeric" >>= traitsNumericP
    image <- m .: "image" >>= traitsImageP
    array <- m .: "array" >>= traitsArrayP
    pure TypeDescription.Traits{..}

-- | A descriptor set: its index plus its (optional, default empty) bindings.
descriptorSetP :: NodeParser DescriptorSet
descriptorSetP =
  YAML.withMap "DescriptorSet" \m -> do
    set <- m .: "set"
    bindings <- m .? "bindings" `seqOf` descriptorBindingP
    pure DescriptorSet.DescriptorSet{..}

-- | A shader input/output interface variable.
--
-- @[recursive]@ placeholders in @members@ are dropped. @built_in@ is only
-- read for variables without members (see the BUG note below).
interfaceVariableP :: NodeParser InterfaceVariable
interfaceVariableP =
  YAML.withMap "InterfaceVariable" \m -> do
    spirv_id <- m .? "spirv_id"
    name <- m .? "name"
    location <- m .: "location"
    component <- m .? "component"
    storage_class <- m .: "storage_class" <&> SpirV.StorageClass
    semantic <- m .: "semantic"
    decoration_flags <- m .: "decoration_flags" <&> Reflect.DecorationFlagBits
    numeric <- m .: "numeric" >>= traitsNumericP
    array <- m .: "array" >>= traitsArrayP
    -- The marker parser goes first, as in 'blockVariableP', so that a
    -- malformed member reports its own parse error rather than the marker's.
    members <-
      m .? "members"
        `seqOfMaybes` alt recursiveP (fmap Just . interfaceVariableP)
    -- BUG: https://github.com/KhronosGroup/SPIRV-Reflect/issues/269
    built_in <-
      if null members
        then m .? "built_in" <&> fmap SpirV.BuiltIn
        else pure Nothing
    format <- m .: "format" <&> Reflect.Format
    type_description <- m .? "type_description" >>= traverse typeDescriptionP
    word_offset <- m .: "word_offset" >>= interfaceVariableWordOffsetP
    pure InterfaceVariable.InterfaceVariable{..}

-- | Word offset of an interface variable's @location@ decoration.
interfaceVariableWordOffsetP :: NodeParser InterfaceVariable.WordOffset
interfaceVariableWordOffsetP =
  YAML.withMap "InterfaceVariable.WordOffset" \m -> do
    location <- m .: "location"
    pure InterfaceVariable.WordOffset{..}

-- | A block variable (e.g. a push-constant block or one of its members).
-- @[recursive]@ placeholders in @members@ are dropped.
blockVariableP :: NodeParser BlockVariable
blockVariableP =
  YAML.withMap "BlockVariable" \m -> do
    spirv_id <- m .? "spirv_id"
    name <- m .? "name"
    offset <- m .: "offset"
    absolute_offset <- m .: "absolute_offset"
    size <- m .: "size"
    padded_size <- m .: "padded_size"
    decorations <- m .: "decorations" <&> Reflect.DecorationFlagBits
    numeric <- m .: "numeric" >>= traitsNumericP
    array <- m .: "array" >>= traitsArrayP
    members <-
      m .? "members"
        `seqOfMaybes` alt recursiveP (fmap Just . blockVariableP)
    type_description <- m .? "type_description" >>= traverse typeDescriptionP
    word_offset <- m .? "word_offset" >>= traverse blockVariableWordOffsetP
    pure BlockVariable.BlockVariable{..}

-- | Word offset of a block variable's @offset@ decoration.
blockVariableWordOffsetP :: NodeParser BlockVariable.WordOffset
blockVariableWordOffsetP =
  YAML.withMap "BlockVariable.WordOffset" \m -> do
    offset <- m .: "offset"
    pure BlockVariable.WordOffset{..}

-- | A specialization constant.
specializationConstantP :: NodeParser SpecializationConstant
specializationConstantP =
  YAML.withMap "SpecializationConstant" \m -> do
    spirv_id <- m .? "spirv_id"
    constant_id <- m .: "constant_id"
    name <- m .? "name"
    pure SpecializationConstant.SpecializationConstant{..}

--------------

-- | Scalar, vector and matrix traits of a numeric type.
traitsNumericP :: NodeParser Traits.Numeric
traitsNumericP =
  YAML.withMap "Numeric" \m -> do
    scalar <- m .: "scalar" >>= traitsScalarP
    vector <- m .: "vector" >>= traitsVectorP
    matrix <- m .: "matrix" >>= traitsMatrixP
    pure Traits.Numeric{..}

-- | Scalar traits. @signedness@ is read as an unsigned integer and any
-- non-zero value means signed.
traitsScalarP :: NodeParser Traits.Scalar
traitsScalarP =
  YAML.withMap "Scalar" \m -> do
    width <- m .: "width"
    signed <- m .: "signedness" <&> (/= (0 :: Word))
    pure Traits.Scalar{..}

-- | Vector traits.
traitsVectorP :: NodeParser Traits.Vector
traitsVectorP =
  YAML.withMap "Vector" \m -> do
    component_count <- m .: "component_count"
    pure Traits.Vector{..}

-- | Matrix traits.
traitsMatrixP :: NodeParser Traits.Matrix
traitsMatrixP =
  YAML.withMap "Matrix" \m -> do
    column_count <- m .: "column_count"
    row_count <- m .: "row_count"
    stride <- m .: "stride"
    pure Traits.Matrix{..}

-- | Array traits. @dims@ is stored as a storable vector, and a @stride@ of 0
-- is treated the same as a missing stride ('Nothing').
--
-- NOTE(review): @dims_count@ is not cross-checked against the length of @dims@.
traitsArrayP :: NodeParser Traits.Array
traitsArrayP =
  YAML.withMap "Array" \m -> do
    dims_count <- m .: "dims_count"
    dims <- m .: "dims" <&> Storable.fromList
    stride <-
      m .? "stride" <&> \s ->
        if s == Just 0 then Nothing else s
    pure Traits.Array{..}

-- | Image traits.
traitsImageP :: NodeParser Traits.Image
traitsImageP =
  YAML.withMap "Image" \m -> do
    dim <- m .: "dim" <&> SpirV.Dim
    depth <- m .: "depth"
    arrayed <- m .: "arrayed"
    ms <- m .: "ms"
    sampled <- m .: "sampled"
    image_format <- m .: "image_format" <&> SpirV.ImageFormat
    pure Traits.Image{..}

-- | The @[recursive]@ placeholder used in member lists. Parses to 'Nothing'.
recursiveP :: NodeParser (Maybe a)
recursiveP = markerP "recursive"

-- | The @[forward pointer]@ placeholder used in member lists. Parses to 'Nothing'.
forwardPointerP :: NodeParser (Maybe a)
forwardPointerP = markerP "forward pointer"

-- | Matches a one-element sequence whose only item is the given string
-- scalar, e.g. @[recursive]@, and yields 'Nothing' so that 'seqOfMaybes'
-- drops it. Any other node fails with a message naming the expected marker.
markerP :: Text -> NodeParser (Maybe a)
markerP marker = \case
  YAML.Sequence _ _ [YAML.Scalar _ (YAML.SStr str)]
    | str == marker ->
        pure Nothing
  _ ->
    fail $ "Expected [" <> Text.unpack marker <> "] marker"

-- | A parser that consumes a single YAML node.
type NodeParser a = YAML.Node YAML.Pos -> YAML.Parser a

-- | Tries the first parser on a node and falls back to the second.
--
-- NOTE: HsYAML's '<|>' appears to keep only the second failure, so put the
-- parser whose error message is most useful last.
alt :: NodeParser a -> NodeParser a -> NodeParser a
alt a b n = a n <|> b n

-- | Optional field lookup: a thin alias for HsYAML's @.:?@.
(.?) :: YAML.FromYAML a => YAML.Mapping YAML.Pos -> Text -> YAML.Parser (Maybe a)
o .? v = o YAML..:? v

-- | Optional field lookup that defaults to 'mempty' when the key is missing.
(.|) :: (YAML.FromYAML a, Monoid a) => YAML.Mapping YAML.Pos -> Text -> YAML.Parser a
o .| v = o YAML..:? v YAML..!= mempty

-- | Parses an optional sequence node element-wise. A missing node gives an
-- empty vector; a present node must be a sequence.
seqOf :: YAML.Parser (Maybe (YAML.Node YAML.Pos)) -> NodeParser a -> YAML.Parser (Vector a)
seqOf mappingP nodeP =
  mappingP >>= \case
    Nothing ->
      pure mempty
    Just items ->
      YAML.withSeq "seqOf" (fmap Vector.fromList . traverse nodeP) items

-- | Like 'seqOf', but elements parsed to 'Nothing' (e.g. placeholder
-- markers) are dropped from the result.
seqOfMaybes :: YAML.Parser (Maybe (YAML.Node YAML.Pos)) -> NodeParser (Maybe a) -> YAML.Parser (Vector a)
seqOfMaybes mappingP nodeP =
  mappingP >>= \case
    Nothing ->
      pure mempty
    Just items ->
      YAML.withSeq "seqOfMaybes" (fmap (Vector.fromList . catMaybes) . traverse nodeP) items