module Shader (
GPU(..),
Shader,
addInput,
runShader,
vertexProgram,
fragmentProgram,
colorFragmentShader,
colorDepthFragmentShader,
addVertexSamplerUniform,
addFragmentSamplerUniform,
Vertex(Vertex),
Fragment(Fragment),
Real'(..),
dFdx,
dFdy,
fwidth,
vSampleFunc,
fSampleFunc,
module Data.Boolean
) where
import System.IO.Unsafe
import Data.Vec ((:.)(..), Vec2, Vec3, Vec4, norm, normalize, dot, cross)
import qualified Data.Vec as Vec
import Data.Unique
import Data.List
import Data.Boolean
import Control.Monad.State
import Data.Map (Map)
import qualified Data.Map as Map hiding (Map)
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet hiding (IntSet)
import Control.Arrow (first, second)
import Resources
import Formats
-- | Types that live on the GPU and have a CPU-side counterpart.
class GPU a where
    -- | The CPU-side type whose values can be converted into @a@.
    type CPU a
    -- | Convert a CPU-side value into its GPU-side representation.
    toGPU :: CPU a -> a
-- | The shader-building monad: a 'State' over the uniforms registered so far
-- ('UniformState') paired with the set of input indices recorded by 'addInput'.
type Shader = State (UniformState, IntSet)
-- Setters that replace one uniform table inside the 'UniformState' half of the
-- shader state; the input set in the second half is left untouched.

-- | Replace the float uniform table.
setFloatUniforms a = modify (first update)
  where
    update us = us { floatUniforms = a }

-- | Replace the int uniform table.
setIntUniforms a = modify (first update)
  where
    update us = us { intUniforms = a }

-- | Replace the bool uniform table.
setBoolUniforms a = modify (first update)
  where
    update us = us { boolUniforms = a }

-- | Replace the 3D sampler uniform table.
setSampler3DUniforms a = modify (first update)
  where
    update us = us { sampler3DUniforms = a }

-- | Replace the 2D sampler uniform table.
setSampler2DUniforms a = modify (first update)
  where
    update us = us { sampler2DUniforms = a }

-- | Replace the 1D sampler uniform table.
setSampler1DUniforms a = modify (first update)
  where
    update us = us { sampler1DUniforms = a }

-- | Replace the cube sampler uniform table.
setSamplerCubeUniforms a = modify (first update)
  where
    update us = us { samplerCubeUniforms = a }
-- | Build the GLSL vertex program.
--
-- @vPos@ produces the position expression, @rast@ is the list of rasterized
-- output expressions and @fins@ the indices of the outputs that are actually
-- needed. The result is @((source, signature), uniforms, usedInputs)@, where
-- @signature@ is the body of @main@ and @source@ the complete program text.
vertexProgram vPos rast fins = ((source, sig), uns, ins)
  where
    ((pos, outs), uns, ins) = runShader $ do
        position <- vPos
        selected <- selectM fins rast
        return (position, selected)
    sig = concat
        [ shaderInputSplits "attr" "va" ins
        , vertexOutputSets fins outs
        , "gl_Position = "
        , pos
        , ";\n"
        ]
    source = concat
        [ "#version 120\n"
        , uniformDecls "vu" uns
        , attributeDecls ins
        , varyingDecls outs
        , "void main(){\n"
        , sig
        , "}\n"
        ]
-- | Build the GLSL fragment program from a shader action that yields the
-- statements writing the fragment outputs.
--
-- The result is @((source, signature), uniforms, usedInputs)@, where
-- @signature@ is the body of @main@ and @source@ the complete program text.
fragmentProgram m = ((source, sig), uns, ins)
  where
    (outs, uns, ins) = runShader m
    sig = shaderInputSplits "var" "fa" ins ++ outs
    source = concat
        [ "#version 120\n"
        , uniformDecls "fu" uns
        , varyingDecls ins
        , "void main(){\n"
        , sig
        , "}\n"
        ]
-- | Run only the actions found at the given ascending indices and collect
-- their results in order.
--
-- After taking the element at index @x@, the remaining list starts at the old
-- index @x + 1@, so every later index is shifted down by @x + 1@ before
-- recursing.
--
-- Calls 'error' if an index lies beyond the end of the action list.
selectM :: Monad m => [Int] -> [m a] -> m [a]
selectM [] _ = return []
selectM (x:xs) ys = do
    picked <- here
    others <- selectM (map (\t -> t - x - 1) xs) rest
    return (picked : others)
  where
    (here, rest) = case drop x ys of
        (a : b) -> (a, b)
        []      -> error "selectM: index beyond the end of the action list"
