module Graphics.LambdaCube.Loader.MeshXML where

import Data.List
import Data.Maybe
import Data.Word
import Foreign.C.Types
import Foreign.Marshal.Array
import Foreign.Ptr
import qualified Data.IntMap as IntMap

import Text.XML.Light

import Graphics.LambdaCube.HardwareBuffer
import Graphics.LambdaCube.HardwareIndexBuffer
import Graphics.LambdaCube.HardwareVertexBuffer
import Graphics.LambdaCube.Mesh
import Graphics.LambdaCube.RenderOperation
import Graphics.LambdaCube.RenderSystem
import Graphics.LambdaCube.Types
import Graphics.LambdaCube.VertexIndexData

-- | Look up the attribute @n@ on element @e@.
-- Only the exact string @"true"@ counts as 'True'; any other value is 'False'.
-- When the attribute is absent the default @v@ is returned.
readBool :: String -> Bool -> Element -> Bool
readBool n v e = case findAttr (unqual n) e of
    Nothing  -> v
    Just str -> str == "true"

-- | Read the attribute @n@ of element @e@ as an 'Int' via 'read'.
-- Returns the default @v@ when the attribute is absent. A malformed value
-- fails when the result is forced.
readInt :: String -> Int -> Element -> Int
readInt n v e = case findAttr (unqual n) e of
    Just str -> read str
    Nothing  -> v

-- | Read the attribute @n@ of element @e@ as a 'FloatType' via 'read'.
-- Returns the default @v@ when the attribute is absent. A malformed value
-- fails when the result is forced.
readFloatType :: String -> FloatType -> Element -> FloatType
readFloatType n v e = case findAttr (unqual n) e of
    Just str -> read str
    Nothing  -> v

-- | Return the raw value of attribute @n@ of an element, or the default @v@
-- when the attribute is absent.
readStr :: String -> String -> Element -> String
readStr n v = fromMaybe v . findAttr (unqual n)

-- | Collect every @vertexboneassignment@ element below @ba@.
--
-- Each assignment carries @vertexindex@ (default 0), @boneindex@ (default 0)
-- and @weight@ (default 1). The assignments are sorted by vertex index and
-- returned as parallel lists of bone indices and weights. Assignments that
-- share a vertex index keep reverse document order (the sort is stable).
readBoneAssignments :: Element -> ([Int],[FloatType])
readBoneAssignments ba =
    unzip [ (bone, weight) | (_, bone, weight) <- sortBy byVertex assignments ]
  where
    -- reversed so that equal vertex indices end up in reverse document order
    assignments = reverse $ map toAssignment $ findElements (unqual "vertexboneassignment") ba

    byVertex (va,_,_) (vb,_,_) = compare va vb

    toAssignment el =
        ( readInt "vertexindex" 0 el
        , readInt "boneindex" 0 el
        , readFloatType "weight" 1 el
        )

