module Foreign.CUDA.Cublas.TH where
import Control.Applicative
import Control.Arrow
import Control.Monad ((>=>), join, void)
import GHC.Exts (groupWith)
import Language.Haskell.TH as TH
import Language.C
import Language.C.System.GCC
import Data.List (isInfixOf, isPrefixOf, isSuffixOf)
import Data.Char (toLower, toUpper)
import Data.Maybe (mapMaybe)
import Debug.Trace
import qualified Foreign.C.Types as C
import qualified Foreign as F
import Foreign.Storable.Complex ()
import Data.Complex (Complex(..))
import System.FilePath.Posix ((</>))
import Foreign.CUDA as FC
import qualified Foreign.CUDA.Runtime.Stream as FC
import qualified Foreign.CUDA.Cublas.Types as BL
import qualified Foreign.CUDA.Cusparse.Types as SP
import qualified Foreign.CUDA.Cublas.Error as BL
import qualified Foreign.CUDA.Cusparse.Error as SP
-- | Pick the value attached to the first condition that holds.
--
-- Conditions are examined left to right.  Calls 'error' when none of them
-- is 'True' (which includes the empty list).
try :: [(Bool, a)] -> a
try candidates = case [value | (True, value) <- candidates] of
  (firstHit : _) -> firstHit
  []             -> error "TH.try: No match!"
