{-# Language TemplateHaskell #-}

module MXNet.Core.IO.Internal.TH where



import Data.List

import Data.Char

import Data.Bifunctor

import Text.ParserCombinators.ReadP

import Language.Haskell.TH



import MXNet.Core.Base

import MXNet.Core.Base.Internal



-- Positional accessors for a 6-tuple. In this module they are applied to the
-- value returned by 'mxDataIterGetIterInfo'. The field meanings below are
-- inferred from the accessor names and from how 'registerDataIters' uses
-- them; NOTE(review): confirm against MXNet.Core.Base.Internal.

-- | Field 1: the iterator name.
diInfoName :: (a, b, c, d, e, f) -> a
diInfoName (name, _, _, _, _, _) = name

-- | Field 2: presumably the iterator description.
diInfoDesc :: (a, b, c, d, e, f) -> b
diInfoDesc (_, description, _, _, _, _) = description

-- | Field 3: presumably the argument count.
diInfoArgc :: (a, b, c, d, e, f) -> c
diInfoArgc (_, _, count, _, _, _) = count

-- | Field 4: the argument names.
diInfoArgN :: (a, b, c, d, e, f) -> d
diInfoArgN (_, _, _, names, _, _) = names

-- | Field 5: the argument type descriptions (fed to 'parseArgDesc').
diInfoArgT :: (a, b, c, d, e, f) -> e
diInfoArgT (_, _, _, _, types, _) = types

-- | Field 6: presumably the argument descriptions.
diInfoArgD :: (a, b, c, d, e, f) -> f
diInfoArgD (_, _, _, _, _, details) = details



-- | Splice in declarations for every data iterator reported by
-- 'mxListDataIters'. The list is read at compile time through 'runIO'.
-- Each iterator is numbered by its position in that list. The number, the
-- iterator name and the de-duplicated (argument name, argument type
-- description) pairs are handed to 'makeDataIter'.
registerDataIters :: Q [Dec]
registerDataIters = do
    described <- runIO $ do
        creators <- mxListDataIters
        mapM describe (zip [0 ..] creators)
    fmap concat (mapM (\(key, params) -> makeDataIter key params) described)
  where
    -- Query one creator's metadata and keep only what 'makeDataIter' needs.
    -- 'nub' drops repeated (name, type) pairs.
    describe (position, creator) = do
        meta <- mxDataIterGetIterInfo creator
        let params = nub (zip (diInfoArgN meta) (diInfoArgT meta))
        return ((position, diInfoName meta), params)



-- | Generate the declarations for one MXNet data iterator.
--
-- The first argument is the iterator's position in the list returned by
-- 'mxListDataIters' together with its name. The second is a list of
-- (argument name, argument description) pairs. Three declarations are
-- produced:
--
--   * @type <Name>_Args@, a promoted type-level list with one
--     @"arg" ':= T@ entry per argument. @T@ is chosen from the description
--     parsed by 'parseArgDesc'.
--   * A signature for the iterator name with its first character
--     lower-cased.
--   * The matching definition. At run time it re-fetches the creator list,
--     picks the creator at the same position and calls
--     'mxDataIterCreateIter' through 'checked'.
--
-- Iterators with arguments take an @HMap kvs@, whose entries are flattened
-- with 'dump'. Iterators without arguments become a plain
-- @IO DataIterHandle@ action.
makeDataIter :: (Integer, String) -> [(String, String)] -> Q [Dec]
makeDataIter (index, name) args = do
    let parsed  = map (second parseArgDesc) args
        dname   = mkName (deCap name)
        kvs     = mkName "kvs"
        allargs = mkName "allargs"
        cstName = mkName $ name ++ "_Args"
        -- Position of this iterator's creator in the run-time creator list.
        pos     = litE (integerL index)

    cst <- tySynD cstName [] (foldr add promotedNilT parsed)

    (sig, fun) <-
        if null parsed
            then do
                -- With no arguments there is no HMap to accept and no kvs to
                -- constrain. Emit a nullary action that passes empty
                -- key/value lists. A one-argument clause or a kvs-only
                -- constraint here would not type-check.
                s <- sigD dname [t| IO DataIterHandle |]
                f <- funD dname [clause [] (normalB [e| do {
                        crts <- mxListDataIters;
                        checked $ mxDataIterCreateIter (crts !! $(pos)) 0 [] []
                      } |]) []]
                return (s, f)
            else do
                s <- sigD dname [t| (MatchKVList $(varT kvs) $(conT cstName), ShowKV $(varT kvs)) => HMap $(varT kvs) -> IO DataIterHandle |]
                f <- funD dname [clause [varP allargs] (normalB [e| do {
                        pairs <- return (dump $(varE allargs));
                        len   <- return (fromIntegral $ length pairs);
                        (keys, vals) <- return (unzip pairs);
                        crts  <- mxListDataIters;
                        checked $ mxDataIterCreateIter (crts !! $(pos)) len keys vals
                      } |]) []]
                return (s, f)

    return [cst, sig, fun]
  where
    -- Lower-case the first character so the result is a valid function
    -- name, e.g. "ImageRecordIter" -> "imageRecordIter". Empty input stays
    -- empty instead of crashing.
    deCap []     = []
    deCap (x:xs) = toLower x : xs

    -- Haskell type used for each parsed argument kind.
    toTyp ArgString        = [t| String |]
    toTyp ArgInt           = [t| Int |]
    toTyp ArgLong          = [t| Integer |]
    toTyp ArgFloat         = [t| Float |]
    toTyp ArgBool          = [t| Bool |]
    toTyp ArgShape         = [t| [Int] |]
    toTyp (ArgEnum _)      = [t| String |]
    toTyp (ArgTuple inner) = [t| [$(toTyp inner)] |]

    -- Prepend one  "name" ':= Type  entry to a promoted type-level list.
    -- The required/optional flag is not encoded in the entry.
    add (nm, (at, _)) lst =
        let item = [t| $(litT (strTyLit nm)) ':= $(toTyp at) |]
        in appT (appT promotedConsT item) lst



-- | Classification of the type part of an MXNet argument description, as
-- recognised by 'typ'.
data ArgType
    = ArgString          -- ^ @string@
    | ArgInt             -- ^ @int@ or @int (non-negative)@
    | ArgLong            -- ^ @long@ or @long (non-negative)@
    | ArgFloat           -- ^ @float@
    | ArgBool            -- ^ @boolean@
    | ArgShape           -- ^ @Shape(tuple)@
    | ArgEnum [String]   -- ^ @{...}@ choice list; carries the listed choices
    | ArgTuple ArgType   -- ^ @tuple of \<T\>@; carries the element type
    deriving (Eq, Show)

-- | Occurrence part of an MXNet argument description, as recognised by 'occ'.
data ArgOccr
    = Required   -- ^ the description says @required@
    | Optional   -- ^ @optional, default=...@; the default value is not kept
    deriving (Eq, Show)



-- | Parse an argument description such as @"int, optional, default=0"@.
--
-- The first result of 'readP_to_S' is taken even when it leaves input
-- unconsumed. This is why defaults that are only partly matched (for
-- example @0.5@, where only the @0@ is consumed) are still accepted.
-- Calls 'error' when the description cannot be parsed at all.
parseArgDesc :: String -> (ArgType, ArgOccr)
parseArgDesc str
    | (parsed, _) : _ <- readP_to_S desc str = parsed
    | otherwise = error ("cannot parse arg desc: " ++ str)



-- | One or more alphanumeric characters, such as a bare word or digit run.
alphaNum :: ReadP String
alphaNum = many1 alnumChar
  where
    alnumChar = satisfy isAlphaNum

-- | A single-quoted literal such as @'NULL'@. Returns the text between the
-- quotes.
--
-- Any character except a single quote may appear inside. This lets quoted
-- defaults like @'[3,224,224]'@ or @'a b'@ parse instead of making
-- 'parseArgDesc' fail. 'munch' is deterministic here because the closing
-- quote cannot occur in the content.
quoted :: ReadP String
quoted = between (char '\'') (char '\'') (munch (/= '\''))

-- | Exactly one item in square brackets, e.g. @[1]@ or @['a']@. Returns
-- whatever the inner item parser returns.
boxed :: ReadP String
boxed = char '[' *> item <* char ']'
  where
    item = quoted +++ number +++ alphaNum

-- | An optionally negative run of decimal digits. Only the digits are
-- returned; a leading @-@ is consumed but not included in the result.
number :: ReadP String
number = do
    optional (char '-')
    many1 (satisfy isDigit)

-- | A comma, with any surrounding whitespace skipped.
comma :: ReadP ()
comma = do
    skipSpaces
    _ <- char ','
    skipSpaces

-- | A brace-enclosed, comma-separated, non-empty list of choices. Each choice
-- is a bare word or a single-quoted string, e.g. @{'gpu', 'cpu'}@.
enum :: ReadP [String]
enum = char '{' *> choices <* char '}'
  where
    choices = (alphaNum +++ quoted) `sepBy1` comma

-- | The leading type part of an argument description.
--
-- 'choice' is symmetric, so both @int@ and @int (non-negative)@ are tried.
-- When the suffix is present, the comma that 'desc' expects next rules out
-- the shorter match.
typ :: ReadP ArgType
typ = choice
    [ ArgString <$ string "string"
    , ArgInt    <$ string "int"
    , ArgInt    <$ string "int (non-negative)"
    , ArgLong   <$ string "long"
    , ArgLong   <$ string "long (non-negative)"
    , ArgBool   <$ string "boolean"
    , ArgFloat  <$ string "float"
    , ArgShape  <$ string "Shape(tuple)"
    , ArgTuple  <$> tupleOf
    , ArgEnum   <$> enum
    ]
  where
    -- "tuple of <T>", with the element type parsed recursively.
    tupleOf = string "tuple of" *> skipSpaces *> between (char '<') (char '>') typ

-- | The occurrence part: either @required@ or @optional, default=<value>@.
-- The default value is parsed only to be skipped; it is not returned.
occ :: ReadP ArgOccr
occ = required +++ optionalWithDefault
  where
    required = Required <$ string "required"
    optionalWithDefault = do
        _ <- string "optional"
        comma
        _ <- string "default="
        _ <- quoted +++ boxed +++ alphaNum +++ number
        return Optional



-- | A whole argument description: type, comma, then occurrence, e.g.
-- @int, optional, default=0@. Input after the parsed default value is not
-- required to be consumed.
desc :: ReadP (ArgType, ArgOccr)
desc = (,) <$> typ <* comma <*> occ