{-# LANGUAGE TypeOperators #-}
module Data.Array.Accelerate.CUBLAS.Level3.Batched (
   Cublas.Handle,
   Cublas.create,
   Element,
   mul,
   mac,
   LU,
   lu,
   luInv,
   inv,
   luSolve,
   newtonInverseStep,
   newtonInverse,
   ) where

import qualified Data.Array.Accelerate.CUBLAS.Level3.Batched.Foreign as Foreign
import Data.Array.Accelerate.CUBLAS.Level3.Batched.Foreign (Element, Vector)

import qualified Data.Array.Accelerate.LinearAlgebra as ALinAlg

import qualified Data.Array.Accelerate.Utility.Sliced1 as Sliced1
import qualified Data.Array.Accelerate.Utility.Sliced as Sliced
import qualified Data.Array.Accelerate.Utility.Arrange as Arrange
import qualified Data.Array.Accelerate.Utility.Loop as Loop
import qualified Data.Array.Accelerate.Utility.Lift.Acc as Acc
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import Data.Array.Accelerate.Utility.Lift.Acc (acc, expr)

import Data.Array.Accelerate (Exp, (:.)((:.)), (!))
import qualified Data.Array.Accelerate.CUDA.Foreign as AF
import qualified Data.Array.Accelerate.IO as AIO
import qualified Data.Array.Accelerate as A

import qualified Foreign.CUDA.Cublas as Cublas

import qualified Data.Vector.Storable as V
import qualified Data.Vector.Storable.Mutable as MV

import Control.Monad.ST (ST)
import Control.Monad (zipWithM_)

import Data.Tuple.HT (mapSnd, uncurry3)

import Data.Word (Word32)


{- |
Scaled batched matrix product: @alpha * (a \<\> b)@.

On the CUDA backend the work is handed to @Foreign.mul@.
On any other backend the product is computed
with 'ALinAlg.multiplyMatrixMatrix' and scaled element-wise afterwards.
-}
mul ::
   (A.Shape ix, A.Slice ix, Eq ix, Element a, A.Elt a, A.IsNum a) =>
   Cublas.Handle ->
   Exp a ->
   ALinAlg.Matrix ix a -> ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a
mul handle alpha a b =
   A.foreignAcc viaCublas viaAccelerate $
   -- the scaling factor travels as a one-element array next to the operands
   A.lift (A.unit alpha, a, b)
   where
      viaCublas =
         AF.CUDAForeignAcc "mul" (const (uncurry3 (Foreign.mul handle)))
      viaAccelerate =
         Acc.modify (expr,acc,acc) $ \(scale, x, y) ->
            A.map (scale *) (ALinAlg.multiplyMatrixMatrix x y)

{- |
Batched multiply-accumulate: @alpha * (a \<\> b) + beta * c@.

On the CUDA backend the work is handed to @Foreign.mac@.
On any other backend the same expression is assembled from
'ALinAlg.multiplyMatrixMatrix', 'A.map' and 'A.zipWith'.
-}
mac ::
   (A.Shape ix, A.Slice ix, Eq ix, Element a, A.Elt a, A.IsNum a) =>
   Cublas.Handle ->
   Exp a -> ALinAlg.Matrix ix a -> ALinAlg.Matrix ix a ->
   Exp a -> ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a
mac handle alpha a b beta c =
   A.foreignAcc viaCublas viaAccelerate $
   -- both scaling factors travel as one-element arrays next to the operands
   A.lift ((A.unit alpha, a, b), (A.unit beta, c))
   where
      viaCublas =
         AF.CUDAForeignAcc "mac" $ const $
         \((scaleProd, x, y), (scaleAcc, z)) ->
            Foreign.mac handle scaleProd x y scaleAcc z
      viaAccelerate =
         Acc.modify ((expr,acc,acc),(expr,acc)) $
         \((scaleProd, x, y), (scaleAcc, z)) ->
            A.zipWith (+)
               (A.map (scaleProd *) (ALinAlg.multiplyMatrixMatrix x y))
               (A.map (scaleAcc *) z)



{- |
Result of 'lu', consumed by 'luInv' and 'luSolve'.

The wrapped triple consists of a matrix array,
a 'Word32' vector that 'luSolve' treats as pivot indices
(it subtracts one from every entry before building a permutation,
so the entries are presumably one-based - to be confirmed against CUBLAS)
and a 'Word32' scalar array that 'luInv' passes on as status information.
-}
newtype LU ix a =
   LU {
      _getLU ::
         (ALinAlg.Matrix ix a,
          ALinAlg.Vector ix Word32, ALinAlg.Scalar ix Word32)
   }

{- |
Runs @Foreign.lu@ on the given matrices
and wraps the resulting triple in 'LU'.
Only available on the CUDA backend (see 'cudaAcc').
-}
lu ::
   (A.Shape ix, Eq ix, Element a, A.Elt a) =>
   Cublas.Handle ->
   ALinAlg.Matrix ix a -> LU ix a
lu handle matrix =
   LU (A.unlift (cudaAcc "lu" (Foreign.lu handle) matrix))


{- |
Feeds a complete 'LU' triple to @Foreign.luInv@
and returns its result together with the status array
that is already stored in the 'LU' value.
Only available on the CUDA backend (see 'cudaAcc').
-}
luInv ::
   (A.Shape ix, Eq ix, Element a, A.Elt a) =>
   Cublas.Handle ->
   LU ix a -> (ALinAlg.Matrix ix a, ALinAlg.Scalar ix Word32)
luInv handle (LU factors@(_, _, status)) =
   (inverse, status)
   where
      inverse = cudaAcc "luInv" (Foreign.luInv handle) (A.lift factors)


{- |
Yields the inverted matrix along with a rank information.
The rank information is zero for an invertible matrix,
otherwise it is the matrix rank plus 1.

'inv' is the composition of 'lu' and 'luInv'.
-}
inv, _inv ::
   (A.Shape ix, Eq ix, Element a, A.Elt a) =>
   Cublas.Handle ->
   ALinAlg.Matrix ix a ->
   (ALinAlg.Matrix ix a, ALinAlg.Scalar ix Word32)
inv handle matrix = luInv handle (lu handle matrix)

{- |
Variant that calls @Foreign.inv@ directly.
The maximum size of matrices is 32 in CUDA-6.0 and CUDA-6.5.
-}
_inv handle matrix =
   A.unlift (cudaAcc "inv" (Foreign.inv handle) matrix)


{- |
Passes the matrix part of an 'LU' value and the given matrix
to @Foreign.luSolve@ (with a scalar factor of 1)
and afterwards reorders the rows of the result
according to the permutation derived from the pivot vector.
The status part of the 'LU' value is ignored.

Matrices with sizes larger than 32
are only supported starting with CUDA-6.5.
With CUDA-6.0 you will get the error
@CUBLAS Exception: unsupported value or parameter passed to a function@.
On CUDA-6.0 you may prefer 'luInv', which surprisingly works there.
-}
luSolve ::
   (A.Shape ix, A.Slice ix, Eq ix, Element a, A.Elt a) =>
   Cublas.Handle ->
   LU ix a ->
   ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a
luSolve handle (LU (factors, pivots, _status)) operand =
   applyRowPerm rowOrder solved
   where
      -- pivots are shifted down by one before they are turned into a permutation
      rowOrder = permutationFromPivotsAcc (A.map (subtract 1) pivots)
      solved =
         cudaAcc "luSolve"
            (uncurry (Foreign.luSolve handle (Acc.singleton 1)))
            (A.lift (factors, operand))


{- |
Rearranges the innermost dimension of every matrix:
the result at index @(ix, j, i)@ is @arr@ at @(ix, j, perm ! (ix, i))@.
-}
_applyColPerm ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   ALinAlg.Vector ix Word32 ->
   ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a
_applyColPerm perm arr =
   Arrange.mapWithIndex pick spread
   where
      -- the permutation, converted to Int and repeated along a new
      -- second-innermost dimension of extent 'Sliced1.length arr'
      spread =
         A.replicate (A.lift (A.Any :. Sliced1.length arr :. A.All)) $
         A.map (A.fromIntegral :: Exp Word32 -> Exp Int) perm
      -- keeps the outer indices and replaces the innermost one
      pick =
         Exp.modify2 (expr:.expr:.expr) expr $
         \(outer :. row :. _col) source -> arr ! A.lift (outer :. row :. source)

{- |
Rearranges the second-innermost dimension of every matrix:
the result at index @(ix, j, i)@ is @arr@ at @(ix, perm ! (ix, j), i)@.
-}
applyRowPerm ::
   (A.Shape ix, A.Slice ix, A.Elt a) =>
   ALinAlg.Vector ix Word32 ->
   ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a
applyRowPerm perm arr =
   Arrange.mapWithIndex pick spread
   where
      -- the permutation, converted to Int and repeated along a new
      -- innermost dimension of extent 'Sliced.length arr'
      spread =
         A.replicate (A.lift (A.Any :. Sliced.length arr)) $
         A.map (A.fromIntegral :: Exp Word32 -> Exp Int) perm
      -- keeps outer and innermost index and replaces the one in between
      pick =
         Exp.modify2 (expr:.expr:.expr) expr $
         \(outer :. _row :. col) source -> arr ! A.lift (outer :. source :. col)

{- |
Accelerate-level wrapper around 'permutationFromPivots'.
The foreign action calls 'AF.peekArray' on the pivot array,
computes the permutation with the pure host function
and calls 'AF.useArray' on the result before returning it
(presumably a device-to-host transfer followed by an upload
of the new array - to be confirmed against accelerate-cuda).
Only available on the CUDA backend (see 'cudaAcc').
-}
permutationFromPivotsAcc ::
   (A.Shape ix) =>
   ALinAlg.Vector ix Word32 -> ALinAlg.Vector ix Word32
permutationFromPivotsAcc =
   cudaAcc "permutations" viaHost
   where
      viaHost pivots = do
         AF.peekArray pivots
         let order = permutationFromPivots pivots
         AF.useArray order
         return order

{- |
Converts an array of pivot vectors into an array of permutations
of the same shape.
The array is turned into its storable vector representation,
processed by 'permutationsFromPivotsSlices' and turned back.
-}
permutationFromPivots ::
   (A.Shape ix) =>
   Vector ix Word32 -> Vector ix Word32
permutationFromPivots vec =
   AIO.fromVectors extent
      (mapSnd (permutationsFromPivotsSlices extent) (AIO.toVectors vec))
   where
      extent = A.arrayShape vec

{- |
Treats the flat pivot vector as @A.arraySize shape@ consecutive slices
of length @width@ and converts every slice independently
with 'permutationFromPivotsMutableBackward'.
The result has the same length as the input.
-}
permutationsFromPivotsSlices ::
   (A.Shape sh) =>
   sh :. Int -> V.Vector Word32 -> V.Vector Word32
permutationsFromPivotsSlices (outer :. n) pivots = V.create (do
   target <- MV.new (V.length pivots)
   let -- start offsets of the slices: 0, n, 2*n, ...
       offsets = take (A.arraySize outer) [0, n ..]
       convertAt off =
          permutationFromPivotsMutableBackward
             (V.slice off n pivots)
             (MV.slice off n target)
   mapM_ convertAt offsets
   return target)

{- |
Works for every pivot vector, but needs two passes over the array:
the first pass fills @perm@ with the indices @0, 1, ..@,
the second one walks from the last position to the first
and swaps position @k@ with position @pivots ! k@.
-}
_permutationFromPivotsMutable ::
   V.Vector Word32 -> MV.MVector s Word32 -> ST s ()
_permutationFromPivotsMutable pivots perm =
   V.copy perm positions >>
   V.zipWithM_ exchange (V.reverse positions) (V.reverse pivots)
   where
      positions = V.enumFromN 0 (V.length pivots)
      exchange pos pivot =
         MV.swap perm (fromIntegral pos) (fromIntegral pivot)

{- |
Single-pass conversion that works only if @forall i. pivot!!i >= i@.

Positions are visited from the last one down to zero.
Each position is first initialised with its own index
and then swapped with the position named by its pivot.
-}
permutationFromPivotsMutableBackward ::
   V.Vector Word32 -> MV.MVector s Word32 -> ST s ()
permutationFromPivotsMutableBackward pivots perm =
   zipWithM_ place [top, top - 1 ..] (V.toList (V.reverse pivots))
   where
      top = V.length pivots - 1
      place pos pivot =
         MV.write perm pos (fromIntegral pos) >>
         MV.swap perm pos (fromIntegral pivot)

{- |
Single-pass conversion that works only if @forall i. pivot!!i <= i@.

Positions are visited from zero upwards.
Each position is first initialised with its own index
and then swapped with the position named by its pivot.
-}
_permutationFromPivotsMutableForward ::
   V.Vector Word32 -> MV.MVector s Word32 -> ST s ()
_permutationFromPivotsMutableForward pivots perm =
   zipWithM_ place [0..] (V.toList pivots)
   where
      place pos pivot =
         MV.write perm pos (fromIntegral pos) >>
         MV.swap perm pos (fromIntegral pivot)



{- |
One refinement step for an approximate inverse @x@ of @a@:
computes @2*x - x \<\> (a \<\> x)@ using 'mul' and 'mac'.
-}
newtonInverseStep ::
   (A.Shape ix, A.Slice ix, Eq ix, Element a, A.Elt a, A.IsNum a) =>
   Cublas.Handle ->
   ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a
newtonInverseStep h a x =
   mac h (-1) x ax 2 x
   where
      ax = mul h 1 a x

{- |
Applies 'newtonInverseStep' @n@ times via 'Loop.nest'.
Mind the argument order: the first matrix is the start value
of the iteration, the second one is the matrix @a@
handed to every step.
-}
newtonInverse ::
   (A.Shape ix, A.Slice ix, Eq ix, Element a, A.Elt a, A.IsNum a) =>
   Cublas.Handle ->
   A.Exp Int ->
   ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a ->
   ALinAlg.Matrix ix a
newtonInverse h n seed a =
   Loop.nest n refine seed
   where
      refine = newtonInverseStep h a


{- |
Wraps a CUDA action as a foreign Accelerate function
that has no implementation for other backends.
The first argument that 'AF.CUDAForeignAcc' passes to its function
is ignored.
The fallback slot of 'A.foreignAcc' is filled with an 'error' call
whose message is the given name followed by @": requires CUDA backend"@.
-}
cudaAcc ::
   (A.Arrays res, A.Arrays acc) =>
   String -> (acc -> AF.CIO res) -> A.Acc acc -> A.Acc res
cudaAcc name f =
   A.foreignAcc cudaOnly noFallback
   where
      cudaOnly = AF.CUDAForeignAcc name (const f)
      noFallback = error (name ++ ": requires CUDA backend")