{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE FlexibleContexts #-} module CUBLASBatched where import qualified Data.Array.Accelerate.Arithmetic.LinearAlgebra as ALinAlg import qualified Data.Array.Accelerate.Utility.Lift.Acc as Acc import Data.Array.Accelerate.Utility.Lift.Acc (acc, expr) import Data.Array.Accelerate.Array.Sugar (EltRepr) import Data.Array.Accelerate (Array, DIM3, Acc, Z (..), (:.) (..), Exp) import qualified Data.Array.Accelerate.CUDA.Foreign as AF import qualified Data.Array.Accelerate.CUDA as AC import qualified Data.Array.Accelerate as A import qualified Foreign.CUDA.Cublas as Cublas import Foreign.CUDA.Ptr (DevicePtr, castDevPtr, advanceDevPtr) import Foreign.C.Types (CFloat, CDouble) import Foreign.Storable (Storable) import Data.Tuple.HT (uncurry3) type Matrix ix = Array (ix :. Int :. Int) type Vector ix = Array (ix :. Int) type Scalar ix = Array ix mul :: (A.Shape ix, A.Slice ix, Eq ix, Element a, A.Elt a, A.IsNum a) => Cublas.Handle -> Exp a -> ALinAlg.Matrix ix a -> ALinAlg.Matrix ix a -> ALinAlg.Matrix ix a mul handle alpha a b = A.foreignAcc (AF.CUDAForeignAcc "mul" $ uncurry3 $ mulPlain handle) (Acc.modify (expr,acc,acc) $ \(alpha0, a0, b0) -> A.map (alpha0 *) $ ALinAlg.multiplyMatrixMatrix a0 b0) $ A.lift (A.unit alpha, a, b) mulPlain :: (A.Shape ix, Eq ix, Element a, A.Elt a) => Cublas.Handle -> A.Scalar a -> Matrix ix a -> Matrix ix a -> AF.CIO (Matrix ix a) mulPlain handle alpha a b = do let (aNumMatrices :. n :. k) = A.arrayShape a let (bNumMatrices :. _k :. m) = A.arrayShape b let numMatrices = if aNumMatrices == bNumMatrices then aNumMatrices else error "mul: mismatching shapes of matrix arrays" c <- AF.allocateArray (numMatrices :. n :. m) (pas, lda) <- arrayPtrs a (pbs, ldb) <- arrayPtrs b (pcs, ldc) <- arrayPtrs c AF.liftIO $ Cublas.gemmBatched handle Cublas.N Cublas.N m n k (storableFromScalar alpha) pbs ldb pas lda 0 pcs ldc (A.arraySize numMatrices) return c mac :: (A.Shape ix, A.Slice ix, Eq ix, Element a, A.Elt a, A.IsNum a) => Cublas.Handle -> Exp a -> ALinAlg.Matrix ix a -> ALinAlg.Matrix ix a -> Exp a -> ALinAlg.Matrix ix a -> ALinAlg.Matrix ix a mac handle alpha a b beta c = A.foreignAcc (AF.CUDAForeignAcc "mac" $ \((alpha0, a0, b0), (beta0, c0)) -> macPlain handle alpha0 a0 b0 beta0 c0) (Acc.modify ((expr,acc,acc),(expr,acc)) $ \((alpha0, a0, b0), (beta0, c0)) -> A.zipWith (+) (A.map (alpha0 *) $ ALinAlg.multiplyMatrixMatrix a0 b0) (A.map (beta0 *) c0)) $ A.lift ((A.unit alpha, a, b), (A.unit beta, c)) macPlain :: (A.Shape ix, Eq ix, Element a, A.Elt a) => Cublas.Handle -> A.Scalar a -> Matrix ix a -> Matrix ix a -> A.Scalar a -> Matrix ix a -> AF.CIO (Matrix ix a) macPlain handle alpha a b beta c = do let (aNumMatrices :. n :. k ) = A.arrayShape a let (bNumMatrices :. _k :. m ) = A.arrayShape b let (cNumMatrices :. n' :. m') = A.arrayShape c let numMatrices = if aNumMatrices == bNumMatrices && aNumMatrices == cNumMatrices then aNumMatrices else error "mac: mismatching shapes of matrix arrays" d <- AF.allocateArray (numMatrices :. n' :. m') AF.copyArray c d (pas, lda) <- arrayPtrs a (pbs, ldb) <- arrayPtrs b (pds, ldd) <- arrayPtrs d AF.liftIO $ Cublas.gemmBatched handle Cublas.N Cublas.N m n k (storableFromScalar alpha) pbs ldb pas lda (storableFromScalar beta) pds ldd (A.arraySize numMatrices) return d lu :: (A.Shape ix, Eq ix, Element a, A.Elt a) => Cublas.Handle -> ALinAlg.Matrix ix a -> (ALinAlg.Matrix ix a, ALinAlg.Vector ix Int, ALinAlg.Scalar ix Int) lu handle = A.unlift . A.foreignAcc (AF.CUDAForeignAcc "lu" $ luPlain handle) (error "Requires CUDA backend") luPlain :: (A.Shape ix, Eq ix, Element a, A.Elt a) => Cublas.Handle -> Matrix ix a -> AF.CIO (Matrix ix a, Vector ix Int, Scalar ix Int) luPlain handle a = do let sh@(numMatrices :. n :. k) = A.arrayShape a let size = if n == k then n else error "lu: matrices must have square shape" b <- AF.allocateArray sh AF.copyArray a b (pbs, ldb) <- arrayPtrs b pivot <- AF.allocateArray (numMatrices :. size) pivotPtr <- fmap (castDevPtr . snd) $ AF.devicePtrsOfArray pivot info <- AF.allocateArray numMatrices infoPtr <- fmap (castDevPtr . snd) $ AF.devicePtrsOfArray info AF.liftIO $ Cublas.getrfBatched handle size pbs ldb pivotPtr infoPtr (A.arraySize numMatrices) return (b, pivot, info) luInv :: (A.Shape ix, Eq ix, Element a, A.Elt a) => Cublas.Handle -> (ALinAlg.Matrix ix a, ALinAlg.Vector ix Int, ALinAlg.Scalar ix Int) -> ALinAlg.Matrix ix a luInv handle = A.foreignAcc (AF.CUDAForeignAcc "luInv" $ luInvPlain handle) (error "Requires CUDA backend") . A.lift luInvPlain :: (A.Shape ix, Eq ix, Element a, A.Elt a) => Cublas.Handle -> (Matrix ix a, Vector ix Int, Scalar ix Int) -> AF.CIO (Matrix ix a) luInvPlain handle (a, pivot, info) = do let sh@(numMatrices :. n :. k) = A.arrayShape a let size = if n == k then n else error "luInv: matrices must have square shape" c <- AF.allocateArray sh AF.copyArray a c (pas, lda) <- arrayPtrs a (pcs, ldc) <- arrayPtrs c pivotPtr <- fmap (castDevPtr . snd) $ AF.devicePtrsOfArray pivot infoPtr <- fmap (castDevPtr . snd) $ AF.devicePtrsOfArray info AF.liftIO $ Cublas.getriBatched handle size pas lda pivotPtr pcs ldc infoPtr (A.arraySize numMatrices) return c inv :: (A.Shape ix, Eq ix, Element a, A.Elt a) => Cublas.Handle -> ALinAlg.Matrix ix a -> (ALinAlg.Matrix ix a, ALinAlg.Scalar ix Int) inv handle a = let sol@(_,_,info) = lu handle a in (luInv handle sol, info) type Element a = (AF.DevicePtrs (EltRepr a) ~ ((), DevicePtr a), Fractional (StorableOf a), Cublas.Cublas (StorableOf a), Storable (StorableOf a), Real a) type family StorableOf float type instance StorableOf Float = CFloat type instance StorableOf Double = CDouble storableFromScalar :: (Real a, StorableOf a ~ b, Fractional b) => A.Scalar a -> b storableFromScalar x = realToFrac $ A.indexArray x Z arrayPtrs :: (Storable a, StorableOf e ~ a, A.Shape ix, AF.DevicePtrs (EltRepr e) ~ ((), DevicePtr e)) => Array (ix :. Int :. Int) e -> AF.CIO ([DevicePtr a], Int) arrayPtrs arr = do let (numMatrices :. n :. k) = A.arrayShape arr pa <- fmap (castDevPtr . snd) $ AF.devicePtrsOfArray arr return (genPointers (n*k) pa (A.arraySize numMatrices), k) genPointers :: (Storable a) => Int -> DevicePtr a -> Int -> [DevicePtr a] genPointers size p n = take n $ iterate (flip advanceDevPtr size) p genMatrices :: (Acc (Array DIM3 Double), Acc (Array DIM3 Double)) genMatrices = (a,b) where a = A.generate (A.constant sha) $ \ix -> let (Z :. i :. j :. k) = unlift ix in A.fromIntegral (i+j+k) b = A.generate (A.constant shb) $ \ix -> let (Z :. i :. j :. k) = unlift ix in A.fromIntegral (i+j+k) numMats = 100 :: Int sha = Z :. numMats :. (3 :: Int) :. (4 :: Int) shb = Z :. numMats :. (4 :: Int) :. (2 :: Int) unlift :: Exp (Z :. Int :. Int :. Int) -> Z :. Exp Int :. Exp Int :. Exp Int unlift = A.unlift test :: IO () test = do handle <- Cublas.create print genMatrices print $ AC.run $ case genMatrices of (a,b) -> mul handle 1 a b