module Data.Array.Accelerate.CUBLAS.Level3.Batched (
Cublas.Handle,
Cublas.create,
Element,
mul,
mac,
LU,
lu,
luInv,
inv,
luSolve,
newtonInverseStep,
newtonInverse,
) where
import qualified Data.Array.Accelerate.CUBLAS.Level3.Batched.Foreign as Foreign
import Data.Array.Accelerate.CUBLAS.Level3.Batched.Foreign (Element, Vector)
import qualified Data.Array.Accelerate.LinearAlgebra as ALinAlg
import qualified Data.Array.Accelerate.Utility.Sliced1 as Sliced1
import qualified Data.Array.Accelerate.Utility.Sliced as Sliced
import qualified Data.Array.Accelerate.Utility.Arrange as Arrange
import qualified Data.Array.Accelerate.Utility.Loop as Loop
import qualified Data.Array.Accelerate.Utility.Lift.Acc as Acc
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import Data.Array.Accelerate.Utility.Lift.Acc (acc, expr)
import Data.Array.Accelerate (Exp, (:.)((:.)), (!))
import qualified Data.Array.Accelerate.CUDA.Foreign as AF
import qualified Data.Array.Accelerate.IO as AIO
import qualified Data.Array.Accelerate as A
import qualified Foreign.CUDA.Cublas as Cublas
import qualified Data.Vector.Storable as V
import qualified Data.Vector.Storable.Mutable as MV
import Control.Monad.ST (ST)
import Control.Monad (zipWithM_)
import Data.Tuple.HT (mapSnd, uncurry3)
import Data.Word (Word32)
-- | Batched scaled matrix product: for every matrix pair of the batch the
-- result is @alpha@ times the matrix product of @a@ and @b@.
--
-- On the CUDA backend this calls 'Foreign.mul'.  Other backends use the
-- portable fallback handed to 'A.foreignAcc', which is built from
-- 'ALinAlg.multiplyMatrixMatrix'.
mul ::
   (A.Shape ix, A.Slice ix, Eq ix, Element a, A.Elt a, A.IsNum a) =>
   Cublas.Handle ->
   Exp a ->
   ALinAlg.Matrix ix a -> ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a
mul handle alpha a b =
   A.foreignAcc cudaImpl referenceImpl (A.lift (A.unit alpha, a, b))
   where
      -- The CUDA stream argument is not needed by the foreign call.
      cudaImpl =
         AF.CUDAForeignAcc "mul" (\_stream -> uncurry3 (Foreign.mul handle))
      referenceImpl =
         Acc.modify (expr, acc, acc) $ \(scale, x, y) ->
            A.map (scale *) (ALinAlg.multiplyMatrixMatrix x y)
-- | Batched multiply-accumulate: for every batch element the result is
-- @alpha@ times the matrix product of @a@ and @b@, plus @beta@ times @c@.
--
-- On the CUDA backend this calls 'Foreign.mac'.  Other backends use the
-- portable Accelerate fallback handed to 'A.foreignAcc'.
mac ::
   (A.Shape ix, A.Slice ix, Eq ix, Element a, A.Elt a, A.IsNum a) =>
   Cublas.Handle ->
   Exp a -> ALinAlg.Matrix ix a -> ALinAlg.Matrix ix a ->
   Exp a -> ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a
mac handle alpha a b beta c =
   A.foreignAcc cudaImpl referenceImpl $
      A.lift ((A.unit alpha, a, b), (A.unit beta, c))
   where
      -- The CUDA stream argument is ignored; the arguments arrive as a
      -- nested tuple matching the lifted input above.
      cudaImpl =
         AF.CUDAForeignAcc "mac" $ \_stream ((alpha0, a0, b0), (beta0, c0)) ->
            Foreign.mac handle alpha0 a0 b0 beta0 c0
      referenceImpl =
         Acc.modify ((expr, acc, acc), (expr, acc)) $
            \((scaleAB, x, y), (scaleC, z)) ->
               let scaledProduct =
                      A.map (scaleAB *) (ALinAlg.multiplyMatrixMatrix x y)
                   scaledAddend = A.map (scaleC *) z
               in  A.zipWith (+) scaledProduct scaledAddend
