module Data.Array.Accelerate.CUBLAS.Level3.Batched.Foreign where
import Data.Array.Accelerate.Array.Sugar (EltRepr)
import Data.Array.Accelerate (Array, Shape, Z(Z), (:.)((:.)))
import qualified Data.Array.Accelerate.CUDA.Foreign as AF
import qualified Data.Array.Accelerate as A
import qualified Foreign.CUDA.Cublas as Cublas
import Foreign.CUDA.Ptr (DevicePtr, castDevPtr, advanceDevPtr)
import Foreign.C.Types (CFloat, CDouble)
import Foreign.Storable (Storable)
import Data.Word (Word32)
-- | A batch of matrices: an array whose shape is the batch shape @ix@
-- extended by two trailing 'Int' dimensions.
type Matrix ix = Array (ix :. Int :. Int)
-- | A batch of vectors: the batch shape @ix@ extended by one trailing
-- 'Int' dimension.
type Vector ix = Array (ix :. Int)
-- | A batch of scalars: exactly one element per index of the batch shape
-- @ix@.  Note that this local alias is distinct from @A.Scalar@, which the
-- signatures in this module use for the single-element @alpha@/@beta@
-- arguments.
type Scalar ix = Array ix
-- | Batched matrix product of @a@ and @b@, scaled by @alpha@.
--
-- The trailing two dimensions of @a@ are read as @rows x inner@ and those of
-- @b@ as @inner x cols@; the result is freshly allocated with trailing
-- dimensions @rows x cols@ and the common batch shape.
--
-- Calls 'error' (via 'unify') when the inner dimensions or the batch shapes
-- of @a@ and @b@ differ.
--
-- NOTE(review): the operands are handed to 'Cublas.gemmBatched' in the order
-- @b@, @a@ with dimensions @cols rows inner@; presumably this compensates for
-- cuBLAS expecting column-major storage -- confirm against the binding.
mul ::
  (Shape ix, Eq ix, Element a, A.Elt a) =>
  Cublas.Handle ->
  A.Scalar a -> Matrix ix a -> Matrix ix a ->
  AF.CIO (Matrix ix a)
mul handle alpha a b = do
  let (batchA :. rows :. innerA) = A.arrayShape a
      (batchB :. innerB :. cols) = A.arrayShape b
      inner = unify "mul: matrix sizes mismatch" innerA innerB
      batch =
        unify "mul: mismatching shapes of matrix arrays" batchA batchB
  result <- AF.allocateArray (batch :. rows :. cols)
  (ptrsA, ldA) <- arrayPtrs a
  (ptrsB, ldB) <- arrayPtrs b
  (ptrsC, ldC) <- arrayPtrs result
  AF.liftIO
    (Cublas.gemmBatched handle Cublas.N Cublas.N cols rows inner
      (storableFromScalar alpha)
      ptrsB ldB
      ptrsA ldA
      -- literal 0 in the beta position: the output is not accumulated into
      0
      ptrsC ldC
      (A.arraySize batch))
  return result
-- | Batched multiply-accumulate over @a@, @b@ and @c@ with scalars @alpha@
-- and @beta@.
--
-- @c@ is first copied into a freshly allocated array, which is then passed
-- to 'Cublas.gemmBatched' as the output operand together with @beta@; the
-- input @c@ itself is only read by that copy.
--
-- Calls 'error' (via 'unify') when the inner dimensions of @a@ and @b@
-- differ, when the outer dimensions do not match those of @c@, or when the
-- three batch shapes are not all equal.
--
-- NOTE(review): presumably computes @alpha * a * b + beta * c@ per batch
-- entry -- confirm against the cuBLAS binding.
mac ::
  (A.Shape ix, Eq ix, Element a, A.Elt a) =>
  Cublas.Handle ->
  A.Scalar a -> Matrix ix a -> Matrix ix a ->
  A.Scalar a -> Matrix ix a ->
  AF.CIO (Matrix ix a)
mac handle alpha a b beta c = do
  let (batchA :. rowsA :. innerA) = A.arrayShape a
      (batchB :. innerB :. colsB) = A.arrayShape b
      (batchC :. rowsC :. colsC) = A.arrayShape c
      sizeMsg = "mac: matrix sizes mismatch"
      batchMsg = "mac: mismatching shapes of matrix arrays"
      inner = unify sizeMsg innerB innerA
      rows = unify sizeMsg rowsA rowsC
      cols = unify sizeMsg colsB colsC
      batch = unify batchMsg batchA (unify batchMsg batchB batchC)
  acc <- AF.allocateArray (batch :. rows :. cols)
  -- work on a copy so that the caller's @c@ is left untouched
  AF.copyArray c acc
  (ptrsA, ldA) <- arrayPtrs a
  (ptrsB, ldB) <- arrayPtrs b
  (ptrsAcc, ldAcc) <- arrayPtrs acc
  AF.liftIO
    (Cublas.gemmBatched handle Cublas.N Cublas.N cols rows inner
      (storableFromScalar alpha)
      ptrsB ldB
      ptrsA ldA
      (storableFromScalar beta)
      ptrsAcc ldAcc
      (A.arraySize batch))
  return acc