-- | Fragment program body that writes a color.
--
-- The first argument is the "not discarded" flag: when it evaluates to false
-- the fragment is discarded, otherwise the four channels are written to
-- @gl_FragColor@.
colorFragmentShader :: Fragment Bool -> Vec4 (Fragment Float) -> Shader String
colorFragmentShader (Fragment nDisc) (Fragment r :. Fragment g :. Fragment b :. Fragment a :. ()) = do
    -- Channels are evaluated in r, g, b, a order, then the discard flag.
    channels <- sequence [r, g, b, a]
    keep <- nDisc
    return $ concat
        [ "if (!", keep, ") discard;\n"
        , "gl_FragColor=vec4(", intercalate "," channels, ");\n"
        ]
-- | Fragment program body that writes both a color and a depth value.
--
-- The first argument is the "not discarded" flag: when it evaluates to false
-- the fragment is discarded, otherwise @gl_FragDepth@ and @gl_FragColor@ are
-- written.
colorDepthFragmentShader :: Fragment Bool -> (Vec4 (Fragment Float), Fragment Float) -> Shader String
colorDepthFragmentShader (Fragment nDisc) ((Fragment r :. Fragment g :. Fragment b :. Fragment a :. ()), Fragment d) = do
    -- Evaluation order: r, g, b, a, depth, then the discard flag.
    channels <- sequence [r, g, b, a]
    depth <- d
    keep <- nDisc
    return $ concat
        [ "if (!", keep, ") discard;\n"
        , "gl_FragDepth=", depth, ";\n"
        , "gl_FragColor=vec4(", intercalate "," channels, ");\n"
        ]
-- | GLSL @uniform@ array declarations for every non-empty uniform table.
--
-- @p@ is the per-stage suffix (@"vu"@ or @"fu"@). Each array is named
-- @\<prefix\>\<p\>@ and sized by the number of entries in its table. Sampler
-- arrays use the prefix @"s" ++ show (fromEnum samplerType)@, which is the
-- name 'addVertexSamplerUniform' and 'addFragmentSamplerUniform' emit when
-- they reference a sampler.
uniformDecls :: String -> UniformState -> String
uniformDecls p uns = concat
    [ makeU "float f" floatUniforms
    , makeU "int i" intUniforms
    , makeU "bool b" boolUniforms
    , makeU ("sampler3D s" ++ show (fromEnum Sampler3D)) sampler3DUniforms
    , makeU ("sampler2D s" ++ show (fromEnum Sampler2D)) sampler2DUniforms
    , makeU ("sampler1D s" ++ show (fromEnum Sampler1D)) sampler1DUniforms
      -- Previously declared with prefix "sc", which never matched the
      -- "s<N>vu"/"s<N>fu" names used to reference cube samplers.
    , makeU ("samplerCube s" ++ show (fromEnum SamplerCube)) samplerCubeUniforms
    ]
  where
    -- Explicitly polymorphic so it can be applied to tables with different
    -- value types even when local bindings are not generalised.
    makeU :: String -> (UniformState -> Map k v) -> String
    makeU tn f
        | Map.null table = ""
        | otherwise      = "uniform " ++ tn ++ p ++ "[" ++ show (Map.size table) ++ "];\n"
      where
        table = f uns
-- | Pack a list of scalar inputs into vectors of at most four components.
--
-- Returns one @(vectorIndex, width)@ pair per vector: every vector is four
-- wide except possibly the last, which holds the remainder. An empty input
-- list yields no vectors.
inputVecs :: [a] -> [(Int, Int)]
inputVecs ins = [ (i, min (count - i * 4) 4) | i <- [0 .. (count - 1) `div` 4] ]
  where
    count = length ins
-- | GLSL @attribute@ declarations (@attr0@, @attr1@, ...) for the packed
-- vectors needed to carry the given inputs.
attributeDecls ins = concatMap declare (inputVecs ins)
  where
    declare (i, v) = "attribute " ++ tName v ++ " attr" ++ show i ++ ";\n"
-- | GLSL @varying@ declarations (@var0@, @var1@, ...) for the packed vectors
-- needed to carry the given values.
varyingDecls ins = concatMap declare (inputVecs ins)
  where
    declare (i, v) = "varying " ++ tName v ++ " var" ++ show i ++ ";\n"
-- | Statements that unpack the packed input vectors into one scalar local per
-- used input.
--
-- The k-th input index @i@ in @ins@ is paired with the k-th component slot
-- of the packed vectors, producing @float \<to\>\<i\> = \<from\>\<vec\>\<.comp\>;@.
shaderInputSplits :: String -> String -> [Int] -> String
shaderInputSplits from to ins = concat (zipWith unpack ins slots)
  where
    -- Every vector contributes four candidate slots; pairing with @ins@
    -- drops the unused ones of a partially filled last vector.
    slots = [ (vec, comp, width) | (vec, width) <- inputVecs ins, comp <- [0 .. 3] ]
    unpack i (vec, comp, width) =
        "float " ++ to ++ show i ++ " = " ++ from ++ show vec ++ subElem comp width ++ ";\n"