-- | Opaque result of the batched LU factorisation performed by 'lu'.
--
-- Only the type is exported, not the constructor.  Outside this module,
-- values are therefore built by 'lu' and consumed by 'luInv' and 'luSolve'.
-- The wrapped triple holds, in order:
--
-- * the factorised matrices returned by the foreign LU routine,
--
-- * per-matrix pivot vectors; 'luSolve' subtracts 1 from every entry before
--   using them as indices, so they are presumably 1-based,
--
-- * one status word per matrix, which 'luInv' passes through unchanged
--   (NOTE(review): presumably the cuBLAS @info@ code — confirm in the
--   Foreign module).
newtype LU ix a =
   LU {
      _getLU ::
         (ALinAlg.Matrix ix a,
          ALinAlg.Vector ix Word32, ALinAlg.Scalar ix Word32)
   }
-- | Batched LU factorisation of every matrix in the batch, computed by
-- 'Foreign.lu'.  The result can be reused by 'luInv' and 'luSolve'.
--
-- Only the CUDA backend is supported: the fallback installed by 'cudaAcc'
-- is an 'error'.
lu ::
   (A.Shape ix, Eq ix, Element a, A.Elt a) =>
   Cublas.Handle ->
   ALinAlg.Matrix ix a -> LU ix a
lu handle matrices =
   LU $ A.unlift $ cudaAcc "lu" (Foreign.lu handle) matrices
-- | Invert every matrix of a batch from its LU factorisation, using
-- 'Foreign.luInv'.
--
-- Returns the inverses together with the status scalar stored in the
-- factorisation, which is passed through unchanged.
luInv ::
   (A.Shape ix, Eq ix, Element a, A.Elt a) =>
   Cublas.Handle ->
   LU ix a -> (ALinAlg.Matrix ix a, ALinAlg.Scalar ix Word32)
luInv handle (LU factors@(_, _, status)) =
   (inverse, status)
   where
      -- The whole factorisation triple is handed to the foreign routine.
      inverse = cudaAcc "luInv" (Foreign.luInv handle) (A.lift factors)
-- | Batched matrix inversion.  Returns the inverses and the per-matrix status
-- from the factorisation step.
--
-- 'inv' factorises with 'lu' and then calls 'luInv'.  '_inv' is an
-- unexported alternative that hands the whole job to the single foreign
-- routine 'Foreign.inv'.
inv, _inv ::
   (A.Shape ix, Eq ix, Element a, A.Elt a) =>
   Cublas.Handle ->
   ALinAlg.Matrix ix a ->
   (ALinAlg.Matrix ix a, ALinAlg.Scalar ix Word32)
inv handle matrices = luInv handle (lu handle matrices)
_inv handle matrices =
   A.unlift (cudaAcc "inv" (Foreign.inv handle) matrices)
-- | Solve the linear systems for every matrix of a batch, reusing a
-- factorisation produced by 'lu'.
--
-- The factorised matrices and the right-hand sides are passed to
-- 'Foreign.luSolve' together with the scalar @1@ (presumably a scale
-- factor — confirm in the Foreign module).  The stored pivots, shifted down
-- by one to 0-based indices, are expanded into explicit permutations on the
-- host.  These permutations are then applied to the rows of the solver's
-- output.
--
-- NOTE(review): for @A = P*L*U@ the row interchanges normally act on the
-- right-hand side /before/ the triangular solves.  Here they are applied to
-- the result afterwards; confirm this matches what 'Foreign.luSolve'
-- computes.
luSolve ::
   (A.Shape ix, A.Slice ix, Eq ix, Element a, A.Elt a) =>
   Cublas.Handle ->
   LU ix a ->
   ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a
luSolve handle (LU (factors, pivots, _status)) rhs =
   applyRowPerm permutation solved
   where
      -- Presumably 1-based pivots from the foreign LU routine become 0-based.
      permutation = permutationFromPivotsAcc (A.map (subtract 1) pivots)
      solved =
         cudaAcc "luSolve"
            (uncurry (Foreign.luSolve handle (Acc.singleton 1)))
            (A.lift (factors, rhs))
-- | Gather columns according to a batched permutation (not used in this
-- module).  The result at @(ix :. j :. i)@ is @arr@ at
-- @(ix :. j :. perm(ix :. i))@.
--
-- The result has as many rows as @Sliced1.length arr@ reports
-- (NOTE(review): presumably the row count of @arr@ — confirm) and one
-- column per permutation entry.
_applyColPerm ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   ALinAlg.Vector ix Word32 ->
   ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a
