-----------------------------------------------------------
-- |
-- module:                      MXNet.Core.Base.Internal.TH
-- copyright:                   (c) 2016-2017 Tao He
-- license:                     MIT
-- maintainer:                  sighingnow@gmail.com
--
-- Template Haskell tools for finding Ops on NDArray and Symbol from the dynamic library.
--
module MXNet.Core.Base.Internal.TH where

import Data.Char
import Data.List
import Data.Monoid
import Language.Haskell.TH

import MXNet.Core.NNVM.Internal
import MXNet.Core.Base.Internal

-------------------------------------------------------------------------------

-- | Register NDArray ops.
--
-- Lists every op name via 'mxListAllOpNames', looks up each op's info and
-- generates declarations for it with 'makeNDArrayFunc' (which may return no
-- declarations for ops it filters out).
registerNDArrayOps :: Bool      -- ^ Whether to support the "out" key in the argument dictionary.
                   -> Q [Dec]
registerNDArrayOps mutable = runIO $ do
    (_, names) <- mxListAllOpNames
    concat <$> mapM register names
  where
    register _name = do
        (_, handle) <- nnGetOpHandle _name
        (_, _, desc, _, argv, argtype, _, _, _) <- mxSymbolGetAtomicSymbolInfo handle
        makeNDArrayFunc mutable _name desc argv argtype

-- | Register symbol functions.
--
-- Same discovery process as 'registerNDArrayOps', but declarations are
-- generated with 'makeSymbolFunc'.
registerSymbolOps :: Q [Dec]
registerSymbolOps = runIO $ do
    (_, names) <- mxListAllOpNames
    concat <$> mapM register names
  where
    register _name = do
        (_, handle) <- nnGetOpHandle _name
        (_, _, desc, _, argv, argtype, _, _, _) <- mxSymbolGetAtomicSymbolInfo handle
        makeSymbolFunc _name desc argv argtype

-------------------------------------------------------------------------------

-- | Generate the TH AST of a function for a NDArray op.
--
-- The generated function takes, in order: the NDArray arguments, the other
-- required arguments, an @HMap kvs@ of optional arguments (only when the op
-- has any), and an output handle list (only when @mutable@). Its name is the
-- lower-cased op name (ops starting with @_@ keep their name; @where@ becomes
-- @where_@), with a trailing prime when @mutable@.
--
-- Returns no declarations for ops that are skipped (deprecated, aliased,
-- duplicated, without arguments, or not marshallable).
makeNDArrayFunc :: Bool         -- ^ Whether to support the "out" key in the argument dictionary.
                -> String       -- ^ Function's name.
                -> String       -- ^ Function's description.
                -> [String]     -- ^ Function's argument names.
                -> [String]     -- ^ Function's argument types.
                -> IO [Dec]     -- ^ Generated signature, function definition and SPECIALISE pragma.