-- | How one C function argument is marshalled by the generated wrapper.
data TypeInfo = TI
  { ctype :: Q Type
    -- ^ Type of the argument on the C side.
  , hsinput :: Either Convert Create
    -- ^ How the C value is obtained before the call: 'Left' converts a
    -- Haskell-side argument (which then appears in the wrapper's signature),
    -- 'Right' runs an action that creates the value without any argument.
  , hsoutput :: Maybe (Either Convert Destroy) }
    -- ^ Optional step run after the call: 'Left' converts the C value into
    -- a result of the wrapper, 'Right' only cleans the C value up.
-- | Paired C-side and Haskell-side descriptions of a 'TypeC'.
data TypeDat = TD
  { ct :: Q Type
    -- ^ The C-side (FFI) type.
  , hst :: Q Type
    -- ^ The Haskell-side type.
  , c2hs :: (Q Exp, ExpType)
    -- ^ Conversion from the C-side value to the Haskell-side value, tagged
    -- with whether the expression is pure or monadic.
  , hs2c :: (Q Exp, ExpType) }
    -- ^ Conversion from the Haskell-side value to the C-side value, tagged
    -- with whether the expression is pure or monadic.
-- | Whether a conversion expression is a plain function ('Pure') or
-- already yields a monadic action ('Monadic').
data ExpType = Pure | Monadic
-- | Simplified description of the C types found in the parsed headers.
--
-- 'EnumC' and 'ArbStructC' carry the C typedef name, 'PtrC' and 'ArrC'
-- wrap the pointed-to / element type, and 'PhonyC' stands for a Haskell
-- type variable.
data TypeC = VoidC | IntC | FloatC | DoubleC | EnumC String
           | ComplexC TypeC | ArbStructC String | PtrC TypeC | PhonyC TH.Name
           | ArrC TypeC
  deriving (Eq, Show)
-- | Build a 'TypeDat' whose two conversions are both 'Pure'.
--
-- Arguments, in order: C-side type, Haskell-side type, C-to-Haskell
-- conversion, Haskell-to-C conversion.
prim :: Q Type -> Q Type -> Q Exp -> Q Exp -> TypeDat
prim cTy hsTy fromC toC = TD
  { ct   = cTy
  , hst  = hsTy
  , c2hs = (fromC, Pure)
  , hs2c = (toC, Pure)
  }
-- | 'TypeDat' for a type that is identical on both sides; both
-- conversions are @id@.
simple :: Q Type -> TypeDat
simple ty = prim ty ty identity identity
  where
    identity = [| id |]
-- | Apply a function to both the real and the imaginary component.
bothc :: (a -> b) -> Complex a -> Complex b
bothc f z = case z of
  re :+ im -> f re :+ f im
-- | Look up the C-side type, the Haskell-side type and the conversions
-- between them for a 'TypeC'.
--
-- Calls 'error' with a descriptive message for enum or struct typedef
-- names that are not listed below, and for a 'ComplexC' whose component
-- conversions are not pure.
typeDat :: TypeC -> TypeDat
-- A phony type is a Haskell type variable used unchanged on both sides.
typeDat (PhonyC n) = simple (varT n)
typeDat VoidC = simple [t| () |]
-- Pointers are raw 'F.Ptr's on the C side and 'FC.DevicePtr's on the
-- Haskell side.
typeDat (PtrC t) =
  prim [t| F.Ptr $(elemCT) |] [t| FC.DevicePtr $(elemCT) |] [| FC.DevicePtr |] [| FC.useDevicePtr |]
  where
    elemCT = ct (typeDat t)
-- Arrays are described exactly like pointers.
typeDat (ArrC t) = typeDat (PtrC t)
typeDat IntC = prim [t| C.CInt |] [t| Int |] [| fromIntegral |] [| fromIntegral |]
typeDat FloatC = simple [t| C.CFloat |]
typeDat DoubleC = simple [t| C.CDouble |]
-- Enums travel as 'C.CInt' and are converted through 'Enum'.
typeDat (EnumC str) =
  prim [t| C.CInt |] hsEnum [| toEnum . fromIntegral |] [| fromIntegral . fromEnum |]
  where
    hsEnum = case str of
      "cublasStatus_t" -> [t| BL.Status |]
      "cublasOperation_t" -> [t| BL.Operation |]
      "cublasSideMode_t" -> [t| BL.SideMode |]
      "cublasFillMode_t" -> [t| BL.FillMode |]
      "cublasPointerMode_t" -> [t| BL.PointerMode |]
      "cublasAtomicsMode_t" -> [t| BL.AtomicsMode |]
      "cublasDiagType_t" -> [t| BL.DiagType |]
      "cusparseStatus_t" -> [t| SP.Status |]
      "cusparseOperation_t" -> [t| SP.Operation |]
      "cusparseDirection_t" -> [t| SP.Direction |]
      "cusparseHybPartition_t" -> [t| SP.HybPartition |]
      "cusparseFillMode_t" -> [t| SP.FillMode |]
      "cusparsePointerMode_t" -> [t| SP.PointerMode |]
      "cusparseDiagType_t" -> [t| SP.DiagType |]
      "cusparseIndexBase_t" -> [t| SP.IndexBase |]
      "cusparseAction_t" -> [t| SP.Action |]
      "cusparseMatrixType_t" -> [t| SP.MatrixType |]
      "cusparseSolvePolicy_t" -> [t| SP.SolvePolicy |]
      _ -> error ("typeDat.EnumC : Missing type: " ++ str)
-- Opaque handles are @F.Ptr ()@ on the C side, wrapped by a constructor on
-- the Haskell side.
typeDat (ArbStructC str) = case str of
  "cublasHandle_t" -> opaque [t| BL.Handle |] [| BL.Handle |] [| BL.useHandle |]
  "cusparseHandle_t" -> opaque [t| SP.Handle |] [| SP.Handle |] [| SP.useHandle |]
  "cusparseHybMat_t" -> opaque [t| SP.HybMat |] [| SP.HybMat |] [| SP.useHybMat |]
  "cusparseMatDescr_t" -> opaque [t| SP.MatDescr |] [| SP.MatDescr |] [| SP.useMatDescr |]
  "cusparseSolveAnalysisInfo_t" -> opaque [t| SP.SolveAnalysisInfo |] [| SP.SolveAnalysisInfo |] [| SP.useSolveAnalysisInfo |]
  "csrsv2Info_t" -> opaque [t| SP.Csrsv2Info |] [| SP.Csrsv2Info |] [| SP.useCsrsv2Info |]
  "csric02Info_t" -> opaque [t| SP.Csric02Info |] [| SP.Csric02Info |] [| SP.useCsric02Info |]
  "csrilu02Info_t" -> opaque [t| SP.Csrilu02Info |] [| SP.Csrilu02Info |] [| SP.useCsrilu02Info |]
  "bsrsv2Info_t" -> opaque [t| SP.Bsrsv2Info |] [| SP.Bsrsv2Info |] [| SP.useBsrsv2Info |]
  "bsric02Info_t" -> opaque [t| SP.Bsric02Info |] [| SP.Bsric02Info |] [| SP.useBsric02Info |]
  "bsrilu02Info_t" -> opaque [t| SP.Bsrilu02Info |] [| SP.Bsrilu02Info |] [| SP.useBsrilu02Info |]
  "cudaStream_t" -> opaque [t| FC.Stream |] [| FC.Stream |] [| FC.useStream |]
  -- Previously an unknown name fell off the end of the case expression.
  _ -> error ("typeDat.ArbStructC : Missing type: " ++ str)
  where
    opaque = prim [t| F.Ptr () |]
-- Complex numbers convert component-wise using the element conversions.
typeDat (ComplexC t) = case typeDat t of
  TD elemCT elemHsT (fromC, Pure) (toC, Pure) ->
    prim
      [t| Complex $(elemCT) |]
      [t| Complex $(elemHsT) |]
      [| bothc $(fromC) |]
      [| bothc $(toC) |]
  _ -> error "typeDat.ComplexC : component conversions must be pure"
-- | Wrap a Haskell-side type and a conversion expression as the 'Left'
-- ('Convert') alternative.
convertT :: Q Type -> Q Exp -> Either Convert b
convertT hsType conversion = Left (Convert hsType conversion)

-- | Wrap a creation action as the 'Right' ('Create') alternative.
createT :: Q Exp -> Either a Create
createT action = Right (Create action)

-- | Wrap a clean-up action as the 'Right' ('Destroy') alternative.
destroyT :: Q Exp -> Either a Destroy
destroyT action = Right (Destroy action)
-- | A conversion step: the Haskell-side type involved and the expression
-- that performs the conversion.
data Convert = Convert (Q Type) (Q Exp)
-- | An expression that produces a C value without a Haskell-side input.
newtype Create = Create (Q Exp)
-- | An expression applied to a C value after the call purely for clean-up.
newtype Destroy = Destroy (Q Exp)
-- | Wrap a type in 'F.Ptr'.
pointerify :: Q Type -> Q Type
pointerify pointee = appT [t| F.Ptr |] pointee
-- | Marshalling for an argument that is passed straight through: the
-- Haskell value is converted with the type's Haskell-to-C conversion and
-- nothing happens after the call.
useT :: TypeC -> TypeInfo
useT typec = case typeDat typec of
  TD cTy hsTy _ (toC, purity) ->
    let marshal = case purity of
          -- A pure conversion is lifted into the monad.
          Pure    -> [| return . $(toC) |]
          Monadic -> toC
    in TI cTy (convertT hsTy marshal) Nothing
-- | Marshalling for an input passed by pointer or as an array.
--
-- * @PtrC t@: the Haskell value is converted and stored in memory obtained
--   with 'F.new'; that memory is released with 'F.free' after the call.
-- * @ArrC t@: a Haskell list is converted element-wise and copied with
--   'FC.newListArray'; the resulting device pointer is released with
--   'FC.free' after the call.
--
-- Monadic conversions are run before the allocation (previously these
-- branches were 'undefined').  Any other 'TypeC' is rejected with a
-- descriptive 'error' instead of a pattern-match failure.
inT :: TypeC -> TypeInfo
inT (PtrC t) = build (typeDat t)
  where
    build (TD cTy hsTy _ (toC, purity)) = TI
      (pointerify cTy)
      (convertT hsTy marshal)
      (Just (destroyT [| F.free |]))
      where
        marshal = case purity of
          Pure    -> [| F.new . $(toC) |]
          Monadic -> [| \v -> $(toC) v >>= F.new |]
inT (ArrC t) = build (typeDat t)
  where
    build (TD cTy hsTy _ (toC, purity)) = TI
      (pointerify cTy)
      (convertT [t| [ $(hsTy) ] |] marshal)
      (Just (destroyT [| FC.free . FC.DevicePtr |]))
      where
        marshal = case purity of
          Pure    -> [| fmap FC.useDevicePtr . FC.newListArray . map $(toC) |]
          Monadic -> [| \vs -> mapM $(toC) vs >>= fmap FC.useDevicePtr . FC.newListArray |]
inT other =
  error ("TH.inT: expected a pointer or array argument, got " ++ show other)
-- | Marshalling for an output-only pointer argument.
--
-- The wrapper allocates the cell with 'F.malloc' (no Haskell-side input).
-- After the call it reads the cell with 'F.peek', frees it with 'F.free'
-- and converts the value to the Haskell-side type, which becomes a result
-- of the wrapper.
--
-- The recorded C type is the pointer to the element type, matching 'inT'
-- and 'inOutT' (it was previously the bare element type).  Non-pointer
-- arguments are rejected with a descriptive 'error'.
outT :: TypeC -> TypeInfo
outT (PtrC t) = build (typeDat t)
  where
    build (TD cTy hsTy (fromC, purity) _) = TI
      (pointerify cTy)
      (createT [| F.malloc |])
      (Just (convertT hsTy [| \p -> do { x <- F.peek p ; F.free p; $(unmarshal) x } |]))
      where
        unmarshal = case purity of
          Pure    -> [| return . $(fromC) |]
          Monadic -> fromC
outT other =
  error ("TH.outT: expected a pointer argument, got " ++ show other)
-- | Marshalling for a pointer argument that is both read and written.
--
-- Before the call the Haskell value is converted and stored with 'F.new'.
-- After the call the cell is read with 'F.peek', freed with 'F.free' and
-- converted back to the Haskell-side type.
--
-- Each conversion is now gated by its own purity flag (the two flags were
-- previously swapped), and a monadic Haskell-to-C conversion is followed
-- by the allocation.  Non-pointer arguments are rejected with a
-- descriptive 'error'.
inOutT :: TypeC -> TypeInfo
inOutT (PtrC t) = build (typeDat t)
  where
    build (TD cTy hsTy (fromC, fromPurity) (toC, toPurity)) = TI
      (pointerify cTy)
      (convertT hsTy marshalIn)
      (Just (convertT hsTy [| \p -> do { x <- F.peek p ; F.free p; $(marshalOut) x } |]))
      where
        marshalIn = case toPurity of
          Pure    -> [| F.new . $(toC) |]
          Monadic -> [| \v -> $(toC) v >>= F.new |]
        marshalOut = case fromPurity of
          Pure    -> [| return . $(fromC) |]
          Monadic -> fromC
inOutT other =
  error ("TH.inOutT: expected a pointer argument, got " ++ show other)
-- | Translate a C type specifier into a 'TypeC'.
--
-- Typedef names are classified as opaque structs (a fixed list of handle
-- names), complex numbers (@cuComplex@, @cuDoubleComplex@) or, failing
-- that, enums.  Every specifier not handled explicitly becomes 'VoidC'.
convert :: CTypeSpecifier a -> TypeC
convert (CVoidType _) = VoidC
convert (CIntType _) = IntC
convert (CFloatType _) = FloatC
convert (CDoubleType _) = DoubleC
convert (CTypeDef ident _)
  | name `elem` opaqueStructs = ArbStructC name
  | name == "cuComplex" = ComplexC FloatC
  | name == "cuDoubleComplex" = ComplexC DoubleC
  | otherwise = EnumC name
  where
    name = identToString ident
    opaqueStructs =
      [ "cublasHandle_t", "cusparseHybMat_t", "cusparseHandle_t"
      , "cusparseMatDescr_t", "cusparseSolveAnalysisInfo_t", "cudaStream_t"
      , "csrsv2Info_t", "csric02Info_t", "csrilu02Info_t"
      , "bsrsv2Info_t", "bsric02Info_t", "bsrilu02Info_t"
      ]
convert _ = VoidC
-- | Translate the first type specifier found in a declaration-specifier
-- list; other kinds of specifier are skipped.  Calls 'error' when the list
-- contains no type specifier.
convert' :: [CDeclarationSpecifier a] -> TypeC
convert' specs = case [ty | CTypeSpec ty <- specs] of
  (firstTy : _) -> convert firstTy
  []            -> error "convert': invalid CDeclarationSpecifier list"
-- | Type of a named, single-declarator C declaration: the projection of
-- its base type, wrapped in one 'F.Ptr' per pointer declarator.
--
-- Only pointer declarators are supported; anything else calls 'error'.
typeOf :: (TypeC -> Q Type) -> CDeclaration a -> Q Type
typeOf proj (CDecl basetype [(Just (CDeclr (Just _) derived _ _ _), _, _)] _) =
  foldr wrap (proj (convert' basetype)) derived
  where
    wrap (CPtrDeclr _ _) inner = [t| F.Ptr $(inner) |]
    wrap _ _ = error "haven't implemented other things"
-- | Turn the derived declarators of a single-declarator declaration into
-- a 'TypeC' wrapper: 'PtrC' per pointer declarator, 'ArrC' per array
-- declarator, nothing for any other declarator.
pointerification :: CDeclaration a -> (TypeC -> TypeC)
pointerification (CDecl _ [(Just (CDeclr _ derived _ _ _), _, _)] _) =
  foldr (\declr wrapRest -> layer declr . wrapRest) id derived
  where
    layer (CPtrDeclr _ _) = PtrC
    layer (CArrDeclr _ _ _) = ArrC
    layer _ = id
-- | 'TypeC' of a declaration's specifiers, ignoring its declarators.
baseType :: CDeclaration a -> TypeC
baseType (CDecl specs _ _) = convert' specs
-- | Full 'TypeC' of a declaration: its base type wrapped according to its
-- pointer and array declarators.
cType :: CDeclaration a -> TypeC
cType decl = wrap (baseType decl)
  where
    wrap = pointerification decl
-- | Choose the marshalling strategy for one argument of the C function
-- named @fn@.  The first matching rule wins:
--
-- 1. array arguments are inputs ('inT');
-- 2. arguments with one of a fixed list of scalar-parameter names are
--    inputs ('inT');
-- 3. any argument of a function whose name starts with @create@, or an
--    argument named @result@, is an output ('outT');
-- 4. arguments whose name ends in @DevHostPtr@ are outputs ('outT');
-- 5. everything else is passed through ('useT').
typeInfo :: String -> CVar -> TypeInfo
typeInfo fn (n, typec) = marshaller typec
  where
    marshaller
      | isArray typec = inT
      | n `elem` ["alpha", "beta", "a", "b", "c", "d1", "d2", "x1", "y1", "s"] = inT
      | "create" `isPrefixOf` fn || n == "result" = outT
      | "DevHostPtr" `isSuffixOf` n = outT
      | otherwise = useT
    isArray (ArrC _) = True
    isArray _ = False
-- | Name declared by a declaration that has exactly one, named
-- declarator; 'Nothing' for any other shape.
declName :: CDeclaration a -> Maybe String
declName decl = case decl of
  CDecl _ [(Just (CDeclr (Just ident) _ _ _ _), _, _)] _ -> Just (identToString ident)
  _ -> Nothing
-- | How a wrapper finishes, depending on the C return type.
--
-- The expression is applied to the pair @(converted return value, outputs)@;
-- the function adjusts the wrapper's result type.
--
-- * status enums: checked with the library's @resultIfOk@, result type
--   unchanged;
-- * @void@: the outputs are returned, result type unchanged;
-- * anything else: the converted return value is returned and the result
--   type is replaced by its Haskell-side type.
outMarshall :: TypeC -> (Q Exp, Q Type -> Q Type)
outMarshall rettype = case rettype of
  EnumC "cublasStatus_t"   -> ([| BL.resultIfOk |], id)
  EnumC "cusparseStatus_t" -> ([| SP.resultIfOk |], id)
  VoidC                    -> ([| return . snd |], id)
  other                    -> ([| return . fst |], const (hst (typeDat other)))
-- | Generate the type signature and definition of a Haskell wrapper
-- around a foreign import.
--
-- The first component is the name of the foreign import to call.  The
-- second is the C function description, whose name field is used as the
-- name of the generated wrapper.
--
-- The generated body is a @do@ block that
--
-- 1. marshals every argument (converting a wrapper parameter or creating
--    a value, as dictated by its 'TypeInfo'),
-- 2. calls the foreign import with the marshalled values,
-- 3. runs each argument's post-call step (conversion to an output, or
--    clean-up), and
-- 4. hands the converted return value and the tuple of outputs to the
--    expression chosen by 'outMarshall'.
--
-- Returns @[signature, definition]@, in that order.
createf' :: (String, CFunction) -> Q [Dec]
createf' (foreignname, cf@(fname, rettype, args)) = do
  ins  <- mapM (safeName "_in") args
  toCs <- mapM (safeName "_out") args
  (outstatements, outs) <-
    second (map snd . filterMaybes) . unzip <$> collect (zip3 args argsTI toCs)
  let instatements = map inMarsh (zip3 argsTI ins toCs)
  ret <- newName "res"
  let runstatement = bindS (varP ret) (foldl f z toCs)
  let returnstatement = [| $(checkStatusExp) ( $(outputConv) $(varE ret), $(tupE (map varE outs)) ) |]
  expr <- doE $ concat [instatements, runstatement:outstatements, [noBindS returnstatement]]
  -- Only arguments that are converted from a Haskell value become
  -- parameters of the wrapper.
  let usedins = map snd . filter (isused . fst) $ zip argsTI ins
  let fdec = FunD fcall [Clause (map VarP usedins) (NormalB expr) []]
  tdec <- sigD fcall $ funTypeMod checkStatusType argsTI
  return [tdec, fdec]
  where
    -- Fresh name derived from the argument name with its first letter
    -- lower-cased; an empty argument name falls back to "arg".
    safeName :: String -> CVar -> Q TH.Name
    safeName end (s:str, _) = newName (toLower s : str ++ end)
    safeName end ([], _)    = newName ("arg" ++ end)
    argsTI = functionTypeInfo cf
    (outputConv, _) = c2hs (typeDat rettype)
    (checkStatusExp, checkStatusType) = outMarshall rettype
    isused (TI _ (Left _) _) = True
    isused _ = False
    fcall = mkName fname
    z = varE (mkName foreignname)
    f x e = appE x (varE e)
    -- Statement binding the marshalled C value e' for one argument.
    inMarsh (ti, e, e') = case hsinput ti of
      Left (Convert _ a) -> bindS (varP e') (appE a (varE e))
      Right (Create a) -> bindS (varP e') a
    -- Post-call statement for every argument that has one, paired with the
    -- name and type of the output it binds (if any).
    collect ((arg, TI _ _ (Just cleanup), e) : xs) = do
      outinfo <- case cleanup of
        Left (Convert t a) -> do
          e' <- safeName "_out" arg
          return (bindS (varP e') (appE a (varE e)), Just (t, e'))
        Right (Destroy a) ->
          return (noBindS (appE a (varE e)), Nothing)
      ys <- collect xs
      return (outinfo : ys)
    collect (_ : xs) = collect xs
    collect [] = return []
-- | Path of @cublas_v2.h@ beneath @CUDA_INCLUDE_DIR@.
-- NOTE(review): @CUDA_INCLUDE_DIR@ is not defined in this file; presumably
-- a preprocessor macro supplied by the build -- confirm.
cublasFile :: FilePath
cublasFile = CUDA_INCLUDE_DIR </> "cublas_v2.h"

-- | Path of @cusparse_v2.h@ beneath @CUDA_INCLUDE_DIR@.
cusparseFile :: FilePath
cusparseFile = CUDA_INCLUDE_DIR </> "cusparse_v2.h"
-- | Keep the contents of every 'Just', in order, dropping the 'Nothing's.
filterMaybes :: [Maybe a] -> [a]
filterMaybes candidates = [x | Just x <- candidates]
-- | Name declared by a single-declarator declaration, or the placeholder
-- @"Weird!"@ when the declaration has any other shape.
funname :: CDeclaration a -> String
funname decl = case decl of
  CDecl _ [(Just (CDeclr (Just ident) _ _ _ _), _, _)] _ -> identToString ident
  _ -> "Weird!"
-- | Whether a function's name starts with the library prefix followed by
-- @Get@ or by one of the letters @S@, @D@, @C@, @Z@, @X@.
desired :: String -> CFunction -> Bool
desired prefix (name, _, _) =
  or [(prefix ++ tag) `isPrefixOf` name | tag <- tags]
  where
    tags = "Get" : [[letter] | letter <- "SDCZX"]
-- | For a function declarator with a parameter list, the 'show'n
-- declaration specifiers of each parameter; 'Nothing' for any other
-- declarator.
infol :: Show a => CDerivedDeclarator a -> Maybe [[String]]
infol declr = case declr of
  CFunDeclr (Right (params, _)) _ _ -> Just (map shownSpecs params)
  _ -> Nothing
  where
    shownSpecs (CDecl specs _ _) = map show specs
-- | Parameter declarations of a declarator whose only derived declarator
-- is a function declarator with a parameter list; 'Nothing' otherwise.
funArgs :: CDeclarator a -> Maybe [CDeclaration a]
funArgs declr = case declr of
  CDeclr _ [CFunDeclr (Right (params, _)) _ _] _ _ _ -> Just params
  _ -> Nothing
-- | The declarator of a declaration that declares exactly one thing;
-- 'Nothing' otherwise.
funDecl :: CDeclaration a -> Maybe (CDeclarator a)
funDecl decl = case decl of
  CDecl _ [(Just declarator, _, _)] _ -> Just declarator
  _ -> Nothing
-- | Interpret a declaration as a function prototype, yielding its name,
-- return type and named parameters.
--
-- 'Nothing' when the declaration is not a single function declarator or
-- when any parameter has no name.
maybeFunction :: CDeclaration a -> Maybe (CFunction)
maybeFunction d@(CDecl returnType _ _) = do
  declarator <- funDecl d
  params     <- funArgs declarator
  fname      <- declName d
  paramNames <- mapM declName params
  return (fname, convert' returnType, zip paramNames (map cType params))
-- | The declaration inside a top-level declaration item; 'Nothing' for
-- every other kind of top-level item.
maybeExternalDec :: CExternalDeclaration a -> Maybe (CDeclaration a)
maybeExternalDec ext = case ext of
  CDeclExt decl -> Just decl
  _ -> Nothing
-- | A C function parameter: its name and type.
type CVar = (String, TypeC)
-- | A C function prototype: its name, return type and parameters.
type CFunction = ( String , TypeC , [CVar] )
-- | Preprocess and parse a C header with @/usr/bin/gcc@ and return every
-- top-level declaration that can be read as a function prototype.
--
-- A parse failure is raised as an 'IOError' naming the file and carrying
-- the parser's error.  Previously it surfaced only as a pattern-match
-- failure in the @do@ block, which hid the parser's message.
getFunctions :: FilePath -> IO [CFunction]
getFunctions fp = do
  parsed <- parseCFile (newGCC "/usr/bin/gcc") Nothing [] fp
  case parsed of
    Left err ->
      ioError . userError $
        "getFunctions: failed to parse " ++ fp ++ ": " ++ show err
    Right (CTranslUnit decls _) ->
      return $ mapMaybe (maybeExternalDec >=> maybeFunction) decls
-- | Foreign import declaration (@ccall safe@) for a C function.
--
-- The import entity is the header path, a space, and the C name; the
-- Haskell name is the C name.  The imported type is built from the C-side
-- types of the parameters and returns in 'IO'.
createf :: FilePath -> CFunction -> Q Dec
createf fp (name, ret, args) =
  forImpD cCall safe entity (mkName name) importType
  where
    entity = fp ++ " " ++ name
    cSide = ct . typeDat
    resultType = [t| IO $(cSide ret) |]
    importType = foldr arrow resultType [cSide argType | (_, argType) <- args]
    arrow from to = [t| $(from) -> $(to) |]
-- | Group functions that exist in exactly four variants differing only in
-- the first character of their cleaned-up name.
--
-- Each function is renamed with 'goodName' (functions it rejects are
-- dropped) and loses one trailing @c@.  Functions are then grouped by the
-- new name minus its first character.  A group is kept when it has four
-- members and the original name of its first member contains none of
-- @rot_v2@, @rotg_v2@, @hybsv_analysis@, @numericBoost@.
--
-- Each result pairs the new name of the group's first member with the
-- members, every member being @(original name, renamed function)@.
sharedDecs :: String -> [CFunction] -> [(String, [(String, CFunction)])]
sharedDecs prefix xs =
  [ (newNameOf (head grp), grp)
  | grp <- groupWith (tail . newNameOf) renamed
  , keep grp
  ]
  where
    renamed = mapMaybe rename xs
    rename (s, ret, args) = do
      cleaned <- goodName prefix s
      return (s, (dropc cleaned, ret, args))
    newNameOf (_, (nm, _, _)) = nm
    dropc name
      | last name == 'c' = init name
      | otherwise = name
    excluded = ["rot_v2", "rotg_v2", "hybsv_analysis", "numericBoost"]
    keep grp =
      length grp == 4 && not (any (`isInfixOf` fst (head grp)) excluded)
-- | Declare a type class, named after the prefix with its first letter
-- upper-cased, with one type variable @a@ and one method signature per
-- function.
--
-- In each function every float, double or complex type (also beneath
-- pointers and arrays) is replaced by the class variable.  A method is
-- named after its function with the first character dropped.
mkClass :: String -> [CFunction] -> Q Dec
mkClass (p:prefix) xs =
  classD (return []) className [PlainTV typeName] [] methodSigs
  where
    className = mkName (toUpper p : prefix)
    typeName = mkName "a"
    methodSigs = [methodSig (phonifyF func) | func <- xs]

    classVar :: TypeC
    classVar = PhonyC typeName

    mkPhony :: TypeC -> TypeC
    mkPhony (PtrC t) = PtrC (mkPhony t)
    mkPhony (ArrC t) = ArrC (mkPhony t)
    mkPhony DoubleC = classVar
    mkPhony FloatC = classVar
    mkPhony (ComplexC _) = classVar
    mkPhony other = other

    phonifyF :: CFunction -> CFunction
    phonifyF (name, ret, args) =
      (name, mkPhony ret, [(argName, mkPhony argType) | (argName, argType) <- args])

    methodSig cfunc@(name, _, _) =
      sigD (mkName (tail name)) (funType (functionTypeInfo cfunc))
-- | One instance declaration per variant letter in @"sdcz"@, for the
-- class named after the prefix with its first letter upper-cased.
--
-- The instance type is the C-side type given by 'typeMap' for the letter.
-- For each group the single variant whose new name starts with that
-- letter is selected; its method is the definition (second declaration)
-- produced by @createf'@, named without the leading letter.
mkClassInstances :: String -> [(String, [(String,CFunction)])] -> [Q Dec]
mkClassInstances (p:prefix) xs = [instanceFor c | c <- "sdcz"]
  where
    classCon = conT (mkName (toUpper p : prefix))
    instanceFor c =
      instanceD (return []) (appT classCon (ct (typeDat (typeMap c)))) (map (method c) xs)
    method c (_, funcs) = do
      decs <- createf' (foreignn, (name, ret, args))
      return (decs !! 1)
      where
        -- Exactly one variant per letter is expected in every group.
        [(foreignn, (_ : name, ret, args))] = filter (startsWith c) funcs
    startsWith c (_, (s : _, _, _)) = s == c

-- | Element type denoted by a variant letter (@s@, @d@, @c@, @z@); any
-- other character calls 'error'.
typeMap :: Char -> TypeC
typeMap letter = case letter of
  'c' -> ComplexC FloatC
  'z' -> ComplexC DoubleC
  'd' -> DoubleC
  's' -> FloatC
  _   -> error "typeMap: Invalid character"
-- | Parse a header and build the class declaration followed by its
-- instances for the function groups found by 'sharedDecs'.  The class
-- methods are taken from the first member of each group.
makeClassDecs :: String -> FilePath -> IO (Q [Dec])
makeClassDecs str fp = build <$> getFunctions fp
  where
    build funcs = sequence (classDec : instanceDecs)
      where
        groups = sharedDecs str funcs
        representative (_, variants) = snd (head variants)
        classDec = mkClass str (map representative groups)
        instanceDecs = mkClassInstances str groups
-- | Parse a header and build a foreign import for every function accepted
-- by 'desired'.
makeFFIDecs :: String -> FilePath -> IO (Q [Dec])
makeFFIDecs str fp = do
  funcs <- getFunctions fp
  return (mapM (createf fp) (filter (desired str) funcs))
-- | Parse a header and build a Haskell wrapper (signature and definition)
-- for every function accepted by 'desired' that 'goodName' can rename.
-- The foreign import called by each wrapper keeps the original C name.
makeAllFuncs :: String -> FilePath -> IO (Q [Dec])
makeAllFuncs str fp = do
  funcs <- getFunctions fp
  let wanted = filter (desired str) funcs
      generators = [createf' renamed | Just renamed <- map alter wanted]
  return (concat <$> sequence generators)
  where
    alter (fname, rettype, args) =
      fmap (\newname -> (fname, (newname, rettype, args))) (goodName str fname)
-- | Derive a Haskell identifier from a C function name.
--
-- Given a library prefix and a C name, returns the C name with the prefix
-- removed, a trailing @_v2@ removed, and the first remaining character
-- lower-cased.
--
-- 'Nothing' when the C name does not start with the prefix, or when
-- nothing is left after stripping (previously that case crashed on a
-- failed pattern match).
goodName :: String -> String -> Maybe String
goodName prefix str
  | pre /= prefix = Nothing
  | otherwise = case stripV2 rest of
      (x : xs) -> Just (toLower x : xs)
      []       -> Nothing
  where
    v2suff = "_v2"
    (pre, rest) = splitAt (length prefix) str
    -- Split off the last @length v2suff@ characters and drop them when
    -- they are exactly the suffix.
    stripV2 name
      | v2 == v2suff = name'
      | otherwise = name
      where
        (name', v2) = splitAt (length name - length v2suff) name
-- | Run an 'IO' action inside 'Q' and then run the 'Q' computation it
-- returns.
doIO :: IO (Q [a]) -> Q [a]
doIO action = do
  generated <- runIO action
  generated
-- | Haskell-side types of the arguments that are converted from a
-- Haskell value, in order.
inTypes :: [TypeInfo] -> [Q Type]
inTypes infos = [hsType | TI _ (Left (Convert hsType _)) _ <- infos]
-- | Haskell-side types of the arguments that are converted back into a
-- result after the call, in order.
outTypes :: [TypeInfo] -> [Q Type]
outTypes infos = [hsType | TI _ _ (Just (Left (Convert hsType _))) <- infos]
-- | Marshalling information for every parameter of a function, chosen by
-- 'typeInfo' from the function name and the parameter.  The return type
-- is not consulted.
functionTypeInfo :: CFunction -> [TypeInfo]
functionTypeInfo (fname, _, args) = [typeInfo fname arg | arg <- args]
-- | Type of a generated wrapper: one parameter per input type, returning
-- in 'IO' the tuple of output types after it has been passed through the
-- given type transformer.
funTypeMod :: (Q Type -> Q Type) -> [TypeInfo] -> Q Type
funTypeMod f args = foldr arrow result (inTypes args)
  where
    outs = outTypes args
    outTuple = foldl appT (tupleT (length outs)) outs
    result = [t| IO $(f outTuple) |]
    arrow from to = [t| $(from) -> $(to) |]
-- | Wrapper type with the output tuple left unchanged.
funType :: [TypeInfo] -> Q Type
funType infos = funTypeMod id infos