-- | Statements that store the selected vertex outputs into the packed
-- varying vectors.
--
-- @ins@ holds the indices of the used outputs and @outs@ the expressions
-- already selected for exactly those indices, in the same order. The k-th
-- expression therefore goes into the k-th component slot. Looking @outs@ up
-- by the original index, as was done before, picked the wrong expression or
-- ran off the end whenever the used indices were not @[0 .. n-1]@.
vertexOutputSets :: [Int] -> [String] -> String
vertexOutputSets ins outs = concat (zipWith assign outs slots)
  where
    slots = [ (vec, comp, width) | (vec, width) <- inputVecs ins, comp <- [0 .. 3] ]
    assign expr (vec, comp, width) =
        "var" ++ show vec ++ subElem comp width ++ " = " ++ expr ++ ";\n"
-- | Component selector suffix for component @comp@ of a vector that is
-- @width@ wide: empty for a scalar (width 1), otherwise @.x@, @.y@, @.z@
-- or @.w@.
subElem :: Int -> Int -> String
subElem comp width
    | width == 1 = ""
    | otherwise  = ['.', "xyzw" !! comp]
-- | GLSL type name for a float vector with the given number of components:
-- @float@ for one component, @vecN@ otherwise.
tName :: Int -> String
tName n = case n of
    1 -> "float"
    _ -> "vec" ++ show n
-- | Emit a call to the GLSL sampling function @f@ in a vertex shader.
--
-- The sampler uniform registered for @(s, tex)@ and the coordinate vector
-- @c@ are passed first, followed by the extra arguments @xs@. The four
-- components of the result are converted with 'toColor'.
vSampleFunc f t s tex c xs = toColor (fromVVec 4 call)
  where
    sampler = addVertexSamplerUniform t s tex
    call = vListFunc f (sampler : vVec c : xs)
-- | Emit a call to the GLSL sampling function @f@ in a fragment shader.
--
-- The sampler uniform registered for @(s, tex)@ and the coordinate vector
-- @c@ are passed first, followed by the extra arguments @xs@. The four
-- components of the result are converted with 'toColor'.
fSampleFunc f t s tex c xs = toColor (fromFVec 4 call)
  where
    sampler = addFragmentSamplerUniform t s tex
    call = fListFunc f (sampler : fVec c : xs)
-- | Register the pair @(s, t)@ as a sampler uniform of the given sampler type
-- for the vertex stage and yield the expression referencing it.
--
-- The uniform array is named @"s" ++ show (fromEnum samplerType) ++ "vu"@.
addVertexSamplerUniform ty s t = Vertex $ case ty of
    Sampler3D   -> addUniform arrayName sampler3DUniforms setSampler3DUniforms (s, t)
    Sampler2D   -> addUniform arrayName sampler2DUniforms setSampler2DUniforms (s, t)
    Sampler1D   -> addUniform arrayName sampler1DUniforms setSampler1DUniforms (s, t)
    SamplerCube -> addUniform arrayName samplerCubeUniforms setSamplerCubeUniforms (s, t)
  where
    arrayName = "s" ++ show (fromEnum ty) ++ "vu"
-- | Register the pair @(s, t)@ as a sampler uniform of the given sampler type
-- for the fragment stage and yield the expression referencing it.
--
-- The uniform array is named @"s" ++ show (fromEnum samplerType) ++ "fu"@.
addFragmentSamplerUniform ty s t = Fragment $ case ty of
    Sampler3D   -> addUniform arrayName sampler3DUniforms setSampler3DUniforms (s, t)
    Sampler2D   -> addUniform arrayName sampler2DUniforms setSampler2DUniforms (s, t)
    Sampler1D   -> addUniform arrayName sampler1DUniforms setSampler1DUniforms (s, t)
    SamplerCube -> addUniform arrayName samplerCubeUniforms setSamplerCubeUniforms (s, t)
  where
    arrayName = "s" ++ show (fromEnum ty) ++ "fu"
-- | A vertex-stage value: a 'Shader' action that yields the GLSL expression
-- text for the value. The type parameter @a@ is phantom and only tracks the
-- type the expression stands for.
newtype Vertex a = Vertex { fromVertex :: Shader String }
-- | A fragment-stage value: a 'Shader' action that yields the GLSL expression
-- text for the value. The type parameter @a@ is phantom and only tracks the
-- type the expression stands for.
newtype Fragment a = Fragment { fromFragment :: Shader String }
-- | Run a shader action from an empty state.
--
-- Returns the action's result, the uniforms it registered and the indices of
-- the inputs it used, in ascending order.
runShader :: Shader a -> (a, UniformState, [Int])
runShader m = (result, uniforms, IntSet.toAscList inputs)
  where
    (result, (uniforms, inputs)) = runState m (noUniforms, IntSet.empty)
    -- One empty table per uniform kind.
    noUniforms =
        UniformState Map.empty Map.empty Map.empty Map.empty Map.empty Map.empty Map.empty
