{-# LANGUAGE TypeOperators, TypeFamilies, FlexibleInstances, MultiParamTypeClasses, EmptyDataDecls, TypeSynonymInstances #-} ----------------------------------------------------------------------------- -- -- Module : Shader -- Copyright : Tobias Bexelius -- License : BSD3 -- -- Maintainer : Tobias Bexelius -- Stability : Experimental -- Portability : Portable -- -- | -- ----------------------------------------------------------------------------- module Shader ( GPU(..), rasterizeVertex, inputVertex, fragmentFrontFacing, Shader(), V, F, Vertex, Fragment, ShaderInfo, getShaders, Real'(..), Convert(..), dFdx, dFdy, fwidth, sampleBinFunc, sampleTernFunc, module Data.Boolean ) where import Control.Monad.Trans.State.Lazy (put, get, StateT, runStateT) import System.IO.Unsafe import Data.Vec ((:.)(..), Vec2, Vec3, Vec4, norm, normalize, dot, cross) import qualified Data.Vec as Vec import Data.Unique import Data.List import Data.Maybe import Data.Boolean import Data.Map (Map) import qualified Data.Map as Map hiding (Map) import qualified Data.HashTable as HT import Control.Exception (evaluate) import System.Mem.StableName import Data.IntSet (IntSet) import qualified Data.IntSet as IntSet hiding (IntSet) import Control.Arrow (first, second) import Resources import Formats infixl 7 `mod'` -- | Denotes a type on the GPU, that can be moved there from the CPU (through the internal use of uniforms). -- Use the existing instances of this class to create new ones. class GPU a where -- | The type on the CPU. type CPU a -- | Converts a value from the CPU to the GPU. toGPU :: CPU a -> a data ShaderTree = ShaderUniform !Uniform | ShaderConstant !Const | ShaderInput !Int | ShaderInputTree ShaderTree | ShaderOp !Op (String -> [String] -> String) [ShaderTree] type ShaderDAG = ([Int],[(ShaderTree, [Int])]) -- | An opaque type constructor for atomic values in a specific GPU context (i.e. 'V' or 'F'), e.g. 'Shader' 'V' 'Float'. newtype Shader c t = Shader { fromS :: ShaderTree } -- | Used to denote a vertex context in the first parameter to 'Shader' data V -- | Used to denote a fragment context in the first parameter to 'Shader' data F -- | A type synonyme for atomic values in a vertex on the GPU, e.g. 'Vertex' 'Float'. type Vertex = Shader V -- | A type synonyme for atomic values in a fragment on the GPU, e.g. 'Fragment' 'Float'. type Fragment = Shader F rasterizeVertex :: Vertex Float -> Fragment Float rasterizeVertex = Shader . ShaderInputTree . fromS inputVertex :: Int -> Vertex Float inputVertex = Shader . ShaderInput fragmentFrontFacing :: Fragment Bool fragmentFrontFacing = Shader $ ShaderOp "gl_ff" (assign bool (const "gl_FrontFacing")) [] getShaders :: Vec4 (Vertex Float) -> Fragment Bool -> Vec4 (Fragment Float) -> Maybe (Fragment Float) -> (ShaderInfo, ShaderInfo, [Int]) getShaders pos (Shader ndisc) color mdepth = ((createShaderKey vdag,vstr,vuns),(createShaderKey fdag,fstr,funs), inputs) where fcolor = fromS $ fromVec "vec4" color (varyings, fdag@(fcolor':ndisc':mdepth',_)) = splitShaders (createDAG (fcolor:ndisc: map fromS (maybeToList mdepth))) vpos = fromS $ fromVec "vec4" pos vdag@(vpos':varyings',_) = createDAG (vpos:varyings) inputs = extractInputs vdag vcodeAssigns = getCodeAssignments (fromJust . flip elemIndex inputs) (length inputs) "v" vdag vCodeFinish = setVaryings varyings' ++ "gl_Position = t" ++ show vpos' ++ ";\n" fcodeAssigns = getCodeAssignments id (length varyings') "f" fdag depthAssign = case mdepth' of [d] -> "gl_FragDepth = t" ++ show d ++ ";\n" [] -> "" fcodeFinish = "if (!t" ++ show ndisc' ++ ") discard;\n" ++ depthAssign ++ "gl_FragColor = t" ++ show fcolor' ++ ";\n" vuns = extractUniforms vdag funs = extractUniforms fdag attributeDecl = inoutDecls "attribute" "v" (length inputs) varyingDecl = inoutDecls "varying" "f" (length varyings') vstr = makeShader (attributeDecl ++ varyingDecl ++ uniformDecls "v" vuns) (vcodeAssigns ++ vCodeFinish) fstr = makeShader (varyingDecl ++ uniformDecls "f" funs) (fcodeAssigns ++ fcodeFinish) sampleBinFunc f t s tex c = toColor $ toVec float 4 (binaryFunc "vec4" f (Shader $ ShaderUniform $ UniformSampler t s tex) (fromVec (tName c) c)) sampleTernFunc f t s tex c x = toColor $ toVec float 4 (ternaryFunc "vec4" f (Shader $ ShaderUniform $ UniformSampler t s tex) (fromVec (tName c) c) x) instance GPU (Shader c Float) where type CPU (Shader c Float) = Float toGPU = Shader . ShaderUniform . UniformFloat instance GPU (Shader c Int) where type CPU (Shader c Int) = Int toGPU = Shader . ShaderUniform . UniformInt instance GPU (Shader c Bool) where type CPU (Shader c Bool) = Bool toGPU = Shader . ShaderUniform . UniformBool instance GPU () where type CPU () = () toGPU = id instance (GPU a, GPU b) => GPU (a,b) where type CPU (a,b) = (CPU a, CPU b) toGPU (a,b)= (toGPU a, toGPU b) instance (GPU a, GPU b, GPU c) => GPU (a,b,c) where type CPU (a,b,c) = (CPU a, CPU b, CPU c) toGPU (a,b,c)= (toGPU a, toGPU b, toGPU c) instance (GPU a, GPU b, GPU c, GPU d) => GPU (a,b,c,d) where type CPU (a,b,c,d) = (CPU a, CPU b, CPU c, CPU d) toGPU (a,b,c,d)= (toGPU a, toGPU b, toGPU c, toGPU d) instance (GPU a, GPU b) => GPU (a:.b) where type CPU (a:.b) = CPU a :. CPU b toGPU (a:.b) = toGPU a :. toGPU b instance Num (Shader c Float) where negate = unaryPreOp float "-" (+) = binaryOp float "+" (*) = binaryOp float "*" fromInteger = Shader . ShaderConstant . ConstFloat . fromInteger abs = unaryFunc float "abs" signum = unaryFunc float "sign" instance Num (Shader c Int) where negate = unaryPreOp int "-" (+) = binaryOp int "+" (*) = binaryOp int "*" fromInteger = Shader . ShaderConstant . ConstInt . fromInteger abs x = ifB (x <* 0) (-x) x signum x = ifB (x <* 0) (-1) 1 instance Fractional (Shader c Float) where (/) = binaryOp float "/" fromRational = Shader . ShaderConstant . ConstFloat . fromRational instance Floating (Shader c Float) where pi = Shader $ ShaderConstant $ ConstFloat pi sqrt = unaryFunc float "sqrt" exp = unaryFunc float "exp" log = unaryFunc float "log" (**) = binaryFunc float "pow" sin = unaryFunc float "sin" cos = unaryFunc float "cos" tan = unaryFunc float "tan" asin = unaryFunc float "asin" acos = unaryFunc float "acos" atan = unaryFunc float "atan" sinh x = (exp x - exp (-x)) / 2 cosh x = (exp x + exp (-x)) / 2 asinh x = log (x + sqrt (x * x + 1)) atanh x = log ((1 + x) / (1 - x)) / 2 acosh x = log (x + sqrt (x * x - 1)) -- | This class provides the GPU functions either not found in Prelude's numerical classes, or that has wrong types. -- Instances are also provided for normal 'Float's and 'Double's. -- Minimal complete definition: 'floor'' and 'ceiling''. class Floating a => Real' a where rsqrt :: a -> a exp2 :: a -> a log2 :: a -> a floor' :: a -> a ceiling' :: a -> a fract' :: a -> a mod' :: a -> a -> a clamp :: a -> a -> a -> a saturate :: a -> a mix :: a -> a -> a-> a step :: a -> a -> a smoothstep :: a -> a -> a -> a rsqrt = (1/) . sqrt exp2 = (2**) log2 = logBase 2 saturate x = clamp x 0 1 mix x y a = x*(1-a)+y*a smoothstep a b x = let t = saturate ((x-a) / (b-a)) in t*t*(3-2*t) fract' x = x - floor' x mod' x y = x - y* floor' (x/y) instance Real' Float where clamp x a = min (max x a) step a x | x < a = 0 | otherwise = 1 floor' = fromIntegral . floor ceiling' = fromIntegral . ceiling instance Real' Double where clamp x a = min (max x a) step a x | x < a = 0 | otherwise = 1 floor' = fromIntegral . floor ceiling' = fromIntegral . ceiling instance Real' (Shader c Float) where rsqrt = unaryFunc float "inversesqrt" exp2 = unaryFunc float "exp2" log2 = unaryFunc float "log2" floor' = unaryFunc float "floor" ceiling' = unaryFunc float "ceil" fract' = unaryFunc float "fract" mod' = binaryFunc float "mod" clamp = ternaryFunc float "clamp" mix = ternaryFunc float "mix" step = binaryFunc float "step" smoothstep = ternaryFunc float "smoothstep" instance Boolean (Shader c Bool) where true = Shader $ ShaderConstant $ ConstBool True false = Shader $ ShaderConstant $ ConstBool False notB = unaryPreOp bool "!" (&&*) = binaryOp bool "&&" (||*) = binaryOp bool "||" instance Eq a => EqB (Shader c Bool) (Shader c a) where (==*) = binaryOp bool "==" (/=*) = binaryOp bool "!=" instance Ord a => OrdB (Shader c Bool) (Shader c a) where (<*) = binaryOp bool "<" (>=*) = binaryOp bool ">=" (>*) = binaryOp bool ">" (<=*) = binaryOp bool "<=" instance IfB (Shader c Bool) (Shader c Int) where ifB c a b = Shader $ ShaderOp "if" (assign int (\[a,b,c]->a++"?"++b++":"++c)) [fromS c,fromS a,fromS b] instance IfB (Shader c Bool) (Shader c Float) where ifB c a b = Shader $ ShaderOp "if" (assign float (\[a,b,c]->a++"?"++b++":"++c)) [fromS c,fromS a,fromS b] instance IfB (Shader c Bool) (Shader c Bool) where ifB c a b = Shader $ ShaderOp "if" (assign bool (\[a,b,c]->a++"?"++b++":"++c)) [fromS c,fromS a,fromS b] -- | Provides a common way to convert numeric types to integer and floating point representations. class Convert a where type ConvertFloat a type ConvertInt a -- | Convert to a floating point number. toFloat :: a -> ConvertFloat a -- | Convert to an integral number, using truncation if necessary. toInt :: a -> ConvertInt a instance Convert Float where type ConvertFloat Float = Float type ConvertInt Float = Int toFloat = id toInt = truncate instance Convert Int where type ConvertFloat Int = Float type ConvertInt Int = Int toFloat = fromIntegral toInt = id instance Convert (Shader c Float) where type ConvertFloat (Shader c Float) = Shader c Float type ConvertInt (Shader c Float) = Shader c Int toFloat = id toInt = unaryFunc int int instance Convert (Shader c Int) where type ConvertFloat (Shader c Int) = Shader c Float type ConvertInt (Shader c Int) = Shader c Int toFloat = unaryFunc float float toInt = id -- | The derivative in x using local differencing of the rasterized value. dFdx :: Fragment Float -> Fragment Float -- | The derivative in y using local differencing of the rasterized value. dFdy :: Fragment Float -> Fragment Float -- | The sum of the absolute derivative in x and y using local differencing of the rasterized value. fwidth :: Fragment Float -> Fragment Float dFdx = unaryFunc float "dFdx" dFdy = unaryFunc float "dFdy" fwidth = unaryFunc float "fwidth" -------------------------------------- -- Vector specializations {-# RULES "norm/F4" norm = normF4 #-} {-# RULES "norm/F3" norm = normF3 #-} {-# RULES "norm/F2" norm = normF2 #-} normF4 :: Vec4 (Shader c Float) -> Shader c Float normF4 = unaryFunc float "length" . fromVec "vec4" normF3 :: Vec3 (Shader c Float) -> Shader c Float normF3 = unaryFunc float "length" . fromVec "vec3" normF2 :: Vec2 (Shader c Float) -> Shader c Float normF2 = unaryFunc float "length" . fromVec "vec2" {-# RULES "normalize/F4" normalize = normalizeF4 #-} {-# RULES "normalize/F3" normalize = normalizeF3 #-} {-# RULES "normalize/F2" normalize = normalizeF2 #-} normalizeF4 :: Vec4 (Shader c Float) -> Vec4 (Shader c Float) normalizeF4 = toVec float 4 . unaryFunc "vec4" "normalize" . fromVec "vec4" normalizeF3 :: Vec3 (Shader c Float) -> Vec3 (Shader c Float) normalizeF3 = toVec float 3 . unaryFunc "vec3" "normalize" . fromVec "vec3" normalizeF2 :: Vec2 (Shader c Float) -> Vec2 (Shader c Float) normalizeF2 = toVec float 2 . unaryFunc "vec2" "normalize" . fromVec "vec2" {-# RULES "dot/F4" dot = dotF4 #-} {-# RULES "dot/F3" dot = dotF3 #-} {-# RULES "dot/F2" dot = dotF2 #-} dotF4 :: Vec4 (Shader c Float) -> Vec4 (Shader c Float) -> Shader c Float dotF4 a b = binaryFunc float "dot" (fromVec "vec4" a) (fromVec "vec4" b) dotF3 :: Vec3 (Shader c Float) -> Vec3 (Shader c Float) -> Shader c Float dotF3 a b = binaryFunc float "dot" (fromVec "vec3" a) (fromVec "vec3" b) dotF2 :: Vec2 (Shader c Float) -> Vec2 (Shader c Float) -> Shader c Float dotF2 a b = binaryFunc float "dot" (fromVec "vec2" a) (fromVec "vec2" b) {-# RULES "cross/F3" cross = crossF3 #-} crossF3 :: Vec3 (Shader c Float) -> Vec3 (Shader c Float) -> Vec3 (Shader c Float) crossF3 a b = toVec float 3 $ binaryFunc "vec3" "cross" (fromVec "vec3" a) (fromVec "vec3" b) {-# RULES "minB/F" minB = minS #-} {-# RULES "maxB/F" maxB = maxS #-} minS :: Shader a Float -> Shader a Float -> Shader a Float minS = binaryFunc float "min" maxS :: Shader a Float -> Shader a Float -> Shader a Float maxS = binaryFunc float "max" -------------------------------------- -- Private -- setVaryings xs = setVaryings' 0 $ map (('t':) . show) xs where setVaryings' _ [] = "" setVaryings' n xs = case splitAt 4 xs of (ys,rest) -> "f" ++ show n ++ " = " ++ tName' (length ys) ++ "(" ++ intercalate "," ys ++ ");\n" ++ setVaryings' (n+1) rest inoutDecls t n i = inoutDecls' i 0 where inoutDecls' i x | i >= 4 = t ++ " vec4 " ++ n ++ show x ++ ";\n" ++ inoutDecls' (i-4) (x+1) | i == 0 = "" | otherwise = t ++ " " ++ tName' i ++ " " ++ n ++ show x ++ ";\n" uniformDecls :: String -> UniformSet -> String uniformDecls p (f,i,b,s) = makeU float "f" (length f) ++ makeU int "i" (length i) ++ makeU bool "b" (length b) ++ concatMap (\(t,xs) -> makeU (sampName t) ('s':show (fromEnum t)) (length xs)) (Map.toList s) where makeU t n 0 = "" makeU t n i = "uniform " ++ t ++ " " ++ p ++ "u" ++ n ++ "[" ++ show i ++ "];\n" makeShader init assignments = "#version 120\n" ++ init ++ "void main(){\n" ++ assignments ++ "}\n" createShaderKey :: ShaderDAG -> ShaderKey createShaderKey (a,xs) = (a,map (first toShaderKeyNode) xs) where toShaderKeyNode (ShaderUniform _) = ShaderKeyUniform toShaderKeyNode (ShaderInput a) = ShaderKeyInput a toShaderKeyNode (ShaderConstant a) = ShaderKeyConstant a toShaderKeyNode (ShaderOp a _ _) = ShaderKeyOp a toShaderKeyNode (ShaderInputTree _) = error "Use splitShaders first" splitShaders :: ShaderDAG -> ([ShaderTree], ShaderDAG) -- ^ (previous, current) splitShaders (a,xs) = case mapAccumL splitNode [] xs of (trees, xs2) -> (reverse trees, (a,xs2)) where splitNode ts (ShaderInputTree a, ys) = (a:ts, (ShaderInput (length ts), ys)) splitNode ts a = (ts, a) createDAG :: [ShaderTree] -> ShaderDAG createDAG = second reverse . unsafePerformIO . startDAG where startDAG xs = do ht <- HT.new (==) (fromIntegral . hashStableName) runStateT (mapM (createDAG' ht) xs) [] createDAG' :: HT.HashTable (StableName ShaderTree) Int -> ShaderTree -> StateT [(ShaderTree, [Int])] IO Int createDAG' ht n = do n' <- liftIO $ evaluate n -- To make makeStableName "stable" k <- liftIO $ makeStableName n' m <- liftIO $ HT.lookup ht k case m of Just i -> return i Nothing -> do xs' <- case n' of ShaderOp _ _ xs -> mapM (createDAG' ht) xs _ -> return [] ys <- get let y = length ys liftIO $ HT.insert ht k y put $ (n',xs'):ys return y extractUniforms :: ShaderDAG -> UniformSet extractUniforms (_,xs) = foldl' extractUniform ([],[],[],Map.empty) $ reverse $ map fst xs where extractUniform (a,b,c,m) (ShaderUniform (UniformFloat x)) = (x:a,b,c,m) extractUniform (a,b,c,m) (ShaderUniform (UniformInt x)) = (a,x:b,c,m) extractUniform (a,b,c,m) (ShaderUniform (UniformBool x)) = (a,b,x:c,m) extractUniform (a,b,c,m) (ShaderUniform (UniformSampler t s tex)) = (a,b,c,Map.insertWith' (++) t [(s,tex)] m) extractUniform x _ = x extractInputs :: ShaderDAG -> [Int] extractInputs (_,xs) = IntSet.toAscList $ foldl' extractIn IntSet.empty $ map fst xs where extractIn s (ShaderInput a) = IntSet.insert a s extractIn x _ = x getCodeAssignments :: (Int -> Int) -> Int -> String -> ShaderDAG -> String getCodeAssignments inF numIns inName (_,xs) = concat $ snd $ mapAccumL getCode ((0,0,0,Map.empty),Map.empty) $ zip [0..] xs where getCode ((f,i,b,s),inlns) (n, (ShaderUniform (UniformFloat _), _)) = (((f+1,i,b,s),inlns), assign float (const $ inName ++ "uf[" ++ show f ++ "]") (var n) []) getCode ((f,i,b,s),inlns) (n, (ShaderUniform (UniformInt _), _)) = (((f,i+1,b,s),inlns), assign int (const $ inName ++ "ui[" ++ show i ++ "]") (var n) []) getCode ((f,i,b,s),inlns) (n, (ShaderUniform (UniformBool _), _)) = (((f,i,b+1,s),inlns), assign bool (const $ inName ++ "ub[" ++ show b ++ "]") (var n) []) getCode ((f,i,b,s),inlns) (n, (ShaderUniform (UniformSampler t _ _), _)) = case first (fromMaybe 0) $ Map.insertLookupWithKey (const $ const (+1)) t 1 s of (x, s') -> (((f,i,b,s'),Map.insert n (inName ++ "us" ++ show (fromEnum t) ++ "[" ++ show x ++ "]") inlns), "") getCode x (n, (ShaderConstant (ConstFloat f), _)) = (x, assign float (const $ show f) (var n) []) getCode x (n, (ShaderConstant (ConstInt i), _)) = (x, assign int (const $ show i) (var n) []) getCode x (n, (ShaderConstant (ConstBool b), _)) = (x, assign bool (const $ if b then "true" else "false") (var n) []) getCode x (n, (ShaderInput i, _)) = (x, assign float (const $ inName ++ inoutAccessor (inF i) numIns) (var n) []) getCode x@(_,inlns) (n, (ShaderOp _ f _, xs)) = (x, f (var n) (map (varMaybeInline inlns) xs)) getCode _ (_, (ShaderInputTree _, _)) = error "Shader.getCodeAssignments: Use splitShaders first!" var n = 't' : show n varMaybeInline inlns n = fromMaybe (var n) (Map.lookup n inlns) inoutAccessor i tot = case divMod i 4 of (d,m) -> if i+1 == tot && m == 0 then show d else show d ++ "." ++ (["x","y","z","w"]!!m) sampName Sampler3D = "sampler3D" sampName Sampler2D = "sampler2D" sampName Sampler1D = "sampler1D" sampName SamplerCube = "samplerCube" tName v = tName' $ Vec.length v tName' 1 = float tName' x = "vec" ++ show x assign :: String -> ([String] -> String) -> String -> [String] -> String assign t f x ys = t ++ " " ++ x ++ "=" ++ f ys ++ ";\n" binFunc :: String -> [String] -> String binFunc s = head . binFunc' where binFunc' (a:b:xs) = binFunc' $ (s ++ "(" ++ a ++ "," ++ b ++ ")"):binFunc' xs binFunc' x = x binaryOp t s a b = Shader $ ShaderOp s (assign t (intercalate s)) [fromS a, fromS b] unaryPreOp t s a = Shader $ ShaderOp s (assign t ((s ++) . head)) [fromS a] unaryPostOp t s a = Shader $ ShaderOp s (assign t ((++ s) . head)) [fromS a] unaryFunc t s a = Shader $ ShaderOp s (assign t (((s ++ "(") ++) . (++ ")") . head)) [fromS a] binaryFunc t s a b = Shader $ ShaderOp s (assign t (binFunc s)) [fromS a, fromS b] ternaryFunc t s a b c = Shader $ ShaderOp s (assign t (\[a,b,c]->s++"("++a++","++b++","++c++")")) [fromS a, fromS b, fromS c] fromVec t = Shader . ShaderOp "" (assign t (((t ++ "(") ++) . (++ ")") . intercalate ",")) . map fromS . Vec.toList toVec t n a = Vec.fromList $ map (\s -> Shader $ ShaderOp s (assign t (\[x]->x++"["++s++"]")) [fromS a]) [show n' | n' <-[0..n - 1]] float = "float" int = "int" bool = "bool"