{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module Aztecs.GL
  ( -- * Colors
    Color (..),

    -- * Meshes
    Mesh (..),

    -- * Shapes
    Rectangle (..),
    Circle (..),
    Triangle (..),

    -- * Rendering
    render,

    -- * Transform
    module Aztecs.Transform,
  )
where

import Aztecs
import Aztecs.GLFW (RawWindow (..), Window (..))
import Aztecs.Transform
import Control.Monad
import Control.Monad.IO.Class
import qualified Data.Vector.Storable as SV
import Foreign.Ptr
import Foreign.Storable
import Graphics.Rendering.OpenGL (($=))
import qualified Graphics.Rendering.OpenGL as GL
import qualified Graphics.UI.GLFW as GLFW
import Prelude hiding (lookup)

-- | An RGBA color, stored as one 'Float' per channel.
--
-- When a 'Mesh' is drawn, the four channels are handed to OpenGL unchanged
-- (apart from a numeric conversion to the GL float type).
data Color = Color
  { -- | Red channel
    colorR :: !Float,
    -- | Green channel
    colorG :: !Float,
    -- | Blue channel
    colorB :: !Float,
    -- | Alpha channel
    colorA :: !Float
  }
  deriving (Eq, Show)

-- | 'Color' defines no lifecycle hooks of its own and relies on the defaults
-- of the 'Component' class.
instance (Monad m) => Component m Color

-- | A vertex buffer that has been uploaded to OpenGL, together with the
-- information needed to issue a draw call for it.
data Mesh = Mesh
  { -- | Buffer object holding the vertex data
    meshVBO :: !GL.BufferObject,
    -- | Number of vertices passed to the draw call
    meshVertexCount :: !GL.NumArrayIndices,
    -- | Primitive mode the vertices are drawn with
    meshPrimitiveMode :: !GL.PrimitiveMode
  }
  deriving (Eq, Show)

-- | 'Mesh' defines no lifecycle hooks of its own and relies on the defaults
-- of the 'Component' class.
instance (Monad m) => Component m Mesh

-- | Run an action with the OpenGL context of the entity's nearest ancestor
-- window made current.
--
-- The search starts at the entity's 'Parent' and climbs the hierarchy until an
-- entity carrying a 'RawWindow' is found. 'render' draws meshes at any depth
-- below a window, so the lookup must not stop at the direct parent, otherwise
-- shapes nested more than one level deep would never get a mesh compiled.
--
-- The action is skipped (and nothing else happens) when the entity has no
-- 'Parent' or when no ancestor owns a 'RawWindow'. The entity itself is never
-- considered as a window candidate.
inParentWindowContext :: (MonadIO m) => EntityID -> Access m () -> Access m ()
inParentWindowContext e action = do
  res <- lookup e
  case res of
    Just (Parent parentE) -> withWindowOf parentE
    Nothing -> return ()
  where
    -- Use the window of @ancestor@ if it has one, otherwise keep climbing.
    -- Terminates as long as the parent chain is acyclic.
    withWindowOf ancestor = do
      windowRes <- lookup ancestor
      case windowRes of
        Just (RawWindow raw _) -> do
          liftIO . GLFW.makeContextCurrent $ Just raw
          action
        Nothing -> do
          parentRes <- lookup ancestor
          case parentRes of
            Just (Parent next) -> withWindowOf next
            Nothing -> return ()

-- | A rectangle described by its full width and height.
--
-- The component's lifecycle hooks keep a 'Mesh' on the same entity in sync
-- with the rectangle. Every hook runs through 'inParentWindowContext', so
-- nothing happens when that helper finds no window context for the entity.
data Rectangle = Rectangle
  { -- | Full width of the rectangle
    rectangleWidth :: !Float,
    -- | Full height of the rectangle
    rectangleHeight :: !Float
  }
  deriving (Eq, Show)

instance (MonadIO m) => Component m Rectangle where
  -- Compile the initial mesh and attach it to the entity.
  componentOnInsert e rect = inParentWindowContext e $ do
    built <- liftIO (compileRectangle rect)
    insert e (bundle built)

  -- Rebuild the mesh only when the rectangle actually changed, releasing the
  -- previous buffer first so it does not leak.
  componentOnChange e oldRect newRect =
    unless (oldRect == newRect) $
      inParentWindowContext e $ do
        existing <- lookup e
        forM_ existing $ \(Mesh staleVbo _ _) ->
          liftIO (GL.deleteObjectName staleVbo)
        rebuilt <- liftIO (compileRectangle newRect)
        insert e (bundle rebuilt)

  -- Release the buffer and drop the 'Mesh' component, if there is one.
  componentOnRemove e _ = inParentWindowContext e $ do
    existing <- lookup e
    forM_ existing $ \(Mesh vbo _ _) -> do
      liftIO (GL.deleteObjectName vbo)
      void (remove @_ @Mesh e)

-- | A circle described by its radius and the number of segments used to
-- approximate its outline.
--
-- The component's lifecycle hooks keep a 'Mesh' on the same entity in sync
-- with the circle. Every hook runs through 'inParentWindowContext', so
-- nothing happens when that helper finds no window context for the entity.
data Circle = Circle
  { -- | Radius of the circle
    circleRadius :: !Float,
    -- | Number of segments approximating the outline
    circleSegments :: !Int
  }
  deriving (Eq, Show)

instance (MonadIO m) => Component m Circle where
  -- Compile the initial mesh and attach it to the entity.
  componentOnInsert e circ = inParentWindowContext e $ do
    built <- liftIO (compileCircle circ)
    insert e (bundle built)

  -- Rebuild the mesh only when the circle actually changed, releasing the
  -- previous buffer first so it does not leak.
  componentOnChange e oldCirc newCirc =
    unless (oldCirc == newCirc) $
      inParentWindowContext e $ do
        existing <- lookup e
        forM_ existing $ \(Mesh staleVbo _ _) ->
          liftIO (GL.deleteObjectName staleVbo)
        rebuilt <- liftIO (compileCircle newCirc)
        insert e (bundle rebuilt)

  -- Release the buffer and drop the 'Mesh' component, if there is one.
  componentOnRemove e _ = inParentWindowContext e $ do
    existing <- lookup e
    forM_ existing $ \(Mesh vbo _ _) -> do
      liftIO (GL.deleteObjectName vbo)
      void (remove @_ @Mesh e)

-- | A triangle described by the coordinates of its three corners.
--
-- The component's lifecycle hooks keep a 'Mesh' on the same entity in sync
-- with the triangle. Every hook runs through 'inParentWindowContext', so
-- nothing happens when that helper finds no window context for the entity.
data Triangle = Triangle
  { -- | X coordinate of the first corner
    triangleX1 :: !Float,
    -- | Y coordinate of the first corner
    triangleY1 :: !Float,
    -- | X coordinate of the second corner
    triangleX2 :: !Float,
    -- | Y coordinate of the second corner
    triangleY2 :: !Float,
    -- | X coordinate of the third corner
    triangleX3 :: !Float,
    -- | Y coordinate of the third corner
    triangleY3 :: !Float
  }
  deriving (Eq, Show)

instance (MonadIO m) => Component m Triangle where
  -- Compile the initial mesh and attach it to the entity.
  componentOnInsert e tri = inParentWindowContext e $ do
    built <- liftIO (compileTriangle tri)
    insert e (bundle built)

  -- Rebuild the mesh only when the triangle actually changed, releasing the
  -- previous buffer first so it does not leak.
  componentOnChange e oldTri newTri =
    unless (oldTri == newTri) $
      inParentWindowContext e $ do
        existing <- lookup e
        forM_ existing $ \(Mesh staleVbo _ _) ->
          liftIO (GL.deleteObjectName staleVbo)
        rebuilt <- liftIO (compileTriangle newTri)
        insert e (bundle rebuilt)

  -- Release the buffer and drop the 'Mesh' component, if there is one.
  componentOnRemove e _ = inParentWindowContext e $ do
    existing <- lookup e
    forM_ existing $ \(Mesh vbo _ _) -> do
      liftIO (GL.deleteObjectName vbo)
      void (remove @_ @Mesh e)

-- | Upload a rectangle as a four-vertex quad centered on the origin.
--
-- The corners are emitted in the order bottom-left, bottom-right, top-right,
-- top-left and drawn with 'GL.Quads'.
compileRectangle :: Rectangle -> IO Mesh
compileRectangle (Rectangle w h) = compileMesh corners GL.Quads
  where
    right = w / 2
    top = h / 2
    left = negate right
    bottom = negate top

    corners :: SV.Vector GL.GLfloat
    corners = SV.fromList [left, bottom, right, bottom, right, top, left, top]

-- | Compile a circle into a VBO mesh.
--
-- The mesh is a triangle fan: the center vertex at the origin followed by
-- @segments + 1@ rim vertices (the first rim vertex is repeated at the end to
-- close the fan).
--
-- A non-positive segment count produces a mesh containing only the center
-- vertex, which draws nothing. In particular a count of zero no longer divides
-- by zero and uploads a NaN vertex.
compileCircle :: Circle -> IO Mesh
compileCircle (Circle radius segments) = do
  let angles
        -- Guard against dividing by a zero segment count; negative counts
        -- already yielded an empty range.
        | segments <= 0 = []
        | otherwise = [2 * pi * fromIntegral i / fromIntegral segments | i <- [0 .. segments]]
      rim = concatMap (\a -> [radius * cos a, radius * sin a]) angles
      vertices = SV.fromList $ [0, 0] ++ rim :: SV.Vector GL.GLfloat
  compileMesh vertices GL.TriangleFan

-- | Upload a triangle's three corners as a mesh drawn with 'GL.Triangles'.
--
-- The corners are emitted in the order they are stored in the 'Triangle'.
compileTriangle :: Triangle -> IO Mesh
compileTriangle (Triangle p1x p1y p2x p2y p3x p3y) =
  compileMesh (SV.fromList corners) GL.Triangles
  where
    corners :: [GL.GLfloat]
    corners = [p1x, p1y, p2x, p2y, p3x, p3y]

-- | Compile vertices into a VBO mesh.
--
-- The vector is interpreted as a flat list of 2D positions
-- (@x0, y0, x1, y1, ...@), so the resulting vertex count is half the vector's
-- length. The data is uploaded into a freshly generated buffer object with
-- 'GL.StaticDraw' usage, and the array buffer binding is cleared again before
-- returning.
--
-- An OpenGL context must be current when this is called.
compileMesh :: SV.Vector GL.GLfloat -> GL.PrimitiveMode -> IO Mesh
compileMesh vertices mode = do
  -- Ask for exactly one name instead of pattern matching on a one-element
  -- list, which would be a failable pattern.
  vbo <- GL.genObjectName
  GL.bindBuffer GL.ArrayBuffer $= Just vbo

  -- Upload vertex data; the pointer is only valid inside the callback.
  SV.unsafeWith vertices $ \ptr ->
    GL.bufferData GL.ArrayBuffer $= (byteCount, ptr, GL.StaticDraw)

  GL.bindBuffer GL.ArrayBuffer $= Nothing
  return $ Mesh vbo vertexCount mode
  where
    -- Two floats per vertex.
    vertexCount = fromIntegral $ SV.length vertices `div` 2
    byteCount = fromIntegral $ SV.length vertices * sizeOf (undefined :: GL.GLfloat)

-- | Draw every @Mesh@ found in the hierarchy below each @Window@.
--
-- For each entity that has a @Window@, a @RawWindow@ and @Children@, the
-- window's OpenGL context is made current, the viewport is set to the window
-- size, the color buffer is cleared to opaque black and an orthographic
-- projection spanning @0..width@ by @0..height@ is installed. The children are
-- then walked depth-first, and every entity that has a @Mesh@, a transform and
-- a @Color@ is drawn. Entities missing any of the three are skipped, but their
-- own children are still visited.
--
-- Buffers are not swapped here.
render :: (MonadIO m) => Access m ()
render = do
  windows <- system (readQuery $ (,,) <$> query @_ @Window <*> query @_ @RawWindow <*> query @_ @Children)
  forM_ windows drawWindow
  where
    -- Prepare one window's frame, then draw everything below it.
    drawWindow (window, RawWindow raw _, children) = do
      let widthPx = windowWidth window
          heightPx = windowHeight window
      liftIO $ do
        GLFW.makeContextCurrent (Just raw)

        -- Viewport covers the whole window; start from a black frame
        GL.viewport $= (GL.Position 0 0, GL.Size (fromIntegral widthPx) (fromIntegral heightPx))
        GL.clearColor $= GL.Color4 0 0 0 1
        GL.clear [GL.ColorBuffer]

        -- Orthographic projection in window units, then reset the model-view
        GL.matrixMode $= GL.Projection
        GL.loadIdentity
        GL.ortho 0 (fromIntegral widthPx) 0 (fromIntegral heightPx) (-1) 1
        GL.matrixMode $= GL.Modelview 0
        GL.loadIdentity
      mapM_ drawEntity $ unChildren children

    -- Draw one entity if it is complete, then recurse into its children.
    drawEntity e = do
      meshRes <- lookup e
      transformRes <- lookup e
      colorRes <- lookup e
      forM_ ((,,) <$> meshRes <*> transformRes <*> colorRes) (liftIO . renderMesh)

      nested <- lookup e
      forM_ nested $ \(Children cs) -> mapM_ drawEntity cs

-- | Draw a single mesh using its transform and color.
--
-- Inside 'GL.preservingMatrix' the translation, rotation about the Z axis and
-- scale are applied in that order, the current color is set, and the mesh's
-- buffer is drawn as tightly packed 2-component float positions. The vertex
-- array state and the array buffer binding are reset afterwards.
--
-- NOTE(review): the rotation value is passed straight to 'GL.rotate'; confirm
-- the transform stores it in the unit that call expects.
renderMesh :: (Mesh, Transform2D, Color) -> IO ()
renderMesh (Mesh vbo vertexCount mode, t, color) =
  GL.preservingMatrix $ do
    -- Model transform: translate, rotate, scale
    GL.translate $ GL.Vector3 (realToFrac tx) (realToFrac ty) (0 :: GL.GLfloat)
    GL.rotate (realToFrac angle) zAxis
    GL.scale (realToFrac sx) (realToFrac sy) (1 :: GL.GLfloat)

    -- Current drawing color
    GL.color rgba

    -- Generic attribute 0 reads positions from the mesh's buffer
    GL.bindBuffer GL.ArrayBuffer $= Just vbo
    GL.vertexAttribPointer attrib $= (GL.ToFloat, positions)
    GL.vertexAttribArray attrib $= GL.Enabled

    -- The fixed-function pipeline needs the legacy vertex pointer as well
    GL.clientState GL.VertexArray $= GL.Enabled
    GL.arrayPointer GL.VertexArray $= positions

    GL.drawArrays mode 0 vertexCount

    -- Leave the GL state as it was found
    GL.clientState GL.VertexArray $= GL.Disabled
    GL.vertexAttribArray attrib $= GL.Disabled
    GL.bindBuffer GL.ArrayBuffer $= Nothing
  where
    V2 tx ty = transformTranslation t
    V2 sx sy = transformScale t
    angle = transformRotation t

    zAxis = GL.Vector3 0 0 (1 :: GL.GLfloat)
    attrib = GL.AttribLocation 0

    -- Two floats per vertex, tightly packed, starting at offset zero
    positions = GL.VertexArrayDescriptor 2 GL.Float 0 nullPtr

    -- Convert one channel of the color to the GL float type
    channel field = realToFrac (field color) :: GL.GLfloat
    rgba = GL.Color4 (channel colorR) (channel colorG) (channel colorB) (channel colorA)