-- | Record that the shader reads the input with the given index.
addInput :: Int -> Shader ()
addInput i = modify (second (IntSet.insert i))
-- | Register the value @u@ in one of the uniform tables and yield the GLSL
-- expression @p[index]@ that refers to it.
--
-- @getter@ and @setter@ read and replace the table inside the shader state;
-- @p@ is the name of the GLSL uniform array. A value whose key is already in
-- the table is not inserted again; its existing index is reused.
--
-- NOTE(review): the table key is a 'Unique' created through
-- 'unsafePerformIO' in the @where@ binding, so two uses share a key only when
-- they share this thunk. That relies on how the compiler shares and inlines
-- the binding; keep its shape intact when editing.
addUniform p getter setter u = do
    known <- gets (getter . fst)
    slot <- case Map.lookupIndex key known of
        Just i  -> return i
        Nothing -> do
            let grown = Map.insert key val known
            setter grown
            return (Map.findIndex key grown)
    return (p ++ "[" ++ show slot ++ "]")
  where
    (key, val) = unsafePerformIO $ do
        fresh <- newUnique
        return (fresh, u)
-- | A CPU 'Float' becomes an element of the vertex-stage @fvu@ uniform array.
instance GPU (Vertex Float) where
    type CPU (Vertex Float) = Float
    toGPU x = Vertex (addUniform "fvu" floatUniforms setFloatUniforms x)

-- | A CPU 'Int' becomes an element of the vertex-stage @ivu@ uniform array.
instance GPU (Vertex Int) where
    type CPU (Vertex Int) = Int
    toGPU x = Vertex (addUniform "ivu" intUniforms setIntUniforms x)

-- | A CPU 'Bool' becomes an element of the vertex-stage @bvu@ uniform array.
instance GPU (Vertex Bool) where
    type CPU (Vertex Bool) = Bool
    toGPU x = Vertex (addUniform "bvu" boolUniforms setBoolUniforms x)
-- | A CPU 'Float' becomes an element of the fragment-stage @ffu@ uniform array.
instance GPU (Fragment Float) where
    type CPU (Fragment Float) = Float
    toGPU x = Fragment (addUniform "ffu" floatUniforms setFloatUniforms x)

-- | A CPU 'Int' becomes an element of the fragment-stage @ifu@ uniform array.
instance GPU (Fragment Int) where
    type CPU (Fragment Int) = Int
    toGPU x = Fragment (addUniform "ifu" intUniforms setIntUniforms x)

-- | A CPU 'Bool' becomes an element of the fragment-stage @bfu@ uniform array.
instance GPU (Fragment Bool) where
    type CPU (Fragment Bool) = Bool
    toGPU x = Fragment (addUniform "bfu" boolUniforms setBoolUniforms x)
-- | The unit value is the same on both sides.
instance GPU () where
    type CPU () = ()
    toGPU () = ()

-- | Pairs are converted component-wise.
instance (GPU a, GPU b) => GPU (a,b) where
    type CPU (a,b) = (CPU a, CPU b)
    toGPU (x, y) = (toGPU x, toGPU y)

-- | Triples are converted component-wise.
instance (GPU a, GPU b, GPU c) => GPU (a,b,c) where
    type CPU (a,b,c) = (CPU a, CPU b, CPU c)
    toGPU (x, y, z) = (toGPU x, toGPU y, toGPU z)

-- | Quadruples are converted component-wise.
instance (GPU a, GPU b, GPU c, GPU d) => GPU (a,b,c,d) where
    type CPU (a,b,c,d) = (CPU a, CPU b, CPU c, CPU d)
    toGPU (x, y, z, w) = (toGPU x, toGPU y, toGPU z, toGPU w)

-- | Vectors built with ':.' are converted head and tail.
instance (GPU a, GPU b) => GPU (a:.b) where
    type CPU (a:.b) = CPU a :. CPU b
    toGPU (h :. t) = toGPU h :. toGPU t
-- | Shader values cannot be compared on the CPU; both methods fail via
-- 'noFun'.
instance Eq (Vertex a) where
    _ == _ = noFun "(==)"
    _ /= _ = noFun "(/=)"

-- | Shader values cannot be compared on the CPU; both methods fail via
-- 'noFun'.
instance Eq (Fragment a) where
    _ == _ = noFun "(==)"
    _ /= _ = noFun "(/=)"
-- | Only 'min' and 'max' are available (as the GLSL functions of the same
-- name); '<=' fails via 'noFun'.
instance Ord a => Ord (Vertex a) where
    _ <= _ = noFun "(<=)"
    min x y = vBinaryFunc "min" x y
    max x y = vBinaryFunc "max" x y

-- | Only 'min' and 'max' are available (as the GLSL functions of the same
-- name); '<=' fails via 'noFun'.
instance Ord a => Ord (Fragment a) where
    _ <= _ = noFun "(<=)"
    min x y = fBinaryFunc "min" x y
    max x y = fBinaryFunc "max" x y
-- | Shader values cannot be shown; 'show' fails via 'noFun'.
instance Show (Vertex a) where
    show _ = noFun "show"