-- | Build the vertex declaration elements for one @vertexbuffer@ element.
--
-- The pair holds the buffer element and its source (binding) index @n@.
-- Which elements exist is controlled by boolean attributes of the buffer.
-- Elements are packed back to back in this fixed order:
--
--   1. position
--   2. normal
--   3. tangent (3 or 4 floats, from @tangent_dimensions@)
--   4. binormal
--   5. diffuse colour
--   6. specular colour
--   7. one texture coordinate set per @texture_coords@, each sized by its
--      @texture_coord_dimensions_K@ attribute (default 2)
--
-- Each element's 'veIndex' is the number of earlier elements in this buffer
-- with the same semantic, so texture coordinate sets are numbered 0, 1, ...
readDecl :: (Element,Int) -> [VertexElement]
readDecl (x,n) = zipWith withIndex layout semanticIndices
  where
    withIndex el idx = el { veIndex = idx }

    flag attr = readBool attr False x

    tangentType
        | readInt "tangent_dimensions" 3 x == 4 = VET_FLOAT4
        | otherwise                             = VET_FLOAT3

    colourType = VET_COLOUR_ABGR

    texCoordCount = readInt "texture_coords" 0 x
    texCoordDims  = [ readInt ("texture_coord_dimensions_" ++ show k) 2 x
                    | k <- [0 .. texCoordCount - 1] ]

    -- (type, semantic) pairs in layout order
    components =
        [ (VET_FLOAT3, VES_POSITION)  | flag "positions" ] ++
        [ (VET_FLOAT3, VES_NORMAL)    | flag "normals" ] ++
        [ (tangentType, VES_TANGENT)  | flag "tangents" ] ++
        [ (VET_FLOAT3, VES_BINORMAL)  | flag "binormals" ] ++
        [ (colourType, VES_DIFFUSE)   | flag "colours_diffuse" ] ++
        [ (colourType, VES_SPECULAR)  | flag "colours_specular" ] ++
        [ (multiplyTypeCount VET_FLOAT1 d, VES_TEXTURE_COORDINATES) | d <- texCoordDims ]

    -- running byte offset of every component inside one vertex
    offsets = scanl (\acc (ty, _) -> acc + getTypeSize ty) 0 components

    layout = zipWith (\off (ty, sem) -> VertexElement n off ty sem 0) offsets components

    -- per-semantic occurrence counter: the first element of a semantic gets 0,
    -- the next one 1, and so on
    (_, semanticIndices) = mapAccumL nextIndex IntMap.empty layout
    nextIndex seen el =
        let key = fromEnum (veSemantic el)
            idx = maybe 0 (+ 1) (IntMap.lookup key seen)
        in  (IntMap.insert key idx seen, idx)

