{-|

Functions for combining meshes.

-}

module Graphics.LambdaCube.MeshUtil (mkVMesh,mkVMesh') where

import qualified Data.Vector as V
-- TODO: introsort is used to sort vertex attributes; switch to a more suitable algorithm if one exists
import qualified Data.Vector.Algorithms.Intro as V
import Data.List
import Data.Maybe
import Data.Ord

import Graphics.LambdaCube.RenderOperation
import Graphics.LambdaCube.Types hiding (transpose)
import Graphics.LambdaCube.Utility
import Graphics.LambdaCube.VertexBufferVector

-- | One submesh prepared for merging, as produced by 'vertexData':
-- material name, vertex buffer, optional index buffer, the transformation
-- to apply to its vertices, and its operation type.
type MeshGroup = (String, VVB, Maybe VIB, Proj4, OperationType)

-- | Build a single mesh that represents the union of a list of
-- transformed meshes, each placed by an orientation and a translation.
-- Each constituent's transformation matrix is the product of the
-- orthogonal matrix derived from its orientation and a translation
-- matrix; the rest of the work is delegated to 'mkVMesh''.  The result
-- is optimised with respect to context switches during rendering.
mkVMesh :: [(U,Vec3,VMesh)] -> VMesh
mkVMesh = mkVMesh' . map placement
  where
    placement (o,p,m) = (orthogonal (rightOrthoU o) .*. translation p, m)

-- | Build a single mesh that represents the union of a list of
-- transformed meshes (given the transformation matrix for each
-- constituent).  Submeshes are grouped by material and then by vertex
-- layout, and each group is merged into one submesh, which reduces
-- context switches during rendering.  The second 'VMesh' field is left
-- as 'Nothing'; every merged submesh carries its own vertex buffer.
mkVMesh' :: [(Proj4,VMesh)] -> VMesh
mkVMesh' vml = VMesh mergedSubMeshes Nothing
  where
    mergedSubMeshes = map joinGroup . groupByGeometry . groupByMaterial $ vertexData vml

-- | Flatten a list of transformed meshes into one 'MeshGroup' per submesh.
--
-- A submesh uses its own (private) vertex buffer when it has one and
-- falls back to the parent mesh's shared buffer otherwise; having neither
-- is an error.  The attributes of the chosen buffer are sorted by
-- 'vectorVertexType' so that buffers with the same attribute set line up
-- attribute by attribute when they are merged later.
--
-- NOTE(review): a shared buffer is copied into every submesh that uses
-- it, so the merged mesh may contain duplicated shared vertices.
vertexData :: [(Proj4,VMesh)] -> [MeshGroup]
vertexData l =
    [ (materialName, sortedVData vData gData, iData, proj, opType)
    | (proj, VMesh subMeshList gData) <- l
    , VSubMesh materialName opType vData iData <- subMeshList
    ]
  where
    -- Private (per-submesh) geometry must win over shared geometry: a
    -- submesh's own index buffer refers to its own vertices, not to the
    -- mesh's shared buffer.
    sortedVData private shared = V.modify (V.sortBy (comparing vectorVertexType)) $ case (private,shared) of
        (Just a, _) -> a
        (Nothing, Just a) -> a
        _ -> error "illegal mesh format"

-- | Partition mesh groups with 'groupSetBy', ordering them by material
-- name only.
groupByMaterial :: [MeshGroup] -> [[MeshGroup]]
groupByMaterial = groupSetBy (comparing materialOf)
  where
    materialOf (name,_,_,_,_) = name

-- | Split every material group further with 'groupSetBy'.  The ordering
-- puts non-indexed meshes before indexed ones, then compares the list of
-- vertex attribute types, then the operation type.
groupByGeometry :: [[MeshGroup]] -> [[MeshGroup]]
groupByGeometry = concatMap (groupSetBy (comparing geometryKey))
  where
    -- 'False' sorts before 'True', so meshes without an index buffer come first.
    geometryKey (_,vb,ib,_,op) = (isJust ib, V.map vectorVertexType vb, op)

-- | Merge one group of compatible submeshes into a single 'VSubMesh'.
-- Compatible means the same material, attribute types, index-buffer
-- presence and operation type.
--
-- Vertex attributes are transformed by each constituent's matrix and
-- concatenated attribute by attribute.  Index buffers are concatenated,
-- each shifted by the number of vertices in the meshes before it.
joinGroup :: [MeshGroup] -> VSubMesh
joinGroup groupMeshList = VSubMesh materialName operationType joinedVertexData joinedIndexData
  where
    -- Properties shared by the whole group are taken from its first member.
    -- NOTE(review): assumes 'groupSetBy' never yields an empty group.
    (materialName,_,indexData,_,operationType) = head groupMeshList

    -- One list of (attribute, transformation) pairs per constituent mesh.
    vertexDataList :: [[(VectorVertexData,Proj4)]]
    vertexDataList = [[(v,proj) | v <- V.toList vd] | (_,vd,_,proj,_) <- groupMeshList]

    -- Vertex count of one constituent, measured by its position attribute.
    -- It is computed per mesh so the offset list stays aligned with the
    -- index buffers even if some mesh has no position attribute.
    vertexCount :: [(VectorVertexData,Proj4)] -> Int
    vertexCount attribs = sum [V.length v | (VVD_POSITION v,_) <- attribs]

    -- Offset of each constituent's indices: total vertices of all earlier ones.
    offsets :: [Int]
    offsets = scanl (+) 0 (map vertexCount vertexDataList)

    -- Members of a group agree on index-buffer presence (see
    -- 'groupByGeometry'), so a missing buffer here is an invariant violation.
    indexBufferOf (_,_,mib,_,_) =
        fromMaybe (error "joinGroup: group mixes indexed and non-indexed meshes") mib

    joinedIndexData = case indexData of
        Nothing -> Nothing
        Just _  -> Just $ V.concat $ zipWith (\o ib -> V.map (+o) ib) offsets (map indexBufferOf groupMeshList)

    -- Every buffer in the group has the same attribute types in the same
    -- sorted order.  Transposing therefore collects same-typed attributes
    -- across meshes.
    joinedVertexData :: Maybe VVB
    joinedVertexData = Just $ V.fromList $ map mergeAttribs $ transpose vertexDataList

-- | Concatenate one kind of vertex attribute, taken from several meshes,
-- into a single attribute buffer.  The kind is read from the first
-- element; elements with a different constructor are skipped by the
-- comprehension patterns below.
--
-- Positions get each mesh's full transformation.  Normals, binormals and
-- tangents get only its linear part, with the translation dropped.  All
-- other attributes are concatenated unchanged.
--
-- NOTE(review): direction attributes are not renormalised and do not use
-- the inverse transpose.  The result is exact only when the linear part
-- is a pure rotation.
mergeAttribs :: [(VectorVertexData, Proj4)] -> VectorVertexData
mergeAttribs parts = case vectorVertexType (fst (head parts)) of
    VVT_POSITION             -> VVD_POSITION $
        V.concat [movePoints m xs | (VVD_POSITION xs, m) <- parts]
    VVT_NORMAL               -> VVD_NORMAL $
        V.concat [turnDirections m xs | (VVD_NORMAL xs, m) <- parts]
    VVT_BINORMAL             -> VVD_BINORMAL $
        V.concat [turnDirections m xs | (VVD_BINORMAL xs, m) <- parts]
    VVT_TANGENT              -> VVD_TANGENT $
        V.concat [turnDirections m xs | (VVD_TANGENT xs, m) <- parts]
    VVT_BLEND_INDICES        -> VVD_BLEND_INDICES $
        V.concat [xs | (VVD_BLEND_INDICES xs, _) <- parts]
    VVT_BLEND_WEIGHTS        -> VVD_BLEND_WEIGHTS $
        V.concat [xs | (VVD_BLEND_WEIGHTS xs, _) <- parts]
    VVT_DIFFUSE              -> VVD_DIFFUSE $
        V.concat [xs | (VVD_DIFFUSE xs, _) <- parts]
    VVT_SPECULAR             -> VVD_SPECULAR $
        V.concat [xs | (VVD_SPECULAR xs, _) <- parts]
    VVT_TEXTURE_COORDINATES1 -> VVD_TEXTURE_COORDINATES1 $
        V.concat [xs | (VVD_TEXTURE_COORDINATES1 xs, _) <- parts]
    VVT_TEXTURE_COORDINATES2 -> VVD_TEXTURE_COORDINATES2 $
        V.concat [xs | (VVD_TEXTURE_COORDINATES2 xs, _) <- parts]
    VVT_TEXTURE_COORDINATES3 -> VVD_TEXTURE_COORDINATES3 $
        V.concat [xs | (VVD_TEXTURE_COORDINATES3 xs, _) <- parts]
  where
    -- Treat the point as a row vector with w = 1, multiply it by the
    -- projective matrix, and drop the fourth component.
    applyProj :: Proj4 -> Vec3 -> Vec3
    applyProj p x = trim ((extendWith 1 x :: Vec4) .* fromProjective p)

    -- Full transformation, including translation.
    movePoints m = V.map (applyProj m)

    -- Upper-left 3x3 block only, so translation is discarded.
    turnDirections m = V.map (applyProj (linear (trim (fromProjective m)) :: Proj4))