-- | Shader values cannot be shown; 'show' fails via 'noFun'.
instance Show (Fragment a) where
    show _ = noFun "show"
-- | Arithmetic on vertex values builds the corresponding GLSL expressions.
instance Num a => Num (Vertex a) where
    negate = vUnaryPreOp "-"
    (+) = vBinaryOp "+"
    (*) = vBinaryOp "*"
    -- Constants are rendered at application precedence so that a negative
    -- value is parenthesised. Without that, negating a negative constant
    -- produced "(--1.0)", which GLSL reads as the decrement operator.
    fromInteger n = Vertex $ return $ showsPrec 11 (fromInteger n :: a) ""
    abs = vUnaryFunc "abs"
    signum = vUnaryFunc "sign"
-- | Arithmetic on fragment values builds the corresponding GLSL expressions.
instance Num a => Num (Fragment a) where
    negate = fUnaryPreOp "-"
    (+) = fBinaryOp "+"
    (*) = fBinaryOp "*"
    -- Constants are rendered at application precedence so that a negative
    -- value is parenthesised. Without that, negating a negative constant
    -- produced "(--1.0)", which GLSL reads as the decrement operator.
    fromInteger n = Fragment $ return $ showsPrec 11 (fromInteger n :: a) ""
    abs = fUnaryFunc "abs"
    signum = fUnaryFunc "sign"
-- | Division and fractional constants for vertex values.
instance Fractional a => Fractional (Vertex a) where
    (/) = vBinaryOp "/"
    -- Rendered at application precedence so a negative constant is
    -- parenthesised and can never fuse with a preceding "-" into "--".
    fromRational q = Vertex $ return $ showsPrec 11 (fromRational q :: a) ""

-- | Division and fractional constants for fragment values.
instance Fractional a => Fractional (Fragment a) where
    (/) = fBinaryOp "/"
    -- Rendered at application precedence so a negative constant is
    -- parenthesised and can never fuse with a preceding "-" into "--".
    fromRational q = Fragment $ return $ showsPrec 11 (fromRational q :: a) ""
-- | Transcendental functions on vertex values map to the GLSL built-ins of
-- the same name ('**' maps to @pow@). The hyperbolic functions listed here
-- have no translation and fail via 'noFun'.
instance Floating a => Floating (Vertex a) where
    pi = fromRational . toRational $ (pi :: Double)
    sqrt x = vUnaryFunc "sqrt" x
    exp x = vUnaryFunc "exp" x
    log x = vUnaryFunc "log" x
    x ** y = vBinaryFunc "pow" x y
    sin x = vUnaryFunc "sin" x
    cos x = vUnaryFunc "cos" x
    tan x = vUnaryFunc "tan" x
    asin x = vUnaryFunc "asin" x
    acos x = vUnaryFunc "acos" x
    atan x = vUnaryFunc "atan" x
    sinh = noFun "sinh"
    cosh = noFun "cosh"
    asinh = noFun "asinh"
    atanh = noFun "atanh"
    acosh = noFun "acosh"
-- | Transcendental functions on fragment values map to the GLSL built-ins of
-- the same name ('**' maps to @pow@). The hyperbolic functions listed here
-- have no translation and fail via 'noFun'.
instance Floating a => Floating (Fragment a) where
    pi = fromRational . toRational $ (pi :: Double)
    sqrt x = fUnaryFunc "sqrt" x
    exp x = fUnaryFunc "exp" x
    log x = fUnaryFunc "log" x
    x ** y = fBinaryFunc "pow" x y
    sin x = fUnaryFunc "sin" x
    cos x = fUnaryFunc "cos" x
    tan x = fUnaryFunc "tan" x
    asin x = fUnaryFunc "asin" x
    acos x = fUnaryFunc "acos" x
    atan x = fUnaryFunc "atan" x
    sinh = noFun "sinh"
    cosh = noFun "cosh"
    asinh = noFun "asinh"
    atanh = noFun "atanh"
    acosh = noFun "acosh"
-- | Real-valued operations that exist as GLSL built-ins, with default
-- definitions for every method except 'floor'' and 'ceiling''.
class (Ord a, Floating a) => Real' a where
    -- | Reciprocal square root.
    rsqrt :: a -> a
    -- | Two raised to the given power.
    exp2 :: a -> a
    -- | Base-2 logarithm.
    log2 :: a -> a
    -- | Largest whole number not above the argument.
    floor' :: a -> a
    -- | Smallest whole number not below the argument.
    ceiling' :: a -> a
    -- | Fractional part: @x - floor' x@.
    fract' :: a -> a
    -- | Floored modulus: @x - y * floor' (x / y)@.
    mod' :: a -> a -> a
    -- | @clamp x a b@ limits @x@ to the range from @a@ to @b@.
    clamp :: a -> a -> a -> a
    -- | Clamp to the range from 0 to 1.
    saturate :: a -> a
    -- | @mix x y a@ blends linearly from @x@ (at @a = 0@) to @y@ (at @a = 1@).
    mix :: a -> a -> a-> a
    -- | @step a x@ is 0 when @x < a@ and 1 otherwise.
    step :: a -> a -> a
    -- | @smoothstep a b x@ is the Hermite interpolation @t*t*(3 - 2*t)@ of
    -- @t = saturate ((x - a) / (b - a))@.
    smoothstep :: a -> a -> a -> a

    rsqrt = (1/) . sqrt
    exp2 = (2**)
    log2 = logBase 2
    clamp x a b = min (max x a) b
    saturate x = clamp x 0 1
    mix x y a = x * (1 - a) + y * a
    step a x | x < a     = 0
             | otherwise = 1
    smoothstep a b x = let t = saturate ((x - a) / (b - a))
                       in t * t * (3 - 2 * t)
    fract' x = x - floor' x
    mod' x y = x - y * floor' (x / y)