-- TODO: use vector vbs
-- | Build 'VertexData' from a geometry element (a submesh's @geometry@ or the
-- mesh's @sharedgeometry@).
--
-- The @vertexcount@ attribute is required. Each @vertexbuffer@ child becomes
-- one hardware vertex buffer. Buffers are bound at source index 0, 1, ... in
-- document order and use the layout produced by 'readDecl'. The buffer's
-- @vertex@ children are then written into it one vertex after another. At
-- most @vertexcount@ vertices are written, because that is what the buffer
-- was allocated for.
--
-- Calls 'error' when a required attribute or sub-element is missing, and
-- when a declaration contains a blend-weight/blend-index semantic.
readGeometry :: (RenderSystem rs vb ib q t p lp) => rs -> Element -> IO (VertexData vb)
readGeometry rs x = do
    let vcount = read $ fromMaybe (error "fromJust 0") $ findAttr (unqual "vertexcount") x
        vbs = findElements (unqual "vertexbuffer") x
        elems = [readDecl v | v <- zip vbs [0..]]
        -- vertex size (stride) of each buffer: the sum of its element sizes
        sizes = [sum (map (getTypeSize . veType) e) | e <- elems]
        -- create VertexDeclaration
        decl = VertexDeclaration $ concat elems
        -- create VertexBufferBinding
        usage = HBU_STATIC -- TODO

    --debugM "readGeometry" $ "gemetry declaration: " ++ show elems

    bufs <- mapM (\s -> createVertexBuffer rs s vcount usage True) sizes
    let binding = VertexBufferBinding $ IntMap.fromList $ zip [0..] bufs
        fillBuffer (d,vex,b) = do
            -- 1. lock buffer
            ptr <- lock b 0 (getSizeInBytes b) HBL_NORMAL
            let fillVertex (i,vx) = do
                    -- iterate over elements and read data from xml subelement
                    let fillAttribute e = do
                            -- address of this attribute inside the i-th vertex
                            let p = plusPtr ptr (i * (getVertexSize b) + veOffset e)
                                -- write the named float attributes of the
                                -- veIndex-th <en> child of the vertex
                                setCFloatAttr :: String -> [String] -> IO ()
                                setCFloatAttr en ea = pokeArray p $ getCFloat ((findElements (unqual en) vx ) !! veIndex e) ea

                                setColourAttr :: String -> IO ()
                                setColourAttr en = pokeArray p $ getColour ((findElements (unqual en) vx ) !! veIndex e)

                                getCFloat :: Element -> [String] -> [CFloat]
                                getCFloat xn ll = map ef ll
                                  where
                                    ef nm = read $ fromMaybe (error "fromJust 1") $ findAttr (unqual nm) xn

                                -- The "value" attribute is parsed as whitespace-separated
                                -- Word8 integers; a 3-component value gets 1 prepended.
                                -- NOTE(review): if files store colours as floats in
                                -- [0,1] (e.g. "1 0.5 0 1"), 'read' fails here - confirm
                                -- the expected format.
                                getColour :: Element -> [Word8]
                                getColour xn = if length values == 4 then values else 1:values
                                  where
                                    values = map read $ words $ fromMaybe (error "fromJust 2") $ findAttr (unqual "value") xn

                            case veSemantic e of
                                VES_POSITION              -> setCFloatAttr "position" ["x","y","z"]
                                VES_BLEND_WEIGHTS         -> error "invalid semantic"
                                VES_BLEND_INDICES         -> error "invalid semantic"
                                VES_NORMAL                -> setCFloatAttr "normal" ["x","y","z"]
                                VES_DIFFUSE               -> setColourAttr "colour_diffuse"
                                VES_SPECULAR              -> setColourAttr "colour_specular"
                                VES_TEXTURE_COORDINATES   -> setCFloatAttr "texcoord" $ take (getTypeCount $ veType e) ["u","v","w","x"]
                                VES_BINORMAL              -> setCFloatAttr "binormal" ["x","y","z"]
                                VES_TANGENT               -> setCFloatAttr "tangent" $ if veType e == VET_FLOAT4 then ["x","y","z","w"] else ["x","y","z"]
                            -- end fillAttribute

                    mapM_ fillAttribute d
                    -- end fillVertex

            -- 2. fill n-th vertex attributes according declaration.
            -- The buffer holds exactly vcount vertices, so surplus <vertex>
            -- elements are ignored rather than written past its end.
            mapM_ fillVertex $ zip [0..] $ take vcount $ findElements (unqual "vertex") vex
            -- 3. unlock buffer
            unlock b
            -- end fillBuffer

    -- fill buffers with data
    mapM_ fillBuffer $ zip3 elems vbs bufs

    -- create VertexData
    return $ VertexData decl binding 0 vcount

-- | Read one @submesh@ element into a 'SubMesh'.
--
-- Attributes read:
--
--   * @material@ (required)
--   * @usesharedvertices@ (default @true@)
--   * @use32bitindexes@ (default @false@)
--   * @operationtype@ (default @triangle_list@)
--
-- Triangle operation types get an index buffer (16- or 32-bit) filled from
-- the @face@ elements of the @faces@ child. Line and point types get no
-- index data. When @usesharedvertices@ is false, the submesh's own
-- @geometry@ child is loaded with 'readGeometry'.
--
-- Fails when @material@ is missing, when @faces@ is missing for a triangle
-- type, or when @geometry@ is missing and vertices are not shared. Calls
-- 'error' on an unknown operation type.
readSubMesh :: (RenderSystem rs vb ib q t p lp) => rs -> Element -> IO (SubMesh vb ib)
readSubMesh rs x = do
    let Just material   = findAttr (unqual "material") x
        useShared       = readBool "usesharedvertices"  True    x
        use32BitIndex   = readBool "use32bitindexes"    False   x
        -- hasFaces: whether this operation type reads a <faces> index list
        (hasFaces,oper) = readOpType $ readStr  "operationtype" "triangle_list" x
        readOpType o    = case o of
            "triangle_list"   -> (True,  OT_TRIANGLE_LIST)
            "triangle_strip"  -> (True,  OT_TRIANGLE_STRIP)
            "triangle_fan"    -> (True,  OT_TRIANGLE_FAN)
            "line_strip"      -> (False, OT_LINE_STRIP)
            "line_list"       -> (False, OT_LINE_LIST)
            "point_list"      -> (False, OT_POINT_LIST)
            _                 -> error "Invalid mesh format!"

    midata <- if not hasFaces then return Nothing else do
        let Just faces      = findElement (unqual "faces") x
            faceList        = findElements (unqual "face") faces
            -- Size the buffer from the <face> elements actually written below,
            -- not from the "count" attribute. A missing count (defaulting to 0)
            -- or an understated one would make pokeArray write past the end
            -- of the locked index buffer.
            faceCount       = length faceList
            isTriList       = oper == OT_TRIANGLE_LIST
            -- A triangle list stores 3 indices per face. A strip or fan stores
            -- 3 for its first face and 1 for each following face (2 + n).
            indexCount
                | isTriList     = 3 * faceCount
                | null faceList = 0
                | otherwise     = 2 + faceCount
            usage           = HBU_STATIC -- TODO

            -- attributes read from every face after the first one
            attrList        = if isTriList then ["v1", "v2", "v3"] else ["v1"]
            -- all indices in face order; an empty face list yields none
            indexList gf    = case faceList of
                []          -> []
                (f0:rest)   -> concat $ gf f0 ["v1", "v2", "v3"] : [gf f attrList | f <- rest]
            getWord16 :: Element -> [String] -> [Word16]
            getWord16 xn ll = map ef ll
              where
                ef nm = read $ fromMaybe (error "fromJust 3") $ findAttr (unqual nm) xn

            getWord32 :: Element -> [String] -> [Word32]
            getWord32 xn ll = map ef ll
              where
                ef nm = read $ fromMaybe (error "fromJust 4") $ findAttr (unqual nm) xn

        -- create IndexData
        ib <- createIndexBuffer rs (if use32BitIndex then IT_32BIT else IT_16BIT) indexCount usage True
        -- 1. lock buffer
        ptr <- lock ib 0 (getSizeInBytes ib) HBL_NORMAL
        -- 2. fill buffer
        case use32BitIndex of
            { True  -> pokeArray (castPtr ptr) $ indexList getWord32
            ; False -> pokeArray (castPtr ptr) $ indexList getWord16
            }
        -- 3. unlock buffer
        unlock ib
        return $ Just IndexData
            { idIndexBuffer = ib
            , idIndexStart  = 0
            , idIndexCount  = indexCount
            }

    -- read geometry if necessary
    mvdata <- if useShared then return Nothing else do
        vd <- readGeometry rs $ fromMaybe (error "fromJust 5") $ findElement (unqual "geometry") x
        return $ Just vd

    return $ SubMesh
        { smOperationType       = oper
        , smVertexData          = mvdata
        , smIndexData           = midata
        , smMaterialName        = material
        }

-- | Build a 'Mesh' from a @mesh@ element.
--
-- Reading happens in three steps:
--
--   1. Every @submesh@ below the required @submeshes@ element is read.
--   2. Every @sharedgeometry@ element is read; only the first is kept as
--      the shared vertex data.
--   3. The bounding radius is computed by 'calculateBoundingRadius' on the
--      assembled mesh and stored in the result.
readMesh :: (RenderSystem rs vb ib q t p lp) => rs -> Element -> IO (Mesh vb ib)
readMesh rs x = do
    let Just submeshRoot = findElement (unqual "submeshes") x
    subMeshes  <- mapM (readSubMesh rs) (findElements (unqual "submesh") submeshRoot)
    sharedGeom <- mapM (readGeometry rs) (findElements (unqual "sharedgeometry") x)
    let draft = Mesh
            { msSubMeshList      = subMeshes
            , msSharedVertexData = listToMaybe sharedGeom
            , msBoundRadius      = 0
            }
    fmap (\radius -> draft { msBoundRadius = radius }) (calculateBoundingRadius draft)

-- | Parse a mesh from the text of an XML document.
--
-- Fails when the text is not well-formed XML or contains no @mesh@ element.
parseMesh :: (RenderSystem rs vb ib q t p lp) => rs -> String -> IO (Mesh vb ib)
parseMesh rs doc = readMesh rs meshElement
  where
    Just root   = parseXMLDoc doc
    meshElement = fromMaybe (error "fromJust 6") (findElement (unqual "mesh") root)

-- | Load a mesh from an XML file on disk.
-- The file is read with the Prelude's (lazy) 'readFile' and handed to
-- 'parseMesh'.
loadMesh :: (RenderSystem rs vb ib q t p lp) => rs -> FilePath -> IO (Mesh vb ib)
loadMesh rs fileName = readFile fileName >>= parseMesh rs