-- | Batched LU factorisation through 'Cublas.getrfBatched'.
--
-- The input is copied into a fresh array which is handed to cuBLAS, so @a@
-- itself is not modified.  Returns that array together with a freshly
-- allocated pivot array (one entry per matrix row) and an info array (one
-- entry per matrix).
--
-- Calls 'error' (via 'unify') when the matrices are not square.
--
-- NOTE(review): the pivot and info buffers are typed 'Word32' here and
-- reinterpreted with a pointer cast; presumably cuBLAS writes C @int@
-- values into them -- confirm.
lu ::
  (A.Shape ix, Eq ix, Element a, A.Elt a) =>
  Cublas.Handle ->
  Matrix ix a ->
  AF.CIO (Matrix ix a, Vector ix Word32, Scalar ix Word32)
lu handle a = do
  let shape@(batch :. rows :. cols) = A.arrayShape a
      order = unify "lu: matrices must have square shape" rows cols
  factors <- AF.allocateArray shape
  AF.copyArray a factors
  (factorPtrs, ldFactors) <- arrayPtrs factors
  pivots <- AF.allocateArray (batch :. order)
  pivotsPtr <- devicePtrsOfArray pivots
  status <- AF.allocateArray batch
  statusPtr <- devicePtrsOfArray status
  AF.liftIO
    (Cublas.getrfBatched handle order
      factorPtrs ldFactors
      pivotsPtr statusPtr
      (A.arraySize batch))
  return (factors, pivots, status)
-- | Batched matrix inversion from the triple produced by 'lu', through
-- 'Cublas.getriBatched'.
--
-- A fresh array with the shape of @a@ is allocated, initialised with a copy
-- of @a@, and passed to cuBLAS as the output operand; @a@ is passed as the
-- input operand.
--
-- Calls 'error' (via 'unify') when the matrices are not square.
--
-- NOTE(review): the pointer of the caller's @info@ array is handed to
-- cuBLAS; if the routine writes status codes there, the caller's array is
-- overwritten -- confirm whether that is intended.
-- NOTE(review): the initial copy of @a@ into the output may be redundant if
-- cuBLAS overwrites the output completely -- confirm before removing.
luInv ::
  (A.Shape ix, Eq ix, Element a, A.Elt a) =>
  Cublas.Handle ->
  (Matrix ix a, Vector ix Word32, Scalar ix Word32) ->
  AF.CIO (Matrix ix a)
luInv handle (a, pivot, info) = do
  let shape@(batch :. rows :. cols) = A.arrayShape a
      order = unify "luInv: matrices must have square shape" rows cols
  inverse <- AF.allocateArray shape
  AF.copyArray a inverse
  (factorPtrs, ldFactors) <- arrayPtrs a
  (inversePtrs, ldInverse) <- arrayPtrs inverse
  pivotsPtr <- devicePtrsOfArray pivot
  statusPtr <- devicePtrsOfArray info
  AF.liftIO
    (Cublas.getriBatched handle order
      factorPtrs ldFactors
      pivotsPtr
      inversePtrs ldInverse
      statusPtr
      (A.arraySize batch))
  return inverse
-- | Batched solve against packed LU factors using two consecutive
-- 'Cublas.trsmBatched' calls (upper/non-unit first, then lower/unit).
--
-- @a@ holds the square factor matrices and @b@ the right-hand sides.  @b@ is
-- copied into a fresh array which both triangular solves update in place;
-- that array is returned and @b@ itself is left untouched.
--
-- The scaling factor @alpha@ is applied exactly once, in the first solve;
-- the second solve uses a factor of 1.  (Passing @alpha@ to both calls would
-- scale the result by @alpha@ squared, because the second solve consumes the
-- already scaled output of the first.)
--
-- Calls 'error' (via 'unify') when the matrices in @a@ are not square, when
-- their size does not match the leading matrix dimension of @b@, or when the
-- batch shapes differ.
--
-- NOTE(review): no pivot array is consulted here; presumably @a@ must stem
-- from a factorisation without row exchanges -- confirm with callers.
luSolve ::
  (A.Shape ix, Eq ix, Element a, A.Elt a) =>
  Cublas.Handle ->
  A.Scalar a ->
  Matrix ix a ->
  Matrix ix a ->
  AF.CIO (Matrix ix a)
luSolve handle alpha a b = do
  let (aNumMatrices :. an :. ak) = A.arrayShape a
  let sh@(bNumMatrices :. bk :. m) = A.arrayShape b
  let n =
        unify "luSolve: matrices must have square shape" an $
        unify "luSolve: matrices dimensions must match" ak bk
  let count =
        A.arraySize $
        unify "luSolve: mismatching shapes of matrix arrays"
          aNumMatrices bNumMatrices
  c <- AF.allocateArray sh
  AF.copyArray b c
  (pas, lda) <- arrayPtrs a
  (pcs, ldc) <- arrayPtrs c
  AF.liftIO $ do
    Cublas.trsmBatched handle
      Cublas.SideRight Cublas.Upper Cublas.N Cublas.NonUnit m n
      (storableFromScalar alpha)
      pas lda
      pcs ldc
      count
    Cublas.trsmBatched handle
      Cublas.SideRight Cublas.Lower Cublas.N Cublas.Unit m n
      -- alpha was already applied by the first solve
      1
      pas lda
      pcs ldc
      count
  return c
