module Data.Array.Accelerate.CUBLAS.Level2.Batched (
   Cublas.Handle,
   Cublas.create,
   Level3.Element,
   mul,
   mac,
   ) where

import qualified Data.Array.Accelerate.CUBLAS.Level3.Batched as Level3
import Data.Array.Accelerate.CUBLAS.Level3.Batched (Element)

import qualified Data.Array.Accelerate.LinearAlgebra as ALinAlg

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Exp)

import qualified Foreign.CUDA.Cublas as Cublas


-- | Matrix-vector product implemented on top of the Level 3 batched
-- matrix-matrix product.
--
-- The vector argument is turned into a matrix with
-- 'ALinAlg.columnFromVector', handed to 'Level3.mul' together with the
-- handle, the scalar and the matrix, and the result is converted back
-- with 'ALinAlg.vectorFromColumn'.
--
-- NOTE(review): the scalar is forwarded unchanged to 'Level3.mul';
-- presumably it is the usual BLAS scaling factor - confirm against the
-- Level 3 module.
mul ::
   (A.Shape ix, A.Slice ix, Eq ix, Element a, A.Elt a, A.IsNum a) =>
   Cublas.Handle ->
   Exp a ->
   ALinAlg.Matrix ix a -> ALinAlg.Vector ix a ->
   ALinAlg.Vector ix a
mul handle alpha matrix vector =
   ALinAlg.vectorFromColumn columnProduct
   where
      -- Level 3 product with the vector wrapped as a matrix.
      columnProduct =
         Level3.mul handle alpha matrix (ALinAlg.columnFromVector vector)

-- | Matrix-vector multiply-accumulate implemented on top of the Level 3
-- batched 'Level3.mac'.
--
-- Both vector arguments are turned into matrices with
-- 'ALinAlg.columnFromVector' before being passed to 'Level3.mac'; the
-- result is then reshaped to the shape of the second vector argument
-- (the one paired with the second scalar).
--
-- NOTE(review): the two scalars are forwarded unchanged to 'Level3.mac';
-- presumably they play the roles of alpha and beta in BLAS GEMM/GEMV -
-- confirm against the Level 3 module.
mac ::
   (A.Shape ix, A.Slice ix, Eq ix, Element a, A.Elt a, A.IsNum a) =>
   Cublas.Handle ->
   Exp a -> ALinAlg.Matrix ix a -> ALinAlg.Vector ix a ->
   Exp a -> ALinAlg.Vector ix a ->
   ALinAlg.Vector ix a
mac handle alpha matrix vector beta accumulator =
   A.reshape resultShape columnResult
   where
      -- The output takes exactly the shape of the accumulator vector.
      resultShape = A.shape accumulator

      -- Level 3 multiply-accumulate on the column-wrapped vectors.
      columnResult =
         Level3.mac handle
            alpha matrix (ALinAlg.columnFromVector vector)
            beta (ALinAlg.columnFromVector accumulator)