{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Array.Accelerate.CUBLAS.Level3.Batched.Foreign where

import Data.Array.Accelerate.Array.Sugar (EltRepr)
import Data.Array.Accelerate (Array, Shape, Z(Z), (:.)((:.)))
import qualified Data.Array.Accelerate.CUDA.Foreign as AF
import qualified Data.Array.Accelerate as A

import qualified Foreign.CUDA.Cublas as Cublas
import Foreign.CUDA.Ptr (DevicePtr, castDevPtr, advanceDevPtr)

import Foreign.C.Types (CFloat, CDouble)
import Foreign.Storable (Storable)

import Data.Word (Word32)


-- | A batch of matrices: the two innermost 'Int' dimensions belong to each
-- individual matrix, the remaining shape @ix@ indexes the batch.
type Matrix ix = Vector (ix :. Int)
-- | A batch of vectors: the innermost 'Int' dimension belongs to each
-- individual vector, the remaining shape @ix@ indexes the batch.
type Vector ix = Scalar (ix :. Int)
-- | A batch of single values: an array whose shape is just the batch shape.
type Scalar ix = Array ix


-- | Batched matrix product.
--
-- Every matrix of @a@ is combined with the matrix of @b@ at the same batch
-- index through a single 'Cublas.gemmBatched' call; @alpha@ is passed as the
-- scaling factor and the literal @0@ as the factor for the freshly
-- allocated result.
--
-- Calls 'error' (via 'unify') if the batch shapes differ or if the second
-- matrix dimension of @a@ differs from the first matrix dimension of @b@.
mul ::
   (Shape ix, Eq ix, Element a, A.Elt a) =>
   Cublas.Handle ->
   A.Scalar a -> Matrix ix a -> Matrix ix a ->
   AF.CIO (Matrix ix a)
mul handle alpha a b = do
   let (batchA :. rows :. innerA) = A.arrayShape a
       (batchB :. innerB :. cols) = A.arrayShape b
       inner = unify "mul: matrix sizes mismatch" innerA innerB
       batch =
          unify "mul: mismatching shapes of matrix arrays" batchA batchB
   c <- AF.allocateArray (batch :. rows :. cols)
   (aPtrs, aLd) <- arrayPtrs a
   (bPtrs, bLd) <- arrayPtrs b
   (cPtrs, cLd) <- arrayPtrs c
   -- The operands (and the corresponding extents) are handed over in
   -- swapped order: b before a, cols before rows.
   -- NOTE(review): presumably this compensates for cuBLAS expecting
   -- column-major storage -- confirm.
   AF.liftIO $
      Cublas.gemmBatched handle Cublas.N Cublas.N cols rows inner
         (storableFromScalar alpha)
         bPtrs bLd
         aPtrs aLd
         0
         cPtrs cLd
         (A.arraySize batch)
   return c

-- | Batched multiply-accumulate.
--
-- @c@ is copied into a fresh array, which is then passed to
-- 'Cublas.gemmBatched' as the accumulator operand together with the
-- factors @alpha@ (for the product of @a@ and @b@) and @beta@ (for the
-- accumulator). The input @c@ itself is not handed to cuBLAS.
--
-- Calls 'error' (via 'unify') if the batch shapes of the three arrays
-- differ or if the matrix dimensions are incompatible.
mac ::
   (A.Shape ix, Eq ix, Element a, A.Elt a) =>
   Cublas.Handle ->
   A.Scalar a -> Matrix ix a -> Matrix ix a ->
   A.Scalar a -> Matrix ix a ->
   AF.CIO (Matrix ix a)
mac handle alpha a b beta c = do
   let (batchA :. rowsA  :. innerA) = A.arrayShape a
       (batchB :. innerB :. colsB)  = A.arrayShape b
       (batchC :. rowsC  :. colsC)  = A.arrayShape c
       inner = unify "mac: matrix sizes mismatch" innerB innerA
       rows  = unify "mac: matrix sizes mismatch" rowsA rowsC
       cols  = unify "mac: matrix sizes mismatch" colsB colsC
       batchMsg = "mac: mismatching shapes of matrix arrays"
       batch = unify batchMsg batchA (unify batchMsg batchB batchC)
   acc <- AF.allocateArray (batch :. rows :. cols)
   -- Work on a copy so that the caller's c stays untouched.
   AF.copyArray c acc
   (aPtrs, aLd) <- arrayPtrs a
   (bPtrs, bLd) <- arrayPtrs b
   (accPtrs, accLd) <- arrayPtrs acc
   -- Operands and extents are handed over in swapped order (b before a,
   -- cols before rows), exactly as in the plain product.
   AF.liftIO $
      Cublas.gemmBatched handle Cublas.N Cublas.N cols rows inner
         (storableFromScalar alpha)
         bPtrs bLd
         aPtrs aLd
         (storableFromScalar beta)
         accPtrs accLd
         (A.arraySize batch)
   return acc

-- | Batched LU factorisation through 'Cublas.getrfBatched'.
--
-- The input is copied first and the factorisation runs on the copy.
-- Returns the factorised copy, a freshly allocated pivot array with one
-- entry per matrix row, and a freshly allocated info array with one entry
-- per matrix; both are filled in by cuBLAS.
--
-- Calls 'error' (via 'unify') if the matrices are not square.
lu ::
   (A.Shape ix, Eq ix, Element a, A.Elt a) =>
   Cublas.Handle ->
   Matrix ix a ->
   AF.CIO (Matrix ix a, Vector ix Word32, Scalar ix Word32)
lu handle a = do
   let shape@(batch :. rows :. cols) = A.arrayShape a
       order = unify "lu: matrices must have square shape" rows cols
   factors <- AF.allocateArray shape
   AF.copyArray a factors
   (factorPtrs, factorLd) <- arrayPtrs factors

   pivot <- AF.allocateArray (batch :. order)
   pivotPtr <- devicePtrsOfArray pivot

   info <- AF.allocateArray batch
   infoPtr <- devicePtrsOfArray info

   AF.liftIO $
      Cublas.getrfBatched handle order
         factorPtrs factorLd
         pivotPtr infoPtr
         (A.arraySize batch)
   return (factors, pivot, info)

-- | Batched matrix inversion from an LU factorisation, through
-- 'Cublas.getriBatched'.
--
-- Takes the triple of factorised matrices, pivot array and info array (the
-- shape returned by 'lu') and returns a freshly allocated array that
-- receives the result.
--
-- Calls 'error' (via 'unify') if the matrices are not square.
--
-- NOTE(review): the @info@ array of the argument is handed to cuBLAS as
-- the info buffer of this call, so its device contents are presumably
-- overwritten -- confirm that no caller reads the factorisation info
-- afterwards.
luInv ::
   (A.Shape ix, Eq ix, Element a, A.Elt a) =>
   Cublas.Handle ->
   (Matrix ix a, Vector ix Word32, Scalar ix Word32) ->
   AF.CIO (Matrix ix a)
luInv handle (a, pivot, info) = do
   let sh@(numMatrices :. n :. k) = A.arrayShape a
   let size = unify "luInv: matrices must have square shape" n k
   -- getriBatched reads the factors from a and writes to the separate
   -- output c, so c does not have to be initialised with a copy of a.
   c <- AF.allocateArray sh
   (pas, lda) <- arrayPtrs a
   (pcs, ldc) <- arrayPtrs c

   pivotPtr <- devicePtrsOfArray pivot
   infoPtr <- devicePtrsOfArray info

   AF.liftIO $
      Cublas.getriBatched handle size
         pas lda
         pivotPtr
         pcs ldc
         infoPtr
         (A.arraySize numMatrices)
   return c

-- | Batched triangular solves against LU-factorised matrices.
--
-- @b@ is copied and two 'Cublas.trsmBatched' calls are applied to the copy
-- in sequence: first with the upper, non-unit triangle of @a@, then with
-- the lower, unit triangle of @a@. The copy is returned.
--
-- Calls 'error' (via 'unify') if the matrices of @a@ are not square, if
-- their size does not match the first matrix dimension of @b@, or if the
-- batch shapes differ.
--
-- NOTE(review): @alpha@ is passed to both triangular solves, so the result
-- is presumably scaled by @alpha@ twice -- confirm whether the second call
-- should use 1 instead.
--
-- NOTE(review): no pivot array is taken or applied here -- confirm that
-- callers only pass factorisations that need no row exchanges.
luSolve ::
   (A.Shape ix, Eq ix, Element a, A.Elt a) =>
   Cublas.Handle ->
   A.Scalar a ->
   Matrix ix a ->
   Matrix ix a ->
   AF.CIO (Matrix ix a)
luSolve handle alpha a b = do
   let    (aNumMatrices :. an :. ak) = A.arrayShape a
   let sh@(bNumMatrices :. bk :. m)  = A.arrayShape b
   let n =
          unify "luSolve: matrices must have square shape" an $
          unify "luSolve: matrices dimensions must match" ak bk
   let count =
          A.arraySize $
          -- The message used to carry the prefix of 'mul' (copy-paste slip).
          unify "luSolve: mismatching shapes of matrix arrays"
             aNumMatrices bNumMatrices
   -- Both solves run in place on c, a copy of b.
   c <- AF.allocateArray sh
   AF.copyArray b c
   (pas, lda) <- arrayPtrs a
   (pcs, ldc) <- arrayPtrs c

   AF.liftIO $ do
      Cublas.trsmBatched handle
         Cublas.SideRight Cublas.Upper Cublas.N Cublas.NonUnit m n
         (storableFromScalar alpha)
         pas lda
         pcs ldc
         count
      Cublas.trsmBatched handle
         Cublas.SideRight Cublas.Lower Cublas.N Cublas.Unit m n
         (storableFromScalar alpha)
         pas lda
         pcs ldc
         count
   return c

-- | Batched direct matrix inversion through 'Cublas.matinvBatched'.
--
-- Returns a freshly allocated array that receives the result together with
-- a freshly allocated info array holding one entry per matrix; both are
-- filled in by cuBLAS. The input @a@ is only passed as the source operand.
--
-- Calls 'error' (via 'unify') if the matrices are not square.
--
-- NOTE(review): the cuBLAS documentation restricts @matinvBatched@ to small
-- matrices (n <= 32) -- confirm that callers respect this or fall back to
-- 'lu' followed by 'luInv'.
inv ::
   (A.Shape ix, Eq ix, Element a, A.Elt a) =>
   Cublas.Handle ->
   Matrix ix a ->
   AF.CIO (Matrix ix a, Scalar ix Word32)
inv handle a = do
   let sh@(numMatrices :. n  :. k) = A.arrayShape a
   let size = unify "inv: matrices must have square shape" n k
   b <- AF.allocateArray sh
   (pas, lda) <- arrayPtrs a
   (pbs, ldb) <- arrayPtrs b
   info <- AF.allocateArray numMatrices
   -- Use the shared helper, as lu and luInv do, instead of open-coding the
   -- pointer extraction and cast.
   infoPtr <- devicePtrsOfArray info

   AF.liftIO $
      Cublas.matinvBatched handle size
         pas lda
         pbs ldb
         infoPtr
         (A.arraySize numMatrices)
   return (b, info)


-- | Element types usable with the batched routines of this module.
--
-- The class has no methods; it only bundles the constraints the wrappers
-- need: the array of @a@ is backed by a single device pointer, and the
-- associated type @'StorableOf' a@ is 'Fractional', 'Storable' and has a
-- 'Cublas.Cublas' instance. @a@ itself must be 'Real' so that scalars can
-- be converted with 'realToFrac'.
class
   (AF.DevicePtrs (EltRepr a) ~ ((), DevicePtr a),
    Fractional (StorableOf a),
    Cublas.Cublas (StorableOf a),
    Storable (StorableOf a),
    Real a) =>
       Element a

instance Element Float
instance Element Double


-- | Maps a Haskell floating-point type to the C type used when talking to
-- cuBLAS: 'Float' to 'CFloat' and 'Double' to 'CDouble'.
type family StorableOf a
type instance StorableOf Float = CFloat
type instance StorableOf Double = CDouble

-- | Read the single element of a rank-0 array (at index 'Z') and convert it
-- with 'realToFrac' to the storable type associated with its element type.
storableFromScalar ::
   (Real a, StorableOf a ~ b, Fractional b) => A.Scalar a -> b
storableFromScalar = realToFrac . flip A.indexArray Z

-- | @genPointers n size p@ yields @n@ pointers: @p@ itself followed by the
-- results of repeatedly advancing the previous pointer by @size@ with
-- 'advanceDevPtr'. The list is empty for @n <= 0@.
genPointers ::
   (Storable a) =>
   Int -> Int -> DevicePtr a -> [DevicePtr a]
genPointers n size p = take n (iterate next p)
   where
      -- Step from the start of one block to the start of the following one.
      next q = advanceDevPtr q size

-- | Per-matrix device pointers and leading dimension of a batch of
-- matrices.
--
-- The first component holds one pointer per matrix of the batch, spaced by
-- the product of the two matrix dimensions, starting at the base pointer
-- of the array. The second component is the innermost dimension, which the
-- callers pass to cuBLAS as the leading dimension.
arrayPtrs ::
   (A.Shape ix,
    Storable a, StorableOf e ~ a,
    AF.DevicePtrs (EltRepr e) ~ ((), DevicePtr e)) =>
   Matrix ix e -> AF.CIO ([DevicePtr a], Int)
arrayPtrs arr = do
   base <- devicePtrsOfArray arr
   return (genPointers count (rows * cols) base, cols)
   where
      (batch :. rows :. cols) = A.arrayShape arr
      count = A.arraySize batch

-- | Fetch the single device pointer backing an array and cast it with
-- 'castDevPtr' to whatever pointer type the caller requires.
--
-- The target element type is unconstrained, so choosing a type that
-- matches the array's contents is the caller's responsibility.
devicePtrsOfArray ::
   (A.Shape ix, AF.DevicePtrs (EltRepr e) ~ ((), DevicePtr e)) =>
   Scalar ix e -> AF.CIO (DevicePtr a)
devicePtrsOfArray arr =
   AF.devicePtrsOfArray arr >>= \((), ptr) -> return (castDevPtr ptr)

-- | Return the common value of two arguments that are expected to be
-- equal; calls 'error' with the given message if they differ.
unify :: (Eq a) => String -> a -> a -> a
unify msg a b
   | a == b    = a
   | otherwise = error msg