{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE TupleSections #-}
{-# OPTIONS_GHC -fno-warn-unused-binds #-}  -- TODO: remove
module LambdaCube.Compiler.CoreToIR
    ( compilePipeline
    ) where

import Data.Char
import Data.Monoid
import Data.Map (Map)
import Data.Maybe
import Data.Function
import Data.List
import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Data.Vector as Vector
import Control.Arrow hiding ((<+>))
import Control.Monad.Writer
import Control.Monad.State

import LambdaCube.IR(Backend(..))
import qualified LambdaCube.IR as IR
import qualified LambdaCube.Linear as IR

import LambdaCube.Compiler.Pretty
import LambdaCube.Compiler.DeBruijn as I
import LambdaCube.Compiler.DesugaredSource hiding (getTTuple)
import LambdaCube.Compiler.Core (Subst(..), down, nType)
import qualified LambdaCube.Compiler.Core as I
import LambdaCube.Compiler.Infer (neutType', makeCaseFunPars')

import Data.Version
import Paths_lambdacube_compiler (version)

--------------------------------------------------------------------------

-- | Compile a type-checked pipeline expression into an 'IR.Pipeline'.
--
-- The expression must be a @ScreenOut@ application; anything else is an
-- error.  Streams, programs, render targets, slots and textures are collected
-- in the 'CG' state while the commands are generated, and are emitted as
-- vectors in allocation (index) order.  Commands of render-to-texture passes
-- ('fst' of the generated pair) come before the commands of the main pass.
compilePipeline :: IR.Backend -> I.ExpType -> IR.Pipeline
compilePipeline backend exp = IR.Pipeline
    { IR.info       = "generated by lambdacube-compiler " ++ showVersion version
    , IR.backend    = backend
    , IR.samplers   = mempty
    , IR.programs   = Vector.fromList . map fst . sortBy (compare `on` snd) . Map.toList $ programs
    , IR.slots      = Vector.fromList . map snd . sortBy (compare `on` fst) . Map.elems $ slots
      -- the 'List' collections are built by consing, so reverse them back
      -- into allocation order
    , IR.targets    = Vector.fromList . reverse . snd $ targets
    , IR.streams    = Vector.fromList . reverse . snd $ streams
    , IR.textures   = Vector.fromList . reverse . snd $ textures
    , IR.commands   = Vector.fromList $ subCmds <> cmds
    }
  where
    ((subCmds,cmds), (streams, programs, targets, slots, textures)) = flip runState ((0, mempty), mempty, (0, mempty), mempty, (0, mempty)) $ case toExp exp of
        A1 "ScreenOut" a -> addTarget backend a [IR.TargetItem s $ Just $ IR.Framebuffer s | s <- getSemantics a]
        x -> error $ "ScreenOut expected inststead of " ++ ppShow x

-- | Code generation state: streams, programs (deduplicated, with their
-- index), render targets, slots (keyed by slot name, with their index) and
-- texture descriptors.
type CG = State (List IR.StreamData, Map IR.Program Int, List IR.RenderTarget, Map String (Int, IR.Slot), List IR.TextureDescriptor)

-- | An append-only collection: the number of elements and the elements in
-- reverse insertion order.
type List a = (Int, [a])

-- Accessors for the five components of the 'CG' state.  Each one passes
-- @f@ a function that rebuilds the tuple with a new component, together with
-- the current value of that component.
streamLens  f (a,b,c,d,e) = f (,b,c,d,e) a
programLens f (a,b,c,d,e) = f (a,,c,d,e) b
targetLens  f (a,b,c,d,e) = f (a,b,,d,e) c
slotLens    f (a,b,c,d,e) = f (a,b,c,,e) d
textureLens f (a,b,c,d,e) = f (a,b,c,d,) e

-- | Run @f :: component -> (result, component')@ on the component selected
-- by the accessor @gs@, storing @component'@ back and returning @result@.
modL gs f = state $ gs $ \fx -> second fx . f

-- | Insert @x@ under key @p@ and return its index.  A new key gets the next
-- free index (the current map size); an existing key keeps its index and its
-- value is replaced by @f x old@.
addL' l p f x = modL l $ \sv -> maybe (length sv, Map.insert p (length sv, x) sv) (\(i, x') -> (i, Map.insert p (i, f x x') sv)) $ Map.lookup p sv

-- | Append @x@ to a 'List' component and return its index.
addL l x = modL l $ \(i, sv) -> (i, (i+1, x: sv))

-- | Return the index of @x@ in a deduplicating map component, inserting it
-- with the next free index if it is not present yet.
addLEq l x = modL l $ \sv -> maybe (let i = length sv in i `seq` (i, Map.insert x i sv)) (\i -> (i, sv)) $ Map.lookup x sv

---------------------------------------------------------

-- | Register a render target built from the target items @tl@, then generate
-- the commands that render the framebuffer expression @a@ into it.  A
-- @SetRenderTarget@ command is prepended to the pass commands.
addTarget backend a tl = do
    rt <- addL targetLens $ IR.RenderTarget $ Vector.fromList tl
    second (IR.SetRenderTarget rt:) <$> getCommands backend a

-- | Generate IR commands for a framebuffer-valued expression.
--
-- The result is @(subCommands, commands)@.  @subCommands@ holds the commands
-- of the render-to-texture passes whose textures are sampled here; @commands@
-- holds the commands of the pass itself.  Only @FrameBuffer@ (a clear) and
-- @Accumulate@ (a draw on top of an earlier framebuffer, whose commands come
-- first) are supported; any other expression is an error.
getCommands :: Backend -> ExpTV{-FrameBuffer-} -> CG ([IR.Command],[IR.Command])
getCommands backend e = case e of
    A1 "FrameBuffer" (ETuple a) -> return ([], [IR.ClearRenderTarget $ Vector.fromList $ map compFrameBuffer a])
    A3 "Accumulate" actx (getFragmentShader -> (frag, getFragFilter -> (ffilter, x1))) fbuf -> case x1 of
        A3 "foldr" (A0 "++") (A0 "Nil") (A2 "map" (EtaPrim3 "rasterizePrimitive" ints rctx) (getVertexShader -> (vert, input_))) -> mdo
            let (vertexInput, pUniforms, vertSrc, fragSrc) = case backend of
                    -- disabled DX11 codegen, due to it's incomplete
                    --IR.DirectX11 -> genHLSLs backend (compRC' rctx) ints vert frag ffilter
                    _ -> genGLSLs backend (compRC' rctx) ints vert frag ffilter
                -- every uniform except the render textures
                pUniforms' = snd <$> Map.filter ((\case UTexture2D{} -> False; _ -> True) . fst) pUniforms
                prg = IR.Program
                    { IR.programUniforms    = pUniforms'
                    , IR.programStreams     = Map.fromList $ zip vertexInput $ map (uncurry IR.Parameter) input
                    , IR.programInTextures  = snd <$> Map.filter ((\case UUniform{} -> False; _ -> True) . fst) pUniforms
                    , IR.programOutput      = pure $ IR.Parameter "f0" IR.V4F -- TODO
                    , IR.vertexShader       = show vertSrc
                    , IR.geometryShader     = mempty -- TODO
                    , IR.fragmentShader     = show fragSrc
                    }
                -- Sampler uniforms get consecutive texture units starting at 0.
                -- Only the samplers are numbered (not every uniform), so the units
                -- handed to the render-texture bindings below, which start at
                -- 'length textureUniforms', can never collide with these.
                samplerNames = [n | (n, IR.FTexture2D) <- Map.toList pUniforms']
                textureUniforms = [IR.SetSamplerUniform n textureUnit | (n, textureUnit) <- zip samplerNames [0..]]
                cmds =
                    [ IR.SetProgram prog ] <>
                    textureUniforms <>
                    concat -- TODO: generate IR.SetSamplerUniform commands for texture slots
                    [ [ IR.SetTexture textureUnit texture
                      , IR.SetSamplerUniform name textureUnit
                      ] | (textureUnit,(name,IR.TextureImage texture _ _)) <- zip [length textureUniforms..] smpBindings
                    ] <>
                    [ IR.SetRasterContext (compRC rctx)
                    , IR.SetAccumulationContext (compAC actx)
                    , renderCommand
                    ]
            -- render textures used by this pass are generated as separate passes
            (smpBindings, txtCmds) <- mconcat <$> traverse (uncurry getRenderTextureCommands) (Map.toList $ fst <$> pUniforms)
            (renderCommand,input) <- case input_ of
                A2 "fetch" (EString slotName) attrs -> do
                    i <- IR.RenderSlot <$> addL' slotLens slotName (flip mergeSlot) IR.Slot
                        { IR.slotName       = slotName
                        , IR.slotUniforms   = IR.programUniforms prg
                        , IR.slotPrograms   = pure prog
                        , IR.slotStreams    = Map.fromList input
                        , IR.slotPrimitive  = compFetchPrimitive $ getPrim $ tyOf input_
                        }
                    return (i, input)
                  where
                    input = compInputType'' attrs
                    -- a slot used by several passes accumulates their uniforms,
                    -- streams and programs
                    mergeSlot a b = a
                      { IR.slotUniforms = IR.slotUniforms a <> IR.slotUniforms b
                      , IR.slotStreams  = IR.slotStreams a <> IR.slotStreams b
                      , IR.slotPrograms = IR.slotPrograms a <> IR.slotPrograms b
                      }
                A1 "fetchArrays" (unzip . compAttributeValue -> (tys, values)) -> do
                    i <- IR.RenderStream <$> addL streamLens IR.StreamData
                        { IR.streamData       = Map.fromList $ zip names values
                        , IR.streamType       = Map.fromList input
                        , IR.streamPrimitive  = compFetchPrimitive $ getPrim $ tyOf input_
                        , IR.streamPrograms   = pure prog
                        }
                    return (i, input)
                  where
                    names = ["attribute_" ++ show i | i <- [0..]]
                    input = zip names tys
                e -> error $ "getSlot: " ++ ppShow e
            prog <- addLEq programLens prg
            -- the framebuffer we accumulate into is rendered first
            (<> (txtCmds, cmds)) <$> getCommands backend fbuf
        x -> error $ "getCommands': " ++ ppShow x
    x -> error $ "getCommands: " ++ ppShow x
  where
    -- For a render-texture uniform: allocate one texture per semantic of the
    -- image, render the image into a new target made of those textures, and
    -- return the binding of the selected texture to uniform @n@ together with
    -- the commands of that render pass.  Other uniforms yield nothing.
    getRenderTextureCommands :: String -> Uniform -> CG ([SamplerBinding],[IR.Command])
    getRenderTextureCommands n = \case
        UTexture2D (fromIntegral -> width) (fromIntegral -> height) img -> do
            let (a, tf) = case img of
                    A1 "PrjImageColor" a -> (,) a $ \[_, x] -> x
                    A1 "PrjImage" a -> (,) a $ \[x] -> x
            tl <- forM (getSemantics a) $ \semantic -> do
                texture <- addL textureLens IR.TextureDescriptor
                    { IR.textureType      = IR.Texture2D (if semantic == IR.Color then IR.FloatT IR.RGBA else IR.FloatT IR.Red) 1
                    , IR.textureSize      = IR.VV2U $ IR.V2 (fromIntegral width) (fromIntegral height)
                    , IR.textureSemantic  = semantic
                    , IR.textureSampler   = IR.SamplerDescriptor
                        { IR.samplerWrapS       = IR.Repeat
                        , IR.samplerWrapT       = Nothing
                        , IR.samplerWrapR       = Nothing
                        , IR.samplerMinFilter   = IR.Linear
                        , IR.samplerMagFilter   = IR.Linear
                        , IR.samplerBorderColor = IR.VV4F (IR.V4 0 0 0 1)
                        , IR.samplerMinLod      = Nothing
                        , IR.samplerMaxLod      = Nothing
                        , IR.samplerLodBias     = 0
                        , IR.samplerCompareFunc = Nothing
                        }
                    , IR.textureBaseLevel = 0
                    , IR.textureMaxLevel  = 0
                    }
                return $ IR.TargetItem semantic $ Just $ IR.TextureImage texture 0 Nothing
            (subCmds, cmds) <- addTarget backend a tl
            let (IR.TargetItem IR.Color (Just tx)) = tf tl
            return ([(n, tx)], subCmds ++ cmds)
        _ -> return mempty

-- | A render texture bound to a sampler uniform.
type SamplerBinding = (IR.UniformName,IR.ImageRef)

----------------------------------------------------------------

-- Second type argument of a @FrameBuffer@ type (the list of image
-- semantics consumed by 'getSemantics').
frameBufferType (A2 "FrameBuffer" _ ty) = ty
frameBufferType x = error $ "illegal target type: " ++ ppShow x

-- | IR semantics (colour, depth, stencil) of a framebuffer-typed expression.
getSemantics = compSemantics . frameBufferType . tyOf

-- | Split an optional @filterFragment@ predicate off a fragment stream.
getFragFilter (A2 "map" (EtaPrim2 "filterFragment" p) x) = (Just p, x)
getFragFilter x = (Nothing, x)

-- | Split an optional @mapPrimitive@ vertex shader off a primitive stream.
-- Returns @((shader, outputType), rest)@; without a shader the output type is
-- the element type of the incoming primitives.
getVertexShader (A2 "map" (EtaPrim2 "mapPrimitive" f@(etaReds -> Just (_, o))) x) = ((Just f, tyOf o), x)
--getVertexShader (A2 "map" (EtaPrim2 "mapPrimitive" f) x) = error $ "gff: " ++ show (case f of ExpTV x _ _ -> x) --ppShow (mapVal unFunc' f)
--getVertexShader x = error $ "gf: " ++ ppShow x
getVertexShader x = ((Nothing, getPrim' $ tyOf x), x)

-- | Split an optional @mapFragment@ fragment shader off a fragment stream,
-- in the same shape as 'getVertexShader'.
getFragmentShader (A2 "map" (EtaPrim2 "mapFragment" f@(etaReds -> Just (_, o))) x) = ((Just f, tyOf o), x)
--getFragmentShader (A2 "map" (EtaPrim2 "mapFragment" f) x) = error $ "gff: " ++ ppShow f
--getFragmentShader x = error $ "gf: " ++ ppShow x
getFragmentShader x = ((Nothing, getPrim'' $ tyOf x), x)

-- Projections out of @List (Primitive a p)@ and fragment stream types.
getPrim (A1 "List" (A2 "Primitive" _ p)) = p
getPrim' (A1 "List" (A2 "Primitive" a _)) = a
getPrim'' (A1 "List" (A2 "Vector" _ (A1 "Maybe" (A1 "SimpleFragment" (TTuple [a]))))) = a
getPrim'' x = error $ "getPrim'':" ++ ppShow x

-- | Translate one image of a @FrameBuffer@ literal into a clear operation.
compFrameBuffer = \case
    A1 "DepthImage" a -> IR.ClearImage IR.Depth $ compValue a
    A1 "ColorImage" a -> IR.ClearImage IR.Color $ compValue a
    x -> error $ "compFrameBuffer " ++ ppShow x

compSemantics = map compSemantic . compList

-- | Convert a @:@ / @Nil@ list expression into a Haskell list.
compList (A2 ":" a x) = a : compList x
compList (A0 "Nil") = []
compList x = error $ "compList: " ++ ppShow x

compSemantic = \case
    A0 "Depth"   -> IR.Depth
    A0 "Stencil" -> IR.Stencil
    A1 "Color" _ -> IR.Color
    x -> error $ "compSemantic: " ++ ppShow x

-- | Accumulation context from a tuple of per-image fragment operations.
compAC (ETuple x) = IR.AccumulationContext Nothing $ map compFrag x

compBlending x = case x of
    A0 "NoBlending" -> IR.NoBlending
    A1 "BlendLogicOp" a -> IR.BlendLogicOp (compLO a)
    A3 "Blend" (ETuple [a,b]) (ETuple [ETuple [c,d],ETuple [e,f]]) (compValue -> IR.VV4F g) -> IR.Blend (compBE a) (compBE b) (compBF c) (compBF d) (compBF e) (compBF f) g
    x -> error $ "compBlending " ++ ppShow x

-- | Blending factor constructor name to IR.
compBF x = case x of
    A0 "ZeroBF" -> IR.Zero
    A0 "OneBF" -> IR.One
    A0 "SrcColor" -> IR.SrcColor
    A0 "OneMinusSrcColor" -> IR.OneMinusSrcColor
    A0 "DstColor" -> IR.DstColor
    A0 "OneMinusDstColor" -> IR.OneMinusDstColor
    A0 "SrcAlpha" -> IR.SrcAlpha
    A0 "OneMinusSrcAlpha" -> IR.OneMinusSrcAlpha
    A0 "DstAlpha" -> IR.DstAlpha
    A0 "OneMinusDstAlpha" -> IR.OneMinusDstAlpha
    A0 "ConstantColor" -> IR.ConstantColor
    A0 "OneMinusConstantColor" -> IR.OneMinusConstantColor
    A0 "ConstantAlpha" -> IR.ConstantAlpha
    A0 "OneMinusConstantAlpha" -> IR.OneMinusConstantAlpha
    A0 "SrcAlphaSaturate" -> IR.SrcAlphaSaturate
    x -> error $ "compBF " ++ ppShow x

-- | Blending equation constructor name to IR.
compBE x = case x of
    A0 "FuncAdd" -> IR.FuncAdd
    A0 "FuncSubtract" -> IR.FuncSubtract
    A0 "FuncReverseSubtract" -> IR.FuncReverseSubtract
    A0 "Min" -> IR.Min
    A0 "Max" -> IR.Max
    x -> error $ "compBE " ++ ppShow x

-- | Logic operation constructor name to IR.
compLO x = case x of
    A0 "Clear" -> IR.Clear
    A0 "And" -> IR.And
    A0 "AndReverse" -> IR.AndReverse
    A0 "Copy" -> IR.Copy
    A0 "AndInverted" -> IR.AndInverted
    A0 "Noop" -> IR.Noop
    A0 "Xor" -> IR.Xor
    A0 "Or" -> IR.Or
    A0 "Nor" -> IR.Nor
    A0 "Equiv" -> IR.Equiv
    A0 "Invert" -> IR.Invert
    A0 "OrReverse" -> IR.OrReverse
    A0 "CopyInverted" -> IR.CopyInverted
    A0 "OrInverted" -> IR.OrInverted
    A0 "Nand" -> IR.Nand
    A0 "Set" -> IR.Set
    x -> error $ "compLO " ++ ppShow x

-- | Comparison function constructor name to IR.
compComparisonFunction x = case x of
    A0 "Never" -> IR.Never
    A0 "Less" -> IR.Less
    A0 "Equal" -> IR.Equal
    A0 "Lequal" -> IR.Lequal
    A0 "Greater" -> IR.Greater
    A0 "Notequal" -> IR.Notequal
    A0 "Gequal" -> IR.Gequal
    A0 "Always" -> IR.Always
    x -> error $ "compComparisonFunction " ++ ppShow x

-- | Matches a @True@ / @False@ literal.
pattern EBool a <- (compBool -> Just a)

compBool x = case x of
    A0 "True" -> Just True
    A0 "False" -> Just False
    x -> Nothing

-- | Per-image fragment operation (depth test or colour blending) to IR.
compFrag x = case x of
    A2 "DepthOp" (compComparisonFunction -> a) (EBool b) -> IR.DepthOp a b
    A2 "ColorOp" (compBlending -> b) (compValue -> v) -> IR.ColorOp b v
    x -> error $ "compFrag " ++ ppShow x

-- | GLSL spelling of a source type; @msg@ only identifies the call site in
-- error messages.
toGLSLType msg x = showGLSLType msg $ compInputType msg x

-- move to lambdacube-ir?
showGLSLType msg = \case
    IR.Bool  -> "bool"
    IR.Word  -> "uint"
    IR.Int   -> "int"
    IR.Float -> "float"
    IR.V2F   -> "vec2"
    IR.V3F   -> "vec3"
    IR.V4F   -> "vec4"
    IR.V2B   -> "bvec2"
    IR.V3B   -> "bvec3"
    IR.V4B   -> "bvec4"
    IR.V2U   -> "uvec2"
    IR.V3U   -> "uvec3"
    IR.V4U   -> "uvec4"
    IR.V2I   -> "ivec2"
    IR.V3I   -> "ivec3"
    IR.V4I   -> "ivec4"
    IR.M22F  -> "mat2"
    IR.M33F  -> "mat3"
    IR.M44F  -> "mat4"
    IR.M23F  -> "mat2x3"
    IR.M24F  -> "mat2x4"
    IR.M32F  -> "mat3x2"
    IR.M34F  -> "mat3x4"
    IR.M42F  -> "mat4x2"
    IR.M43F  -> "mat4x3"
    IR.FTexture2D -> "sampler2D"
    t -> error $ "toGLSLType: " ++ msg ++ " " ++ show t

-- | Whether a source type has an IR input-type representation.
supType = isJust . compInputType_

-- | IR input type of a scalar, vector or float-matrix source type.
compInputType_ x = case x of
    TFloat          -> Just IR.Float
    TVec 2 TFloat   -> Just IR.V2F
    TVec 3 TFloat   -> Just IR.V3F
    TVec 4 TFloat   -> Just IR.V4F
    TBool           -> Just IR.Bool
    TVec 2 TBool    -> Just IR.V2B
    TVec 3 TBool    -> Just IR.V3B
    TVec 4 TBool    -> Just IR.V4B
    TInt            -> Just IR.Int
    TVec 2 TInt     -> Just IR.V2I
    TVec 3 TInt     -> Just IR.V3I
    TVec 4 TInt     -> Just IR.V4I
    TWord           -> Just IR.Word
    TVec 2 TWord    -> Just IR.V2U
    TVec 3 TWord    -> Just IR.V3U
    TVec 4 TWord    -> Just IR.V4U
    TMat 2 2 TFloat -> Just IR.M22F
    TMat 2 3 TFloat -> Just IR.M23F
    TMat 2 4 TFloat -> Just IR.M24F
    TMat 3 2 TFloat -> Just IR.M32F
    TMat 3 3 TFloat -> Just IR.M33F
    TMat 3 4 TFloat -> Just IR.M34F
    TMat 4 2 TFloat -> Just IR.M42F
    TMat 4 3 TFloat -> Just IR.M43F
    TMat 4 4 TFloat -> Just IR.M44F
    _ -> Nothing

compInputType msg x = fromMaybe (error $ "compInputType " ++ msg ++ " " ++ ppShow x) $ compInputType_ x

is234 = (`elem` [2,3,4])

-- | Names and IR types of the attributes in a tuple of @Attribute@s.
compInputType'' (ETuple attrs) = map compAttribute attrs

compAttribute = \case
    x@(A1 "Attribute" (EString s)) -> (s, compInputType "compAttr" $ tyOf x)
    x -> error $ "compAttribute " ++ ppShow x

-- | Flatten a tuple of attribute lists into typed IR arrays.
-- All lists must have the same length.  Only float scalar/vector element
-- values can be flattened; other element types hit an error.
compAttributeValue :: ExpTV -> [(IR.InputType,IR.ArrayValue)]
compAttributeValue (ETuple x) = checkLength $ map go x
  where
    emptyArray t | t `elem` [IR.Float,IR.V2F,IR.V3F,IR.V4F,IR.M22F,IR.M23F,IR.M24F,IR.M32F,IR.M33F,IR.M34F,IR.M42F,IR.M43F,IR.M44F] = IR.VFloatArray mempty
    emptyArray t | t `elem` [IR.Int,IR.V2I,IR.V3I,IR.V4I] = IR.VIntArray mempty
    emptyArray t | t `elem` [IR.Word,IR.V2U,IR.V3U,IR.V4U] = IR.VWordArray mempty
    emptyArray t | t `elem` [IR.Bool,IR.V2B,IR.V3B,IR.V4B] = IR.VBoolArray mempty
    emptyArray _ = error "compAttributeValue - emptyArray"

    flatten IR.Float (IR.VFloat x) (IR.VFloatArray l) = IR.VFloatArray $ pure x <> l
    flatten IR.V2F (IR.VV2F (IR.V2 x y)) (IR.VFloatArray l) = IR.VFloatArray $ pure x <> pure y <> l
    flatten IR.V3F (IR.VV3F (IR.V3 x y z)) (IR.VFloatArray l) = IR.VFloatArray $ pure x <> pure y <> pure z <> l
    flatten IR.V4F (IR.VV4F (IR.V4 x y z w)) (IR.VFloatArray l) = IR.VFloatArray $ pure x <> pure y <> pure z <> pure w <> l
    flatten _ _ _ = error "compAttributeValue"

    -- an empty attribute tuple has nothing to compare
    checkLength [] = []
    checkLength l@((a,_):_) = case all (\(i,_) -> i == a) l of
        True -> snd $ unzip l
        False -> error "FetchArrays array length mismatch!"

    -- (element count, (element type, flattened array))
    go a = (length values,(t,foldr (flatten t) (emptyArray t) values))
      where
        (A1 "List" (compInputType "compAV" -> t)) = tyOf a
        values = map compValue $ compList a

compFetchPrimitive x = case x of
    A0 "Point" -> IR.Points
    A0 "Line" -> IR.Lines
    A0 "Triangle" -> IR.Triangles
    A0 "LineAdjacency" -> IR.LinesAdjacency
    A0 "TriangleAdjacency" -> IR.TrianglesAdjacency
    x -> error $ "compFetchPrimitive " ++ ppShow x

-- | Constant scalar / vector literal to an IR value.
compValue x = case x of
    EFloat a -> IR.VFloat $ realToFrac a
    EInt a -> IR.VInt $ fromIntegral a
    A2 "V2" (EFloat a) (EFloat b) -> IR.VV2F $ IR.V2 (realToFrac a) (realToFrac b)
    A3 "V3" (EFloat a) (EFloat b) (EFloat c) -> IR.VV3F $ IR.V3 (realToFrac a) (realToFrac b) (realToFrac c)
    A4 "V4" (EFloat a) (EFloat b) (EFloat c) (EFloat d) -> IR.VV4F $ IR.V4 (realToFrac a) (realToFrac b) (realToFrac c) (realToFrac d)
    A2 "V2" (EBool a) (EBool b) -> IR.VV2B $ IR.V2 a b
    A3 "V3" (EBool a) (EBool b) (EBool c) -> IR.VV3B $ IR.V3 a b c
    A4 "V4" (EBool a) (EBool b) (EBool c) (EBool d) -> IR.VV4B $ IR.V4 a b c d
    x -> error $ "compValue " ++ ppShow x

-- | Rasterization context to IR.
compRC x = case x of
    A3 "PointCtx" a (EFloat b) c -> IR.PointCtx (compPS a) (realToFrac b) (compPSCO c)
    A2 "LineCtx" (EFloat a) b -> IR.LineCtx (realToFrac a) (compPV b)
    A4 "TriangleCtx" a b c d -> IR.TriangleCtx (compCM a) (compPM b) (compPO c) (compPV d)
    x -> error $ "compRC " ++ ppShow x

-- | The @ProgramPointSize@ expression of a rasterization context, if any.
compRC' x = case x of
    A3 "PointCtx" a _ _ -> compPS' a
    A4 "TriangleCtx" _ b _ _ -> compPM' b
    x -> Nothing

compPSCO x = case x of
    A0 "LowerLeft" -> IR.LowerLeft
    A0 "UpperLeft" -> IR.UpperLeft
    x -> error $ "compPSCO " ++ ppShow x

-- | Cull mode to IR; the front face is always counter-clockwise.
compCM x = case x of
    A0 "CullNone" -> IR.CullNone
    A0 "CullFront" -> IR.CullFront IR.CCW
    A0 "CullBack" -> IR.CullBack IR.CCW
    x -> error $ "compCM " ++ ppShow x

compPM x = case x of
    A0 "PolygonFill" -> IR.PolygonFill
    A1 "PolygonLine" (EFloat a) -> IR.PolygonLine $ realToFrac a
    A1 "PolygonPoint" a -> IR.PolygonPoint $ compPS a
    x -> error $ "compPM " ++ ppShow x

compPM' x = case x of
    A1 "PolygonPoint" a -> compPS' a
    x -> Nothing

compPS x = case x of
    A1 "PointSize" (EFloat a) -> IR.PointSize $ realToFrac a
    A1 "ProgramPointSize" _ -> IR.ProgramPointSize
    x -> error $ "compPS " ++ ppShow x

compPS' x = case x of
    A1 "ProgramPointSize" x -> Just x
    x -> Nothing

compPO x = case x of
    A2 "Offset" (EFloat a) (EFloat b) -> IR.Offset (realToFrac a) (realToFrac b)
    A0 "NoOffset" -> IR.NoOffset
    x -> error $ "compPO " ++ ppShow x

compPV x = case x of
    A0 "FirstVertex" -> IR.FirstVertex
    A0 "LastVertex" -> IR.LastVertex
    x -> error $ "compPV " ++ ppShow x

--------------------------------------------------------------- GLSL generation

-- | Generate the GLSL vertex and fragment shaders of one rendering pass.
--
-- Returns @(vertexInputNames, uniforms, vertexShader, fragmentShader)@.
-- Each varying output of the vertex shader is paired with one interpolation
-- qualifier from @ints@; the first vertex shader output is @gl_Position@.
genGLSLs backend
    rp                  -- program point size
    (ETuple ints)       -- interpolations
    (vert, tvert)       -- vertex shader
    (frag, tfrag)       -- fragment shader
    ffilter             -- fragment filter
    = ( -- vertex input
        vertInNames

      , -- uniforms
        vertUniforms <> fragUniforms

      , -- vertex shader code
        shader $
           uniformDecls vertUniforms
        <> [shaderDecl (caseWO "attribute" "in") (text t) (text n) | (n, t) <- zip vertInNames vertIns]
        <> vertOutDecls "out"
        <> vertFuncs
        <> [mainFunc $
               vertVals
            <> [shaderLet (text n) x | (n, x) <- zip vertOutNamesWithPosition vertGLSL]
            <> [shaderLet "gl_PointSize" x | Just x <- [ptGLSL]]
           ]

      , -- fragment shader code
        shader $
           uniformDecls fragUniforms
        <> vertOutDecls "in"
        <> [shaderDecl "out" (text t) (text n) | (n, t) <- zip fragOutNames fragOuts, backend == OpenGL33]
        <> fragFuncs
        <> [mainFunc $
               fragVals
            <> [shaderStmt $ "if" <+> parens ("!" <> parens filt) <+> "discard" | Just filt <- [filtGLSL]]
            <> [shaderLet (text n) x | (n, x) <- zip fragOutNames fragGLSL ]
           ]
      )
  where
    uniformDecls us = [shaderDecl "uniform" (text $ showGLSLType "2" t) (text n) | (n, (_, t)) <- Map.toList us]
    vertOutDecls io = [shaderDecl (caseWO "varying" $ text i <+> io) (text t) (text n) | (n, (i, t)) <- zip vertOutNames vertOuts]

    -- Only zero or one fragment output is handled; more outputs hit a
    -- non-exhaustive case here.
    fragOutNames = case length frags of
        0 -> []
        1 -> [caseWO "gl_FragColor" "f0"]

    (vertIns, verts) = case vert of
        Just (etaReds -> Just (xs, ETuple ys)) -> (toGLSLType "3" <$> xs, ys)
        Nothing -> ([toGLSLType "4" tvert], [mkTVar 0 tvert])

    (fragOuts, frags) = case frag of
        Just (etaReds -> Just (xs, ETuple ys)) -> (toGLSLType "31" . tyOf <$> ys, ys)
        Nothing -> ([toGLSLType "41" tfrag], [mkTVar 0 tfrag])

    -- Generate both stages in one State run so fresh names ("s0", "s1", ...)
    -- are unique across the vertex and fragment shader.
    (((vertGLSL, ptGLSL), (vertUniforms, (vertFuncs, vertVals))), ((filtGLSL, fragGLSL), (fragUniforms, (fragFuncs, fragVals)))) = flip evalState shaderNames $ do
        ((g1, (us1, verts)), (g2, (us2, frags))) <- (,)
            <$> runWriterT ((,)
                <$> traverse (genGLSL' "1" vertInNames . (,) vertIns) verts
                <*> traverse (genGLSL' "2" vertOutNamesWithPosition . reds) rp)
            <*> runWriterT ((,)
                <$> traverse (genGLSL' "3" vertOutNames . red) ffilter
                <*> traverse (genGLSL' "4" vertOutNames . (,) (snd <$> vertOuts)) frags)
        (,) <$> ((,) g1 <$> fixFuncs us1 mempty mempty verts)
            <*> ((,) g2 <$> fixFuncs us2 mempty mempty frags)

    -- Generate GLSL for the helper functions found during code generation,
    -- repeating until generating them discovers no new ones.  @ns@ holds the
    -- names already generated; @fsb@ accumulates (declarations, statements for
    -- main).  Each round's definitions are prepended, so functions discovered
    -- later (the callees) end up before their callers.
    fixFuncs :: Uniforms -> Set.Set SName -> ([Doc], [Doc]) -> Map.Map SName (ExpTV, ExpTV, [ExpTV]) -> State [SName] (Uniforms, ([Doc], [Doc]))
    fixFuncs us ns fsb (Map.toList -> fsa)
        | null fsa = return (us, fsb)
        | otherwise = do
            (unzip -> (defs, unzip -> (us', fs'))) <- forM fsa $ \(fn, (def, ty, tys)) -> runWriterT $ genGLSL (reverse $ take (length tys) funArgs) $ removeLams (length tys) def
            let fsb' = mconcat (zipWith combine fsa defs) <> fsb
                ns' = ns <> Set.fromList (map fst fsa)
            fixFuncs (us <> mconcat us') ns' fsb' (mconcat fs' `Map.difference` Map.fromSet (const undefined) ns')
      where
        -- A function without parameters becomes a global declaration that is
        -- assigned in main; one with parameters becomes a GLSL function.
        combine (fn, (_, ty, tys)) def = case tys of
            [] -> ( [shaderDecl' ot n], [shaderLet n def] )
            _ -> ( [shaderFunc ot n (zipWith (<+>) (map (toGLSLType "45") tys) (map text funArgs)) [shaderReturn def]]
                 , [] )
          where
            ot = toGLSLType "44" ty
            n = text fn

    funArgs = map (("z" ++) . show) [0..]
    shaderNames = map (("s" ++) . show) [0..]
    vertInNames = map (("vi" ++) . show) [1..length vertIns]
    vertOutNames = map (("vo" ++) . show) [1..length vertOuts]
    vertOutNamesWithPosition = "gl_Position": vertOutNames

    red (etaReds -> Just (ps, o)) = (ps, o)
    red x = error $ "red: " ++ ppShow x
    reds (etaReds -> Just (ps, o)) = (ps, o)
    reds x = error $ "red: " ++ ppShow x

    -- generate the body @o@ with its parameters @ps@ named by @vertOuts@
    genGLSL' err vertOuts (ps, o)
        | length ps == length vertOuts = genGLSL (reverse vertOuts) o
        | otherwise = error $ "makeSubst illegal input " ++ err ++ " " ++ ppShow ps ++ "\n" ++ ppShow vertOuts

    noUnit TTuple0 = False
    noUnit _ = True

    -- (interpolation qualifier, GLSL type) of each varying
    vertOuts = zipWith go ints $ tail verts
      where
        go (A0 n) e = (interpName n, toGLSLType "3" $ tyOf e)

        interpName "Smooth" = "smooth"
        interpName "Flat" = "flat"
        interpName "NoPerspective" = "noperspective"

    -- Prefix a shader body with the version directive and the backend
    -- specific prelude: precision qualifiers and a hand-written 'transpose'
    -- for WebGL1, a 'texture2D' shim over 'texture' for OpenGL33.
    shader xs = vcat $
           ["#version" <+> caseWO "100" "330 core"]
        <> ["precision highp float;" | backend == WebGL1]
        <> ["precision highp int;" | backend == WebGL1]
        <> [shaderFunc "vec4" "texture2D" ["sampler2D s", "vec2 uv"] [shaderReturn "texture(s,uv)"] | backend == OpenGL33]
        <> [shaderFunc "mat4" "transpose" ["mat4 m"] -- todo: not just for 4 dimension
              [ shaderLet "vec4 i0" "m[0]"
              , shaderLet "vec4 i1" "m[1]"
              , shaderLet "vec4 i2" "m[2]"
              , shaderLet "vec4 i3" "m[3]"
              , shaderReturn "mat4(\
                   \vec4(i0.x, i1.x, i2.x, i3.x),\
                   \vec4(i0.y, i1.y, i2.y, i3.y),\
                   \vec4(i0.z, i1.z, i2.z, i3.z),\
                   \vec4(i0.w, i1.w, i2.w, i3.w)\
                   \)"
              ]
           | backend == WebGL1
           ]
        <> xs

    shaderFunc outtype name pars body = nest 4 (outtype <+> name <> tupled pars <+> "{" <$$> vcat body) <$$> "}"
    mainFunc xs = shaderFunc "void" "main" [] xs
    shaderStmt xs = nest 4 $ xs <> ";"
    shaderReturn xs = shaderStmt $ "return" <+> xs
    -- "a = b;" (the '</>' lets the pretty printer break long right-hand sides)
    shaderLet a b = shaderStmt $ a <+> "=" </> b
    shaderDecl a b c = shaderDecl' (a <+> b) c
    shaderDecl' b c = shaderStmt $ b <+> c

    -- pick the WebGL1 or OpenGL33 variant; no other backend is handled here
    caseWO w o = case backend of WebGL1 -> w; OpenGL33 -> o

-- | Kinds of uniform found while generating shader code: a plain uniform, a
-- texture sampler taken from a slot, or a render texture of the given width
-- and height produced from an image expression.
data Uniform
    = UUniform
    | UTexture2DSlot
    | UTexture2D Integer Integer ExpTV

type Uniforms = Map String (Uniform, IR.InputType)

tellUniform x = tell (x, mempty)

-- | Expressions that must not be lifted into a GLSL helper function.
simpleExpr = \case
    Con cn xs -> case cn of
        "Uniform" -> True
        _ -> False
    _ -> False

-- | Translate an expression into a GLSL expression.
--
-- @dns@ names the de Bruijn variables in scope (index 0 first).  Uniforms and
-- the helper functions the expression calls are reported through the writer;
-- fresh names for render-texture samplers are taken from the state.
genGLSL :: [SName] -> ExpTV -> WriterT (Uniforms, Map.Map SName (ExpTV, ExpTV, [ExpTV])) (State [String]) Doc
genGLSL dns e = case e of

    ELit a -> pure $ pShow a
    Var i _ -> pure $ text $ dns !! i

    Func fn def ty xs | not (simpleExpr def) -> tell (mempty, Map.singleton fn (def, ty, map tyOf xs)) >> call fn xs

    Con cn xs -> case cn of
        "primIfThenElse" -> case xs of
            [a, b, c] -> hsep <$> sequence [gen a, pure "?", gen b, pure ":", gen c]

        "swizzscalar" -> case xs of
            [e, getSwizzChar -> Just s] -> showSwizzProj [s] <$> gen e

        "swizzvector" -> case xs of
            [e, Con ((`elem` ["V2","V3","V4"]) -> True) (traverse getSwizzChar -> Just s)] -> showSwizzProj s <$> gen e

        "Uniform" -> case xs of
            [EString s] -> do
                tellUniform $ Map.singleton s $ (,) UUniform $ compInputType "unif" $ tyOf e
                pure $ text s
        "Sampler" -> case xs of
            [_, _, A1 "Texture2DSlot" (EString s)] -> do
                tellUniform $ Map.singleton s $ (,) UTexture2DSlot IR.FTexture2D{-compInputType $ tyOf e -- TODO-}
                pure $ text s
            [_, _, A2 "Texture2D" (A2 "V2" (EInt w) (EInt h)) b] -> do
                s <- newName
                tellUniform $ Map.singleton s $ (,) (UTexture2D w h b) IR.FTexture2D
                pure $ text s

        'P':'r':'i':'m':n | n'@(_:_) <- trName (dropS n) -> call n' xs
          where
            ifType p a b = if all (p . tyOf) xs then a else b

            -- strip the "S" (scalar argument) suffix of the listed primitives;
            -- the null check keeps 'last' total for a bare "Prim"
            dropS n
                | not (null n) && last n == 'S' && init n `elem` ["Add", "Sub", "Div", "Mod", "BAnd", "BOr", "BXor", "BShiftL", "BShiftR", "Min", "Max", "Clamp", "Mix", "Step", "SmoothStep"] = init n
                | otherwise = n

            -- GLSL name of a primitive; "" means not a primitive handled here
            trName = \case

                -- Arithmetic Functions
                "Add"               -> "+"
                "Sub"               -> "-"
                "Neg"               -> "-_"
                "Mul"               -> ifType isMatrix "matrixCompMult" "*"
                "MulS"              -> "*"
                "Div"               -> "/"
                "Mod"               -> ifType isIntegral "%" "mod"

                -- Bit-wise Functions
                "BAnd"              -> "&"
                "BOr"               -> "|"
                "BXor"              -> "^"
                "BNot"              -> "~_"
                "BShiftL"           -> "<<"
                "BShiftR"           -> ">>"

                -- Logic Functions
                "And"               -> "&&"
                "Or"                -> "||"
                "Xor"               -> "^"
                "Not"               -> ifType isScalar "!_" "not"

                -- Integer/Float Conversion Functions
                "FloatBitsToInt"    -> "floatBitsToInt"
                "FloatBitsToUInt"   -> "floatBitsToUint"
                "IntBitsToFloat"    -> "intBitsToFloat"
                "UIntBitsToFloat"   -> "uintBitsToFloat"

                -- Matrix Functions
                "OuterProduct"      -> "outerProduct"
                "MulMatVec"         -> "*"
                "MulVecMat"         -> "*"
                "MulMatMat"         -> "*"

                -- Fragment Processing Functions
                "DFdx"              -> "dFdx"
                "DFdy"              -> "dFdy"

                -- Vector and Scalar Relational Functions
                "LessThan"          -> ifType isScalarNum "<"  "lessThan"
                "LessThanEqual"     -> ifType isScalarNum "<=" "lessThanEqual"
                "GreaterThan"       -> ifType isScalarNum ">"  "greaterThan"
                "GreaterThanEqual"  -> ifType isScalarNum ">=" "greaterThanEqual"
                "Equal"             -> "=="
                "EqualV"            -> ifType isScalar "==" "equal"
                "NotEqual"          -> "!="
                "NotEqualV"         -> ifType isScalar "!=" "notEqual"

                -- Angle and Trigonometry Functions
                "ATan2"             -> "atan"
                -- Exponential Functions
                "InvSqrt"           -> "inversesqrt"
                -- Common Functions
                "RoundEven"         -> "roundEven"
                "ModF"              -> error "PrimModF is not implemented yet!" -- TODO
                "MixB"              -> "mix"

                n | n `elem`
                    -- Logic Functions
                    [ "Any", "All"
                    -- Angle and Trigonometry Functions
                    , "ACos", "ACosH", "ASin", "ASinH", "ATan", "ATanH", "Cos", "CosH", "Degrees", "Radians", "Sin", "SinH", "Tan", "TanH"
                    -- Exponential Functions
                    , "Pow", "Exp", "Exp2", "Log2", "Sqrt"
                    -- Common Functions
                    , "IsNan", "IsInf", "Abs", "Sign", "Floor", "Trunc", "Round", "Ceil", "Fract", "Min", "Max", "Mix", "Clamp", "Step", "SmoothStep"
                    -- Geometric Functions
                    , "Length", "Distance", "Dot", "Cross", "Normalize", "FaceForward", "Reflect", "Refract"
                    -- Matrix Functions
                    , "Transpose", "Determinant", "Inverse"
                    -- Fragment Processing Functions
                    , "FWidth"
                    -- Noise Functions
                    , "Noise1", "Noise2", "Noise3", "Noise4"
                    ] -> map toLower n

                _ -> ""

        n | n@(_:_) <- trName n -> call n xs
          where
            trName n = case n of
                "texture2D" -> "texture2D"

                "True"  -> "true"
                "False" -> "false"

                "M22F" -> "mat2"
                "M33F" -> "mat3"
                "M44F" -> "mat4"

                "==" -> "=="

                n | n `elem` ["primNegateWord", "primNegateInt", "primNegateFloat"] -> "-_"
                n | n `elem` ["V2", "V3", "V4"] -> toGLSLType (n ++ " " ++ show (length xs)) $ tyOf e
                _ -> ""

        -- not supported
        n | n `elem` ["primIntToWord", "primIntToFloat", "primCompareInt", "primCompareWord", "primCompareFloat"] -> error $ "WebGL 1 does not support: " ++ ppShow e
        n | n `elem` ["M23F", "M24F", "M32F", "M34F", "M42F", "M43F"] -> error "WebGL 1 does not support matrices with this dimension"
        x -> error $ "GLSL codegen - unsupported function: " ++ ppShow x

    x -> error $ "GLSL codegen - unsupported expression: " ++ ppShow x

  where
    newName = gets head <* modify tail

    -- Render a call: alphabetic names use function-call syntax, names of the
    -- form [op, '_'] are prefix unary operators, anything else is a binary
    -- infix operator with parenthesised operands.
    call f xs = case f of
        (c:_) | isAlpha c -> case xs of
            [] -> return $ text f
            xs -> (text f </>) . tupled <$> mapM gen xs
        [op, '_'] -> case xs of
            [a] -> (text [op] <+>) . parens <$> gen a
        o -> case xs of
            [a, b] -> hsep <$> sequence [parens <$> gen a, pure $ text o, parens <$> gen b]

    gen = genGLSL dns

-- Type classifiers used to choose between GLSL operators and functions.
isMatrix :: Ty -> Bool
isMatrix TMat{} = True
isMatrix _ = False

isIntegral :: Ty -> Bool
isIntegral TWord = True
isIntegral TInt = True
isIntegral (TVec _ TWord) = True
isIntegral (TVec _ TInt) = True
isIntegral _ = False

isScalarNum :: Ty -> Bool
isScalarNum = \case
    TInt -> True
    TWord -> True
    TFloat -> True
    _ -> False

isScalar :: Ty -> Bool
isScalar TBool = True
isScalar x = isScalarNum x

getSwizzChar = \case
    A0 "Sx" -> Just 'x'
    A0 "Sy" -> Just 'y'
    A0 "Sz" -> Just 'z'
    A0 "Sw" -> Just 'w'
    _ -> Nothing

showSwizzProj x a = parens a <> "." <> text x

--------------------------------------------------------------------------------

-- expression + type + type of local variables
data ExpTV = ExpTV_ I.Exp I.Exp [I.Exp]
    deriving (Eq)

-- The builder strips labels from the type before storing it.
pattern ExpTV a b c <- ExpTV_ a b c
  where ExpTV a b c = ExpTV_ (a) (unLab' b) c

type Ty = ExpTV

-- | The type of an expression, itself as an 'ExpTV' of type @Type@.
tyOf :: ExpTV -> Ty
tyOf (ExpTV _ t vs) = t .@ vs

expOf (ExpTV x _ _) = x

mapVal f (ExpTV a b c) = ExpTV (f a) b c

toExp :: I.ExpType -> ExpTV
toExp (I.ET x xt) = ExpTV x xt []

pattern Pi h a b <- (mkPi . mapVal unLab' -> Just (h, a, b))
pattern Lam h a b <- (mkLam . mapVal unFunc' -> Just (h, a, b))
pattern Con h b <- (mkCon . mapVal unLab' -> Just (h, b))
pattern App a b <- (mkApp . mapVal unLab' -> Just (a, b))
pattern Var a b <- (mkVar . mapVal unLab' -> Just (a, b))
pattern ELit l <- ExpTV (unLab' -> I.ELit l) _ _
pattern TType <- ExpTV (unLab' -> I.TType) _ _
pattern Func fn def ty xs <- (mkFunc -> Just (fn, def, ty, xs))

pattern EString s <- ELit (LString s)
pattern EFloat s <- ELit (LFloat s)
pattern EInt s <- ELit (LInt s)

t .@ vs = ExpTV t I.TType vs
infix 1 .@

mkVar (ExpTV (I.Var i) t vs) = Just (i, t .@ vs)
mkVar _ = Nothing

mkPi (ExpTV (I.Pi b x y) _ vs) = Just (b, x .@ vs, y .@ addToEnv x vs)
mkPi _ = Nothing

mkLam (ExpTV (I.Lam y) (I.Pi b x yt) vs) = Just (b, x .@ vs, ExpTV y yt $ addToEnv x vs)
mkLam _ = Nothing

-- | View constructor-like applications (data/type constructors, functions,
-- case functions) as a name and the list of their visible arguments.
mkCon (ExpTV (I.Con s n (reverse -> xs)) et vs) = Just (untick $ show s, chain vs (I.conType et s) $ I.mkConPars n et ++ xs)
mkCon (ExpTV (I.TyCon s (reverse -> xs)) et vs) = Just (untick $ show s, chain vs (nType s) xs)
mkCon (ExpTV (I.Neut (I.Fun s@(I.FunName _ loc _{-I.DeltaDef{}-} _) (reverse -> xs) def)) et vs) = Just (untick $ show s, drop loc $ chain vs (nType s) xs)
mkCon (ExpTV (I.CaseFun s xs n) et vs) = Just (untick $ show s, chain vs (nType s) $ makeCaseFunPars' (mkEnv vs) n ++ xs ++ [I.Neut n])
mkCon (ExpTV (I.TyCaseFun s [m, t, f] n) et vs) = Just (untick $ show s, chain vs (nType s) [m, t, I.Neut n, f])
mkCon _ = Nothing

mkApp (ExpTV (I.Neut (I.App_ a b)) et vs) = Just (ExpTV (I.Neut a) t vs, head $ chain vs t [b])
  where
    t = neutType' (mkEnv vs) a
mkApp _ = Nothing

-- | Strip @n@ lambdas and the RHS marker from a function definition.
removeRHS 0 (I.RHS x) = Just x
removeRHS n (I.Lam x) | n > 0 = I.Lam <$> removeRHS (n-1) x
removeRHS _ _ = Nothing

-- | Recognise a call of a defined function that can become a GLSL helper:
-- its definition is available and the result and all value arguments have
-- GLSL-representable types.  Returns (mangled name including the first @loc@
-- arguments, definition applied to those arguments, result type, remaining
-- arguments).
mkFunc r@(ExpTV (I.Neut (I.Fun (I.FunName (show -> n) loc (I.ExpDef def_) nt) xs I.RHS{})) ty vs)
    | Just def <- removeRHS (length xs) def_
    , all (supType . tyOf) (r: xs') && n `notElem` ["typeAnn"] && all validChar n
    = Just
        ( untick n +++ intercalate "_" (filter (/="TT") $ map (filter isAlphaNum . plainShow . shortForm . pShow) hs)
        , toExp $ I.ET (foldl I.app_ def hs) (foldl I.appTy nt hs)
        , tyOf r
        , xs'
        )
  where
    a +++ [] = a
    a +++ b = a ++ "_" ++ b
    (map (expOf . snd) -> hs, map snd -> xs') = splitAt loc $ chain' vs nt $ reverse xs
    validChar = isAlphaNum
{-
mkFunc r@(ExpTV (I.Neut (I.Fun (I.FunName (show -> n) loc (I.ExpDef def_) nt) xs I.RHS{})) ty vs)
    | Just def <- removeRHS (length xs) def_
    , all validChar n
    = tracePShow (text n, reverse xs, supType . tyOf <$> (r: xs')) Nothing
  where
    a +++ [] = a
    a +++ b = a ++ "_" ++ b
    (map (expOf . snd) -> hs, map snd -> xs') = splitAt loc $ chain' vs nt $ reverse xs
    validChar = isAlphaNum
mkFunc r@(ExpTV (I.Neut (I.Fun (I.FunName (show -> n) loc (I.ExpDef def_) nt) xs I.RHS{})) ty vs)
    = tracePShow (text n, take loc $ reverse xs) Nothing
-}
mkFunc _ = Nothing

-- | Pair the arguments with their types by walking the Pi type @t@;
-- 'chain' also drops leading hidden arguments.
chain vs t@(I.Pi Hidden at y) (a: as) = chain vs (I.appTy t a) as
chain vs t xs = map snd $ chain' vs t xs

chain' vs t [] = []
chain' vs t@(I.Pi b at y) (a: as) = (b, ExpTV a at vs): chain' vs (I.appTy t a) as
chain' vs t _ = error $ "chain: " ++ ppShow t

mkTVar i (ExpTV t _ vs) = ExpTV (I.Var i) t vs

unLab' (I.Reduced x) = unLab' x
unLab' (I.RHS x) = unLab' x   -- TODO: remove
unLab' x = x

unFunc' (I.Reduced x) = unFunc' x   -- todo: remove?
unFunc' (I.Neut (I.Fun (I.FunName _ _ I.ExpDef{} _) _ y)) = unFunc' y
unFunc' (I.RHS x) = unFunc' x   -- TODO: remove
unFunc' x = x

instance Subst I.Exp ExpTV where
    subst_ i0 dx x (ExpTV a at vs) = ExpTV (subst_ i0 dx x a) (subst_ i0 dx x at) (zipWith (\i -> subst_ (i0+i) (I.shiftFreeVars i dx) $ up i x{-todo: review-}) [1..] vs)

addToEnv x xs = x: xs
mkEnv xs = {-trace_ ("mk " ++ show (length xs)) $ -} zipWith up [1..] xs

instance HasFreeVars ExpTV where
    getFreeVars (ExpTV x xt vs) = getFreeVars x <> getFreeVars xt

instance PShow ExpTV where
    pShow (ExpTV x t _) = pShow (x, t)

isSampler (I.TyCon n _) = show n == "'Sampler"
isSampler _ = False

-------------------------------------------------------------------------------- ExpTV conversion -- TODO: remove

-- | Drop @i@ visible lambdas (and any hidden ones in between).
removeLams 0 x = x
removeLams i (ELam _ x) = removeLams (i-1) x
removeLams i (Lam Hidden _ x) = removeLams i x

-- | View a lambda as (parameter types, body), eta-reducing first and
-- unpacking HList-pattern lambdas into one parameter per element.
etaReds (ELam _ (App (down 0 -> Just f) (EVar 0))) = etaReds f
etaReds (ELam _ (hlistLam -> x@Just{})) = x
etaReds (ELam p i) = Just ([p], i)
etaReds x = Nothing

hlistLam :: ExpTV -> Maybe ([ExpTV], ExpTV)
hlistLam (A3 "hlistNilCase" _ (down 0 -> Just x) (EVar 0)) = Just ([], x)
hlistLam (A3 "hlistConsCase" _ (down 0 -> Just (getPats 2 -> Just ([p, px], x))) (EVar 0)) = first (p:) <$> hlistLam x
hlistLam _ = Nothing

getPats 0 e = Just ([], e)
getPats i (ELam p e) = first (p:) <$> getPats (i-1) e
getPats i (Lam Hidden p (down 0 -> Just e)) = getPats i e
getPats i x = error $ "getPats: " ++ show i ++ " " ++ ppShow x

pattern EtaPrim1 s <- (getEtaPrim -> Just (s, []))
pattern EtaPrim2 s x <- (getEtaPrim -> Just (s, [x]))
pattern EtaPrim3 s x1 x2 <- (getEtaPrim -> Just (s, [x1, x2]))
pattern EtaPrim4 s x1 x2 x3 <- (getEtaPrim -> Just (s, [x1, x2, x3]))
pattern EtaPrim5 s x1 x2 x3 x4 <- (getEtaPrim -> Just (s, [x1, x2, x3, x4]))
pattern EtaPrim2_2 s <- (getEtaPrim2 -> Just (s, []))

-- | Match @\x -> C a1 .. an x@ (resp. two lambdas) and return @C@ with the
-- arguments that do not mention the lambda-bound variables.
getEtaPrim (ELam _ (Con s (initLast -> Just (traverse (down 0) -> Just xs, EVar 0)))) = Just (s, xs)
getEtaPrim _ = Nothing

getEtaPrim2 (ELam _ (ELam _ (Con s (initLast -> Just (initLast -> Just (traverse (down 0) -> Just (traverse (down 0) -> Just xs), EVar 0), EVar 0))))) = Just (s, xs)
getEtaPrim2 _ = Nothing

initLast [] = Nothing
initLast xs = Just (init xs, last xs)

-------------

pattern EVar n <- Var n _
pattern ELam t b <- Lam Visible t b

pattern A0 n <- Con n []
pattern A1 n a <- Con n [a]
pattern A2 n a b <- Con n [a, b]
pattern A3 n a b c <- Con n [a, b, c]
pattern A4 n a b c d <- Con n [a, b, c, d]
pattern A5 n a b c d e <- Con n [a, b, c, d, e]

pattern TTuple0 <- A1 "HList" (A0 "Nil")
pattern TBool <- A0 "Bool"
pattern TWord <- A0 "Word"
pattern TInt <- A0 "Int"
pattern TNat <- A0 "Nat"
pattern TFloat <- A0 "Float"
pattern TString <- A0 "String"
pattern TVec n a <- A2 "VecS" a (Nat n)
pattern TMat i j a <- A3 "Mat" (Nat i) (Nat j) a

pattern Nat n <- (fromNat -> Just n)

-- | Unary (@Zero@ / @Succ@) natural number literal to an 'Int'.
fromNat :: ExpTV -> Maybe Int
fromNat (A0 "Zero") = Just 0
fromNat (A1 "Succ" n) = (1 +) <$> fromNat n
fromNat _ = Nothing

pattern TTuple xs <- (getTTuple -> Just xs)
pattern ETuple xs <- (getTuple -> Just xs)

getTTuple (A1 "HList" l) = Just $ compList l
getTTuple _ = Nothing

getTuple (A0 "HNil") = Just []
getTuple (A2 "HCons" x (getTuple -> Just xs)) = Just (x: xs)
getTuple _ = Nothing

------------ HLSL DX11

-- NOTE(review): genHLSL continues past this chunk; its head is kept exactly
-- as it was so that it still joins with the unchanged text that follows.
genHLSL :: [SName] -> ExpTV -> WriterT (Uniforms, Map.Map SName (ExpTV, ExpTV, [ExpTV])) (State [String]) Doc genHLSL dns e = case e of ELit a -> pure $ pShow a Var i _ -> pure $ text $ dns !!
i Func fn def ty xs | not (simpleExpr def) -> tell (mempty, Map.singleton fn (def, ty, map tyOf xs)) >> call fn xs Con cn xs -> case cn of "primIfThenElse" -> case xs of [a, b, c] -> hsep <$> sequence [gen a, pure "?", gen b, pure ":", gen c] "swizzscalar" -> case xs of [e, getSwizzChar -> Just s] -> showSwizzProj [s] <$> gen e "swizzvector" -> case xs of [e, Con ((`elem` ["V2","V3","V4"]) -> True) (traverse getSwizzChar -> Just s)] -> showSwizzProj s <$> gen e "Uniform" -> case xs of [EString s] -> do tellUniform $ Map.singleton s $ (,) UUniform $ compInputType "unif" $ tyOf e pure $ text s "Sampler" -> case xs of [_, _, A1 "Texture2DSlot" (EString s)] -> do tellUniform $ Map.singleton s $ (,) UTexture2DSlot IR.FTexture2D{-compInputType $ tyOf e -- TODO-} pure $ text s [_, _, A2 "Texture2D" (A2 "V2" (EInt w) (EInt h)) b] -> do s <- newName tellUniform $ Map.singleton s $ (,) (UTexture2D w h b) IR.FTexture2D pure $ text s 'P':'r':'i':'m':n | n'@(_:_) <- trName (dropS n) -> call n' xs where ifType p a b = if all (p . 
tyOf) xs then a else b dropS n | last n == 'S' && init n `elem` ["Add", "Sub", "Div", "Mod", "BAnd", "BOr", "BXor", "BShiftL", "BShiftR", "Min", "Max", "Clamp", "Mix", "Step", "SmoothStep"] = init n | otherwise = n trName = \case -- Arithmetic Functions "Add" -> "+" "Sub" -> "-" "Neg" -> "-_" "Mul" -> "*" "MulS" -> "*" "Div" -> "/" "Mod" -> ifType isIntegral "%" "mod" -- Bit-wise Functions "BAnd" -> "&" "BOr" -> "|" "BXor" -> "^" "BNot" -> "~_" "BShiftL" -> "<<" "BShiftR" -> ">>" -- Logic Functions "And" -> "&&" "Or" -> "||" "Xor" -> "^" "Not" -> ifType isScalar "!_" "not" -- Integer/Float Conversion Functions "FloatBitsToInt" -> "floatBitsToInt" "FloatBitsToUInt" -> "floatBitsToUint" "IntBitsToFloat" -> "intBitsToFloat" "UIntBitsToFloat" -> "uintBitsToFloat" -- Matrix Functions "OuterProduct" -> "outerProduct" "MulMatVec" -> "mul" "MulVecMat" -> "mul" "MulMatMat" -> "mul" -- Fragment Processing Functions "DFdx" -> "dFdx" "DFdy" -> "dFdy" -- Vector and Scalar Relational Functions "LessThan" -> ifType isScalarNum "<" "lessThan" "LessThanEqual" -> ifType isScalarNum "<=" "lessThanEqual" "GreaterThan" -> ifType isScalarNum ">" "greaterThan" "GreaterThanEqual" -> ifType isScalarNum ">=" "greaterThanEqual" "Equal" -> "==" "EqualV" -> ifType isScalar "==" "equal" "NotEqual" -> "!=" "NotEqualV" -> ifType isScalar "!=" "notEqual" -- Angle and Trigonometry Functions "ATan2" -> "atan" -- Exponential Functions "InvSqrt" -> "inversesqrt" -- Common Functions "RoundEven" -> "roundEven" "ModF" -> error "PrimModF is not implemented yet!" 
-- TODO "MixB" -> "mix" n | n `elem` -- Logic Functions [ "Any", "All" -- Angle and Trigonometry Functions , "ACos", "ACosH", "ASin", "ASinH", "ATan", "ATanH", "Cos", "CosH", "Degrees", "Radians", "Sin", "SinH", "Tan", "TanH" -- Exponential Functions , "Pow", "Exp", "Exp2", "Log2", "Sqrt" -- Common Functions , "IsNan", "IsInf", "Abs", "Sign", "Floor", "Trunc", "Round", "Ceil", "Fract", "Min", "Max", "Mix", "Step", "SmoothStep" -- Geometric Functions , "Length", "Distance", "Dot", "Cross", "Normalize", "FaceForward", "Reflect", "Refract" -- Matrix Functions , "Transpose", "Determinant", "Inverse" -- Fragment Processing Functions , "FWidth" -- Noise Functions , "Noise1", "Noise2", "Noise3", "Noise4" ] -> map toLower n _ -> "" n | n@(_:_) <- trName n -> call n xs where trName n = case n of "texture2D" -> "texture2D" "True" -> "true" "False" -> "false" "M22F" -> "float2x2" "M33F" -> "float3x3" "M44F" -> "float4x4" "==" -> "==" n | n `elem` ["primNegateWord", "primNegateInt", "primNegateFloat"] -> "-_" n | n `elem` ["V2", "V3", "V4"] -> toHLSLType (n ++ " " ++ show (length xs)) $ tyOf e _ -> "" -- not supported n | n `elem` ["primIntToWord", "primIntToFloat", "primCompareInt", "primCompareWord", "primCompareFloat"] -> error $ "WebGL 1 does not support: " ++ ppShow e n | n `elem` ["M23F", "M24F", "M32F", "M34F", "M42F", "M43F"] -> error "WebGL 1 does not support matrices with this dimension" x -> error $ "HLSL codegen - unsupported function: " ++ ppShow x x -> error $ "HLSL codegen - unsupported expression: " ++ ppShow x where newName = gets head <* modify tail call f xs = case f of (c:_) | isAlpha c -> case xs of [] -> return $ text f xs -> (text f ) . tupled <$> mapM gen xs [op, '_'] -> case xs of [a] -> (text [op] <+>) . 
parens <$> gen a o -> case xs of [a, b] -> hsep <$> sequence [parens <$> gen a, pure $ text o, parens <$> gen b] gen = genHLSL dns isMatrix :: Ty -> Bool isMatrix TMat{} = True isMatrix _ = False isIntegral :: Ty -> Bool isIntegral TWord = True isIntegral TInt = True isIntegral (TVec _ TWord) = True isIntegral (TVec _ TInt) = True isIntegral _ = False isScalarNum :: Ty -> Bool isScalarNum = \case TInt -> True TWord -> True TFloat -> True _ -> False isScalar :: Ty -> Bool isScalar TBool = True isScalar x = isScalarNum x getSwizzChar = \case A0 "Sx" -> Just 'x' A0 "Sy" -> Just 'y' A0 "Sz" -> Just 'z' A0 "Sw" -> Just 'w' _ -> Nothing showSwizzProj x a = parens a <> "." <> text x genHLSLs backend rp -- program point size (ETuple ints) -- interpolations (vert, tvert) -- vertex shader (frag, tfrag) -- fragment shader ffilter -- fragment filter = ( -- vertex input vertInNames , -- uniforms vertUniforms <> fragUniforms , -- vertex shader code shader $ ["cbuffer cbuf {"] <> uniformDecls vertUniforms -- non texture uniforms <> ["};"] <> ["struct VS_IN {"] <> [shaderDecl' (text t) (text n) | (n, t) <- zip vertInNames vertIns] <> ["};"] <> ["struct PS_IN {"] <> vertOutDecls "" <> ["};"] <> vertFuncs <> [shaderFunc "PS_IN" "VS" [("VS_IN" <+> "VS_input")] $ vertVals <> [shaderLet (text n) x | (n, x) <- zip vertOutNamesWithPosition vertHLSL] -- <> [shaderLet "gl_PointSize" x | Just x <- [ptHLSL]] ] , -- fragment shader code shader $ uniformDecls fragUniforms <> ["struct PS_IN {"] <> vertOutDecls "" <> ["};"] <> ["struct PS_OUT {"] <> [shaderDecl' (text t) (text n) | (n, t) <- zip fragOutNames fragOuts, backend == OpenGL33] <> ["};"] <> fragFuncs <> [shaderFunc "PS_OUT" "PS" [("PS_IN" <+> "PS_input")] $ fragVals <> [shaderStmt $ "if" <+> parens ("!" 
<> parens filt) <+> "discard" | Just filt <- [filtHLSL]] <> [shaderLet (text n) x | (n, x) <- zip fragOutNames fragHLSL ] ] ) where uniformDecls us = [shaderDecl' (text $ showHLSLType "2" t) (text n) | (n, (_, t)) <- Map.toList us] vertOutDecls io = [shaderDecl (text i <+> io) (text t) (text n) | (n, (i, t)) <- zip vertOutNames vertOuts] fragOutNames = case length frags of 0 -> [] 1 -> ["gl_FragColor"] (vertIns, verts) = case vert of Just (etaReds -> Just (xs, ETuple ys)) -> (toHLSLType "3" <$> xs, ys) Nothing -> ([toHLSLType "4" tvert], [mkTVar 0 tvert]) (fragOuts, frags) = case frag of Just (etaReds -> Just (xs, ETuple ys)) -> (toHLSLType "31" . tyOf <$> ys, ys) Nothing -> ([toHLSLType "41" tfrag], [mkTVar 0 tfrag]) (((vertHLSL, ptHLSL), (vertUniforms, (vertFuncs, vertVals))), ((filtHLSL, fragHLSL), (fragUniforms, (fragFuncs, fragVals)))) = flip evalState shaderNames $ do ((g1, (us1, verts)), (g2, (us2, frags))) <- (,) <$> runWriterT ((,) <$> traverse (genHLSL' "1" vertInNames . (,) vertIns) verts <*> traverse (genHLSL' "2" vertOutNamesWithPosition . reds) rp) <*> runWriterT ((,) <$> traverse (genHLSL' "3" vertOutNames . red) ffilter <*> traverse (genHLSL' "4" vertOutNames . 
(,) (snd <$> vertOuts)) frags) (,) <$> ((,) g1 <$> fixFuncs us1 mempty mempty verts) <*> ((,) g2 <$> fixFuncs us2 mempty mempty frags) fixFuncs :: Uniforms -> Set.Set SName -> ([Doc], [Doc]) -> Map.Map SName (ExpTV, ExpTV, [ExpTV]) -> State [SName] (Uniforms, ([Doc], [Doc])) fixFuncs us ns fsb (Map.toList -> fsa) | null fsa = return (us, fsb) | otherwise = do (unzip -> (defs, unzip -> (us', fs'))) <- forM fsa $ \(fn, (def, ty, tys)) -> runWriterT $ genHLSL (reverse $ take (length tys) funArgs) $ removeLams (length tys) def let fsb' = mconcat (zipWith combine fsa defs) <> fsb ns' = ns <> Set.fromList (map fst fsa) fixFuncs (us <> mconcat us') ns' fsb' (mconcat fs' `Map.difference` Map.fromSet (const undefined) ns') where combine (fn, (_, ty, tys)) def = case tys of [] -> ( [shaderDecl' ot n], [shaderLet n def] ) _ -> ( [shaderFunc ot n (zipWith (<+>) (map (toHLSLType "45") tys) (map text funArgs)) [shaderReturn def]] , [] ) where ot = toHLSLType "44" ty n = text fn funArgs = map (("z" ++) . show) [0..] shaderNames = map (("s" ++) . show) [0..] vertInNames = map (("vi" ++) . show) [1..length vertIns] vertOutNames = map (("vo" ++) . 
show) [1..length vertOuts] vertOutNamesWithPosition = "gl_Position": vertOutNames red (etaReds -> Just (ps, o)) = (ps, o) red x = error $ "red: " ++ ppShow x reds (etaReds -> Just (ps, o)) = (ps, o) reds x = error $ "red: " ++ ppShow x genHLSL' err vertOuts (ps, o) | length ps == length vertOuts = genHLSL (reverse vertOuts) o | otherwise = error $ "makeSubst illegal input " ++ err ++ " " ++ ppShow ps ++ "\n" ++ ppShow vertOuts noUnit TTuple0 = False noUnit _ = True vertOuts = zipWith go ints $ tail verts where go (A0 n) e = (interpName n, toHLSLType "3" $ tyOf e) interpName "Smooth" = "smooth" interpName "Flat" = "flat" interpName "NoPerspective" = "noperspective" shader xs = vcat $ [shaderFunc "vec4" "texture2D" ["sampler2D s", "vec2 uv"] [shaderReturn "texture(s,uv)"] | backend == OpenGL33] <> [shaderFunc "mat4" "transpose" ["mat4 m"] -- todo: not just for 4 dimension [ shaderLet "vec4 i0" "m[0]" , shaderLet "vec4 i1" "m[1]" , shaderLet "vec4 i2" "m[2]" , shaderLet "vec4 i3" "m[3]" , shaderReturn "mat4(\ \vec4(i0.x, i1.x, i2.x, i3.x),\ \vec4(i0.y, i1.y, i2.y, i3.y),\ \vec4(i0.z, i1.z, i2.z, i3.z),\ \vec4(i0.w, i1.w, i2.w, i3.w)\ \)" ] | backend == WebGL1 ] <> xs shaderFunc outtype name pars body = nest 4 (outtype <+> name <> tupled pars <+> "{" <$$> vcat body) <$$> "}" mainFunc xs = shaderFunc "void" "main" [] xs shaderStmt xs = nest 4 $ xs <> ";" shaderReturn xs = shaderStmt $ "return" <+> xs shaderLet a b = shaderStmt $ a <+> "=" b shaderDecl a b c = shaderDecl' (a <+> b) c shaderDecl' b c = shaderStmt $ b <+> c toHLSLType msg x = showHLSLType msg $ compInputType msg x -- move to lambdacube-ir? 
-- | Name of the HLSL type that corresponds to an IR input type.
--
-- Every result is a string literal, so with OverloadedStrings the caller
-- picks the string type (callers in this module use it both as a 'String'
-- passed to 'text' and directly as a 'Doc'). Input types without an HLSL
-- counterpart abort with an error tagged by @msg@.
showHLSLType msg ty = case ty of
    -- scalars
    IR.Bool  -> "bool"
    IR.Word  -> "uint"
    IR.Int   -> "int"
    IR.Float -> "float"

    -- float vectors
    IR.V2F -> "float2"
    IR.V3F -> "float3"
    IR.V4F -> "float4"

    -- signed integer vectors
    IR.V2I -> "int2"
    IR.V3I -> "int3"
    IR.V4I -> "int4"

    -- unsigned integer vectors
    IR.V2U -> "uint2"
    IR.V3U -> "uint3"
    IR.V4U -> "uint4"

    -- boolean vectors
    IR.V2B -> "bool2"
    IR.V3B -> "bool3"
    IR.V4B -> "bool4"

    -- square float matrices
    IR.M22F -> "float2x2"
    IR.M33F -> "float3x3"
    IR.M44F -> "float4x4"

    -- non-square float matrices (IR name digits copied as-is)
    IR.M23F -> "float2x3"
    IR.M24F -> "float2x4"
    IR.M32F -> "float3x2"
    IR.M34F -> "float3x4"
    IR.M42F -> "float4x2"
    IR.M43F -> "float4x3"

    -- textures
    IR.FTexture2D -> "Texture2D"

    unsupported -> error ("toHLSLType: " ++ msg ++ " " ++ show unsupported)