_applyColPerm perm arr =
   Arrange.mapWithIndex pickColumn columnIndices
   where
      -- One copy of the (Int-converted) permutation for every row.
      columnIndices =
         A.replicate (A.lift $ A.Any :. Sliced1.length arr :. A.All) $
         A.map (A.fromIntegral :: Exp Word32 -> Exp Int) perm
      pickColumn =
         Exp.modify2 (expr :. expr :. expr) expr $
            \(batch :. row :. _col) srcCol ->
               arr ! A.lift (batch :. row :. srcCol)
-- | Gather rows according to a batched permutation.  The result at
-- @(ix :. j :. i)@ is @arr@ at @(ix :. perm(ix :. j) :. i)@.
--
-- The result has one row per permutation entry and as many columns as
-- @Sliced.length arr@ reports (NOTE(review): presumably the column count
-- of @arr@ — confirm).
applyRowPerm ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   ALinAlg.Vector ix Word32 ->
   ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a
applyRowPerm perm arr =
   Arrange.mapWithIndex pickRow sourceRows
   where
      -- Repeat each (Int-converted) permutation entry along a new innermost
      -- dimension.
      sourceRows =
         A.replicate (A.lift (A.Any :. Sliced.length arr)) $
         A.map (A.fromIntegral :: Exp Word32 -> Exp Int) perm
      pickRow =
         Exp.modify2 (expr :. expr :. expr) expr $
            \(batch :. _row :. col) srcRow ->
               arr ! A.lift (batch :. srcRow :. col)
-- | Turn a batch of 0-based pivot vectors into explicit permutation vectors
-- of the same shape.
--
-- The conversion runs on the host: 'AF.peekArray' fetches the pivots,
-- 'permutationFromPivots' expands them, and 'AF.useArray' makes the result
-- available again.  Only the CUDA backend is supported (see 'cudaAcc').
permutationFromPivotsAcc ::
   (A.Shape ix) =>
   ALinAlg.Vector ix Word32 -> ALinAlg.Vector ix Word32
permutationFromPivotsAcc =
   cudaAcc "permutations" $ \pivots ->
      AF.peekArray pivots >>
      let expanded = permutationFromPivots pivots
      in  AF.useArray expanded >> return expanded
-- | Host-side expansion of a batch of 0-based pivot vectors into explicit
-- permutation vectors of the same shape.  The flat element buffer is
-- handled by 'permutationsFromPivotsSlices', one innermost slice at a time.
permutationFromPivots ::
   (A.Shape ix) =>
   Vector ix Word32 -> Vector ix Word32
permutationFromPivots vec =
   AIO.fromVectors shape
      (mapSnd (permutationsFromPivotsSlices shape) (AIO.toVectors vec))
   where
      shape = A.arrayShape vec
-- | Expand a flat buffer of pivot vectors.  The buffer holds one
-- length-@width@ slice per element of @shape@, and each slice is expanded
-- independently by 'permutationFromPivotsMutableBackward'.  The output has
-- the same length as the input buffer.
permutationsFromPivotsSlices ::
   (A.Shape sh) =>
   sh :. Int -> V.Vector Word32 -> V.Vector Word32
permutationsFromPivotsSlices (shape :. width) pivots =
   V.create
      (do out <- MV.new (V.length pivots)
          -- Start offsets 0, width, 2*width, ... of the slices.
          let offsets = take (A.arraySize shape) (iterate (+ width) 0)
              expandSlice offset =
                 permutationFromPivotsMutableBackward
                    (V.slice offset width pivots)
                    (MV.slice offset width out)
          mapM_ expandSlice offsets
          return out)
-- | Alternative pivot expansion (not used in this module).  It copies the
-- identity @0 .. n-1@ into @perm@, then swaps slot @k@ with slot
-- @pivots!k@ for @k@ running from the last index down to 0.
_permutationFromPivotsMutable ::
   V.Vector Word32 -> MV.MVector s Word32 -> ST s ()
_permutationFromPivotsMutable pivots perm =
   let identity = V.enumFromN 0 (V.length pivots)
       swapRows k j = MV.swap perm (fromIntegral k) (fromIntegral j)
   in  do V.copy perm identity
          -- Reversing both vectors pairs each index with its own pivot,
          -- visiting them last-to-first.
          V.zipWithM_ swapRows (V.reverse identity) (V.reverse pivots)