instance Real' Float where
    floor' x = fromIntegral (floor x :: Integer)
    ceiling' x = fromIntegral (ceiling x :: Integer)

instance Real' Double where
    floor' x = fromIntegral (floor x :: Integer)
    ceiling' x = fromIntegral (ceiling x :: Integer)
-- | Vertex-stage floats use the GLSL built-ins directly instead of the class
-- defaults.
instance Real' (Vertex Float) where
    rsqrt x = vUnaryFunc "inversesqrt" x
    exp2 x = vUnaryFunc "exp2" x
    log2 x = vUnaryFunc "log2" x
    floor' x = vUnaryFunc "floor" x
    ceiling' x = vUnaryFunc "ceil" x
    fract' x = vUnaryFunc "fract" x
    mod' x y = vBinaryFunc "mod" x y
    clamp v lo hi = vListFunc "clamp" [v, lo, hi]
    mix from to weight = vListFunc "mix" [from, to, weight]
    step edge x = vBinaryFunc "step" edge x
    smoothstep lo hi x = vListFunc "smoothstep" [lo, hi, x]
-- | Fragment-stage floats use the GLSL built-ins directly instead of the
-- class defaults.
instance Real' (Fragment Float) where
    rsqrt x = fUnaryFunc "inversesqrt" x
    exp2 x = fUnaryFunc "exp2" x
    log2 x = fUnaryFunc "log2" x
    floor' x = fUnaryFunc "floor" x
    ceiling' x = fUnaryFunc "ceil" x
    fract' x = fUnaryFunc "fract" x
    mod' x y = fBinaryFunc "mod" x y
    clamp v lo hi = fListFunc "clamp" [v, lo, hi]
    mix from to weight = fListFunc "mix" [from, to, weight]
    step edge x = fBinaryFunc "step" edge x
    smoothstep lo hi x = fListFunc "smoothstep" [lo, hi, x]
-- | Boolean algebra on vertex values, rendered with the GLSL operators
-- @!@, @&&@ and @||@.
instance Boolean (Vertex Bool) where
    true = Vertex (return "true")
    false = Vertex (return "false")
    notB x = vUnaryPreOp "!" x
    x &&* y = vBinaryOp "&&" x y
    x ||* y = vBinaryOp "||" x y

-- | Boolean algebra on fragment values, rendered with the GLSL operators
-- @!@, @&&@ and @||@.
instance Boolean (Fragment Bool) where
    true = Fragment (return "true")
    false = Fragment (return "false")
    notB x = fUnaryPreOp "!" x
    x &&* y = fBinaryOp "&&" x y
    x ||* y = fBinaryOp "||" x y
-- | Equality tests on vertex values, rendered as GLSL @==@ and @!=@.
instance Eq a => EqB (Vertex Bool) (Vertex a) where
    x ==* y = vBinaryOp "==" x y
    x /=* y = vBinaryOp "!=" x y

-- | Equality tests on fragment values, rendered as GLSL @==@ and @!=@.
instance Eq a => EqB (Fragment Bool) (Fragment a) where
    x ==* y = fBinaryOp "==" x y
    x /=* y = fBinaryOp "!=" x y
-- | Ordering tests on vertex values, rendered as the GLSL relational
-- operators.
instance Ord a => OrdB (Vertex Bool) (Vertex a) where
    x <* y = vBinaryOp "<" x y
    x >=* y = vBinaryOp ">=" x y
    x >* y = vBinaryOp ">" x y
    x <=* y = vBinaryOp "<=" x y

-- | Ordering tests on fragment values, rendered as the GLSL relational
-- operators.
instance Ord a => OrdB (Fragment Bool) (Fragment a) where
    x <* y = fBinaryOp "<" x y
    x >=* y = fBinaryOp ">=" x y
    x >* y = fBinaryOp ">" x y
    x <=* y = fBinaryOp "<=" x y
-- | Conditional on vertex values, rendered as the parenthesised GLSL ternary
-- @(c?a:b)@. The condition is evaluated first, then both branches.
instance IfB (Vertex Bool) (Vertex a) where
    ifB c a b = Vertex $ do
        cond <- fromVertex c
        yes <- fromVertex a
        no <- fromVertex b
        return (concat ["(", cond, "?", yes, ":", no, ")"])