-- | Batched direct matrix inversion through 'Cublas.matinvBatched'.
--
-- Returns a freshly allocated array for the inverses together with a freshly
-- allocated info array holding one entry per matrix.  The input @a@ is only
-- passed as the source operand.
--
-- Calls 'error' (via 'unify') when the matrices are not square.
--
-- NOTE(review): cuBLAS documents an upper bound on the matrix size accepted
-- by this routine; no such check is made here -- confirm that callers
-- respect it.
inv ::
  (A.Shape ix, Eq ix, Element a, A.Elt a) =>
  Cublas.Handle ->
  Matrix ix a ->
  AF.CIO (Matrix ix a, Scalar ix Word32)
inv handle a = do
  let sh@(numMatrices :. n :. k) = A.arrayShape a
  let size = unify "inv: matrices must have square shape" n k
  b <- AF.allocateArray sh
  (pas, lda) <- arrayPtrs a
  (pbs, ldb) <- arrayPtrs b
  info <- AF.allocateArray numMatrices
  -- use the module's own helper, as 'lu' and 'luInv' do, instead of
  -- unpacking the device pointer tuple by hand
  infoPtr <- devicePtrsOfArray info
  AF.liftIO $
    Cublas.matinvBatched handle size
      pas lda
      pbs ldb
      infoPtr
      (A.arraySize numMatrices)
  return (b, info)
-- | Element types supported by the batched cuBLAS wrappers in this module.
--
-- The superclass context bundles everything the wrappers need: the device
-- representation of the element is a single device pointer, and the
-- associated storable type ('StorableOf') is a 'Fractional', 'Storable'
-- type with a 'Cublas.Cublas' instance.  The class has no methods.
class
  ( AF.DevicePtrs (EltRepr a) ~ ((), DevicePtr a)
  , Fractional (StorableOf a)
  , Cublas.Cublas (StorableOf a)
  , Storable (StorableOf a)
  , Real a
  ) =>
  Element a

instance Element Float

instance Element Double
-- | Maps a Haskell floating point type to the C type that is passed to
-- cuBLAS in its place.
type family StorableOf float
-- 'Float' is passed as 'CFloat'.
type instance StorableOf Float = CFloat
-- 'Double' is passed as 'CDouble'.
type instance StorableOf Double = CDouble
-- | Read the single element of a scalar array (at index 'Z') and convert it
-- with 'realToFrac' to the storable type associated with the element type.
storableFromScalar ::
  (Real a, StorableOf a ~ b, Fractional b) => A.Scalar a -> b
storableFromScalar scalar =
  let value = A.indexArray scalar Z
  in realToFrac value
-- | List of @n@ device pointers, starting at @p@, each one advanced by
-- @size@ (via 'advanceDevPtr') relative to its predecessor.
--
-- Yields the empty list when @n@ is zero or negative.
genPointers ::
  (Storable a) =>
  Int -> Int -> DevicePtr a -> [DevicePtr a]
genPointers n size p = take n (iterate step p)
  where
    step ptr = advanceDevPtr ptr size
-- | Per-matrix device pointers for a batch of matrices, paired with the size
-- of the innermost dimension.
--
-- One pointer is generated per batch entry; consecutive pointers lie the
-- product of the two trailing dimensions apart.  The callers in this module
-- pass the second tuple component to cuBLAS as the leading dimension.
arrayPtrs ::
  (A.Shape ix,
   Storable a, StorableOf e ~ a,
   AF.DevicePtrs (EltRepr e) ~ ((), DevicePtr e)) =>
  Matrix ix e -> AF.CIO ([DevicePtr a], Int)
arrayPtrs arr = do
  let (batch :. rows :. cols) = A.arrayShape arr
      count = A.arraySize batch
      stride = rows * cols
  base <- devicePtrsOfArray arr
  return (genPointers count stride base, cols)
-- | Device pointer of an array whose element representation is a single
-- device pointer, cast to whatever pointer type the caller requests.
--
-- The result type @a@ is unrelated to the element type @e@; the cast is
-- unchecked, so the caller is responsible for choosing a compatible type.
devicePtrsOfArray ::
  (A.Shape ix, AF.DevicePtrs (EltRepr e) ~ ((), DevicePtr e)) =>
  Scalar ix e -> AF.CIO (DevicePtr a)
devicePtrsOfArray arr =
  AF.devicePtrsOfArray arr >>= \((), ptr) -> return (castDevPtr ptr)
-- | Return the first value when both values are equal; otherwise abort with
-- 'error' using the given message.
unify :: (Eq a) => String -> a -> a -> a
unify msg a b
  | a == b = a
  | otherwise = error msg