makeNDArrayFunc mutable _name desc argv argtype = do
    let deprecated = desc `startWith` "DEPRECATED" || _name == "Softmax" -- Softmax is renamed to SoftmaxOutput
        alias = _name `elem` ["Concat", "Pad", "Flatten", "Reshape"]
        baseName
          | "_" `isPrefixOf` _name = _name
          | _name == "where"       = "where_"   -- `where` is a Haskell keyword.
          | otherwise              = toLower <$> _name
        name = if mutable then baseName <> "'" else baseName

    let explicitArg = getExplicitArg argv argtype
        -- NDArray/Symbol-typed arguments become the op's input handle list;
        -- all other required arguments go into the key/value dictionary.
        (ndarrayArg, ordinaryArg) = partition (isArrayType . snd) explicitArg
        implicitArg = getImplicitArg argv argtype
        hasImplicit = not (null implicitArg)

    let forallArgT = makeForallArgT implicitArg
        explicitArgT = makeHsType . snd <$> explicitArg
        implicitArgT = [AppT (conT' "HMap") (VarT (mkName "kvs")) | hasImplicit]
        outputArgT = [AppT ListT (conT' "NDArrayHandle") | mutable]
        argT = explicitArgT <> implicitArgT <> outputArgT

    let argVarP = varP' . ("arg'" <>) . fst
        ndarrayArgP = argVarP <$> ndarrayArg
        ordinaryArgP = argVarP <$> ordinaryArg
        implicitArgP = [varP' "varargs" | hasImplicit]
        returnArgP = [varP' "outputs" | mutable]

    -- Chain the NDArray arguments into one list expression: single handles
    -- are consed on, handle lists are appended.
    let ndargs = foldr consArg (ListE []) ndarrayArg
        consArg (v, t) rest = case makeHsType t of
            ConT _       -> UInfixE (varE' ("arg'" <> v)) (conE' ":") rest
            AppT ListT _ -> UInfixE (varE' ("arg'" <> v)) (varE' "++") rest
            _            -> error "Impossible: not a valid haskell type representation."
        dictargs = UInfixE (varE' "varArgK") (varE' "zip") (varE' "varArgV")

    -- Fold the ordinary arguments into the argument map with add', starting
    -- from the `varargs` parameter when present, otherwise from `nil`.
    let allArgsE = foldr addArg (varE' (if hasImplicit then "varargs" else "nil")) ordinaryArg
        addArg (argName, t) acc =
            appsE' (varE' "add'")
                [ SigE (conE' "Proxy") (AppT (conT' "Proxy") (LitT (StrTyLit argName)))
                , SigE (varE' ("arg'" <> argName)) (makeHsType t)
                , acc
                ]
        outArgE = if mutable
                    then AppE (conE' "Just") (varE' "outputs")
                    else conE' "Nothing"
        func = NormalB . DoE $
            [ LetS [ ValD (varP' "allArgs") (NormalB allArgsE) [] ]
            , LetS [ ValD (varP' "args") (NormalB $ AppE (varE' "dump") (varE' "allArgs")) []
                   , ValD (TupP [varP' "varArgK", varP' "varArgV"]) (NormalB $ AppE (varE' "unzip") (varE' "args")) []
                   , ValD (varP' "outArg") (NormalB outArgE) []
                   ]
            , BindS (TupP [varP' "_", varP' "op"]) $
                AppE (varE' "nnGetOpHandle") (LitE (StringL _name))
            , BindS (TupP [varP' "_", varP' "res"]) $
                appsE' (varE' "mxImperativeInvoke") [varE' "op", ndargs, dictargs, varE' "outArg"]
            , NoBindS $ AppE (varE' "return") (AppE (varE' "toResult") (varE' "res"))
            ]

    let arrows res = foldr (\a b -> ArrowT `AppT` a `AppT` b) res argT
        sig = SigD (mkName name) $
            ForallT [PlainTV (mkName "r")]
                    [AppT (conT' "NDArrayOpResult") (VarT (mkName "r"))]
                    (forallArgT (arrows (AppT (conT' "IO") (VarT (mkName "r")))))
        -- Specialise the polymorphic result type to a single NDArrayHandle.
        pragma = PragmaD $ SpecialiseP (mkName name)
            (forallArgT (arrows (AppT (conT' "IO") (conT' "NDArrayHandle"))))
            (Just Inline) AllPhases
        fun = FunD (mkName name) [Clause (ndarrayArgP <> ordinaryArgP <> implicitArgP <> returnArgP) func []]

    let skipped = null argv
            || deprecated
            || alias
            || _name `elem` ["_NDArray", "_Native", "_arange"]
            || _name `elem` ["cast", "crop"] -- duplicate with "Cast" and "Crop"
            || null explicitArg
            || _name == "take" -- Operator @take@ will take two @SymbolHandle@ as arguments, can't be marshalled as strings.
    return $ if skipped then [] else [sig, fun, pragma]
  where
    -- Shorthands for building TH nodes from plain names.
    varE' :: String -> Exp
    varE' = VarE . mkName

    conE' :: String -> Exp
    conE' = ConE . mkName

    varP' :: String -> Pat
    varP' = VarP . mkName

    conT' :: String -> Type
    conT' = ConT . mkName

    -- Apply a function expression to arguments, left to right.
    appsE' :: Exp -> [Exp] -> Exp
    appsE' f args = foldl' AppE f args

    -- Whether an mxnet type name denotes an array (NDArray/Symbol) argument.
    isArrayType :: String -> Bool
    isArrayType t = t `startWith` "NDArray" || t `startWith` "Symbol"

    -- Translate mxnet's type name to Haskell's type name.
    makeHsType :: String -> Type
    makeHsType s = case s of
        "boolean" -> conT' "Bool"
        "float" -> conT' "Float"
        "double" -> conT' "Double"
        "real_t" -> conT' "Float"
        'i':'n':'t':_ -> conT' "Int"
        'l':'o':'n':'g':_ -> conT' "Int"
        "string" -> conT' "String"
        "NDArray" -> conT' "NDArrayHandle"
        "NDArray-or-Symbol" -> conT' "NDArrayHandle"
        "NDArray-or-Symbol[]" -> AppT ListT (conT' "NDArrayHandle")
        "Symbol" -> conT' "NDArrayHandle"
        "NDArray[]" -> AppT ListT (conT' "NDArrayHandle")
        "Symbol[]" -> AppT ListT (conT' "NDArrayHandle")
        "Symbol or Symbol[]" -> AppT ListT (conT' "NDArrayHandle")
        '{':_ -> conT' "String"
        "Shape(tuple)" -> conT' "String"
        -- NOTE(review): only one "tuple of " spelling can ever match here; if
        -- MXNet reports distinct float/double tuple types (e.g. "tuple of <float>"
        -- / "tuple of <double>"), they need distinct patterns — confirm.
        "tuple of " -> AppT ListT (conT' "Float")
        other -> conT' ("unknown type name: " <> other)

    -- Generate the type-level list of implicit arguments, as `name := Type`.
    makeKVListT :: [(String, String, String)]    -- ^ [(name, type, default value)]
                -> Type
    makeKVListT args = foldr combineKV PromotedNilT (makeKV <$> args)
      where
        makeKV (v, t, _) = AppT (AppT (PromotedT (mkName ":=")) (LitT (StrTyLit v))) (makeHsType t)
        combineKV a acc = AppT (AppT (PromotedT (mkName ":")) a) acc

    -- Make the forall/constraint wrapper for the implicit arguments; identity
    -- when there are none.
    makeForallArgT :: [(String, String, String)]     -- ^ Implicit arguments, (name, type, default value)
                   -> (Type -> Type)
    makeForallArgT [] = id
    makeForallArgT implicits =
        ForallT [ KindedTV (mkName "kvs") (AppT ListT (AppT (conT' "KV") StarT)) ]
                [ AppT (conT' "ShowKV") (VarT (mkName "kvs"))
                , AppT (AppT (conT' "MatchKVList") (VarT (mkName "kvs"))) (makeKVListT implicits)
                ]

-- | Generate the TH AST of a function for a Symbol op.
--
-- The generated function takes, in order: the symbol's name, the
-- NDArray/Symbol arguments, the other required arguments, and an @HMap kvs@
-- of optional arguments (only when the op has any). It creates the atomic
-- symbol, composes it with the array arguments and returns the handle.
--
-- Returns no declarations for ops that are skipped (deprecated, aliased,
-- duplicated, without arguments, @take@ or @where@).
makeSymbolFunc :: String        -- ^ Function's name.
               -> String        -- ^ Function's description.
               -> [String]      -- ^ Function's argument names.
               -> [String]      -- ^ Function's argument types.
               -> IO [Dec]      -- ^ Generated signature and function definition.
makeSymbolFunc _name desc argv argtype = do
    let deprecated = desc `startWith` "DEPRECATED" || _name == "Softmax" -- Softmax is renamed to SoftmaxOutput
        alias = _name `elem` ["Concat", "Pad", "Flatten", "Reshape"]
        name
          | "_" `isPrefixOf` _name = _name
          | _name == "where"       = "where_"   -- `where` is a Haskell keyword.
          | otherwise              = toLower <$> _name

    let explicitArg = getExplicitArg argv argtype
        -- NDArray/Symbol-typed arguments are composed into the symbol;
        -- all other required arguments go into the key/value dictionary.
        (ndarrayArg, ordinaryArg) = partition (isArrayType . snd) explicitArg
        implicitArg = getImplicitArg argv argtype
        hasImplicit = not (null implicitArg)

    let forallArgT = makeForallArgT implicitArg
        explicitArgT = makeHsType . snd <$> explicitArg
        implicitArgT = [AppT (conT' "HMap") (VarT (mkName "kvs")) | hasImplicit]
        -- The first parameter of every generated symbol function is its name.
        argT = conT' "String" : explicitArgT <> implicitArgT

    let argVarP = varP' . ("arg'" <>) . fst
        nameArgP = [varP' "name"]
        ndarrayArgP = argVarP <$> ndarrayArg
        ordinaryArgP = argVarP <$> ordinaryArg
        implicitArgP = [varP' "varargs" | hasImplicit]

    -- Chain the symbol arguments into one list expression: single handles
    -- are consed on, handle lists are appended.
    let ndargs = foldr consArg (ListE []) ndarrayArg
        consArg (v, t) rest = case makeHsType t of
            ConT _       -> UInfixE (varE' ("arg'" <> v)) (conE' ":") rest
            AppT ListT _ -> UInfixE (varE' ("arg'" <> v)) (varE' "++") rest
            _            -> error "Impossible: not a valid haskell type representation."

    -- Fold the ordinary arguments into the argument map with add', starting
    -- from the `varargs` parameter when present, otherwise from `nil`.
    let allArgsE = foldr addArg (varE' (if hasImplicit then "varargs" else "nil")) ordinaryArg
        addArg (argName, t) acc =
            appsE' (varE' "add'")
                [ SigE (conE' "Proxy") (AppT (conT' "Proxy") (LitT (StrTyLit argName)))
                , SigE (varE' ("arg'" <> argName)) (makeHsType t)
                , acc
                ]
        func = NormalB . DoE $
            [ LetS [ ValD (varP' "allArgs") (NormalB allArgsE) [] ]
            , LetS [ ValD (varP' "args") (NormalB $ AppE (varE' "dump") (varE' "allArgs")) []
                   , ValD (TupP [varP' "varArgK", varP' "varArgV"]) (NormalB $ AppE (varE' "unzip") (varE' "args")) []
                   ]
            , BindS (TupP [varP' "_", varP' "op"]) $
                AppE (varE' "nnGetOpHandle") (LitE (StringL _name))
            , LetS [ ValD (varP' "nargs")
                          (NormalB $ AppE (varE' "fromIntegral") (AppE (varE' "length") (varE' "varArgK")))
                          []
                   ]
            , BindS (TupP [varP' "_", varP' "sym"]) $
                appsE' (varE' "mxSymbolCreateAtomicSymbol") [varE' "op", varE' "nargs", varE' "varArgK", varE' "varArgV"]
            , BindS (varP' "_") $
                appsE' (varE' "nnSymbolCompose") [varE' "sym", varE' "name", ListE [], ndargs]
            , NoBindS $ AppE (varE' "return") (varE' "sym")
            ]

    let arrows res = foldr (\a b -> ArrowT `AppT` a `AppT` b) res argT
        sig = SigD (mkName name) $ forallArgT (arrows (AppT (conT' "IO") (conT' "SymbolHandle")))
        fun = FunD (mkName name) [Clause (nameArgP <> ndarrayArgP <> ordinaryArgP <> implicitArgP) func []]

    let skipped = null argv
            || deprecated
            || alias
            || _name `elem` ["_NDArray", "_Native", "_arange"]
            || _name `elem` ["cast", "crop"] -- duplicate with "Cast" and "Crop"
            || null explicitArg
            || _name == "take" -- Operator @take@ will take two @SymbolHandle@ as arguments, can't be marshalled as strings.
            || _name == "where"
    return $ if skipped then [] else [sig, fun]
  where
    -- Shorthands for building TH nodes from plain names.
    varE' :: String -> Exp
    varE' = VarE . mkName

    conE' :: String -> Exp
    conE' = ConE . mkName

    varP' :: String -> Pat
    varP' = VarP . mkName

    conT' :: String -> Type
    conT' = ConT . mkName

    -- Apply a function expression to arguments, left to right.
    appsE' :: Exp -> [Exp] -> Exp
    appsE' f args = foldl' AppE f args

    -- Whether an mxnet type name denotes an array (NDArray/Symbol) argument.
    isArrayType :: String -> Bool
    isArrayType t = t `startWith` "NDArray" || t `startWith` "Symbol"

    -- Translate mxnet's type name to Haskell's type name.
    makeHsType :: String -> Type
    makeHsType s = case s of
        "boolean" -> conT' "Bool"
        "float" -> conT' "Float"
        "double" -> conT' "Double"
        "real_t" -> conT' "Float"
        'i':'n':'t':_ -> conT' "Int"
        'l':'o':'n':'g':_ -> conT' "Int"
        "string" -> conT' "String"
        "NDArray" -> conT' "SymbolHandle"
        "Symbol" -> conT' "SymbolHandle"
        "NDArray-or-Symbol" -> conT' "SymbolHandle"
        "NDArray-or-Symbol[]" -> AppT ListT (conT' "SymbolHandle")
        "NDArray[]" -> AppT ListT (conT' "SymbolHandle")
        "Symbol[]" -> AppT ListT (conT' "SymbolHandle")
        "Symbol or Symbol[]" -> AppT ListT (conT' "SymbolHandle")
        '{':_ -> conT' "String"
        "Shape(tuple)" -> conT' "String"
        -- NOTE(review): only one "tuple of " spelling can ever match here; if
        -- MXNet reports distinct float/double tuple types (e.g. "tuple of <float>"
        -- / "tuple of <double>"), they need distinct patterns — confirm.
        "tuple of " -> AppT ListT (conT' "Float")
        other -> conT' ("unknown type name: " <> other)

    -- Generate the type-level list of implicit arguments, as `name := Type`.
    makeKVListT :: [(String, String, String)]    -- ^ [(name, type, default value)]
                -> Type
    makeKVListT args = foldr combineKV PromotedNilT (makeKV <$> args)
      where
        makeKV (v, t, _) = AppT (AppT (PromotedT (mkName ":=")) (LitT (StrTyLit v))) (makeHsType t)
        combineKV a acc = AppT (AppT (PromotedT (mkName ":")) a) acc

    -- Make the forall/constraint wrapper for the implicit arguments; identity
    -- when there are none.
    makeForallArgT :: [(String, String, String)]     -- ^ Implicit arguments, (name, type, default value)
                   -> (Type -> Type)
    makeForallArgT [] = id
    makeForallArgT implicits =
        ForallT [ KindedTV (mkName "kvs") (AppT ListT (AppT (conT' "KV") StarT)) ]
                [ AppT (conT' "ShowKV") (VarT (mkName "kvs"))
                , AppT (AppT (conT' "MatchKVList") (VarT (mkName "kvs"))) (makeKVListT implicits)
                ]

-------------------------------------------------------------------------------

-- | @startWith s t@ means @s@ starts with @t@.
startWith :: String -> String -> Bool
startWith s t = t `isPrefixOf` s

-- | Prepend the entries of the second map whose keys are not yet present in
-- the first one (existing keys win; later duplicates in the second map are
-- ignored).
updateMap :: [(String, String)] -> [(String, String)] -> [(String, String)]
updateMap = foldl' insertIfAbsent
  where
    insertIfAbsent acc kv@(k, _)
        | any ((== k) . fst) acc = acc
        | otherwise              = kv : acc

-- | Split an argument type string on "," into its parts: name, type, default
-- value and required information. Leading spaces of each part are dropped;
-- splitting stops at the first empty part.
splitArgType :: String -> [String]
splitArgType (' ' : xs) = splitArgType xs
splitArgType ts = case break (== ',') ts of
    ([], _)     -> []
    (t, [])     -> [t]
    (t, _ : xs) -> t : splitArgType xs

-- | Get explicit (required) arguments from all arguments.
getExplicitArg :: [String]              -- ^ Argument names.
               -> [String]              -- ^ Argument types.
               -> [(String, String)]    -- ^ Return necessary arguments' name and type.
getExplicitArg argv argtype = [arg | Just arg <- resolve <$> zip argv argtype]
  where
    resolve (v, t) = case splitArgType t of
        ts | "optional" `elem` ts -> Nothing
        -- Seems that `tuple of ` can't be exported correctly by
        -- mxSymbolGetAtomicSymbolInfo, so a type that splits into nothing is
        -- assumed to be that tuple type.
        []       -> Just (v, "tuple of ")
        (ty : _) -> Just (v, ty)

-- | Get implicit (optional) arguments from all arguments.
--
-- An optional argument whose type string carries no @default=@ entry is
-- skipped rather than aborting code generation.
getImplicitArg :: [String]                      -- ^ Argument names.
               -> [String]                      -- ^ Argument types.
               -> [(String, String, String)]    -- ^ Return optional arguments' names, types and default values.
getImplicitArg argv argtype = [arg | Just arg <- resolve <$> zip argv argtype]
  where
    resolve (v, t) = case splitArgType t of
        ts@(ty : _) | "optional" `elem` ts -> (\d -> (v, ty, d)) <$> getDefault ts
        _ -> Nothing
    -- Total replacement for `head . filter`: no "default=" part yields Nothing.
    getDefault ts = case filter ("default=" `isPrefixOf`) ts of
        (d : _) -> stripPrefix "default=" d
        []      -> Nothing