-----------------------------------------------------------------------------
--
-- Module      :  GPUStream
-- Copyright   :  Tobias Bexelius
-- License     :  BSD3
--
-- Maintainer  :  Tobias Bexelius
-- Stability   :  Experimental
-- Portability :  Portable
--
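-- | Primitive and fragment streams on the GPU, and the loaders that turn fragment streams into draw calls.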
-----------------------------------------------------------------------------

module GPUStream (
    PrimitiveStream(..),
    FragmentStream(..),
    VertexPosition,
    CullMode(..),
    Primitive(..),
    Triangle(..),
    Line(..),
    Point(..),
    VertexSetup(..),
    FragmentSetup,
    PrimitiveStreamDesc,
    FragmentStreamDesc,
    filterFragments,
    loadFragmentColorStream,
    loadFragmentDepthStream,
    loadFragmentColorDepthStream,
    loadFragmentAnyStream
) where

import Control.Arrow (first, second)
import Data.Monoid
import Data.Semigroup (Semigroup(..))

import Data.Vec (Vec4)
import Graphics.Rendering.OpenGL (cullFace, ($=), Face(..))
import qualified Graphics.Rendering.OpenGL as GL (PrimitiveMode(..))

import Formats
import Resources
import Shader

-- | A stream of primitives built from vertices on the GPU. The first type
-- parameter is the primitive type (currently 'Triangle', 'Line' or 'Point');
-- it is phantom and does not occur in the representation. The second is the
-- type of each primitive's vertices (built up of atoms of type 'Vertex').
newtype PrimitiveStream p a = PrimitiveStream [(PrimitiveStreamDesc, a)]
-- | A stream of fragments on the GPU, parameterized on the fragment type
-- (built up of atoms of type 'Fragment').
--
-- Each entry carries its stream description, a @'Fragment' 'Bool'@ flag
-- (which 'filterFragments' combines with its predicate using @&&*@), and the
-- fragment value itself.
newtype FragmentStream a = FragmentStream [(FragmentStreamDesc, Fragment Bool, a)]

-- | The type of a vertex position: a 'Vec4' of 'Vertex' 'Float' components.
type VertexPosition = Vec4 (Vertex Float)

-- | Which faces (if any) are culled when drawing.
data CullMode
    = CullNone
    | CullFront
    | CullBack
    deriving (Eq, Ord, Bounded, Enum, Show)

-- | Raw vertex data: one list of 'Float's per vertex, optionally together
-- with a list of vertex indices.
data VertexSetup
    = VertexSetup [[Float]]
    | IndexedVertexSetup [[Float]] [Int]
    deriving (Eq, Ord, Show)

-- | Shader fragments handed to the vertex program alongside the position.
type FragmentSetup = [Shader String]

-- | How a primitive stream is drawn: the OpenGL primitive mode and the
-- vertex data.
type PrimitiveStreamDesc = (GL.PrimitiveMode, VertexSetup)

-- | Everything needed to draw a fragment stream: the primitive description,
-- the cull mode, the fragment setup and the vertex-position shader.
type FragmentStreamDesc = (PrimitiveStreamDesc, CullMode, FragmentSetup, Shader String)

-- | Maps over the vertex values, leaving every stream description untouched.
instance Functor (PrimitiveStream p) where
    fmap f (PrimitiveStream entries) = PrimitiveStream (map onVertex entries)
        where
            -- Lazy match, exactly like 'second' on a pair.
            onVertex ~(desc, v) = (desc, f v)
-- | Maps over the fragment values, keeping each description and flag.
instance Functor FragmentStream where
    fmap f (FragmentStream entries) = FragmentStream (map onFragment entries)
        where
            onFragment (desc, keep, v) = (desc, keep, f v)

-- | Streams combine by concatenating their entries, left operand first.
--
-- Since base-4.11 (GHC 8.4) 'Semigroup' is a superclass of 'Monoid', so a
-- 'Semigroup' instance is required for the 'Monoid' instance to compile.
-- 'mappend' is spelled out rather than written as @(<>)@ so the instances
-- also compile on base-4.9/4.10, where "Data.Monoid" exports its own @<>@.
instance Semigroup (PrimitiveStream p a) where
    PrimitiveStream a <> PrimitiveStream b        = PrimitiveStream (a ++ b)
instance Monoid (PrimitiveStream p a) where
    mempty                                        = PrimitiveStream []
    PrimitiveStream a `mappend` PrimitiveStream b = PrimitiveStream (a ++ b)

-- | Fragment streams combine by concatenating their entries, left operand
-- first.
instance Semigroup (FragmentStream a) where
    FragmentStream a <> FragmentStream b          = FragmentStream (a ++ b)
instance Monoid (FragmentStream a) where
    mempty                                        = FragmentStream []
    FragmentStream a `mappend` FragmentStream b   = FragmentStream (a ++ b)

-- | Keeps only the fragments for which the provided predicate evaluates to
-- 'true'; fragments where it evaluates to 'false' are discarded.
--
-- The predicate is combined with @&&*@ into each entry's existing
-- @'Fragment' 'Bool'@ flag (the flag the loaders name @notDisc@), so
-- successive filters accumulate: a fragment survives only if every
-- predicate holds for it.
filterFragments :: (a -> Fragment Bool) -> FragmentStream a -> FragmentStream a
filterFragments f (FragmentStream xs) = FragmentStream $ map filterOne xs
    where
        filterOne (fdesc, b, a) = (fdesc, b &&* f a, a)

-----------------------------------------

-- | Primitive kinds that correspond to an OpenGL primitive mode.
class Primitive p where
    -- | The OpenGL primitive mode for this primitive value.
    getPrimitiveMode :: p -> GL.PrimitiveMode

-- | Triangle topologies.
data Triangle
    = TriangleStrip
    | TriangleList
    | TriangleFan
    deriving (Eq, Ord, Bounded, Enum, Show)

-- | Line topologies.
data Line
    = LineStrip
    | LineList
    deriving (Eq, Ord, Bounded, Enum, Show)

-- | Point topology.
data Point
    = PointList
    deriving (Eq, Ord, Bounded, Enum, Show)

instance Primitive Triangle where
    getPrimitiveMode t = case t of
        TriangleStrip -> GL.TriangleStrip
        TriangleList  -> GL.Triangles
        TriangleFan   -> GL.TriangleFan

instance Primitive Line where
    getPrimitiveMode l = case l of
        LineStrip -> GL.LineStrip
        LineList  -> GL.Lines

instance Primitive Point where
    getPrimitiveMode pt = case pt of
        PointList -> GL.Points

-----------------------------------------
-- | Layers one color-only draw call per stream entry on top of the given
-- action. Each fragment color is first converted with @fromColor 0 1@.
loadFragmentColorStream :: ColorFormat f => FragmentStream (Color f (Fragment Float)) -> ContextCacheIO () -> ContextCacheIO ()
loadFragmentColorStream stream =
    case fmap (fromColor 0 1) stream of
        FragmentStream entries -> layerMapM_ drawCallColor entries
-- | Layers one color-and-depth draw call per stream entry on top of the
-- given action. Each fragment's depth is paired with the literal color @0@.
loadFragmentDepthStream :: FragmentStream (Fragment Float) -> ContextCacheIO () -> ContextCacheIO ()
loadFragmentDepthStream (FragmentStream entries) =
    layerMapM_ (\(desc, notDisc, depth) -> drawCallColorDepth (desc, notDisc, (0, depth))) entries

-- | Layers one color-and-depth draw call per stream entry on top of the
-- given action. The color half of each fragment is first converted with
-- @fromColor 0 1@; the depth half is passed through unchanged.
loadFragmentColorDepthStream :: ColorFormat f => FragmentStream (Color f (Fragment Float), Fragment Float) -> ContextCacheIO () -> ContextCacheIO ()
loadFragmentColorDepthStream stream =
    case fmap (first (fromColor 0 1)) stream of
        FragmentStream entries -> layerMapM_ drawCallColorDepth entries
-- | Layers one color-only draw call per stream entry on top of the given
-- action. The fragment values are ignored; every entry is drawn with the
-- literal color @0@.
loadFragmentAnyStream :: FragmentStream a -> ContextCacheIO () -> ContextCacheIO ()
loadFragmentAnyStream (FragmentStream entries) =
    layerMapM_ (\(desc, notDisc, _) -> drawCallColor (desc, notDisc, 0)) entries

-- | Wraps @base@ with @wrap@ once per list element, innermost first: for
-- @[x1, x2]@ the result is @wrap x2 (wrap x1 base)@. This is a lazy left
-- fold; the accumulator is never forced here.
layerMapM_ :: (a -> b -> b) -> [a] -> b -> b
layerMapM_ wrap items base = case items of
    []       -> base
    (y : ys) -> layerMapM_ wrap ys (wrap y base)

-- | Draws one stream entry with a color-only fragment shader, after the
-- action @io@. The fragment program comes from 'colorFragmentShader'
-- applied to the entry's flag and color. The vertex program is built from
-- the position shader, the fragment setup and the fragment program's
-- inputs.
drawCallColor (((p, vs), cull, rast, vPos), nd, c) io =
    drawCall p cull vins vs vp fp vuns funs io
  where
    (fp, funs, fins) = fragmentProgram (colorFragmentShader nd c)
    (vp, vuns, vins) = vertexProgram vPos rast fins

-- | Draws one stream entry with a color-and-depth fragment shader, after the
-- action @io@. The fragment program comes from 'colorDepthFragmentShader'
-- applied to the entry's flag and its (color, depth) pair. The vertex
-- program is built from the position shader, the fragment setup and the
-- fragment program's inputs.
drawCallColorDepth (((p, vs), cull, rast, vPos), nd, cd) io =
    drawCall p cull vins vs vp fp vuns funs io
  where
    (fp, funs, fins) = fragmentProgram (colorDepthFragmentShader nd cd)
    (vp, vuns, vins) = vertexProgram vPos rast fins


-- | @mapSelect ixs rows@ picks, from every row, the elements at the 0-based
-- positions in @ixs@.
--
-- The indices are expected to be strictly increasing. Each step drops only
-- the elements between the previous pick and the next one, so a row is
-- walked once and the index list is not rewritten at every step.
-- Out-of-range indices make the lazy @(a:rest)@ pattern fail when that
-- element is demanded.
mapSelect :: [Int] -> [[a]] -> [[a]]
mapSelect ixs = map (select 0 ixs)
  where
    -- @next@ is the absolute position of the head of the remaining row.
    select _ [] _ = []
    select next (i:is) row =
        let (a:rest) = drop (i - next) row
        in  a : select (i + 1) is rest


-- | Issues a single draw call for one stream entry, running @io@ first.
--
-- Arguments:
--
-- * @p@: the primitive mode.
-- * @cull@: the 'CullMode'.
-- * @ins@: the attribute indices returned by 'vertexProgram'.
-- * The 'VertexSetup': vertex rows, plus an index list in the indexed case.
-- * @vp@ / @fp@: the vertex and fragment programs.
-- * @vuns@ / @funs@: their uniform values.
-- * @io@: the action built up by the draw calls layered before this one.
--
-- Every input is passed through 'ioEvaluate' before @io@ runs. This is
-- presumably so that errors hidden in lazily built programs or vertex data
-- surface before any earlier draw call executes. NOTE(review): confirm what
-- 'ioEvaluate' forces (WHNF vs. full evaluation).
--
-- NOTE(review): the two equations differ only in the index list and the
-- index buffer; keep them in sync when editing either.
drawCall p cull ins (VertexSetup v) vp fp vuns funs io = do
    -- From every vertex row, keep only the attributes listed in ins.
    xs <- ioEvaluate (mapSelect ins v)
    ins' <- ioEvaluate ins
    vp' <- ioEvaluate vp
    fp' <- ioEvaluate fp
    -- Number of selected attributes, and number of vertex rows.
    s <- ioEvaluate (length ins)
    vs <- ioEvaluate (length v)
    vuns'<-ioEvaluate vuns
    funs'<-ioEvaluate funs
    cull'<-ioEvaluate cull
    p'<-ioEvaluate p
    -- Run the previously layered draw calls before this one.
    io
    (pr, (vu, fu)) <- createProgramResource vp' fp' s
    vb <- createVertexBuffer xs ins' v
    useProgramResource pr
    useUniforms vu vuns'
    useUniforms fu funs'
    liftIO $ do useCull cull'
                drawVertexBuffer p' vb vs

-- Indexed variant: as above, but also forces the index list first and
-- creates an index buffer (sized with the vertex count) to draw through.
drawCall p cull ins (IndexedVertexSetup v i) vp fp vuns funs io = do
    i' <- ioEvaluate i
    xs <- ioEvaluate (mapSelect ins v)
    ins' <- ioEvaluate ins
    vp' <- ioEvaluate vp
    fp' <- ioEvaluate fp
    s <- ioEvaluate (length ins)
    vs <- ioEvaluate (length v)
    vuns'<-ioEvaluate vuns
    funs'<-ioEvaluate funs
    cull'<-ioEvaluate cull
    p'<-ioEvaluate p
    -- Run the previously layered draw calls before this one.
    io
    (pr, (vu, fu)) <- createProgramResource vp' fp' s
    ib <- createIndexBuffer i' vs
    vb <- createVertexBuffer xs ins' v
    useProgramResource pr
    useUniforms vu vuns'
    useUniforms fu funs'
    liftIO $ do useCull cull'
                drawIndexVertexBuffer p' vb ib

-- | Sets OpenGL's face-culling state for the given 'CullMode'. 'CullNone'
-- disables culling; 'CullFront' and 'CullBack' cull front or back faces.
useCull mode = case mode of
    CullNone  -> cullFace $= Nothing
    CullFront -> cullFace $= Just Front
    CullBack  -> cullFace $= Just Back