-- | Conditional on fragment values, rendered as the parenthesised GLSL
-- ternary @(c?a:b)@. The condition is evaluated first, then both branches.
instance IfB (Fragment Bool) (Fragment a) where
    ifB c a b = Fragment $ do
        cond <- fromFragment c
        yes <- fromFragment a
        no <- fromFragment b
        return (concat ["(", cond, "?", yes, ":", no, ")"])
-- | Emits a call to the GLSL @dFdx@ function.
dFdx :: Fragment Float -> Fragment Float
dFdx x = fUnaryFunc "dFdx" x

-- | Emits a call to the GLSL @dFdy@ function.
dFdy :: Fragment Float -> Fragment Float
dFdy x = fUnaryFunc "dFdy" x

-- | Emits a call to the GLSL @fwidth@ function.
fwidth :: Fragment Float -> Fragment Float
fwidth x = fUnaryFunc "fwidth" x
-- Vector norms: the components are packed into a GLSL vector and passed to
-- the @length@ built-in.

-- | Norm of a 4-component fragment vector.
normF4 :: Vec4 (Fragment Float) -> Fragment Float
normF4 v = fUnaryFunc "length" (fVec v)

-- | Norm of a 3-component fragment vector.
normF3 :: Vec3 (Fragment Float) -> Fragment Float
normF3 v = fUnaryFunc "length" (fVec v)

-- | Norm of a 2-component fragment vector.
normF2 :: Vec2 (Fragment Float) -> Fragment Float
normF2 v = fUnaryFunc "length" (fVec v)

-- | Norm of a 4-component vertex vector.
normV4 :: Vec4 (Vertex Float) -> Vertex Float
normV4 v = vUnaryFunc "length" (vVec v)

-- | Norm of a 3-component vertex vector.
normV3 :: Vec3 (Vertex Float) -> Vertex Float
normV3 v = vUnaryFunc "length" (vVec v)

-- | Norm of a 2-component vertex vector.
normV2 :: Vec2 (Vertex Float) -> Vertex Float
normV2 v = vUnaryFunc "length" (vVec v)
-- Vector normalisation: the components are packed into a GLSL vector, passed
-- to the @normalize@ built-in and the result is split back into components.

-- | Normalise a 4-component fragment vector.
normalizeF4 :: Vec4 (Fragment Float) -> Vec4 (Fragment Float)
normalizeF4 v = fromFVec 4 (fUnaryFunc "normalize" (fVec v))

-- | Normalise a 3-component fragment vector.
normalizeF3 :: Vec3 (Fragment Float) -> Vec3 (Fragment Float)
normalizeF3 v = fromFVec 3 (fUnaryFunc "normalize" (fVec v))

-- | Normalise a 2-component fragment vector.
normalizeF2 :: Vec2 (Fragment Float) -> Vec2 (Fragment Float)
normalizeF2 v = fromFVec 2 (fUnaryFunc "normalize" (fVec v))

-- | Normalise a 4-component vertex vector.
normalizeV4 :: Vec4 (Vertex Float) -> Vec4 (Vertex Float)
normalizeV4 v = fromVVec 4 (vUnaryFunc "normalize" (vVec v))

-- | Normalise a 3-component vertex vector.
normalizeV3 :: Vec3 (Vertex Float) -> Vec3 (Vertex Float)
normalizeV3 v = fromVVec 3 (vUnaryFunc "normalize" (vVec v))

-- | Normalise a 2-component vertex vector.
normalizeV2 :: Vec2 (Vertex Float) -> Vec2 (Vertex Float)
normalizeV2 v = fromVVec 2 (vUnaryFunc "normalize" (vVec v))
-- Dot products: both operands are packed into GLSL vectors and passed to the
-- @dot@ built-in.

-- | Dot product of 4-component fragment vectors.
dotF4 :: Vec4 (Fragment Float) -> Vec4 (Fragment Float) -> Fragment Float
dotF4 a b = fBinaryFunc "dot" lhs rhs
  where
    lhs = fVec a
    rhs = fVec b

-- | Dot product of 3-component fragment vectors.
dotF3 :: Vec3 (Fragment Float) -> Vec3 (Fragment Float) -> Fragment Float
dotF3 a b = fBinaryFunc "dot" lhs rhs
  where
    lhs = fVec a
    rhs = fVec b

-- | Dot product of 2-component fragment vectors.
dotF2 :: Vec2 (Fragment Float) -> Vec2 (Fragment Float) -> Fragment Float
dotF2 a b = fBinaryFunc "dot" lhs rhs
  where
    lhs = fVec a
    rhs = fVec b

-- | Dot product of 4-component vertex vectors.
dotV4 :: Vec4 (Vertex Float) -> Vec4 (Vertex Float) -> Vertex Float
dotV4 a b = vBinaryFunc "dot" lhs rhs
  where
    lhs = vVec a
    rhs = vVec b

