{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
-- | Batched level-3 cuBLAS operations on Accelerate device arrays,
-- running in the 'AF.CIO' monad.
--
-- An array of shape @ix :. rows :. cols@ is treated as a batch of
-- @arraySize ix@ matrices stored back to back.
module Data.Array.Accelerate.CUBLAS.Level3.Batched.Foreign where

import Data.Array.Accelerate.Array.Sugar (EltRepr)
import Data.Array.Accelerate (Array, Shape, Z(Z), (:.)((:.)))
import qualified Data.Array.Accelerate.CUDA.Foreign as AF
import qualified Data.Array.Accelerate as A

import qualified Foreign.CUDA.Cublas as Cublas
import Foreign.CUDA.Ptr (DevicePtr, castDevPtr, advanceDevPtr)

import Foreign.C.Types (CFloat, CDouble)
import Foreign.Storable (Storable)

import Data.Word (Word32)


-- | A batch of matrices; the two innermost dimensions index one matrix.
type Matrix ix = Array (ix :. Int :. Int)

-- | A batch of vectors; the innermost dimension indexes one vector.
type Vector ix = Array (ix :. Int)

-- | A batch of scalars (not to be confused with 'A.Scalar').
type Scalar ix = Array ix


-- | Batched matrix product scaled by @alpha@, written to a freshly
-- allocated array via 'Cublas.gemmBatched' with @beta = 0@.
--
-- Calls 'error' (through 'unify') if the inner dimensions of @a@ and @b@
-- or their batch shapes disagree.
mul ::
  (Shape ix, Eq ix, Element a, A.Elt a) =>
  Cublas.Handle ->
  A.Scalar a ->
  Matrix ix a ->
  Matrix ix a ->
  AF.CIO (Matrix ix a)
mul handle alpha a b = do
  let (aNumMatrices :. n :. ak) = A.arrayShape a
  let (bNumMatrices :. bk :. m) = A.arrayShape b
  let k = unify "mul: matrix sizes mismatch" ak bk
  let numMatrices =
        unify "mul: mismatching shapes of matrix arrays"
          aNumMatrices bNumMatrices
  c <- AF.allocateArray (numMatrices :. n :. m)
  (pas, lda) <- arrayPtrs a
  (pbs, ldb) <- arrayPtrs b
  (pcs, ldc) <- arrayPtrs c
  -- b is handed to cuBLAS before a, with dimensions m n k.
  -- NOTE(review): presumably this compensates for row-major Accelerate
  -- arrays versus column-major cuBLAS -- confirm.
  AF.liftIO $
    Cublas.gemmBatched handle Cublas.N Cublas.N m n k
      (storableFromScalar alpha) pbs ldb pas lda 0 pcs ldc
      (A.arraySize numMatrices)
  return c

-- | Batched multiply-accumulate: copies @c@ into a fresh array and runs
-- 'Cublas.gemmBatched' on it with scaling factors @alpha@ and @beta@.
-- The input @c@ is not modified; the fresh array is returned.
--
-- Calls 'error' (through 'unify') if any matrix dimensions or batch
-- shapes disagree.
mac ::
  (A.Shape ix, Eq ix, Element a, A.Elt a) =>
  Cublas.Handle ->
  A.Scalar a ->
  Matrix ix a ->
  Matrix ix a ->
  A.Scalar a ->
  Matrix ix a ->
  AF.CIO (Matrix ix a)
mac handle alpha a b beta c = do
  let (aNumMatrices :. an :. ak) = A.arrayShape a
  let (bNumMatrices :. bk :. bm) = A.arrayShape b
  let (cNumMatrices :. cn :. cm) = A.arrayShape c
  let k = unify "mac: matrix sizes mismatch" ak bk
  let n = unify "mac: matrix sizes mismatch" an cn
  let m = unify "mac: matrix sizes mismatch" bm cm
  let numMatrices =
        let msg = "mac: mismatching shapes of matrix arrays"
        in  unify msg aNumMatrices $ unify msg bNumMatrices cNumMatrices
  d <- AF.allocateArray (numMatrices :. n :. m)
  AF.copyArray c d
  (pas, lda) <- arrayPtrs a
  (pbs, ldb) <- arrayPtrs b
  (pds, ldd) <- arrayPtrs d
  AF.liftIO $
    Cublas.gemmBatched handle Cublas.N Cublas.N m n k
      (storableFromScalar alpha) pbs ldb pas lda
      (storableFromScalar beta) pds ldd
      (A.arraySize numMatrices)
  return d

-- | Batched factorisation via 'Cublas.getrfBatched', run on a copy of
-- the input.
--
-- Returns the overwritten copy, the pivot array and the per-matrix info
-- array exactly as filled in by cuBLAS.
-- Calls 'error' (through 'unify') if the matrices are not square.
lu ::
  (A.Shape ix, Eq ix, Element a, A.Elt a) =>
  Cublas.Handle ->
  Matrix ix a ->
  AF.CIO (Matrix ix a, Vector ix Word32, Scalar ix Word32)
lu handle a = do
  let sh@(numMatrices :. n :. k) = A.arrayShape a
  let size = unify "lu: matrices must have square shape" n k
  b <- AF.allocateArray sh
  AF.copyArray a b
  (pbs, ldb) <- arrayPtrs b
  pivot <- AF.allocateArray (numMatrices :. size)
  pivotPtr <- devicePtrsOfArray pivot
  info <- AF.allocateArray numMatrices
  infoPtr <- devicePtrsOfArray info
  AF.liftIO $
    Cublas.getrfBatched handle size pbs ldb pivotPtr infoPtr
      (A.arraySize numMatrices)
  return (b, pivot, info)

-- | Batched inversion via 'Cublas.getriBatched', taking a triple of the
-- form returned by 'lu'.
--
-- The result array is allocated fresh and initialised with a copy of the
-- first component before the cuBLAS call. The passed info array is
-- handed to cuBLAS as its info buffer.
-- Calls 'error' (through 'unify') if the matrices are not square.
luInv ::
  (A.Shape ix, Eq ix, Element a, A.Elt a) =>
  Cublas.Handle ->
  (Matrix ix a, Vector ix Word32, Scalar ix Word32) ->
  AF.CIO (Matrix ix a)
luInv handle (a, pivot, info) = do
  let sh@(numMatrices :. n :. k) = A.arrayShape a
  let size = unify "luInv: matrices must have square shape" n k
  c <- AF.allocateArray sh
  AF.copyArray a c
  (pas, lda) <- arrayPtrs a
  (pcs, ldc) <- arrayPtrs c
  pivotPtr <- devicePtrsOfArray pivot
  infoPtr <- devicePtrsOfArray info
  AF.liftIO $
    Cublas.getriBatched handle size pas lda pivotPtr pcs ldc infoPtr
      (A.arraySize numMatrices)
  return c

-- | Two successive batched triangular solves ('Cublas.trsmBatched')
-- against the same matrix array @a@, applied to a copy of @b@: first
-- with the upper, non-unit triangle, then with the lower, unit triangle.
--
-- No pivot array is taken, so any row permutation from a pivoted
-- factorisation is not applied here.
-- Calls 'error' (through 'unify') if @a@ is not square, if its size does
-- not match the row count of @b@, or if the batch shapes disagree.
luSolve ::
  (A.Shape ix, Eq ix, Element a, A.Elt a) =>
  Cublas.Handle ->
  A.Scalar a ->
  Matrix ix a ->
  Matrix ix a ->
  AF.CIO (Matrix ix a)
luSolve handle alpha a b = do
  let (aNumMatrices :. an :. ak) = A.arrayShape a
  let sh@(bNumMatrices :. bk :. m) = A.arrayShape b
  let n =
        unify "luSolve: matrices must have square shape" an $
          unify "luSolve: matrices dimensions must match" ak bk
  let count =
        A.arraySize $
          unify "luSolve: mismatching shapes of matrix arrays"
            aNumMatrices bNumMatrices
  c <- AF.allocateArray sh
  AF.copyArray b c
  (pas, lda) <- arrayPtrs a
  (pcs, ldc) <- arrayPtrs c
  -- NOTE(review): alpha is passed to both solves, so the result is
  -- presumably scaled by alpha twice -- confirm this is intended.
  AF.liftIO $ do
    Cublas.trsmBatched handle
      Cublas.SideRight Cublas.Upper Cublas.N Cublas.NonUnit
      m n (storableFromScalar alpha) pas lda pcs ldc count
    Cublas.trsmBatched handle
      Cublas.SideRight Cublas.Lower Cublas.N Cublas.Unit
      m n (storableFromScalar alpha) pas lda pcs ldc count
  return c

-- | Batched matrix inversion via 'Cublas.matinvBatched'.
--
-- Returns the freshly allocated result array together with the
-- per-matrix info array filled in by cuBLAS.
-- Calls 'error' (through 'unify') if the matrices are not square.
inv ::
  (A.Shape ix, Eq ix, Element a, A.Elt a) =>
  Cublas.Handle ->
  Matrix ix a ->
  AF.CIO (Matrix ix a, Scalar ix Word32)
inv handle a = do
  let sh@(numMatrices :. n :. k) = A.arrayShape a
  let size = unify "inv: matrices must have square shape" n k
  b <- AF.allocateArray sh
  (pas, lda) <- arrayPtrs a
  (pbs, ldb) <- arrayPtrs b
  info <- AF.allocateArray numMatrices
  infoPtr <- devicePtrsOfArray info
  AF.liftIO $
    Cublas.matinvBatched handle size pas lda pbs ldb infoPtr
      (A.arraySize numMatrices)
  return (b, info)


-- | Constraints an element type must satisfy to be passed to cuBLAS:
-- its device representation is a single pointer, and its 'StorableOf'
-- counterpart is a 'Storable', 'Fractional' cuBLAS element type.
type Element a =
  ( AF.DevicePtrs (EltRepr a) ~ ((), DevicePtr a)
  , Fractional (StorableOf a)
  , Cublas.Cublas (StorableOf a)
  , Storable (StorableOf a)
  , Real a
  )

-- | Maps a Haskell floating point type to its C counterpart.
type family StorableOf float
type instance StorableOf Float = CFloat
type instance StorableOf Double = CDouble

-- | Reads the single element of an Accelerate scalar array and converts
-- it to the corresponding C type with 'realToFrac'.
storableFromScalar ::
  (Real a, StorableOf a ~ b, Fractional b) => A.Scalar a -> b
storableFromScalar x = realToFrac $ A.indexArray x Z


-- | @genPointers n size p@ yields @n@ pointers starting at @p@, each one
-- advanced by @size@ elements from its predecessor.
genPointers :: (Storable a) => Int -> Int -> DevicePtr a -> [DevicePtr a]
genPointers n size p = take n $ iterate (flip advanceDevPtr size) p

-- | One device pointer per matrix of the batch (consecutive matrices are
-- @rows * cols@ elements apart), paired with the innermost dimension,
-- which callers pass to cuBLAS as the leading dimension.
arrayPtrs ::
  ( A.Shape ix
  , Storable a
  , StorableOf e ~ a
  , AF.DevicePtrs (EltRepr e) ~ ((), DevicePtr e)
  ) =>
  Matrix ix e ->
  AF.CIO ([DevicePtr a], Int)
arrayPtrs arr = do
  let (numMatrices :. n :. k) = A.arrayShape arr
  pa <- devicePtrsOfArray arr
  return (genPointers (A.arraySize numMatrices) (n * k) pa, k)

-- | Device pointer of a single-component array, cast to whatever pointee
-- type the caller requires.
devicePtrsOfArray ::
  (A.Shape ix, AF.DevicePtrs (EltRepr e) ~ ((), DevicePtr e)) =>
  Scalar ix e ->
  AF.CIO (DevicePtr a)
devicePtrsOfArray arr = do
  ((), pa) <- AF.devicePtrsOfArray arr
  return $ castDevPtr pa

-- | Returns the first value if both values are equal; otherwise calls
-- 'error' with the given message.
unify :: (Eq a) => String -> a -> a -> a
unify msg a b = if a == b then a else error msg