-----------------------------------------------------------------------------
--
-- Module      :  Resources
-- Copyright   :  Tobias Bexelius
-- License     :  BSD3
--
-- Maintainer  :  Tobias Bexelius
-- Stability   :  Experimental
-- Portability :  Portable
--
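-- | OpenGL resource management: cached shader programs, vertex and index
-- buffers, per-window context caches, and textures shared between windows.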
-----------------------------------------------------------------------------

module Resources (
    createProgramResource,
    createVertexBuffer,
    createIndexBuffer,
    useUniforms,
    useProgramResource,
    drawIndexVertexBuffer,
    drawVertexBuffer,
    VertexBuffer(..),
    IndexBuffer(..),
    UniformLocationSet,
    UniformState(..),
    ContextCacheIO,
    ContextCache(contextWindow,contextViewPort),
    newContextCache,
    setContextWindow,
    liftIO,
    hiddenWindowContextCache,
    WinMappedTexture,
    newWinMappedTexture,
    bindWinMappedTexture,
    Sampler(..),
    Filter(..),
    EdgeMode(..),
    SamplerType(..),
    cubeMapTargets,
    getContextCache,
    saveContextCache,
    changeContextSize,
    getCurrentOrSetHiddenContext,
    ioEvaluate,
    evaluateDeep,
    evaluatePtr
) where

import Graphics.Rendering.OpenGL hiding (Sampler3D, Sampler2D, Sampler1D, SamplerCube, Point, Linear, Clamp)
import qualified Data.HashTable as HT
import qualified Graphics.UI.GLUT as GLUT
import Data.Map (Map)
import qualified Data.Map as Map
import System.Mem.StableName
    (makeStableName, hashStableName, StableName)
import Data.Bits
import Control.Monad.Reader
import Control.Monad
import Foreign.C.Types
import Foreign.Storable
import Foreign.Marshal
import Data.List
import Data.Maybe
import Foreign.Ptr
import System.IO.Unsafe (unsafePerformIO)
import Data.IORef
import Control.Exception (evaluate)
import Foreign.ForeignPtr
import System.Mem.Weak (addFinalizer)
import Data.Unique


-- | A GL buffer object holding interleaved vertex data, together with the
-- number of floats stored per vertex (used as the stride by 'useVertexBuffer').
data VertexBuffer = VertexBuffer BufferObject Int
-- | A GL element-array buffer object, the number of indices it holds, and the
-- GL data type each index is stored as.
data IndexBuffer = IndexBuffer BufferObject Int DataType
-- | Uniform locations for one shader stage: the float, int and bool uniform
-- arrays, plus one sampler-array location per 'SamplerType'.
type UniformLocationSet = (UniformLocation,UniformLocation,UniformLocation, Map SamplerType UniformLocation)

-- | The kinds of texture sampler a shader can use.
--
-- The constructor order matters. 'fromEnum' of a constructor is embedded in
-- the uniform names looked up by 'createProgramResource' (e.g. @s0vu@).
-- That function also builds its map with 'Map.fromAscList' over
-- @[minBound .. maxBound]@, which relies on 'Enum' order matching 'Ord' order.
data SamplerType = Sampler3D | Sampler2D | Sampler1D | SamplerCube deriving (Eq, Ord, Enum, Bounded)

-- | Uniform values to upload against one 'UniformLocationSet', each keyed by
-- a 'Unique'.
--
-- 'useUniforms' sends every map's elements, in key order, as a single
-- uniform array. Sampler entries pair the sampler state with the texture to
-- sample.
data UniformState = UniformState
                    {
                        floatUniforms :: Map Unique Float,
                        intUniforms :: Map Unique Int,
                        boolUniforms :: Map Unique Bool,
                        sampler3DUniforms :: Map Unique (Sampler, WinMappedTexture),
                        sampler2DUniforms :: Map Unique (Sampler, WinMappedTexture),
                        sampler1DUniforms :: Map Unique (Sampler, WinMappedTexture),
                        samplerCubeUniforms :: Map Unique (Sampler, WinMappedTexture)
                    } deriving Eq
                    
-- | A structure describing how a texture is sampled: one 'Filter' used for
-- minification, mipmap selection and magnification, and one 'EdgeMode' used
-- for every texture coordinate.
data Sampler = Sampler Filter EdgeMode deriving (Eq, Ord)
-- | Filter mode used in sampler state: 'Point' maps to GL nearest-neighbour
-- filtering, 'Linear' to GL linear filtering (see 'toGLFilter').
data Filter = Point | Linear
            deriving (Eq,Ord,Bounded,Enum,Show)
-- | Edge mode used in sampler state: 'Wrap' repeats, 'Mirror' repeats
-- mirrored, and 'Clamp' clamps to the edge (see 'toGLWrap').
data EdgeMode = Wrap | Mirror | Clamp
              deriving (Eq,Ord,Bounded,Enum,Show)

-- | Translate a 'Filter' into OpenGL filter settings of the shape
-- ((minification, Just mipmap selection), magnification). The same GL
-- filter kind is used in all three positions.
toGLFilter filt = ((glf, Just glf), glf)
  where glf = case filt of
                Point  -> Nearest
                Linear -> Linear'
-- | Translate an 'EdgeMode' into an OpenGL (repetition, clamping) wrap mode.
toGLWrap edge = case edge of
    Wrap   -> (Repeated, Repeat)
    Mirror -> (Mirrored, Repeat)
    Clamp  -> (Repeated, ClampToEdge)

-- | Linked programs keyed by the (vertex, fragment) shader signature strings,
-- stored with their (vertex, fragment) uniform locations.
type ProgramCache = HT.HashTable (String, String) (Program, (UniformLocationSet,UniformLocationSet))
-- | Vertex buffers keyed by an Int list and the stable name of the vertex
-- data value passed to 'createVertexBuffer'.
type VBCache = HT.HashTable ([Int], StableName [[Float]]) VertexBuffer
-- | Index buffers keyed by the stable name of the index list.
type IBCache = HT.HashTable (StableName [Int]) IndexBuffer
-- | GL resource caches belonging to one GLUT window, plus that window's
-- viewport size (applied by 'setContextWindow').
data ContextCache = ContextCache {programCache :: ProgramCache,
                                  vbCache :: VBCache,
                                  ibCache :: IBCache,
                                  contextWindow :: GLUT.Window,
                                  contextViewPort :: Size}
-- | Create empty resource caches for window @w@ with a 0x0 viewport.
-- The caches are registered in the global window-to-cache map and returned.
newContextCache :: GLUT.Window -> IO ContextCache
newContextCache w = do
    programs      <- HT.new (==) programKeyHash
    vertexBuffers <- HT.new (==) vertexKeyHash
    indexBuffers  <- HT.new (==) (HT.hashInt . hashStableName)
    let cache = ContextCache { programCache    = programs
                             , vbCache         = vertexBuffers
                             , ibCache         = indexBuffers
                             , contextWindow   = w
                             , contextViewPort = Size 0 0 }
    saveContextCache cache
    return cache
  where
    -- Combine the hashes of both shader signatures.
    programKeyHash (vsig, fsig) = HT.hashString vsig `xor` HT.hashString fsig
    -- Hash the Int list by reading it as a String of code points.
    -- NOTE(review): 'toEnum' fails for values outside the Char range.
    vertexKeyHash (ints, name) = HT.hashString (map toEnum ints) `xor` HT.hashInt (hashStableName name)

-- | Make this cache's window the current GLUT window, and set the GL viewport
-- to the size stored in the cache.
setContextWindow :: ContextCacheIO ()
setContextWindow = do
    cache <- ask
    liftIO $ do
        GLUT.currentWindow $= Just (contextWindow cache)
        viewport $= (Position 0 0, contextViewPort cache)


-- | IO actions with read-only access to a window's 'ContextCache'.
type ContextCacheIO = ReaderT ContextCache IO

-- | Compile a shader object from the GLSL source @str@.
-- If compilation fails, this calls 'error' with the shader info log followed
-- by the source.
createShaderResource :: Shader s => String -> IO s
createShaderResource str = do
    [shader] <- genObjectNames 1
    shaderSource shader $= [str]
    compileShader shader
    compiled <- get (compileStatus shader)
    unless compiled $ do
        infoLog <- get (shaderInfoLog shader)
        error $ infoLog ++ "\nSource:\n\n" ++ str
    return shader

-- | Return the linked shader program for the given vertex and fragment
-- shaders, together with the uniform locations of its vertex and fragment
-- stages. The program is compiled and cached on first use.
--
-- Each shader is passed as (GLSL source, signature). The cache is keyed by the
-- pair of signatures only.
--
-- @s@ controls the vertex attribute bindings: attribute @attrN@ is bound to
-- location @N@ for @N@ in @[0 .. (s-1) `div` 4]@, matching 'useVertexBuffer'.
-- Presumably @s@ is the number of floats per vertex; confirm with the caller.
-- When @s <= 0@, no attributes are bound.
--
-- Compile or link failure calls 'error' with the GL info log.
createProgramResource :: (String,String) -> (String,String) -> Int -> ContextCacheIO (Program, (UniformLocationSet,UniformLocationSet))
createProgramResource (vstr,vsig) (fstr,fsig) s = do
    cache <- asks programCache
    let key = (vsig, fsig)
    cached <- liftIO $ HT.lookup cache key
    case cached of
        Just p  -> return p
        Nothing -> liftIO $ do
            [p] <- genObjectNames 1
            vs <- createShaderResource vstr
            fs <- createShaderResource fstr
            attachedShaders p $= ([vs],[fs])
            -- Enumerate in Int and convert each index afterwards. Converting
            -- (s-1) `div` 4 == -1 to an unsigned GLuint first would wrap to
            -- 2^32-1 and loop over billions of locations.
            forM_ [0 .. (s - 1) `div` 4] $ \ix ->
                attribLocation p ("attr" ++ show ix) $= AttribLocation (fromIntegral ix)
            linkProgram p
            linked <- get $ linkStatus p
            unless linked $ do
                e <- get $ programInfoLog p
                error e
            vertexLocs <- stageLocations p "vu"
            fragmentLocs <- stageLocations p "fu"
            let p' = (p, (vertexLocs, fragmentLocs))
            HT.insert cache key p'
            return p'
  where
    allSamplers = [minBound .. maxBound :: SamplerType]
    -- Look up one stage's float/int/bool uniform arrays and one sampler array
    -- per 'SamplerType', e.g. "fvu", "ivu", "bvu", "s0vu", "s1vu", ...
    stageLocations p suffix = do
        let loc name = get $ uniformLocation p name
        fl <- loc ('f' : suffix)
        il <- loc ('i' : suffix)
        bl <- loc ('b' : suffix)
        sls <- mapM (\t -> loc ('s' : show (fromEnum t) ++ suffix)) allSamplers
        return (fl, il, bl, Map.fromAscList (zip allSamplers sls))

-- | Return a vertex buffer holding @xs@, creating and caching it on first use.
-- @xs@ has one inner list of floats per vertex; the lists are concatenated
-- into a single float array.
--
-- The cache key is @i@ together with the 'StableName' of @v@. A finalizer
-- attached to @v@ removes the cache entry and deletes the GL buffer. It does
-- so with this cache's window made current, then restores the previously
-- current window.
--
-- NOTE(review): GHC documents finalizers on ordinary Haskell values (such as
-- lists) as fragile; they may run earlier than expected. Confirm that this is
-- acceptable here.
createVertexBuffer :: [[Float]] -> [Int] -> [[Float]] -> ContextCacheIO VertexBuffer
createVertexBuffer xs i v = do
    vName <- liftIO $ makeStableName v
    let k = (i, vName)
    cache <- asks vbCache
    test <- liftIO $ HT.lookup cache k
    case test of
        Just b -> return b
        Nothing -> do
            w <- asks contextWindow
            liftIO $ do
                [b] <- genObjectNames 1
                bindBuffer ArrayBuffer $= Just b
                let xsdata = map realToFrac $ concat xs :: [CFloat]
                withArray xsdata $ \p ->
                    bufferData ArrayBuffer $= (fromIntegral $ sizeOf (0 :: CFloat) * length xsdata, p, StaticDraw)
                -- Floats per vertex, taken from the first vertex. An empty
                -- vertex list yields 0 rather than failing in 'head'.
                let floatsPerVertex = case xs of
                                          (x:_) -> length x
                                          []    -> 0
                    b' = VertexBuffer b floatsPerVertex
                HT.insert cache k b'
                addFinalizer v $ do
                    HT.delete cache k
                    w' <- get GLUT.currentWindow
                    GLUT.currentWindow $= Just w
                    deleteObjectNames [b]
                    GLUT.currentWindow $= w'
                return b'

-- | Return an index buffer holding @xs@, creating and caching it on first use.
--
-- @vs@ (presumably the number of vertices being indexed) selects the index
-- storage type:
--
--   * unsigned int if @vs > 65535@,
--   * unsigned short if @vs > 255@,
--   * otherwise unsigned byte.
--
-- The cache key is the 'StableName' of @xs@. A finalizer on @xs@ removes the
-- entry and deletes the GL buffer with this cache's window made current, then
-- restores the previously current window.
createIndexBuffer :: [Int] -> Int -> ContextCacheIO IndexBuffer
createIndexBuffer xs vs = do
    cache <- asks ibCache
    iName <- liftIO $ makeStableName xs
    test <- liftIO $ HT.lookup cache iName
    case test of
        Just i -> return i
        Nothing -> do
            w <- asks contextWindow
            liftIO $ do
                [b] <- genObjectNames 1
                bindBuffer ElementArrayBuffer $= Just b
                i <- if vs > fromIntegral (maxBound :: CUShort)
                       then upload b (map fromIntegral xs :: [CUInt]) UnsignedInt
                       else if vs > fromIntegral (maxBound :: CUChar)
                         then upload b (map fromIntegral xs :: [CUShort]) UnsignedShort
                         else upload b (map fromIntegral xs :: [CUChar]) UnsignedByte
                HT.insert cache iName i
                addFinalizer xs $ do
                    HT.delete cache iName
                    w' <- get GLUT.currentWindow
                    GLUT.currentWindow $= Just w
                    deleteObjectNames [b]
                    GLUT.currentWindow $= w'
                return i
  where
    count = length xs
    -- Upload already-converted indices into the bound element array buffer.
    -- The byte size must use the stored element type: a fixed 4 bytes per
    -- index would make glBufferData read past the end of short/byte arrays.
    upload b elems ty = do
        withArray elems $ \p ->
            bufferData ElementArrayBuffer $= (fromIntegral (count * elemBytes elems), p, StaticDraw)
        return $ IndexBuffer b count ty
    -- Size of one element of the list's type. 'sizeOf' never inspects its
    -- argument, so the placeholder element is never evaluated.
    elemBytes :: Storable e => [e] -> Int
    elemBytes = sizeOf . elemWitness
    elemWitness :: [e] -> e
    elemWitness _ = undefined

-- | Make @p@ the GL program used for subsequent draw calls.
useProgramResource :: Program -> ContextCacheIO ()
useProgramResource = liftIO . (currentProgram $=) . Just

-- | Upload the uniform values in @uns@ to the locations in one
-- 'UniformLocationSet' of the current program.
--
-- Float, int and bool uniforms are each uploaded as one array, in 'Map' key
-- order; bools are sent as 0/1 ints.
--
-- For samplers, each distinct (sampler state, texture) pair is bound to its
-- own texture unit with its filter and wrap modes set. The sampler array
-- uniform then receives those unit numbers. Units are allocated
-- consecutively across all sampler types, because GL forbids samplers of
-- different types from referring to the same texture unit.
--
-- NOTE(review): unit numbering restarts at 0 on every call. If the vertex and
-- fragment sets are uploaded by separate calls, their units can still
-- collide; confirm against the caller.
useUniforms :: UniformLocationSet -> UniformState -> ContextCacheIO ()
useUniforms (fu,iu,bu,su) uns = do
    w <- asks contextWindow
    liftIO $ do
        unless (null f) $ withArray (map (TexCoord1 . realToFrac) f :: [TexCoord1 GLfloat]) $ uniformv fu (fromIntegral $ length f)
        unless (null i) $ withArray (map (TexCoord1 . fromIntegral) i :: [TexCoord1 GLint]) $ uniformv iu (fromIntegral $ length i)
        unless (null b) $ withArray (map (TexCoord1 . fromBool) b :: [TexCoord1 GLint]) $ uniformv bu (fromIntegral $ length b)
        foldM_ (useSampler w) 0
            [ (Sampler3D, sampler3DUniforms uns)
            , (Sampler2D, sampler2DUniforms uns)
            , (Sampler1D, sampler1DUniforms uns)
            , (SamplerCube, samplerCubeUniforms uns) ]
  where
    f = Map.elems $ floatUniforms uns
    i = Map.elems $ intUniforms uns
    b = Map.elems $ boolUniforms uns

    -- Bind the samplers of type t to units starting at firstUnit, upload the
    -- unit numbers, and return the first unit left free for the next type.
    useSampler w firstUnit (t, m)
        | Map.null m = return firstUnit
        | otherwise = do
            let xs = Map.elems m
                texs = nub xs
            units <- mapM (createSampler w t) $ zip texs [firstUnit ..]
            let texToUnit = zip texs units
                ss = map (fromJust . flip lookup texToUnit) xs
            withArray ss $ uniformv (su Map.! t) (fromIntegral $ length ss)
            return (firstUnit + fromIntegral (length texs))

    -- Bind one texture to the given unit, set its filter and wrap modes, and
    -- return the unit number to store in the sampler uniform.
    createSampler w t ((Sampler filt edge, tex), unit) = do
        activeTexture $= TextureUnit unit
        bindWinMappedTexture (target t) w tex t
        textureFilter (target t) $= toGLFilter filt
        -- GL only has wrap-mode parameters for S, T and R; there is none for Q.
        mapM_ (\c -> textureWrapMode (target t) c $= toGLWrap edge) [S, T, R]
        return $ TexCoord1 (fromIntegral unit :: GLint)

    target Sampler3D = Texture3D
    target Sampler2D = Texture2D
    target Sampler1D = Texture1D
    target SamplerCube = TextureCubeMap

-- | Bind a vertex buffer and configure its vertex attributes.
-- Each vertex holds @stride@ floats. Attribute location @k@ is enabled and
-- pointed at floats @4k .. 4k+3@ of every vertex; the last group may have
-- fewer than four components.
useVertexBuffer :: VertexBuffer -> IO ()
useVertexBuffer (VertexBuffer buf stride) = do
    bindBuffer ArrayBuffer $= Just buf
    forM_ [0 .. (stride - 1) `div` 4] $ \slot -> do
        let loc         = AttribLocation (fromIntegral slot)
            components  = fromIntegral (min (stride - slot * 4) 4)
            strideBytes = fromIntegral (floatBytes * stride)
            offset      = nullPtr `plusPtr` (floatBytes * slot * 4)
        vertexAttribArray loc $= Enabled
        vertexAttribPointer loc $= (ToFloat, VertexArrayDescriptor components Float strideBytes offset)
  where
    floatBytes = sizeOf (0 :: CFloat)


-- | Draw primitives of mode @p@, using every index in the index buffer to
-- select vertices from the vertex buffer.
drawIndexVertexBuffer :: PrimitiveMode -> VertexBuffer -> IndexBuffer -> IO ()
drawIndexVertexBuffer p vb (IndexBuffer ibuf count ty) = do
    useVertexBuffer vb
    bindBuffer ElementArrayBuffer $= Just ibuf
    drawElements p (fromIntegral count) ty nullPtr

-- | Draw primitives of mode @p@ from the first @s@ vertices of the buffer.
drawVertexBuffer :: PrimitiveMode -> VertexBuffer -> Int -> IO ()
drawVertexBuffer p vb s = useVertexBuffer vb >> drawArrays p 0 (fromIntegral s)


{-# NOINLINE hiddenWindowContextCache #-}
-- | Context cache of a hidden GLUT window. It is used when no window with a
-- cache is current (see 'getCurrentOrSetHiddenContext').
--
-- Created on first use through 'unsafePerformIO'. The NOINLINE pragma keeps
-- the creation from being inlined and duplicated, so there is a single
-- hidden window.
--
-- The display mode is set before the window is created because
-- 'GLUT.initialDisplayMode' only affects windows created afterwards.
-- 'newContextCache' also registers the cache in the global window map.
--
-- NOTE(review): this assumes GLUT has been initialised before the value is
-- first forced.
hiddenWindowContextCache :: ContextCache
hiddenWindowContextCache = unsafePerformIO $ do
       GLUT.initialDisplayMode $= [ GLUT.SingleBuffered, GLUT.RGBMode, GLUT.WithAlphaComponent, GLUT.WithDepthBuffer, GLUT.WithStencilBuffer ]
       w <- GLUT.createWindow "Hidden Window"
       GLUT.windowStatus $= GLUT.Hidden
       newContextCache w

{-# NOINLINE windowContextCaches #-}
-- | Global map from each GLUT window to its 'ContextCache'.
-- NOINLINE ensures that 'unsafePerformIO' allocates exactly one 'IORef'.
windowContextCaches :: IORef (Map GLUT.Window ContextCache)
windowContextCaches = unsafePerformIO $ newIORef $ Map.empty

-- | Look up the registered 'ContextCache' of window @w@, if any.
getContextCache :: GLUT.Window -> IO (Maybe ContextCache)
getContextCache w = liftM (Map.lookup w) (readIORef windowContextCaches)

-- | Register (or replace) the cache for the cache's own window.
saveContextCache :: ContextCache -> IO ()
saveContextCache c =
    atomicModifyIORef windowContextCaches $ \caches ->
        let caches' = Map.insert (contextWindow c) c caches
        in (caches', ())

-- | Record a new viewport size for window @w@'s registered cache.
-- Windows without a cache are left untouched.
changeContextSize :: GLUT.Window -> Size -> IO ()
changeContextSize w s =
    atomicModifyIORef windowContextCaches $ \caches -> (Map.adjust resize w caches, ())
  where
    resize c = c { contextViewPort = s }

-- | Return the cache of the current GLUT window.
-- If there is no current window, or it has no registered cache, the hidden
-- window is made current and its cache is returned instead.
getCurrentOrSetHiddenContext :: IO ContextCache
getCurrentOrSetHiddenContext = do
    mw <- get GLUT.currentWindow
    mc <- maybe (return Nothing) getContextCache mw
    maybe useHidden return mc
  where
    useHidden = do
        GLUT.currentWindow $= Just (contextWindow hiddenWindowContextCache)
        return hiddenWindowContextCache


-- | Force @a@ by comparing it with itself, then return it unchanged.
--
-- How much of @a@ is evaluated depends on its 'Eq' instance; a structural
-- comparison may stop early (e.g. at the first unequal component).
--
-- The comparison result is deliberately ignored. For values containing NaN,
-- @a == a@ is 'False', and the value must still be returned rather than
-- replaced by 'undefined'.
evaluateDeep :: Eq a => a -> IO a
evaluateDeep a = do _ <- evaluate (a == a)
                    return a

-- | 'evaluateDeep', lifted into 'ContextCacheIO'.
ioEvaluate :: Eq a => a -> ContextCacheIO a
ioEvaluate x = liftIO (evaluateDeep x)

-- | Force a read of the first byte at @p@, then return @p@ unchanged.
evaluatePtr :: Ptr a -> IO (Ptr a)
evaluatePtr p = do
    firstByte <- peek (castPtr p :: Ptr CUChar)
    ok <- evaluate (firstByte == firstByte)
    if ok then return p else return undefined

----------------------------------------------------
-- Texture operations


-- | A texture that may have copies in several windows' GL contexts.
-- It maps each window to the texture object holding that window's copy.
-- Created by 'newWinMappedTexture'; copies for further windows are made on
-- demand by 'bindWinMappedTexture'.
type WinMappedTexture = IORef (Map GLUT.Window TextureObject)

-- | Create a texture and return a reference mapping windows to its copies.
--
-- The texture object is created in the current window's context. If the
-- current window has no cache, the hidden window is made current and used
-- instead. The texture is then initialised with @ionew@.
--
-- A finalizer on the returned 'IORef' deletes every copy in its own window,
-- then restores the previously current window.
-- NOTE(review): finalizers run at GC time, on the finalizer thread.
newWinMappedTexture :: (TextureObject -> ContextCache -> IO a) -> IO WinMappedTexture
newWinMappedTexture ionew = do
    cache <- getCurrentOrSetHiddenContext
    [tex] <- genObjectNames 1
    _ <- ionew tex cache
    ref <- newIORef (Map.singleton (contextWindow cache) tex)
    _ <- mkWeakIORef ref (releaseAll ref)
    return ref
  where
    releaseAll ref = do
        copies <- readIORef ref
        saved <- get GLUT.currentWindow
        forM_ (Map.toList copies) $ \(win, t) -> do
            GLUT.currentWindow $= Just win
            deleteObjectNames [t]
        GLUT.currentWindow $= saved
                                   
-- | Bind window @w@'s copy of the texture in @ref@ to @target@ in the active
-- texture unit. @s@ is the texture's sampler type.
--
-- If @w@ has no copy yet, one is transferred from the first window (in key
-- order) that has a copy:
--
--   * every mipmap level in that copy's level range is read back as RGBA
--     floats;
--   * the data is uploaded, with the same internal format, to a new texture
--     object created in @w@;
--   * the new object is recorded in @ref@.
--
-- After a transfer, @w@ is the current GLUT window.
-- Assumes @ref@ is never empty ('newWinMappedTexture' starts it with one entry).
bindWinMappedTexture target w ref s = do
    mtex <- atomicModifyIORef ref (\a -> (a, takeOne a))
    case mtex of
        Right tex -> textureBinding target $= Just tex
        Left (w', t) -> do
            -- Inspect and read the texture in a window that already has it.
            GLUT.currentWindow $= Just w'
            textureBinding target $= Just t
            ft <- get $ textureLevelRange target
            -- Level parameters cannot be queried on GL_TEXTURE_CUBE_MAP itself,
            -- so cube maps are queried through one of their faces.
            f <- get $ textureInternalFormat (levelTarget s) 0
            tex <- transferTexture s ft f
            textureLevelRange target $= ft
            -- Keep any copy for w recorded in the meantime.
            -- NOTE(review): in that case 'tex' is never deleted.
            atomicModifyIORef ref (flip (,) () . Map.insertWith (const id) w tex)
  where
    -- This window's copy, or else the first (window, texture) entry.
    takeOne a = case Map.lookup w a of
                    Nothing -> Left $ Map.elemAt 0 a
                    Just t  -> Right t

    levelTarget SamplerCube = Right TextureCubeMapPositiveX
    levelTarget _           = Left target

    -- Create a new texture object in window w and bind it to the target.
    createTexInWin = do
        GLUT.currentWindow $= Just w
        [tex] <- genObjectNames 1
        textureBinding target $= Just tex
        return tex

    -- Copy levels [from .. to] of the bound texture into a new texture in w.
    -- All data is read back before createTexInWin switches windows.
    transferTexture Sampler3D (from, to) f = do
        psSize <- mapM getDataAndSize3D [from .. to]
        tex <- createTexInWin
        mapM_ (setDataWithSize3D f) $ zip psSize [from .. to]
        return tex
    transferTexture Sampler2D (from, to) f = do
        psSize <- mapM getDataAndSize2D [from .. to]
        tex <- createTexInWin
        mapM_ (setDataWithSize2D f) $ zip psSize [from .. to]
        return tex
    transferTexture Sampler1D (from, to) f = do
        psSize <- mapM getDataAndSize1D [from .. to]
        tex <- createTexInWin
        mapM_ (setDataWithSize1D f) $ zip psSize [from .. to]
        return tex
    transferTexture SamplerCube (from, to) f = do
        let faces = [(n, side) | n <- [from .. to], side <- cubeMapTargets]
        psSize <- mapM getDataAndSizeCube faces
        tex <- createTexInWin
        mapM_ (setDataWithSizeCube f) $ zip psSize faces
        return tex

    getDataAndSize3D n = do
        size@(TextureSize3D x y z) <- get $ textureSize3D (Left Texture3D) n
        fp <- mallocForeignPtrBytes (fromIntegral x * fromIntegral y * fromIntegral z * 4 * sizeOf (undefined :: Float))
        withForeignPtr fp $ \p -> getTexImage (Left Texture3D) n (PixelData RGBA Float p)
        return (size, fp)
    setDataWithSize3D f ((size, fp), n) =
        withForeignPtr fp $ \p -> texImage3D NoProxy n f size 0 (PixelData RGBA Float p)

    getDataAndSize2D n = do
        size@(TextureSize2D x y) <- get $ textureSize2D (Left Texture2D) n
        fp <- mallocForeignPtrBytes (fromIntegral x * fromIntegral y * 4 * sizeOf (undefined :: Float))
        withForeignPtr fp $ \p -> getTexImage (Left Texture2D) n (PixelData RGBA Float p)
        return (size, fp)
    setDataWithSize2D f ((size, fp), n) =
        withForeignPtr fp $ \p -> texImage2D Nothing NoProxy n f size 0 (PixelData RGBA Float p)

    getDataAndSize1D n = do
        size@(TextureSize1D x) <- get $ textureSize1D (Left Texture1D) n
        fp <- mallocForeignPtrBytes (fromIntegral x * 4 * sizeOf (undefined :: Float))
        -- Read back from the 1D target the size was queried on.
        withForeignPtr fp $ \p -> getTexImage (Left Texture1D) n (PixelData RGBA Float p)
        return (size, fp)
    setDataWithSize1D f ((size, fp), n) =
        withForeignPtr fp $ \p -> texImage1D NoProxy n f size 0 (PixelData RGBA Float p)

    getDataAndSizeCube (n, side) = do
        size@(TextureSize2D x y) <- get $ textureSize2D (Right side) n
        fp <- mallocForeignPtrBytes (fromIntegral x * fromIntegral y * 4 * sizeOf (undefined :: Float))
        withForeignPtr fp $ \p -> getTexImage (Right side) n (PixelData RGBA Float p)
        return (size, fp)
    setDataWithSizeCube f ((size, fp), (n, side)) =
        withForeignPtr fp $ \p -> texImage2D (Just side) NoProxy n f size 0 (PixelData RGBA Float p)

-- | The six cube-map faces in the order +X, -X, +Y, -Y, +Z, -Z. This is the
-- order in which 'bindWinMappedTexture' copies the faces of each mipmap level.
cubeMapTargets =
    [ TextureCubeMapPositiveX, TextureCubeMapNegativeX
    , TextureCubeMapPositiveY, TextureCubeMapNegativeY
    , TextureCubeMapPositiveZ, TextureCubeMapNegativeZ
    ]