{-# LANGUAGE RankNTypes, ScopedTypeVariables, DataKinds, KindSignatures,
             TypeFamilies, FlexibleContexts, UndecidableInstances,
             FlexibleInstances, DefaultSignatures, TypeOperators #-}

{-|
Module:      Graphics.Rendering.Ombra.Shader
License:     BSD3
Maintainer:  ziocroc@gmail.com
Stability:   experimental
Portability: GHC only
-}

module Graphics.Rendering.Ombra.Shader (
        module Graphics.Rendering.Ombra.Shader.Language,
        ShaderStage(..),
        Shader,
        VertexShader,
        FragmentShader,
        -- * Uniforms
        uniform,
        (~<),
        foldUniforms,
        -- * Optimized shaders
        UniformSetter,
        shader,
        sarr,
        shaderParam,
        pshader,
        ushader,
        pushader,
        uniform',
        (~<*),
        -- * Fragment shader functionalities
        Fragment(..),
        farr,
        fragment,
        -- * Loops
        forLoop,
        foldGArray,
        -- * Classes
        MultiShaderType(..),
        ShaderInput(..),
        FragmentShaderOutput(..),
        MapShader(..),
        Uniform(..)
) where

import Control.Arrow
import Control.Applicative
import Control.Category
import Data.Hashable
import Data.MemoTrie
import Data.Proxy
import GHC.Generics
import GHC.TypeLits
import Graphics.Rendering.Ombra.Backend (GLES)
import Graphics.Rendering.Ombra.Internal.GL (Sampler2D)
import Graphics.Rendering.Ombra.Shader.Language
import qualified Graphics.Rendering.Ombra.Shader.Language.Functions as Shader
import Graphics.Rendering.Ombra.Shader.Language.Types
import Graphics.Rendering.Ombra.Shader.CPU
import Graphics.Rendering.Ombra.Shader.Types
import Graphics.Rendering.Ombra.Texture (Texture)
import Prelude hiding (id, (.))

-- | A transparent wrapper around a value meant to be fed to uniforms (see
-- 'uniform'', '~<*', 'ushader' and 'pushader'). Functor, Applicative and
-- Monad simply apply functions to the wrapped value (an identity-like
-- wrapper).
newtype UniformSetter x = UniformSetter { unUniformSetter :: x }

instance Functor UniformSetter where
        fmap f = UniformSetter . f . unUniformSetter

instance Applicative UniformSetter where
        pure = UniformSetter
        wrappedF <*> wrappedX =
                UniformSetter $ unUniformSetter wrappedF (unUniformSetter wrappedX)

instance Monad UniformSetter where
        return = pure
        wrapped >>= k = k (unUniformSetter wrapped)

-- | Replace every expression inside a 'MultiShaderType' value with a
-- 'HashDummy' expression that carries only the hash of the original
-- expression.
hashMST :: MultiShaderType a => a -> a
hashMST = mapMST (\x -> fromExpr (HashDummy (hash (toExpr x))))

-- | Create a shader function that can be reused efficiently. Ideally, every
-- operation on G* and *Sampler types should be performed by a top level Shader
-- created with this function, while arrow combinators and uniforms can appear
-- anywhere.
--
-- The first component of the shader is kept as is; the second (hashing)
-- component is wrapped with 'memoHash', so it is memoized and its output
-- expressions are replaced by their hashes.
shader :: (MultiShaderType i, MultiShaderType o) => Shader s i o -> Shader s i o
shader sh = case sh of
        Shader run hashRun -> Shader run (memoHash hashRun)
-- BUG: shader modifies the hash of the shader

-- | This variant of 'shader' can be used with shaders that have a mostly static
-- parameter. It will create a different shader every time the parameter changes
-- to a new value, therefore parameters should not be used for things like
-- model matrices (for which uniforms are more appropriate). Unlike uniforms,
-- parameters can be used anywhere, in particular they can be used to change the
-- shader structure. 'Shader's themselves can be used as parameters.
--
-- Implementation: the hashing component is memoized twice. The outer memo
-- table is keyed by the parameter, and each entry is a 'memoHash'-wrapped
-- hashing function for that parameter value.
shaderParam :: (HasTrie p, MultiShaderType i, MultiShaderType o)
            => Shader s (p, i) o
            -> Shader s (p, i) o
shaderParam (Shader f hf) = Shader f lookupHash
        where perParam = memo $ \p -> memoHash (\(uid, i) -> hf (uid, (p, i)))
              lookupHash (uid, (p, i)) = perParam p (uid, i)

-- | See 'shaderParam'. The result of partially applying this function is a
-- function to which the same rules as for 'shader' apply (that is, it should
-- be reused rather than recreated at every frame).
--
-- The parameterized shader is built once per partial application. Each
-- resulting @Shader@ pairs the given parameter with its input and runs that
-- shared, parameter-memoized shader.
pshader :: (HasTrie p, MultiShaderType i, MultiShaderType o)
        => (p -> Shader s i o)
        -> (p -> Shader s i o)
pshader shaderf = \p -> (,) p ^>> shared
        where shared = shaderParam (first shaderf ^>> app)

-- | 'shader' with an additional parameter that can be used to set the values of
-- the uniforms. Like 'pshader', this should be used as a function of
-- functions, not a function with two arguments.
--
-- The hashing component is obtained once by calling @shaderf@ with a
-- placeholder setter that raises an error if evaluated, and is then memoized
-- with 'memoHash'. Applying the result to a real setter only swaps in the
-- first component of @shaderf setter@. Therefore @shaderf@ must not inspect
-- the setter when building its hashing component.
ushader :: (MultiShaderType i, MultiShaderType o)
        => (UniformSetter x -> Shader s i o)
        -> (UniformSetter x -> Shader s i o)
ushader shaderf = withSetter
        where placeholder = error "ushader: not an uniform value"
              Shader _ hashOnly = shaderf placeholder
              memoized = memoHash hashOnly
              withSetter setter = let Shader run _ = shaderf setter
                                  in Shader run memoized

-- | Combination of 'pshader' and 'ushader'.
--
-- For every parameter value, a memoized hashing component is computed once.
-- It is obtained by calling @shaderf p@ with a placeholder setter that raises
-- an error if evaluated. Applying the result to a real setter only swaps in
-- the first component of @shaderf p setter@.
pushader :: (HasTrie p, MultiShaderType i, MultiShaderType o)
         => (p -> UniformSetter x -> Shader s i o)
         -> (p -> UniformSetter x -> Shader s i o)
pushader shaderf = withSetter
        where placeholder = error "pushader: not an uniform value"
              hashFor = memo $ \p -> let Shader _ hashOnly = shaderf p placeholder
                                     in memoHash hashOnly
              withSetter p setter = let Shader run _ = shaderf p setter
                                    in Shader run (hashFor p)

-- | Lift a pure function into a reusable shader: @'shader' . 'arr'@.
sarr :: (MultiShaderType i, MultiShaderType o) => (i -> o) -> Shader s i o
sarr f = shader (arr f)

-- | Memoize a shader's hashing component.
--
-- The input value is converted to its expression form ('toExprMST'), which
-- serves as the memo key together with the 'UniformID'. The memoized
-- function converts it back ('fromExprMST'), runs @hf@, and replaces every
-- output expression with its hash ('hashMST').
memoHash :: (MultiShaderType i, MultiShaderType o)
         => ((UniformID, i) -> (UniformID, o))
         -> ((UniformID, i) -> (UniformID, o))
memoHash hf = memoized . second toExprMST
        where memoized = memo (second hashMST . hf . second fromExprMST)

-- | Add a shader variable that can be set with a CPU value.
--
-- The first component of the resulting 'Shader' takes the current
-- 'ShaderState' and the CPU value(s). It builds the uniform expression(s)
-- with 'buildMST'', advances the state's uniform ID to the value returned by
-- 'buildMST'', and prepends @(id, value)@ pairs to the state's uniform list.
-- Texture values are additionally prepended to the state's texture list.
--
-- The second component (the one 'shader' wraps with 'memoHash') only builds
-- the uniform expression(s) and advances the ID. It never sees a CPU value.
uniform :: forall u s. Uniform u => Shader s (CPUUniform u) u
uniform = Shader (\(ShaderState uid umap tmap, multiValue) ->
                        -- uniExpr is the GPU-side value; uid' is the ID
                        -- counter after building it.
                        let (uniExpr, uid') =
                                buildMST' (\t -> fromExpr . Uniform t) uid
                            -- Accumulator for foldrUniform: assigns the current
                            -- ID to a value and moves to the previous ID.
                            -- NOTE(review): the fold starts at uid' - 1 and
                            -- counts down, which assumes buildMST' assigns IDs
                            -- in ascending order in the same field order that
                            -- foldrUniform visits (right to left) — confirm.
                            acc value@(UniformValue _ _) (uid, umap, tmap) =
                                    (uid - 1, (uid, value) : umap, tmap)
                            acc value@(UniformTexture tex) (uid, umap, tmap) =
                                    (uid - 1, (uid, value) : umap, tex : tmap)
                            -- Record every CPU value; the final ID is unused.
                            (_, umap', tmap') =
                                    foldrUniform (Proxy :: Proxy u) acc
                                                 (uid' - 1, umap, tmap)
                                                 multiValue
                        in (ShaderState uid' umap' tmap', uniExpr)
                 )
                 -- Hashing component: same expression and ID allocation as
                 -- above, without touching any CPU value.
                 (\(uid, _) ->
                       let (uniExpr, uid') =
                               buildMST' (\t -> fromExpr . Uniform t) uid
                       in (uid', uniExpr)
                 )

-- | Like 'uniform', but the CPU value is supplied wrapped in a 'UniformSetter'.
uniform' :: Uniform u => Shader s (UniformSetter (CPUUniform u)) u
uniform' = arr unUniformSetter >>> uniform

-- | Add a uniform and directly set it with the second operand. The uniform
-- becomes the first element of the pair fed to the left operand. The fixity
-- lets @sh ~< a ~< b@ feed @a@ and then @b@.
infixl 9 ~<
(~<) :: Uniform u => Shader s (u, i) o -> CPUUniform u -> Shader s i o
sh ~< value = ((arr (const value) >>> uniform) &&& id) >>> sh

-- | Add a uniform and directly set it with a 'UniformSetter'. This is the
-- 'UniformSetter' counterpart of '~<', built on 'uniform''.
infixl 9 ~<*
(~<*) :: Uniform u
      => Shader s (u, i) o
      -> UniformSetter (CPUUniform u)
      -> Shader s i o
sh ~<* setter = ((arr (const setter) >>> uniform') &&& id) >>> sh

-- | This works like 'sarr', but the lifted function also receives a
-- 'Fragment' record of built-in fragment-shader functions.
farr :: (MultiShaderType i, MultiShaderType o)
     => (Fragment -> i -> o)
     -> FragmentShader i o
farr f = sarr (f frag)

-- | A fragment shader that ignores its input and yields the 'Fragment'
-- record of built-in fragment-shader functions.
fragment :: FragmentShader a Fragment
fragment = arr (\_ -> frag)

-- | The 'Fragment' record. Each field is the shader-language function of the
-- same name.
frag :: Fragment
frag = Fragment { dFdx = Shader.dFdx
                , dFdy = Shader.dFdy
                , fwidth = Shader.fwidth
                , fragCoord = Shader.fragCoord
                , fragFrontFacing = Shader.fragFrontFacing
                }

-- | Repeatedly apply a function to a shader value. This is compiled to an
-- actual for loop, therefore it won't duplicate the function code (doing that
-- could slow down compilation or cause an out of memory error). The same
-- applies to derived functions like 'foldGArray' and 'foldUniforms'.
forLoop :: ShaderInput a 
        => Int -- ^ Maximum number of iterations (should be as low as possible)
        -> a -- ^ Initial value
        -> (GInt -> a -> (a, GBool)) -- ^ Iteration -> Old value -> (Next, Stop)
        -> a
forLoop iters iacc f = fromExprList $ Shader.unsafeLoop maxIters initial step
        where -- The iteration bound, emitted as an "int" literal.
              maxIters = fromExpr . Literal "int" $ show iters
              -- Loop state: a (type name, expression) pair for each field of
              -- the initial value.
              initial = foldrMST (\x -> ((typeName x, toExpr x) :)) [] iacc
              -- One iteration: rebuild the value from the current expressions,
              -- apply the user function, then flatten the result again.
              step i exprs = let (next, stop) = f i (fromExprList exprs)
                             in (foldrMST (\x -> (toExpr x :)) [] next, stop)
              -- XXX / NOTE(review): relies on the partial (!!); assumes the
              -- list has an entry for every field being rebuilt — confirm.
              fromExprList exprs =
                      fst $ buildMST (\k -> fromExpr $ exprs !! k) 0

-- | Left-fold over every element of a 'GArray' using 'forLoop'. The loop runs
-- exactly @n@ iterations, because the stop flag is always 'false'.
foldGArray :: forall t n a. (ShaderType t, KnownNat n, ShaderInput a)
           => (a -> t -> a)
           -> a
           -> GArray n t
           -> a
foldGArray f iacc xs =
        forLoop (fromIntegral (natVal (Proxy :: Proxy n))) iacc step
        where step i acc = (f acc (xs ! i), false)

-- | Create an array uniform and then fold over it with the given function and
-- initial value.
--
-- The type-level size of the uniform array is taken at runtime from the
-- length of the CPU-side list. The fold is compiled with 'foldGArray'.
foldUniforms :: forall a u s. (ShaderInput a, ArrayUniform u, GLES)
             => Shader s (((a -> u -> a), a), [CPUBase u]) a
foldUniforms = (\((f, i), us) ->
                        case someNatVal . fromIntegral $ length us of
                                Just (SomeNat p) -> (foldArray p f i, us)
                                -- Unreachable: a list length is never
                                -- negative. Matched so the case is total.
                                Nothing ->
                                        error "foldUniforms: negative list length"
               ) ^>> app
        where -- Read the list into an array uniform of size n and fold
              -- over it.
              foldArray :: forall n. KnownNat n
                        => Proxy n
                        -> (a -> u -> a)
                        -> a
                        -> Shader s [CPUBase u] a
              foldArray p f i = baseUniformGArray p (Proxy :: Proxy u) $
                                        uniform >>^ \(xs :: GArray n u) ->
                                                        foldGArray f i xs