{-# OPTIONS_GHC -DFlt=Float -DVECT_Float #-}

-- TODO: the pointer versions of these functions should really be implemented
-- via the pointer versions of the original OpenGL functions.

-- | OpenGL support, including 'Vertex', 'TexCoord', etc instances for 'Vec2', 'Vec3' and 'Vec4'.
 
module Data.Vect.Flt.OpenGL where

import Control.Monad
import Data.Vect.Flt.Base
import Data.Vect.Flt.Util.Projective
import qualified Graphics.Rendering.OpenGL as GL

import Foreign

import Graphics.Rendering.OpenGL hiding 
  ( Normal3 , rotate , translate , scale
  , matrix , currentMatrix , withMatrix , multMatrix 
  )

--------------------------------------------------------------------------------

-- | There should be a big warning here about the different conventions, 
-- hidden transpositions, and all the confusion this will inevitably cause...
--
-- As it stands, 
--
-- > glRotate t1 axis1 >> glRotate t2 axis2 >> glRotate t3 axis3
-- 
-- has the same result as
--
-- > multMatrix (rotMatrixProj4 t3 axis3 .*. rotMatrixProj4 t2 axis2 .*. rotMatrixProj4 t1 axis1)
--
-- because at the interface of OpenGL and this library there is a transposition
-- to compensate for the different conventions. (This transposition is implicit
-- in the code, because the way the matrices are stored in the memory is also
-- different: OpenGL stores them column-major, and we store them row-major).

class ToOpenGLMatrix m where
  -- | Build a @GLmatrix@ holding the contents of a value of type @m@.
  makeGLMatrix :: m -> IO (GLmatrix Flt)

-- | Matrix types that can be read back out of a @GLmatrix@.
class FromOpenGLMatrix m where
  -- | Convert a @GLmatrix@ into a value of type @m@.
  peekGLMatrix :: GLmatrix Flt -> IO m
  
-- | Convert the argument with @makeGLMatrix@ and assign the result to the
-- OpenGL matrix state variable selected by the given mode.
setMatrix :: ToOpenGLMatrix m => Maybe MatrixMode -> m -> IO ()
setMatrix mode m = do
  glm <- makeGLMatrix m
  GL.matrix mode $= glm
 
-- | Read the OpenGL matrix state variable selected by the given mode and
-- convert it with @peekGLMatrix@.
getMatrix :: FromOpenGLMatrix m => Maybe MatrixMode -> IO m
getMatrix mode = do
  raw <- get (GL.matrix mode)
  peekGLMatrix raw

-- | State variable combining @getMatrix@ (reads) and @setMatrix@ (writes)
-- for the given matrix mode.
matrix :: (ToOpenGLMatrix m, FromOpenGLMatrix m) => Maybe MatrixMode -> StateVar m
matrix mode = makeStateVar readIt writeIt
  where
    readIt    = getMatrix mode
    writeIt m = setMatrix mode m

-- | The matrix state variable with no explicit matrix mode, that is the
-- same thing as @matrix Nothing@.
currentMatrix :: (ToOpenGLMatrix m, FromOpenGLMatrix m) => StateVar m
currentMatrix = makeStateVar (getMatrix Nothing) (setMatrix Nothing)

-- | Convert the argument with @makeGLMatrix@ and pass the result to
-- @GL.multMatrix@.
multMatrix :: ToOpenGLMatrix m => m -> IO ()
multMatrix m = do
  glm <- makeGLMatrix m
  GL.multMatrix glm

instance ToOpenGLMatrix Mat4 where
  -- The matrix value is poked as-is into the freshly created storage, which
  -- is tagged as column-major; no element reordering happens here.
  makeGLMatrix m = GL.withNewMatrix GL.ColumnMajor fill
    where fill p = poke (castPtr p) m
 
instance FromOpenGLMatrix Mat4 where
  -- Reads the matrix straight out of the pointer handed over by
  -- GL.withMatrix. The first callback argument is ignored.
  -- NOTE(review): presumably that argument is the storage order, in which
  -- case only one order is handled correctly here -- confirm.
  peekGLMatrix x = GL.withMatrix x readRaw
    where readRaw _ p = peek (castPtr p)
  
instance ToOpenGLMatrix Mat3 where
  -- Widen to a Mat4 via extendWith 1, then reuse the Mat4 instance.
  makeGLMatrix = makeGLMatrix . widen
    where
      widen :: Mat3 -> Mat4
      widen = extendWith 1
 
instance ToOpenGLMatrix Mat2 where
  -- Widen to a Mat4 via extendWith 1, then reuse the Mat4 instance.
  makeGLMatrix small = makeGLMatrix big
    where
      big :: Mat4
      big = extendWith 1 small

instance ToOpenGLMatrix Ortho4 where
  -- Unwrap to the underlying Mat4 and reuse its instance.
  makeGLMatrix = makeGLMatrix . (fromOrtho :: Ortho4 -> Mat4)

instance ToOpenGLMatrix Ortho3 where
  -- Unwrap to the underlying Mat3 and reuse its instance.
  makeGLMatrix = makeGLMatrix . (fromOrtho :: Ortho3 -> Mat3)

instance ToOpenGLMatrix Ortho2 where
  -- Unwrap to the underlying Mat2 and reuse its instance.
  makeGLMatrix = makeGLMatrix . (fromOrtho :: Ortho2 -> Mat2)

instance ToOpenGLMatrix Proj4 where
  -- Unwrap to the underlying Mat4 and reuse its instance.
  makeGLMatrix p = makeGLMatrix plain
    where
      plain :: Mat4
      plain = fromProjective p

instance ToOpenGLMatrix Proj3 where
  -- Unwrap to the underlying Mat3 and reuse its instance.
  makeGLMatrix p = makeGLMatrix plain
    where
      plain :: Mat3
      plain = fromProjective p
  
--------------------------------------------------------------------------------

{-# SPECIALISE radianToDegrees :: Float  -> Float  #-}
{-# SPECIALISE radianToDegrees :: Double -> Double #-}
-- | Convert an angle from radians to degrees by multiplying with the
-- constant 180 over pi.
--
-- The body only multiplies by a fractional literal, so @Fractional@ is the
-- weakest sufficient constraint; every type accepted by the previous
-- @RealFrac@ signature is still accepted.
radianToDegrees :: Fractional a => a -> a
radianToDegrees x = x * 57.295779513082322

{-# SPECIALIZE degreesToRadian :: Float  -> Float  #-}
{-# SPECIALIZE degreesToRadian :: Double -> Double #-}
-- | Convert an angle from degrees to radians by multiplying with the
-- constant pi over 180.
--
-- The body only multiplies by a fractional literal, so @Fractional@ is the
-- weakest sufficient constraint; every type accepted by the previous
-- @Floating@ signature is still accepted.
degreesToRadian :: Fractional a => a -> a
degreesToRadian x = x * 1.7453292519943295e-2

-- | Call @GL.rotate@ for the given angle and axis. The angle is taken in
-- radians and converted with @radianToDegrees@ first.
-- (WARNING: OpenGL uses degrees!)
glRotate :: Flt -> Vec3 -> IO ()
glRotate angle (Vec3 x y z) = GL.rotate degrees axis
  where
    degrees = radianToDegrees angle
    axis    = Vector3 x y z

-- | Call @GL.translate@ with the three components of the given vector.
glTranslate :: Vec3 -> IO ()
glTranslate offset = case offset of
  Vec3 dx dy dz -> GL.translate (Vector3 dx dy dz)

-- | Call @GL.scale@ with one factor per axis, taken from the components of
-- the given vector.
glScale3 :: Vec3 -> IO ()
glScale3 factors = case factors of
  Vec3 sx sy sz -> GL.scale sx sy sz

-- | Call @GL.scale@ with the same factor on all three axes.
glScale :: Flt -> IO ()
glScale factor = GL.scale factor factor factor

--------------------------------------------------------------------------------
 
-- | \"Orthogonal projection\" matrix, a la OpenGL
-- (the corresponding functionality is removed in OpenGL 3.1).
--
-- The result is not defined sensibly when any of the three intervals is
-- empty, since each extent is used as a divisor.
orthoMatrix 
  :: (Flt,Flt)   -- ^ (left,right)
  -> (Flt,Flt)   -- ^ (bottom,top)
  -> (Flt,Flt)   -- ^ (near,far)
  -> Mat4 
orthoMatrix (l,r) (b,t) (n,f) = Mat4 row1 row2 row3 row4
  where
    -- extents of the viewing box along each axis
    width  = r - l
    height = t - b
    depth  = f - n
    row1 = Vec4 (2 / width) 0 0 0
    row2 = Vec4 0 (2 / height) 0 0
    row3 = Vec4 0 0 (-2 / depth) 0
    -- the translation part lives in the last Vec4
    row4 = Vec4 (-(r + l) / width) (-(t + b) / height) (-(f + n) / depth) 1
  
-- | The same as @orthoMatrix@, but taking two opposite corners of the
-- viewing box instead of three intervals.
orthoMatrix2
  :: Vec3     -- ^ (left,top,near)
  -> Vec3     -- ^ (right,bottom,far)
  -> Mat4 
orthoMatrix2 (Vec3 l t n) (Vec3 r b f) = orthoMatrix horizontal vertical depth
  where
    horizontal = (l, r)
    vertical   = (b, t)
    depth      = (n, f)

-- | \"Perspective projection\" matrix, a la OpenGL 
-- (the corresponding functionality is removed in OpenGL 3.1).
--
-- The entries follow the glFrustum definition, laid out in the transposed
-- form used by this module (the -1 sits at the end of the third Vec4 and
-- the depth offset term in the fourth).
frustumMatrix
  :: (Flt,Flt)   -- ^ (left,right)
  -> (Flt,Flt)   -- ^ (bottom,top)
  -> (Flt,Flt)   -- ^ (near,far)
  -> Mat4 
frustumMatrix (l,r) (b,t) (n,f) = Mat4
  (Vec4 (2*n/(r-l)) 0 0 0)
  (Vec4 0 (2*n/(t-b)) 0 0)
  (Vec4 ((r+l)/(r-l)) ((t+b)/(t-b)) (-(f+n)/(f-n)) (-1))
  -- depth offset term: -2*f*n divided by (f-n), not multiplied by it
  (Vec4 0 0 (-2*f*n/(f-n)) 0)
  
-- | The same as @frustumMatrix@, but taking two opposite corners of the
-- near and far planes instead of three intervals.
frustumMatrix2
  :: Vec3     -- ^ (left,top,near)
  -> Vec3     -- ^ (right,bottom,far)
  -> Mat4 
frustumMatrix2 (Vec3 l t n) (Vec3 r b f) = frustumMatrix horizontal vertical depth
  where
    horizontal = (l, r)
    vertical   = (b, t)
    depth      = (n, f)

--------------------------------------------------------------------------------
-- Vertex instances

instance GL.Vertex Vec2 where
  -- Repackage the components as a GL.Vertex2 and delegate.
  vertex v = case v of
    Vec2 x y -> GL.vertex (GL.Vertex2 x y)
  -- Pointer version: load the vector, then go through vertex.
  vertexv = peek >=> vertex
  
instance GL.Vertex Vec3 where
  -- Repackage the components as a GL.Vertex3 and delegate.
  vertex v = case v of
    Vec3 x y z -> GL.vertex (GL.Vertex3 x y z)
  -- Pointer version: load the vector, then go through vertex.
  vertexv = peek >=> vertex
  
instance GL.Vertex Vec4 where
  -- Repackage the components as a GL.Vertex4 and delegate.
  vertex v = case v of
    Vec4 x y z w -> GL.vertex (GL.Vertex4 x y z w)
  -- Pointer version: load the vector, then go through vertex.
  vertexv = peek >=> vertex

--------------------------------------------------------------------------------
-- the Normal instance
-- note that there is no Normal2\/Normal4 in the OpenGL binding

instance GL.Normal Normal3 where
  -- Unwrap with fromNormal, then repackage as a GL.Normal3 and delegate.
  normal u =
    let Vec3 x y z = fromNormal u
    in  GL.normal (GL.Normal3 x y z)
  -- Pointer version: load the value, then go through normal.
  normalv = peek >=> normal

instance GL.Normal Vec3 where
  -- Repackage the components as a GL.Normal3 and delegate; the vector is
  -- passed through without any normalization in this instance.
  normal v = case v of
    Vec3 x y z -> GL.normal (GL.Normal3 x y z)
  -- Pointer version: load the vector, then go through normal.
  normalv = peek >=> normal

--------------------------------------------------------------------------------
-- Color instances
  
instance GL.Color Vec3 where
  -- Components are taken in (red, green, blue) order.
  color rgb = case rgb of
    Vec3 r g b -> GL.color (GL.Color3 r g b)
  -- Pointer version: load the vector, then go through color.
  colorv = peek >=> color

instance GL.Color Vec4 where
  -- Components are taken in (red, green, blue, alpha) order.
  color rgba = case rgba of
    Vec4 r g b a -> GL.color (GL.Color4 r g b a)
  -- Pointer version: load the vector, then go through color.
  colorv = peek >=> color

instance GL.SecondaryColor Vec3 where
  -- Components are taken in (red, green, blue) order.
  secondaryColor rgb = case rgb of
    Vec3 r g b -> GL.secondaryColor (GL.Color3 r g b)
  -- Pointer version: load the vector, then go through secondaryColor.
  secondaryColorv = peek >=> secondaryColor

{-
-- there is no such thing?
instance GL.SecondaryColor Vec4 where
  secondaryColor (Vec4 r g b a) = GL.secondaryColor (GL.Color4 r g b a)
  secondaryColorv p = peek p >>= secondaryColor
-}

--------------------------------------------------------------------------------
-- TexCoord instances

instance GL.TexCoord Vec2 where
  -- Repackage the components as a GL.TexCoord2 and delegate.
  texCoord tc = case tc of
    Vec2 u v -> GL.texCoord (GL.TexCoord2 u v)
  texCoordv = peek >=> texCoord
  -- Same as above, but for an explicitly given texture unit.
  multiTexCoord unit tc = case tc of
    Vec2 u v -> GL.multiTexCoord unit (GL.TexCoord2 u v)
  multiTexCoordv unit = peek >=> multiTexCoord unit

instance GL.TexCoord Vec3 where
  -- Repackage the components as a GL.TexCoord3 and delegate.
  texCoord tc = case tc of
    Vec3 u v w -> GL.texCoord (GL.TexCoord3 u v w)
  texCoordv = peek >=> texCoord
  -- Same as above, but for an explicitly given texture unit.
  multiTexCoord unit tc = case tc of
    Vec3 u v w -> GL.multiTexCoord unit (GL.TexCoord3 u v w)
  multiTexCoordv unit = peek >=> multiTexCoord unit

instance GL.TexCoord Vec4 where
  -- Repackage the components as a GL.TexCoord4 and delegate.
  texCoord tc = case tc of
    Vec4 u v w z -> GL.texCoord (GL.TexCoord4 u v w z)
  texCoordv = peek >=> texCoord
  -- Same as above, but for an explicitly given texture unit.
  multiTexCoord unit tc = case tc of
    Vec4 u v w z -> GL.multiTexCoord unit (GL.TexCoord4 u v w z)
  multiTexCoordv unit = peek >=> multiTexCoord unit

--------------------------------------------------------------------------------
-- Vertex Attributes (experimental)

-- | Types whose values can be passed as a generic vertex attribute.
class VertexAttrib' a where
  -- | Set the vertex attribute at the given location to the given value.
  vertexAttrib :: GL.AttribLocation -> a -> IO ()
  
-- The inline comment on the instance head below must stay where it is:
-- it exists because CPP is sensitive to primes.
instance VertexAttrib' {- ' CPP is sensitive to primes -} Flt where
  -- A single scalar maps directly onto the one-component setter.
  vertexAttrib = GL.vertexAttrib1

instance VertexAttrib' Vec2 where
  -- Pass the two components to the two-component setter.
  vertexAttrib loc v = case v of
    Vec2 x y -> GL.vertexAttrib2 loc x y

instance VertexAttrib' Vec3 where
  -- Pass the three components to the three-component setter.
  vertexAttrib loc v = case v of
    Vec3 x y z -> GL.vertexAttrib3 loc x y z

instance VertexAttrib' Vec4 where
  -- Pass the four components to the four-component setter.
  vertexAttrib loc v = case v of
    Vec4 x y z w -> GL.vertexAttrib4 loc x y z w

instance VertexAttrib' Normal2 where
  -- Unwrap with fromNormal, then pass the two components on.
  vertexAttrib loc u =
    let Vec2 x y = fromNormal u
    in  GL.vertexAttrib2 loc x y

instance VertexAttrib' Normal3 where
  -- Unwrap with fromNormal, then pass the three components on.
  vertexAttrib loc u =
    let Vec3 x y z = fromNormal u
    in  GL.vertexAttrib3 loc x y z

instance VertexAttrib' Normal4 where
  -- Unwrap with fromNormal, then pass the four components on.
  vertexAttrib loc u =
    let Vec4 x y z w = fromNormal u
    in  GL.vertexAttrib4 loc x y z w
   
--------------------------------------------------------------------------------
-- Uniform (again, experimental)

-- (note that the uniform location code in the OpenGL 2.2.1.1 is broken; 
-- a work-around is to put a zero character at the end of uniform names)

{-
toFloat :: Flt -> Float
toFloat = realToFrac

fromFloat :: Float -> Flt
fromFloat = realToFrac
-}

-- Uniforms are always floats...
#ifdef VECT_Float

instance GL.Uniform Flt where
  -- A scalar uniform is accessed through the Index1 wrapper, which is
  -- added on writes and stripped on reads.
  uniform loc = GL.makeStateVar getter setter
    where
      unwrap (GL.Index1 x) = x
      getter   = fmap unwrap (get (uniform loc))
      setter x = uniform loc $= Index1 x
  -- The array version reinterprets the pointer as pointing to Index1 values.
  uniformv loc cnt ptr = uniformv loc cnt wrapped
    where
      wrapped :: Ptr (Index1 Flt)
      wrapped = castPtr ptr

instance GL.Uniform Vec2 where
  -- A Vec2 uniform is accessed through GL.Vertex2, converting on both
  -- reads and writes.
  uniform loc = GL.makeStateVar getter setter where
    getter = liftM (\(GL.Vertex2 x y) -> Vec2 x y) $ get (uniform loc)
    setter (Vec2 x y) = ($=) (uniform loc) (Vertex2 x y) 
  -- Upload cnt two-component vectors. The pointer is reinterpreted as
  -- pointing to Vertex2 values so that the same uniform type is used as in
  -- the getter and setter; going through the scalar instance with a doubled
  -- count would upload an array of single floats instead.
  -- NOTE(review): assumes Vec2 and Vertex2 share the same memory layout
  -- (two consecutive components) -- confirm against the Storable instances.
  uniformv loc cnt ptr = uniformv loc cnt (castPtr ptr :: Ptr (Vertex2 Flt))

instance GL.Uniform Vec3 where
  -- A Vec3 uniform is accessed through GL.Vertex3, converting on both
  -- reads and writes.
  uniform loc = GL.makeStateVar getter setter where
    getter = liftM (\(GL.Vertex3 x y z) -> Vec3 x y z) $ get (uniform loc)
    setter (Vec3 x y z) = ($=) (uniform loc) (Vertex3 x y z) 
  -- Upload cnt three-component vectors. The pointer is reinterpreted as
  -- pointing to Vertex3 values so that the same uniform type is used as in
  -- the getter and setter; going through the scalar instance with a tripled
  -- count would upload an array of single floats instead.
  -- NOTE(review): assumes Vec3 and Vertex3 share the same memory layout
  -- (three consecutive components) -- confirm against the Storable instances.
  uniformv loc cnt ptr = uniformv loc cnt (castPtr ptr :: Ptr (Vertex3 Flt))

instance GL.Uniform Vec4 where
  -- A Vec4 uniform is accessed through GL.Vertex4, converting on both
  -- reads and writes.
  uniform loc = GL.makeStateVar getter setter where
    getter = liftM (\(GL.Vertex4 x y z w) -> Vec4 x y z w) $ get (uniform loc)
    setter (Vec4 x y z w) = ($=) (uniform loc) (Vertex4 x y z w) 
  -- Upload cnt four-component vectors. The pointer is reinterpreted as
  -- pointing to Vertex4 values so that the same uniform type is used as in
  -- the getter and setter; going through the scalar instance with a
  -- quadrupled count would upload an array of single floats instead.
  -- NOTE(review): assumes Vec4 and Vertex4 share the same memory layout
  -- (four consecutive components) -- confirm against the Storable instances.
  uniformv loc cnt ptr = uniformv loc cnt (castPtr ptr :: Ptr (Vertex4 Flt))
    
#endif