-- | Expand one vector of 0-based pivot indices into an explicit permutation,
-- written into @perm@.  @perm@ must have at least as many slots as @pivots@
-- has entries.
--
-- Pivots are processed from the last index @length pivots - 1@ down to 0.
-- Slot @k@ is set to @k@ just before it is swapped with slot @pivots!k@, so
-- @perm@ does not need to be initialised beforehand.  This relies on every
-- pivot satisfying @pivots!k >= k@, which is the LAPACK @getrf@ convention
-- (NOTE(review): confirm for the batched cuBLAS routine).  An empty pivot
-- vector performs no writes.
permutationFromPivotsMutableBackward ::
   V.Vector Word32 -> MV.MVector s Word32 -> ST s ()
permutationFromPivotsMutableBackward pivots perm =
   let n = V.length pivots
   in  zipWithM_
          (\k j -> do
             MV.write perm k (fromIntegral k)
             MV.swap perm k (fromIntegral j))
          -- Descending indices paired with the reversed pivots, so index k
          -- meets pivots!k.
          [n - 1, n - 2 .. 0]
          (V.toList $ V.reverse pivots)
-- | Forward-order counterpart of 'permutationFromPivotsMutableBackward' (not
-- used in this module).  Starting from the identity, slot @k@ is swapped
-- with slot @pivots!k@ for @k = 0, 1, ...@, and the result is written into
-- @perm@.  @perm@ must have at least as many slots as @pivots@ has entries.
--
-- Every slot is set to the identity before the first swap.  A forward sweep
-- can swap slot @k@ with a later slot (@pivots!k > k@).  Initialising each
-- slot only when it is reached would therefore read a slot that has not
-- been written yet (e.g. fresh 'MV.new' memory), and would later overwrite
-- the value swapped into it.
_permutationFromPivotsMutableForward ::
   V.Vector Word32 -> MV.MVector s Word32 -> ST s ()
_permutationFromPivotsMutableForward pivots perm = do
   let n = V.length pivots
   mapM_ (\k -> MV.write perm k (fromIntegral k)) [0 .. n - 1]
   zipWithM_
      (\k j -> MV.swap perm k (fromIntegral j))
      [0 ..] (V.toList pivots)
-- | One Newton-Schulz step towards the inverse of every matrix in @a@.
--
-- Parameters: @h@ is the cuBLAS handle, @a@ holds the matrices to invert,
-- and @x@ is the current approximation.  The step computes
--
-- > x' = 2*x - x*(a*x)
--
-- 'mul' forms @a*x@, and 'mac' (which computes @alpha*(p*q) + beta*r@) then
-- evaluates @(-1)*(x*(a*x)) + 2*x@ in a single batched call.  The first
-- scale factor must be @-1@: with @+1@ the update would be @x*a*x + 2*x@,
-- which does not converge to the inverse.
newtonInverseStep ::
   (A.Shape ix, A.Slice ix, Eq ix, Element a, A.Elt a, A.IsNum a) =>
   Cublas.Handle ->
   ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a
newtonInverseStep h a x =
   mac h (-1) x (mul h 1 a x) 2 x
-- | Iterate 'newtonInverseStep' for the matrices @a@ via 'Loop.nest', with
-- @n@ as the iteration count (presumably the number of steps applied).
-- The iteration starts from the approximation @seed@ and returns the final
-- iterate.
--
-- NOTE(review): the iteration only converges for a suitable seed (for
-- Newton-Schulz the usual condition is that the spectral radius of
-- @I - a*seed@ is below 1); nothing here checks that.
newtonInverse ::
   (A.Shape ix, A.Slice ix, Eq ix, Element a, A.Elt a, A.IsNum a) =>
   Cublas.Handle ->
   A.Exp Int ->
   ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a
newtonInverse h n seed a =
   Loop.nest n step seed
   where
      step = newtonInverseStep h a
-- | Wrap a CUDA-only host action @f@ as an Accelerate foreign function
-- named @name@.
--
-- The CUDA stream argument is ignored.  There is no portable
-- implementation: on any other backend the fallback is an 'error' whose
-- message names the function.
cudaAcc ::
   (A.Arrays res, A.Arrays acc) =>
   String -> (acc -> AF.CIO res) -> A.Acc acc -> A.Acc res
cudaAcc name f =
   A.foreignAcc foreignImpl missingFallback
   where
      foreignImpl = AF.CUDAForeignAcc name (\_stream -> f)
      missingFallback = error (name ++ ": requires CUDA backend")