-----------------------------------------------------------------------------
--
-- Module      :  Shader
-- Copyright   :  Tobias Bexelius
-- License     :  BSD3
--
-- Maintainer  :  Tobias Bexelius
-- Stability   :  Experimental
-- Portability :  Portable
--
-- | Shader-building monad and GLSL (#version 120) code generation for
--   'Vertex' and 'Fragment' values.
--
-----------------------------------------------------------------------------

module Shader
    ( GPU(..)
    , Shader
    , Uniform(..)
    , addInput
    , addUniform
    , runShader
    , vertexProgram
    , fragmentProgram
    , colorFragmentShader
    , colorDepthFragmentShader
    , addVertexSamplerUniform
    , addFragmentSamplerUniform
    , Vertex(Vertex)
    , Fragment(Fragment)
    , Real'(..)
    , dFdx
    , dFdy
    , fwidth
    , vSampleFunc
    , fSampleFunc
    , module Data.Boolean
    ) where

import System.IO.Unsafe
import Data.Vec ((:.)(..), Vec2, Vec3, Vec4, norm, normalize, dot, cross)
import qualified Data.Vec as Vec
import Data.Unique
import Data.List
import Data.Boolean
import Control.Monad.State
import Data.Map (Map)
import qualified Data.Map as Map hiding (Map)
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet hiding (IntSet)
import Control.Arrow (first, second)
import Resources
import Formats

-- | Denotes a type on the GPU that can be moved there from the CPU (through the internal use of uniforms).
--   Use the existing instances of this class to create new ones. Note that 'toGPU' should not be strict on its argument.
--   Its definition should also always use the same series of 'toGPU' calls to convert values of the same type. This unfortunately
--   rules out ordinary lists (but instances for fixed-length lists from the Vec package are provided).
class GPU a where
    -- | The type on the CPU.
    type CPU a
    -- | Converts a value from the CPU to the GPU.
    toGPU :: CPU a -> a

-- | The shader-building monad. Its state holds the uniforms registered so far
--   (keyed by a 'Unique') and the set of input indices recorded with 'addInput'.
type Shader = State (Map Unique Uniform, IntSet)

-- | Builds a vertex program.
--
--   * @vPos@ produces the expression assigned to @gl_Position@.
--   * @rast@ holds the candidate per-vertex outputs; only those at the
--     positions listed in @fins@ are run and written to varyings (see 'selectM').
--
--   Returns @((fullSource, mainBody), uniforms, inputs)@, where @inputs@ are
--   the attribute indices the program reads.
vertexProgram vPos rast fins = ((source, body), uniforms, inputs)
    where
        ((position, outputs), uniforms, inputs) = runShader $ do
            p  <- vPos
            os <- selectM fins rast
            return (p, os)
        body = concat
            [ shaderInputSplits "attr" "va" inputs
            , vertexOutputSets fins outputs
            , "gl_Position = ", position, ";\n"
            ]
        source = concat
            [ "#version 120\n"
            , uniformDecls "vu" uniforms
            , attributeDecls inputs
            , varyingDecls outputs
            , "void main(){\n", body, "}\n"
            ]

-- | Builds a fragment program from a shader action producing the statements
--   of its @main@ body (e.g. 'colorFragmentShader').
--
--   Returns @((fullSource, mainBody), uniforms, inputs)@, where @inputs@ are
--   the varying indices the program reads.
fragmentProgram m = ((source, body), uniforms, inputs)
    where
        (outputs, uniforms, inputs) = runShader m
        body = shaderInputSplits "var" "fa" inputs ++ outputs
        source = concat
            [ "#version 120\n"
            , uniformDecls "fu" uniforms
            , varyingDecls inputs
            , "void main(){\n", body, "}\n"
            ]

-- | Runs, in order, the shader actions of the second list found at the
--   zero-based positions given by the first list, and collects their results.
--   The positions must be ascending: after taking an element, every remaining
--   position is rebased past it.
selectM [] _ = return []
selectM (i:is) ys =
    let (y:rest) = drop i ys
    in liftM2 (:) y (selectM (map (subtract (i + 1)) is) rest)

-- | Emits the fragment statements for a colour-only output. The fragment is
--   discarded when the boolean is false; otherwise the four components are
--   written to @gl_FragColor@.
colorFragmentShader :: Fragment Bool -> Vec4 (Fragment Float) -> Shader String
colorFragmentShader (Fragment keep) (Fragment r :. Fragment g :. Fragment b :. Fragment a :. ()) = do
    rgba  <- sequence [r, g, b, a]
    keep' <- keep
    return $ concat
        [ "if (!", keep', ") discard;\n"
        , "gl_FragColor=vec4(", intercalate "," rgba, ");\n"
        ]

-- | Like 'colorFragmentShader', but also writes the given value to
--   @gl_FragDepth@.
colorDepthFragmentShader :: Fragment Bool -> (Vec4 (Fragment Float), Fragment Float) -> Shader String
colorDepthFragmentShader (Fragment keep) ((Fragment r :. Fragment g :. Fragment b :. Fragment a :. ()), Fragment d) = do
    rgba  <- sequence [r, g, b, a]
    depth <- d
    keep' <- keep
    return $ concat
        [ "if (!", keep', ") discard;\n"
        , "gl_FragDepth=", depth, ";\n"
        , "gl_FragColor=vec4(", intercalate "," rgba, ");\n"
        ]

-- | Declares one GLSL uniform array per non-empty uniform kind, with the given
--   suffix appended to each array name (e.g. @"vu"@ gives @fvu@, @ivu@, @bvu@,
--   @s<n>vu@).
uniformDecls :: String -> UniformSet -> String
uniformDecls p (fs, is, bs, ss) =
    concat ([ declare "float f" (length fs)
            , declare "int i" (length is)
            , declare "bool b" (length bs)
            ] ++
            [ declare (samplerTypeString t ++ " s" ++ show (fromEnum t)) (length xs) | (t, xs) <- ss ])
    where
        -- Empty kinds get no declaration at all.
        declare _ 0 = ""
        declare typeAndName count = "uniform " ++ typeAndName ++ p ++ "[" ++ show count ++ "];\n"
-- | Packs a list of scalar shader inputs into GLSL vectors of at most four
--   components, returning @(vectorIndex, componentCount)@ for each vector.
--   Generates e.g. [(0,4), (1,4), (2,3)] from [x,x,x,x,x,x,x,x,x,x,x] (length 11).
inputVecs :: [a] -> [(Int, Int)]
inputVecs ins = [(i,min (length ins - i*4) 4) | i <- [0..(length ins - 1) `div` 4]]

-- | The packed slot of every element of the list, in order: the index of the
--   vector it lives in and the swizzle suffix selecting its component (empty
--   for single-component vectors). E.g. a list of length 5 gives
--   [(0,".x"),(0,".y"),(0,".z"),(0,".w"),(1,"")].
packedSlots :: [a] -> [(Int, String)]
packedSlots ins = [(v, subElem c e) | (v,e) <- inputVecs ins, c <- [0..e-1]]

-- | Declares one GLSL @attribute@ vector (@attr0@, @attr1@, ...) per packed
--   group of inputs.
attributeDecls :: [a] -> String
attributeDecls ins = concat [ "attribute " ++ tName n ++ " attr" ++ show i ++ ";\n" | (i,n) <- inputVecs ins]

-- | Declares one GLSL @varying@ vector (@var0@, @var1@, ...) per packed group
--   of values.
varyingDecls :: [a] -> String
varyingDecls ins = concat [ "varying " ++ tName n ++ " var" ++ show i ++ ";\n" | (i,n) <- inputVecs ins]

-- | Unpacks the packed input vectors into one scalar local per input:
--   @float \<to\>\<i\> = \<from\>\<vector\>\<swizzle\>;@ for each input index @i@.
--   The k-th element of @ins@ reads the k-th packed slot.
shaderInputSplits :: String -> String -> [Int] -> String
shaderInputSplits from to ins =
    concat [ "float " ++ to ++ show i ++ " = " ++ from ++ show v ++ sw ++ ";\n"
           | (i, (v, sw)) <- zip ins (packedSlots ins) ]

-- | Writes the vertex shader's outputs into the packed varyings:
--   @var\<vector\>\<swizzle\> = \<expr\>;@.
--
--   @ins@ are the (ascending) indices of the selected outputs and @outs@ the
--   selected expressions themselves, one per index and in the same order (as
--   produced by 'selectM'). The k-th output goes to the k-th packed slot,
--   which is the slot 'shaderInputSplits' reads for the k-th input on the
--   fragment side. Outputs are paired positionally: indexing @outs@ by the
--   selected index would pick the wrong expression (or run off the end)
--   whenever the selected indices are not exactly @[0..n-1]@.
vertexOutputSets :: [Int] -> [String] -> String
vertexOutputSets ins outs =
    concat [ "var" ++ show v ++ sw ++ " = " ++ o ++ ";\n"
           | (o, (v, sw)) <- zip outs (packedSlots ins) ]

-- | The GLSL swizzle suffix selecting component @x@ (0-3) of an
--   @e@-component vector; scalars (@e == 1@) need no suffix.
subElem :: Int -> Int -> String
subElem _ 1 = ""
subElem x _ = ['.', (['x','y','z','w']!!x)]

-- | The GLSL type name of a float vector with the given number of components.
tName :: Int -> String
tName 1 = "float"
tName x = "vec" ++ show x

-- | A uniform value registered by a shader.
data Uniform = FloatUniform Float | IntUniform Int | BoolUniform Bool | SamplerUniform SamplerType Sampler WinMappedTexture

-- | The GLSL type name of a sampler.
samplerTypeString :: SamplerType -> String
samplerTypeString Sampler3D = "sampler3D"
samplerTypeString Sampler2D = "sampler2D"
samplerTypeString Sampler1D = "sampler1D"
samplerTypeString SamplerCube = "samplerCube"

-- | Emits a vertex-shader call of the GLSL function @f@. The call receives
--   the sampler uniform, the coordinate vector @c@ and the extra arguments
--   @xs@. Four components are read back from the result and converted with
--   'toColor'.
vSampleFunc f t s tex c xs = toColor $ fromVVec 4 (vListFunc f $ [addVertexSamplerUniform t s tex, vVec c] ++ xs)

-- | Fragment-shader counterpart of 'vSampleFunc'.
fSampleFunc f t s tex c xs = toColor $ fromFVec 4 (fListFunc f $ [addFragmentSamplerUniform t s tex, fVec c] ++ xs)

-- | Registers a sampler uniform in the vertex shader's @s\<n\>vu@ array,
--   where @n@ is @fromEnum@ of the sampler type.
addVertexSamplerUniform t s = Vertex . addUniform ("s" ++ show (fromEnum t) ++ "vu") . SamplerUniform t s

-- | Registers a sampler uniform in the fragment shader's @s\<n\>fu@ array.
addFragmentSamplerUniform t s = Fragment . addUniform ("s" ++ show (fromEnum t) ++ "fu") . SamplerUniform t s

-- | An opaque type constructor for atomic values in a vertex on the GPU, e.g. 'Vertex' 'Float'.
newtype Vertex a = Vertex { fromVertex :: Shader String }

-- | An opaque type constructor for atomic values in a fragment on the GPU, e.g. 'Fragment' 'Float'.
newtype Fragment a = Fragment { fromFragment :: Shader String }

-- | Runs a shader action. Returns its result, the uniforms it registered
--   grouped per kind, and the ascending list of input indices recorded with
--   'addInput'.
runShader :: Shader a -> (a, UniformSet, [Int])
runShader m = (a, getSamplerList $ splitSet ([],[],[], Map.empty) $ reverse $ Map.elems $ fst s, IntSet.toAscList $ snd s)
    where
        (a,s) = runState m (Map.empty, IntSet.empty)
        -- Folding over the reversed elements while consing leaves every
        -- per-kind list in ascending key order. That is the order in which
        -- 'addUniform' computes array indices.
        splitSet (f,i,b,s) (FloatUniform u:xs) = splitSet (u:f,i,b,s) xs
        splitSet (f,i,b,s) (IntUniform u:xs) = splitSet (f,u:i,b,s) xs
        splitSet (f,i,b,s) (BoolUniform u:xs) = splitSet (f,i,u:b,s) xs
        splitSet (f,i,b,s) (SamplerUniform t samp tex:xs) = splitSet (f,i,b, Map.insertWith (++) t [(samp,tex)] s) xs
        splitSet s [] = s
        getSamplerList (f,i,b,s) = (f,i,b, Map.toList s)

-- | Records that the shader reads the input with the given index.
addInput :: Int -> Shader ()
addInput = modify . second . IntSet.insert

-- | Registers a uniform value and returns the GLSL expression that reads it:
--   @p ++ "[" ++ index ++ "]"@, where @p@ names the uniform array for the
--   uniform's kind (e.g. @"fvu"@).
--
--   The index is the uniform's position among registered uniforms of the
--   /same kind/ in ascending key order. Same kind means the same
--   constructor and, for samplers, the same 'SamplerType'. This matches the
--   per-kind arrays built by 'runShader' and declared by @uniformDecls@.
--   The position within the whole mixed-kind map would index past the end
--   of a per-kind array as soon as a shader uses more than one kind.
--
--   The key is a 'Unique' created by 'unsafePerformIO' in a where-binding of
--   this action. Re-running the same action therefore finds the same key and
--   does not register the uniform twice. The binding returns @u@ alongside
--   the key, presumably so it depends on @u@ and is not floated out and
--   shared across different uniforms.
--
--   NOTE(review): indices are computed at registration time. A uniform
--   registered later with a smaller key (possible if its key thunk was forced
--   earlier, e.g. by another program) would shift indices already handed out.
--   Confirm callers cannot trigger this.
addUniform :: String -> Uniform -> Shader String
addUniform p u = do
        x <- gets fst
        x' <- case Map.lookup n x of
                Just _  -> return x
                Nothing -> do
                    let x'' = Map.insert n u' x
                    modify $ first $ const x''
                    return x''
        return $ p ++ "[" ++ show (kindIndex x') ++ "]"
    where
        (n,u') = unsafePerformIO $ do k <- newUnique
                                      return (k,u) --wire u to make it happen as often as we want...
        -- Number of uniforms of the same kind registered under a smaller key.
        kindIndex m = Map.size $ Map.filter (sameUniformKind u') $ fst $ Map.split n m

-- | Whether two uniforms are stored in the same GLSL uniform array: the same
--   constructor and, for samplers, the same 'SamplerType'.
sameUniformKind :: Uniform -> Uniform -> Bool
sameUniformKind (FloatUniform _) (FloatUniform _) = True
sameUniformKind (IntUniform _) (IntUniform _) = True
sameUniformKind (BoolUniform _) (BoolUniform _) = True
sameUniformKind (SamplerUniform t _ _) (SamplerUniform t' _ _) = t == t'
sameUniformKind _ _ = False

-- Scalar uniforms: each value becomes an element of the matching per-kind
-- uniform array of the vertex ("vu") or fragment ("fu") program.
instance GPU (Vertex Float) where
    type CPU (Vertex Float) = Float
    toGPU = Vertex . addUniform "fvu" . FloatUniform
instance GPU (Vertex Int) where
    type CPU (Vertex Int) = Int
    toGPU = Vertex . addUniform "ivu" . IntUniform
instance GPU (Vertex Bool) where
    type CPU (Vertex Bool) = Bool
    toGPU = Vertex . addUniform "bvu" . BoolUniform
instance GPU (Fragment Float) where
    type CPU (Fragment Float) = Float
    toGPU = Fragment . addUniform "ffu" . FloatUniform
instance GPU (Fragment Int) where
    type CPU (Fragment Int) = Int
    toGPU = Fragment . addUniform "ifu" . IntUniform
instance GPU (Fragment Bool) where
    type CPU (Fragment Bool) = Bool
    toGPU = Fragment . addUniform "bfu" . BoolUniform

-- Structural instances. The irrefutable (~) patterns keep 'toGPU' non-strict
-- in its argument, as the 'GPU' class contract requires. A plain pattern
-- would force the tuple / vector constructor.
instance GPU () where
    type CPU () = ()
    toGPU = id
instance (GPU a, GPU b) => GPU (a,b) where
    type CPU (a,b) = (CPU a, CPU b)
    toGPU ~(a,b)= (toGPU a, toGPU b)
instance (GPU a, GPU b, GPU c) => GPU (a,b,c) where
    type CPU (a,b,c) = (CPU a, CPU b, CPU c)
    toGPU ~(a,b,c)= (toGPU a, toGPU b, toGPU c)
instance (GPU a, GPU b, GPU c, GPU d) => GPU (a,b,c,d) where
    type CPU (a,b,c,d) = (CPU a, CPU b, CPU c, CPU d)
    toGPU ~(a,b,c,d)= (toGPU a, toGPU b, toGPU c, toGPU d)
instance (GPU a, GPU b) => GPU (a:.b) where
    type CPU (a:.b) = CPU a :. CPU b
    toGPU ~(a:.b) = toGPU a :. toGPU b

-- Comparison and printing cannot be evaluated on the CPU; these methods raise
-- an error via 'noFun'. 'min'/'max' map to the GLSL built-ins instead.
instance Eq (Vertex a) where
    (==) = noFun "(==)"
    (/=) = noFun "(/=)"
instance Eq (Fragment a) where
    (==) = noFun "(==)"
    (/=) = noFun "(/=)"
instance Ord a => Ord (Vertex a) where
    (<=) = noFun "(<=)"
    min = vBinaryFunc "min"
    max = vBinaryFunc "max"
instance Ord a => Ord (Fragment a) where
    (<=) = noFun "(<=)"
    min = fBinaryFunc "min"
    max = fBinaryFunc "max"
instance Show (Vertex a) where
    show = noFun "show"
instance Show (Fragment a) where
    show = noFun "show"

-- Arithmetic builds GLSL expressions. Literals are emitted with 'show' of
-- the corresponding CPU value.
instance Num a => Num (Vertex a) where
    negate = vUnaryPreOp "-"
    (+) = vBinaryOp "+"
    (*) = vBinaryOp "*"
    fromInteger a = Vertex $ return $ show (fromInteger a :: a)
    abs = vUnaryFunc "abs"
    signum = vUnaryFunc "sign"
instance Num a => Num (Fragment a) where
    negate = fUnaryPreOp "-"
    (+) = fBinaryOp "+"
    (*) = fBinaryOp "*"
    fromInteger a = Fragment $ return $ show (fromInteger a :: a)
    abs = fUnaryFunc "abs"
    signum = fUnaryFunc "sign"
instance Fractional a => Fractional (Vertex a) where
    (/) = vBinaryOp "/"
    fromRational a = Vertex $ return $ show (fromRational a :: a)
instance Fractional a => Fractional (Fragment a) where
    (/) = fBinaryOp "/"
    fromRational a = Fragment $ return $ show (fromRational a :: a)

-- Transcendental functions map to GLSL built-ins. The hyperbolic functions
-- have no mapping here and raise an error via 'noFun'.
instance Floating a => Floating (Vertex a) where
    pi = fromRational (toRational (pi :: Double))
    sqrt = vUnaryFunc "sqrt"
    exp = vUnaryFunc "exp"
    log = vUnaryFunc "log"
    (**) = vBinaryFunc "pow"
    sin = vUnaryFunc "sin"
    cos = vUnaryFunc "cos"
    tan = vUnaryFunc "tan"
    asin = vUnaryFunc "asin"
    acos = vUnaryFunc "acos"
    atan = vUnaryFunc "atan"
    sinh = noFun "sinh"
    cosh = noFun "cosh"
    asinh = noFun "asinh"
    atanh = noFun "atanh"
    acosh = noFun "acosh"
instance Floating a => Floating (Fragment a) where
    pi = fromRational (toRational (pi :: Double))
    sqrt = fUnaryFunc "sqrt"
    exp = fUnaryFunc "exp"
    log = fUnaryFunc "log"
    (**) = fBinaryFunc "pow"
    sin = fUnaryFunc "sin"
    cos = fUnaryFunc "cos"
    tan = fUnaryFunc "tan"
    asin = fUnaryFunc "asin"
    acos = fUnaryFunc "acos"
    atan = fUnaryFunc "atan"
    sinh = noFun "sinh"
    cosh = noFun "cosh"
    asinh = noFun "asinh"
    atanh = noFun "atanh"
    acosh = noFun "acosh"

-- | This class provides the GPU functions that are either not found in Prelude's numerical classes or that have the wrong types there.
--   Instances are also provided for normal 'Float's and 'Double's.
--   Minimal complete definition: 'floor'' and 'ceiling''.
class (Ord a, Floating a) => Real' a where
    rsqrt :: a -> a
    exp2 :: a -> a
    log2 :: a -> a
    floor' :: a -> a
    ceiling' :: a -> a
    fract' :: a -> a
    mod' :: a -> a -> a
    clamp :: a -> a -> a -> a
    saturate :: a -> a
    mix :: a -> a -> a-> a
    step :: a -> a -> a
    smoothstep :: a -> a -> a -> a

    rsqrt = (1/) . sqrt
    exp2 = (2**)
    log2 = logBase 2
    clamp x a b = min (max x a) b
    saturate x = clamp x 0 1
    mix x y a = x*(1-a)+y*a
    step a x | x < a = 0
             | otherwise = 1
    smoothstep a b x = let t = saturate ((x-a) / (b-a))
                       in t*t*(3-2*t)
    fract' x = x - floor' x
    mod' x y = x - y* floor' (x/y)

instance Real' Float where
    floor' = fromIntegral . floor
    ceiling' = fromIntegral . ceiling

instance Real' Double where
    floor' = fromIntegral . floor
    ceiling' = fromIntegral . ceiling

-- GPU instances: every method maps directly to the GLSL built-in of the same meaning.
instance Real' (Vertex Float) where
    rsqrt = vUnaryFunc "inversesqrt"
    exp2 = vUnaryFunc "exp2"
    log2 = vUnaryFunc "log2"
    floor' = vUnaryFunc "floor"
    ceiling' = vUnaryFunc "ceil"
    fract' = vUnaryFunc "fract"
    mod' = vBinaryFunc "mod"
    clamp x a b = vListFunc "clamp" [x,a,b]
    mix x y a = vListFunc "mix" [x,y,a]
    step = vBinaryFunc "step"
    smoothstep a b x = vListFunc "smoothstep" [a,b,x]

instance Real' (Fragment Float) where
    rsqrt = fUnaryFunc "inversesqrt"
    exp2 = fUnaryFunc "exp2"
    log2 = fUnaryFunc "log2"
    floor' = fUnaryFunc "floor"
    ceiling' = fUnaryFunc "ceil"
    fract' = fUnaryFunc "fract"
    mod' = fBinaryFunc "mod"
    clamp x a b = fListFunc "clamp" [x,a,b]
    mix x y a = fListFunc "mix" [x,y,a]
    step = fBinaryFunc "step"
    smoothstep a b x = fListFunc "smoothstep" [a,b,x]

-- Boolean logic, comparisons and conditionals (Data.Boolean) emitted as GLSL operators.
instance Boolean (Vertex Bool) where
    true = Vertex $ return "true"
    false = Vertex $ return "false"
    notB = vUnaryPreOp "!"
    (&&*) = vBinaryOp "&&"
    (||*) = vBinaryOp "||"
instance Boolean (Fragment Bool) where
    true = Fragment $ return "true"
    false = Fragment $ return "false"
    notB = fUnaryPreOp "!"
    (&&*) = fBinaryOp "&&"
    (||*) = fBinaryOp "||"

instance Eq a => EqB (Vertex Bool) (Vertex a) where
    (==*) = vBinaryOp "=="
    (/=*) = vBinaryOp "!="
instance Eq a => EqB (Fragment Bool) (Fragment a) where
    (==*) = fBinaryOp "=="
    (/=*) = fBinaryOp "!="
instance Ord a => OrdB (Vertex Bool) (Vertex a) where
    (<*) = vBinaryOp "<"
    (>=*) = vBinaryOp ">="
    (>*) = vBinaryOp ">"
    (<=*) = vBinaryOp "<="
instance Ord a => OrdB (Fragment Bool) (Fragment a) where
    (<*) = fBinaryOp "<"
    (>=*) = fBinaryOp ">="
    (>*) = fBinaryOp ">"
    (<=*) = fBinaryOp "<="

-- Conditionals become the GLSL ternary operator; both branches are emitted.
instance IfB (Vertex Bool) (Vertex a) where
    ifB c a b = Vertex $ do
        c' <- fromVertex c
        a' <- fromVertex a
        b' <- fromVertex b
        return $ "(" ++ c' ++ "?" ++ a' ++ ":" ++ b' ++ ")"
instance IfB (Fragment Bool) (Fragment a) where
    ifB c a b = Fragment $ do
        c' <- fromFragment c
        a' <- fromFragment a
        b' <- fromFragment b
        return $ "(" ++ c' ++ "?" ++ a' ++ ":" ++ b' ++ ")"

-- | The derivative in x using local differencing of the rasterized value.
dFdx :: Fragment Float -> Fragment Float
-- | The derivative in y using local differencing of the rasterized value.
dFdy :: Fragment Float -> Fragment Float
-- | The sum of the absolute derivative in x and y using local differencing of the rasterized value.
fwidth :: Fragment Float -> Fragment Float
dFdx = fUnaryFunc "dFdx"
dFdy = fUnaryFunc "dFdy"
fwidth = fUnaryFunc "fwidth"

--------------------------------------
-- Vector specializations
-- The RULES below rewrite the generic Data.Vec operations to the
-- corresponding GLSL vector built-ins when used at these types.

{-# RULES "norm/F4" norm = normF4 #-}
{-# RULES "norm/F3" norm = normF3 #-}
{-# RULES "norm/F2" norm = normF2 #-}
normF4 :: Vec4 (Fragment Float) -> Fragment Float
normF4 = fUnaryFunc "length" . fVec
normF3 :: Vec3 (Fragment Float) -> Fragment Float
normF3 = fUnaryFunc "length" . fVec
normF2 :: Vec2 (Fragment Float) -> Fragment Float
normF2 = fUnaryFunc "length" . fVec
{-# RULES "norm/V4" norm = normV4 #-}
{-# RULES "norm/V3" norm = normV3 #-}
{-# RULES "norm/V2" norm = normV2 #-}
normV4 :: Vec4 (Vertex Float) -> Vertex Float
normV4 = vUnaryFunc "length" . vVec
normV3 :: Vec3 (Vertex Float) -> Vertex Float
normV3 = vUnaryFunc "length" . vVec
normV2 :: Vec2 (Vertex Float) -> Vertex Float
normV2 = vUnaryFunc "length" . vVec

{-# RULES "normalize/F4" normalize = normalizeF4 #-}
{-# RULES "normalize/F3" normalize = normalizeF3 #-}
{-# RULES "normalize/F2" normalize = normalizeF2 #-}
normalizeF4 :: Vec4 (Fragment Float) -> Vec4 (Fragment Float)
normalizeF4 = fromFVec 4 . fUnaryFunc "normalize" . fVec
normalizeF3 :: Vec3 (Fragment Float) -> Vec3 (Fragment Float)
normalizeF3 = fromFVec 3 . fUnaryFunc "normalize" . fVec
normalizeF2 :: Vec2 (Fragment Float) -> Vec2 (Fragment Float)
normalizeF2 = fromFVec 2 . fUnaryFunc "normalize" . fVec
{-# RULES "normalize/V4" normalize = normalizeV4 #-}
{-# RULES "normalize/V3" normalize = normalizeV3 #-}
{-# RULES "normalize/V2" normalize = normalizeV2 #-}
normalizeV4 :: Vec4 (Vertex Float) -> Vec4 (Vertex Float)
normalizeV4 = fromVVec 4 . vUnaryFunc "normalize" . vVec
normalizeV3 :: Vec3 (Vertex Float) -> Vec3 (Vertex Float)
normalizeV3 = fromVVec 3 . vUnaryFunc "normalize" . vVec
normalizeV2 :: Vec2 (Vertex Float) -> Vec2 (Vertex Float)
normalizeV2 = fromVVec 2 . vUnaryFunc "normalize" . vVec

{-# RULES "dot/F4" dot = dotF4 #-}
{-# RULES "dot/F3" dot = dotF3 #-}
{-# RULES "dot/F2" dot = dotF2 #-}
dotF4 :: Vec4 (Fragment Float) -> Vec4 (Fragment Float) -> Fragment Float
dotF4 a b = fBinaryFunc "dot" (fVec a) (fVec b)
dotF3 :: Vec3 (Fragment Float) -> Vec3 (Fragment Float) -> Fragment Float
dotF3 a b = fBinaryFunc "dot" (fVec a) (fVec b)
dotF2 :: Vec2 (Fragment Float) -> Vec2 (Fragment Float) -> Fragment Float
dotF2 a b = fBinaryFunc "dot" (fVec a) (fVec b)
{-# RULES "dot/V4" dot = dotV4 #-}
{-# RULES "dot/V3" dot = dotV3 #-}
{-# RULES "dot/V2" dot = dotV2 #-}
dotV4 :: Vec4 (Vertex Float) -> Vec4 (Vertex Float) -> Vertex Float
dotV4 a b = vBinaryFunc "dot" (vVec a) (vVec b)
dotV3 :: Vec3 (Vertex Float) -> Vec3 (Vertex Float) -> Vertex Float
dotV3 a b = vBinaryFunc "dot" (vVec a) (vVec b)
dotV2 :: Vec2 (Vertex Float) -> Vec2 (Vertex Float) -> Vertex Float
dotV2 a b = vBinaryFunc "dot" (vVec a) (vVec b)

{-# RULES "cross/F3" cross = crossF3 #-}
crossF3 :: Vec3 (Fragment Float) -> Vec3 (Fragment Float) -> Vec3 (Fragment Float)
crossF3 a b = fromFVec 3 $ fBinaryFunc "cross" (fVec a) (fVec b)
{-# RULES "cross/V3" cross = crossV3 #-}
crossV3 :: Vec3 (Vertex Float) -> Vec3 (Vertex Float) ->Vec3 (Vertex Float)
crossV3 a b = fromVVec 3 $ vBinaryFunc "cross" (vVec a) (vVec b)

--------------------------------------
-- Private
--

-- | Raises an error for a Prelude method that cannot be evaluated for
--   'Vertex'/'Fragment' values.
noFun :: String -> a
noFun = error . (++ ": No overloading for Vertex/Fragment")

-- | Emits a GLSL call @s(arg1,arg2,...)@, running the argument actions in order.
vListFunc s xs = Vertex $ do
    xs' <- mapM fromVertex xs
    return $ s ++ "(" ++ intercalate "," xs' ++ ")"
fListFunc s xs = Fragment $ do
    xs' <- mapM fromFragment xs
    return $ s ++ "(" ++ intercalate "," xs' ++ ")"

vUnaryFunc s a = vListFunc s [a]
fUnaryFunc s a = fListFunc s [a]
vBinaryFunc s a b = vListFunc s [a,b]
fBinaryFunc s a b = fListFunc s [a,b]

-- | Emits a parenthesised prefix operator application, e.g. @(-x)@.
vUnaryPreOp s a = Vertex $ do
    a' <- fromVertex a
    return $ "(" ++ s ++ a' ++ ")"
fUnaryPreOp s a = Fragment $ do
    a' <- fromFragment a
    return $ "(" ++ s ++ a' ++ ")"

-- | Emits a parenthesised infix operator application, e.g. @(x+y)@.
vBinaryOp s a b = Vertex $ do
    a' <- fromVertex a
    b' <- fromVertex b
    return $ "(" ++ a' ++ s ++ b' ++ ")"
fBinaryOp s a b = Fragment $ do
    a' <- fromFragment a
    b' <- fromFragment b
    return $ "(" ++ a' ++ s ++ b' ++ ")"

-- | Packs a Data.Vec vector of scalars into a GLSL @float@/@vecN@ constructor call.
vVec v = let xs = Vec.toList v in vListFunc (tName $ length xs) xs
fVec v = let xs = Vec.toList v in fListFunc (tName $ length xs) xs

-- | Splits an @e@-component GLSL vector expression back into @e@ scalars by
--   swizzling. Each component re-runs the vector's action and repeats its
--   expression text.
fromVVec e v = Vec.fromList $ map (\n -> Vertex $ do { v' <- fromVertex v; return (v' ++ subElem n e) }) [0..(e-1)]
fromFVec e v = Vec.fromList $ map (\n -> Fragment $ do { v' <- fromFragment v; return (v' ++ subElem n e) }) [0..(e-1)]