-- | Dot product of 3-component vertex vectors.
dotV3 :: Vec3 (Vertex Float) -> Vec3 (Vertex Float) -> Vertex Float
dotV3 a b = vBinaryFunc "dot" lhs rhs
  where
    lhs = vVec a
    rhs = vVec b

-- | Dot product of 2-component vertex vectors.
dotV2 :: Vec2 (Vertex Float) -> Vec2 (Vertex Float) -> Vertex Float
dotV2 a b = vBinaryFunc "dot" lhs rhs
  where
    lhs = vVec a
    rhs = vVec b
-- | Cross product of fragment vectors via the GLSL @cross@ built-in; the
-- result is split back into three components.
crossF3 :: Vec3 (Fragment Float) -> Vec3 (Fragment Float) -> Vec3 (Fragment Float)
crossF3 a b = fromFVec 3 product3
  where
    product3 = fBinaryFunc "cross" (fVec a) (fVec b)

-- | Cross product of vertex vectors via the GLSL @cross@ built-in; the
-- result is split back into three components.
crossV3 :: Vec3 (Vertex Float) -> Vec3 (Vertex Float) ->Vec3 (Vertex Float)
crossV3 a b = fromVVec 3 product3
  where
    product3 = vBinaryFunc "cross" (vVec a) (vVec b)
-- | Fail with an error naming a class method that has no meaning for
-- 'Vertex' or 'Fragment' values.
noFun :: String -> a
noFun name = error (name ++ ": No overloading for Vertex/Fragment")
-- | Vertex-stage GLSL function call @s(arg1,arg2,...)@; the arguments are
-- evaluated left to right.
vListFunc s xs = Vertex (fmap call (mapM fromVertex xs))
  where
    call args = concat [s, "(", intercalate "," args, ")"]

-- | Fragment-stage GLSL function call @s(arg1,arg2,...)@; the arguments are
-- evaluated left to right.
fListFunc s xs = Fragment (fmap call (mapM fromFragment xs))
  where
    call args = concat [s, "(", intercalate "," args, ")"]
-- | One-argument vertex-stage function call.
vUnaryFunc s a = vListFunc s (a : [])

-- | One-argument fragment-stage function call.
fUnaryFunc s a = fListFunc s (a : [])

-- | Two-argument vertex-stage function call.
vBinaryFunc s a b = vListFunc s (a : b : [])

-- | Two-argument fragment-stage function call.
fBinaryFunc s a b = fListFunc s (a : b : [])
-- | Parenthesised vertex-stage prefix operator application @(s a)@.
vUnaryPreOp s a = Vertex (fmap wrap (fromVertex a))
  where
    wrap operand = concat ["(", s, operand, ")"]

-- | Parenthesised fragment-stage prefix operator application @(s a)@.
fUnaryPreOp s a = Fragment (fmap wrap (fromFragment a))
  where
    wrap operand = concat ["(", s, operand, ")"]
-- | Parenthesised vertex-stage infix operator application @(a s b)@; the
-- left operand is evaluated first.
vBinaryOp s a b = Vertex $ do
    lhs <- fromVertex a
    rhs <- fromVertex b
    return (concat ["(", lhs, s, rhs, ")"])

-- | Parenthesised fragment-stage infix operator application @(a s b)@; the
-- left operand is evaluated first.
fBinaryOp s a b = Fragment $ do
    lhs <- fromFragment a
    rhs <- fromFragment b
    return (concat ["(", lhs, s, rhs, ")"])
-- | Pack the components of a vertex vector into a GLSL constructor call such
-- as @vec3(x,y,z)@, chosen by the number of components.
vVec v = vListFunc (tName (length comps)) comps
  where
    comps = Vec.toList v

-- | Pack the components of a fragment vector into a GLSL constructor call
-- such as @vec3(x,y,z)@, chosen by the number of components.
fVec v = fListFunc (tName (length comps)) comps
  where
    comps = Vec.toList v
-- | Split a vertex-stage GLSL vector expression with @e@ components into a
-- vector of scalar expressions, one per component index @0 .. e - 1@.
--
-- Each component re-runs the source expression and appends its selector
-- (@.x@, @.y@, ...), so the vector expression appears once per component in
-- the generated code.
fromVVec e v = Vec.fromList
    [ Vertex (fmap (++ subElem n e) (fromVertex v)) | n <- [0 .. e - 1] ]

-- | Split a fragment-stage GLSL vector expression with @e@ components into a
-- vector of scalar expressions, one per component index @0 .. e - 1@.
--
-- Each component re-runs the source expression and appends its selector
-- (@.x@, @.y@, ...), so the vector expression appears once per component in
-- the generated code.
fromFVec e v = Vec.fromList
    [ Fragment (fmap (++ subElem n e) (fromFragment v)) | n <- [0 .. e - 1] ]