{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Private where

import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor, ColumnMajor), transposeFromOrder)
import Numeric.LAPACK.Wrapper (Flip(Flip, getFlip))

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated, Conjugated))
import Numeric.LAPACK.Scalar (RealOf, zero, one, isZero)

import qualified Foreign.Marshal.Array.Guarded as ForeignArray
import qualified Foreign.Marshal.Utils as Marshal
import qualified Foreign.C.String as CStr
import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.Marshal.Alloc (alloca)
import Foreign.C.Types (CChar, CInt)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (Storable, poke, peek, pokeElemOff, peekElemOff)

import Text.Printf (printf)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT, runContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)
import Control.Applicative (Const(Const,getConst), liftA2, (<$>))

import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)

import qualified Data.Complex as Complex
import Data.Complex (Complex)
import Data.Tuple.HT (swap)

import Prelude hiding (sum)


-- | Reinterpret a pointer to scalars as a pointer to their real type.
-- For complex elements the result points at the real part of the first element.
realPtr :: Ptr a -> Ptr (RealOf a)
realPtr ptr = castPtr ptr


-- | Write the value @a@ into @n@ contiguous elements starting at @dstPtr@.
-- Implemented as a BLAS copy whose source stride is zero,
-- so the single source value is replicated.
fill :: (Class.Floating a) => a -> Int -> Ptr a -> IO ()
fill a n dstPtr = evalContT $ do
   countPtr <- Call.cint n
   valuePtr <- Call.number a
   srcIncPtr <- Call.cint 0
   dstIncPtr <- Call.cint 1
   liftIO $ BlasGen.copy countPtr valuePtr srcIncPtr dstPtr dstIncPtr


-- | Copy @n@ contiguous elements from @srcPtr@ to @dstPtr@ via BLAS copy.
copyBlock :: (Class.Floating a) => Int -> Ptr a -> Ptr a -> IO ()
copyBlock n srcPtr dstPtr = evalContT $ do
   countPtr <- Call.cint n
   -- source and destination both use unit stride;
   -- BLAS only reads the increments, so one cell serves both
   unitIncPtr <- Call.cint 1
   liftIO $ BlasGen.copy countPtr srcPtr unitIncPtr dstPtr unitIncPtr

-- | Copy the first @n@ elements behind @fptr@ into a buffer allocated
-- with 'Call.allocaArray' inside the surrounding 'ContT' computation,
-- and return that buffer.
copyToTemp :: (Storable a) => Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyToTemp n fptr = do
   srcPtr <- ContT (withForeignPtr fptr)
   dstPtr <- Call.allocaArray n
   liftIO (copyArray dstPtr srcPtr n) >> return dstPtr


{- |
Obtain a pointer to the complex conjugate of the first @n@ elements.
For real element types no copy is made:
the pointer of the 'ForeignPtr' itself is returned.
For complex element types the data is copied to a temporary buffer
which is then conjugated in place.
-}
conjugateToTemp ::
   (Class.Floating a) => Int -> ForeignPtr a -> ContT r IO (Ptr a)
conjugateToTemp n =
   runCopyToTemp
      (Class.switchFloating
         (CopyToTemp withPtr)
         (CopyToTemp withPtr)
         (CopyToTemp (complexConjugateToTemp n))
         (CopyToTemp (complexConjugateToTemp n)))
   where
      -- conjugation of real numbers is the identity, so no copy is needed
      withPtr fptr = ContT (withForeignPtr fptr)

-- | Wrapper that lets 'Class.switchFloating' choose, per element type,
-- how a 'ForeignPtr' is turned into a (possibly temporary) 'Ptr'
-- within a 'ContT' computation.
newtype CopyToTemp r a =
   CopyToTemp {runCopyToTemp :: ForeignPtr a -> ContT r IO (Ptr a)}

-- | Copy @n@ complex elements to a temporary buffer
-- and conjugate them there with LAPACK's lacgv.
-- The original data behind the 'ForeignPtr' is left untouched.
complexConjugateToTemp ::
   Class.Real a =>
   Int -> ForeignPtr (Complex a) -> ContT r IO (Ptr (Complex a))
complexConjugateToTemp n x = do
   tmpPtr <- copyToTemp n x
   countPtr <- Call.cint n
   unitIncPtr <- Call.cint 1
   liftIO (LapackComplex.lacgv countPtr tmpPtr unitIncPtr)
   return tmpPtr


-- | Conjugate the vector at @yPtr@ in place if 'Conjugated' is requested;
-- otherwise do nothing.
condConjugate ::
   (Class.Floating a) => Conjugation -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
condConjugate conj nPtr yPtr incyPtr =
   case conj of
      Conjugated -> lacgv nPtr yPtr incyPtr
      _ -> return ()

-- | Copy a strided vector from @xPtr@ to @yPtr@,
-- then conjugate the destination in place (a no-op for real types).
copyConjugate ::
   (Class.Floating a) =>
   Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
copyConjugate nPtr xPtr incxPtr yPtr incyPtr =
   BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr >>
   lacgv nPtr yPtr incyPtr

-- | Copy a strided vector from @xPtr@ to @yPtr@,
-- then conjugate the destination in place if 'Conjugated' is requested.
copyCondConjugate ::
   (Class.Floating a) =>
   Conjugation -> Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
copyCondConjugate conj nPtr xPtr incxPtr yPtr incyPtr =
   BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr >>
   condConjugate conj nPtr yPtr incyPtr

-- | Obtain a pointer to the data, conjugated if requested.
-- 'NonConjugated' returns the array's own pointer.
-- 'Conjugated' delegates to 'conjugateToTemp', which copies only for complex types.
condConjugateToTemp ::
   (Class.Floating a) =>
   Conjugation -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
condConjugateToTemp NonConjugated _n x = ContT (withForeignPtr x)
condConjugateToTemp Conjugated n x = conjugateToTemp n x

-- | Always copy @n@ contiguous elements into a fresh temporary buffer,
-- conjugating the copy if 'Conjugated' is requested.
-- The source is only pinned while the copy is made.
copyCondConjugateToTemp ::
   (Class.Floating a) =>
   Conjugation -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyCondConjugateToTemp conj n a = do
   dstPtr <- Call.allocaArray n
   liftIO $ evalContT $ do
      srcPtr <- ContT (withForeignPtr a)
      countPtr <- Call.cint n
      unitIncPtr <- Call.cint 1
      liftIO $
         copyCondConjugate conj countPtr srcPtr unitIncPtr dstPtr unitIncPtr
   return dstPtr



{- |
In ColumnMajor:
Copy an m-by-n matrix with lda>=m and ldb>=m.
The whole rectangle is copied ('copySubTrapezoid' with @'A'@).
-}
copySubMatrix ::
   (Class.Floating a) =>
   Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copySubMatrix m n lda aPtr ldb bPtr =
   copySubTrapezoid 'A' m n lda aPtr ldb bPtr

{- |
Thin wrapper around LAPACK's lacpy for a column-major m-by-n block.
@side@ is passed as lacpy's UPLO argument
('A' copies the whole block, as 'copySubMatrix' does).
@lda@ and @ldb@ are the leading dimensions of source and destination.
-}
copySubTrapezoid ::
   (Class.Floating a) =>
   Char -> Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copySubTrapezoid side m n lda aPtr ldb bPtr = evalContT $ do
   uploPtr <- Call.char side
   rowsPtr <- Call.cint m
   colsPtr <- Call.cint n
   srcLdPtr <- Call.leadingDim lda
   dstLdPtr <- Call.leadingDim ldb
   liftIO $
      LapackGen.lacpy uploPtr rowsPtr colsPtr aPtr srcLdPtr bPtr dstLdPtr

{- |
Copy a row-major @rows@-by-@cols@ matrix at @aPtr@
into column-major storage at @bPtr@ with leading dimension @ldb@.
Column @j@ of the result is gathered from @aPtr+j@ with stride @cols@.
-}
copyTransposed ::
   (Class.Floating a) =>
   Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyTransposed rows cols aPtr ldb bPtr = evalContT $ do
   rowsPtr <- Call.cint rows
   srcIncPtr <- Call.cint cols
   dstIncPtr <- Call.cint 1
   let copyColumn srcPtr dstPtr =
         BlasGen.copy rowsPtr srcPtr srcIncPtr dstPtr dstIncPtr
   liftIO $ sequence_ $
      zipWith copyColumn
         (take cols (pointerSeq 1 aPtr))
         (pointerSeq ldb bPtr)


{- |
Copy an m-by-n matrix to ColumnMajor order.
The destination is dense, i.e. it has leading dimension @m@.
ColumnMajor input is copied as one contiguous block;
RowMajor input is transposed on the fly.
-}
copyToColumnMajor ::
   (Class.Floating a) =>
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyToColumnMajor order m n aPtr bPtr =
   case order of
      ColumnMajor -> copyBlock (m*n) aPtr bPtr
      RowMajor -> copyTransposed m n aPtr m bPtr

{- |
Copy an m-by-n matrix given in @order@ into column-major storage
with leading dimension @ldb@.
If the source is ColumnMajor and @ldb == m@, one contiguous block copy suffices.
NOTE(review): presumably requires ldb >= m; confirm with callers.
-}
copyToSubColumnMajor ::
   (Class.Floating a) =>
   Order -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyToSubColumnMajor order m n aPtr ldb bPtr =
   case order of
      RowMajor -> copyTransposed m n aPtr ldb bPtr
      ColumnMajor
         | m == ldb -> copyBlock (m*n) aPtr bPtr
         | otherwise -> copySubMatrix m n m aPtr ldb bPtr


-- | Infinite list of pointers @ptr, ptr+k, ptr+2k, ...@,
-- where steps are counted in elements, not bytes.
pointerSeq :: (Storable a) => Int -> Ptr a -> [Ptr a]
pointerSeq k = iterate (`advancePtr` k)


{- |
Create an array of shape @shapeX@ that @act@ fills in.
@act@ receives a column-major buffer together with its leading dimension.

* If @m>n@, the buffer is a temporary of @m*nrhs@ elements with leading
  dimension @m@. Afterwards its top @n@ rows are copied into the result
  array with leading dimension @n@.

* Otherwise @act@ writes directly into the result with leading dimension @n@.

The value returned by @act@ is returned alongside the array.
NOTE(review): presumably serves least-squares style solvers,
where the right-hand side needs max m n rows but the solution only n;
confirm with callers.
-}
createHigherArray ::
   (Shape.C sh, Class.Floating a) =>
   sh -> Int -> Int -> Int ->
   ((Ptr a, Int) -> IO rank) -> IO (rank, Array sh a)
createHigherArray shapeX m n nrhs act =
   swap <$>
      ArrayIO.unsafeCreateWithSizeAndResult shapeX
         (\ _ xPtr ->
            if n < m
               then
                  runContT (Call.allocaArray (m*nrhs)) $ \tmpPtr -> do
                     result <- act (tmpPtr, m)
                     copySubMatrix n nrhs m tmpPtr n xPtr
                     return result
               else act (xPtr, n))



-- | Wrapper that lets 'Class.switchFloating' pick a summation routine
-- (count, start pointer, stride) per element type.
newtype Sum a = Sum {runSum :: Int -> Ptr a -> Int -> IO a}

-- | Sum of @n@ elements starting at the pointer with the given stride.
-- Dispatches to 'sumReal' for real types and 'sumComplex' for complex types.
sum :: Class.Floating a => Int -> Ptr a -> Int -> IO a
sum =
   runSum
      (Class.switchFloating
         (Sum sumReal) (Sum sumReal)
         (Sum sumComplex) (Sum sumComplex))

-- | Sum of @n@ real elements with stride @incx@, computed as the BLAS dot
-- product with a single @one@ read at stride zero, i.e. with a vector of ones.
sumReal :: Class.Real a => Int -> Ptr a -> Int -> IO a
sumReal n xPtr incx = evalContT $ do
   countPtr <- Call.cint n
   strideXPtr <- Call.cint incx
   onePtr <- Call.real one
   strideZeroPtr <- Call.cint 0
   liftIO (BlasReal.dot countPtr xPtr strideXPtr onePtr strideZeroPtr)

{- |
Sum of @n@ complex elements starting at @xPtr@ with stride @incx@.
The real and the imaginary parts are summed separately,
each by a real dot product with a single @one@ read at stride zero.
-}
sumComplex ::
   Class.Real a => Int -> Ptr (Complex a) -> Int -> IO (Complex a)
sumComplex n xPtr incx =
   evalContT $ do
      countPtr <- Call.cint n
      let rePtr = realPtr xPtr
      -- one complex step spans two real numbers
      realStridePtr <- Call.cint (2*incx)
      onePtr <- Call.real one
      strideZeroPtr <- Call.cint 0
      let partSum partPtr =
            BlasReal.dot countPtr partPtr realStridePtr onePtr strideZeroPtr
      liftIO $
         liftA2 (Complex.:+) (partSum rePtr) (partSum (advancePtr rePtr 1))

{- |
Alternative to 'sumComplex':
treat the complex vector as a 2-by-n real matrix
(column j holds real and imaginary part of element j,
leading dimension @2*inca@) and multiply it by a vector of n ones
using the zero-column-safe 'gemv'.
-}
sumComplexAlt ::
   Class.Real a => Int -> Ptr (Complex a) -> Int -> IO (Complex a)
sumComplexAlt n aPtr inca =
   evalContT $ do
      transPtr <- Call.char 'N'
      rowsPtr <- Call.cint 2
      colsPtr <- Call.cint n
      onePtr <- Call.number one
      strideZeroPtr <- Call.cint 0
      ldaPtr <- Call.leadingDim (2*inca)
      onesPtr <- Call.allocaArray n
      unitIncXPtr <- Call.cint 1
      betaPtr <- Call.number zero
      resultPtr <- Call.alloca
      unitIncYPtr <- Call.cint 1
      liftIO $ do
         -- replicate the single one into the n-element vector
         BlasGen.copy colsPtr onePtr strideZeroPtr onesPtr unitIncXPtr
         gemv
            transPtr rowsPtr colsPtr onePtr (realPtr aPtr) ldaPtr
            onesPtr unitIncXPtr betaPtr (realPtr resultPtr) unitIncYPtr
         peek resultPtr


{- |
Element-wise product of @n@ elements: @y[i] := a[i] * x[i]@,
where @a@, @x@, @y@ have strides @inca@, @incx@, @incy@.
Implemented as a Hermitian band matrix-vector product (hbmv)
with zero superdiagonals, so the matrix is the diagonal @a@.
NOTE(review): BLAS hbmv assumes the diagonal of a Hermitian matrix to be real,
which is presumably why this is called 'mulReal'.
-}
mulReal ::
   (Class.Floating a) =>
   Int -> Ptr a -> Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mulReal n aPtr inca xPtr incx yPtr incy = evalContT $ do
   upperPtr <- Call.char 'U'
   sizePtr <- Call.cint n
   superDiagsPtr <- Call.cint 0
   alphaPtr <- Call.number one
   ldaPtr <- Call.leadingDim inca
   strideXPtr <- Call.cint incx
   betaPtr <- Call.number zero
   strideYPtr <- Call.cint incy
   liftIO $
      BlasGen.hbmv upperPtr sizePtr superDiagsPtr
         alphaPtr aPtr ldaPtr
         xPtr strideXPtr
         betaPtr yPtr strideYPtr

{- |
Element-wise product of @n@ elements: @y[i] := a[i] * x[i]@,
where @a@, @x@, @y@ have strides @inca@, @incx@, @incy@.
Implemented as a general band matrix-vector product (gbmv)
with no sub- and superdiagonals, i.e. multiplication by the diagonal @a@.
-}
mul ::
   (Class.Floating a) =>
   Int -> Ptr a -> Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mul n aPtr inca xPtr incx yPtr incy = evalContT $ do
   noTransPtr <- Call.char 'N'
   sizePtr <- Call.cint n
   -- kl and ku are both zero; BLAS only reads them, so one cell serves both
   zeroBandPtr <- Call.cint 0
   alphaPtr <- Call.number one
   ldaPtr <- Call.leadingDim inca
   strideXPtr <- Call.cint incx
   betaPtr <- Call.number zero
   strideYPtr <- Call.cint incy
   liftIO $
      BlasGen.gbmv noTransPtr
         sizePtr sizePtr zeroBandPtr zeroBandPtr
         alphaPtr aPtr ldaPtr
         xPtr strideXPtr
         betaPtr yPtr strideYPtr

{- |
Product of @n@ elements starting at @aPtr@ with stride @inca@.
Uses the foldBalanced trick: instead of a left-to-right fold,
each round multiplies neighbouring pairs with a single vectorized 'mul' call
(see 'mulPairs' and 'productLoop').
Returns 'one' for @n < 1@.
-}
product :: (Class.Floating a) => Int -> Ptr a -> Int -> IO a
product n aPtr inca
   | n < 1 = return one
   | n == 1 = peek aPtr
   | otherwise =
      let half = div n 2
          count = n - half
      in ForeignArray.alloca (2*count-1) $ \xPtr -> do
            mulPairs half aPtr inca xPtr 1
            -- an unpaired last element is carried over unchanged
            when (odd n) $
               pokeElemOff xPtr half =<< peekElemOff aPtr ((n-1)*inca)
            productLoop count xPtr

{- |
If 'mul' were based on a scalar loop,
we would not need to cut the vector into chunks.

The invariant is:
When calling @productLoop n xPtr@,
storage for 2*n-1 elements is allocated starting at xPtr.
Each round writes the products of the pairs among the first @n@ elements
right behind them. If @n@ is odd, the unpaired element @x[n-1]@
directly precedes those products, so the next round
starts at offset @2*(div n 2)@ in both cases.
-}
productLoop :: (Class.Floating a) => Int -> Ptr a -> IO a
productLoop n xPtr
   | n == 1 = peek xPtr
   | otherwise = do
      let half = div n 2
      mulPairs half xPtr 1 (advancePtr xPtr n) 1
      productLoop (n - half) (advancePtr xPtr (2*half))

-- | For @i < n@: @x[i*incx] := a[2*i*inca] * a[(2*i+1)*inca]@,
-- i.e. multiply neighbouring pairs of the strided input @a@.
mulPairs ::
   (Class.Floating a) =>
   Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mulPairs n aPtr inca xPtr incx =
   mul n aPtr pairStride (advancePtr aPtr inca) pairStride xPtr incx
   where
      pairStride = 2*inca


-- | Wrapper that lets 'Class.switchFloating' pick an in-place
-- conjugation routine (count, vector, stride) per element type.
newtype LACGV a = LACGV {getLACGV :: Ptr CInt -> Ptr a -> Ptr CInt -> IO ()}

-- | Conjugate a strided vector in place.
-- A no-op for real element types; LAPACK's lacgv for complex ones.
lacgv :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
lacgv =
   getLACGV
      (Class.switchFloating
         (LACGV ignore) (LACGV ignore)
         (LACGV LapackComplex.lacgv) (LACGV LapackComplex.lacgv))
   where
      ignore _ _ _ = return ()


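{-
Work around an inconsistency of BLAS.
In case of a zero-column matrix (columns counted after the optional transposition),
BLAS's gemv and gbmv do not initialize the target vector,
not even when beta is zero.
In contrast, these work-arounds fill the target vector with zeros in that case.
-}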
{-# INLINE gemv #-}
-- | BLAS gemv, preceded by 'initializeMV' so that the target vector
-- is also initialized for a zero-column matrix with zero beta.
gemv ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt ->
   Ptr a -> Ptr CInt -> Ptr a -> Ptr a -> Ptr CInt -> IO ()
gemv transPtr mPtr nPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr =
   initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr
   >>
   BlasGen.gemv transPtr mPtr nPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr

{-# INLINE gbmv #-}
-- | BLAS gbmv, preceded by 'initializeMV' so that the target vector
-- is also initialized for a zero-column matrix with zero beta.
gbmv ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> IO ()
gbmv transPtr mPtr nPtr klPtr kuPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr =
   initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr
   >>
   BlasGen.gbmv transPtr mPtr nPtr klPtr kuPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr

{- |
Prepare the target vector @y@ of a gemv/gbmv call
for the case that the BLAS routine leaves it untouched:
if op(A) has zero columns and beta is zero,
every one of the (rows of op(A)) elements of @y@ is set to beta, i.e. zero.

@transPtr@ determines which of @mPtr@ and @nPtr@ counts the columns of op(A).
-}
initializeMV ::
   Class.Floating a =>
   Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr a -> Ptr a -> Ptr CInt -> IO ()
initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr = do
   trans <- peek transPtr
   -- BLAS accepts the TRANS flag case-insensitively.
   -- Treating a lowercase 'n' as a transposition would swap m and n here,
   -- leaving y uninitialized or copying more elements than y holds.
   let notTransposed = elem trans (map CStr.castCharToCChar "Nn")
       (mtPtr,ntPtr) = if notTransposed then (mPtr,nPtr) else (nPtr,mPtr)
   n <- peek ntPtr
   beta <- peek betaPtr
   when (n == 0 && isZero beta) $
      -- replicate the (zero) beta into all elements of y via stride 0
      Marshal.with 0 $ \incbPtr ->
      BlasGen.copy mtPtr betaPtr incbPtr yPtr incyPtr


{- |
Compute @C := A*B@ with BLAS gemm, where A is m-by-k and B is k-by-n,
each stored in the given order.
C is written to @cPtr@ in ColumnMajor order with leading dimension @m@.
Since beta is zero, C's previous contents are not used.
-}
multiplyMatrix ::
   (Class.Floating a) =>
   Order -> Order -> Int -> Int -> Int ->
   ForeignPtr a -> ForeignPtr a -> Ptr a -> IO ()
multiplyMatrix orderA orderB m k n a b cPtr =
   evalContT $ do
      transaPtr <- Call.char (transposeFromOrder orderA)
      transbPtr <- Call.char (transposeFromOrder orderB)
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      kPtr <- Call.cint k
      alphaPtr <- Call.number one
      aPtr <- ContT (withForeignPtr a)
      ldaPtr <- Call.leadingDim (leadingDimOf orderA k m)
      bPtr <- ContT (withForeignPtr b)
      ldbPtr <- Call.leadingDim (leadingDimOf orderB n k)
      betaPtr <- Call.number zero
      ldcPtr <- Call.leadingDim m
      liftIO $
         BlasGen.gemm transaPtr transbPtr mPtr nPtr kPtr
            alphaPtr aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr
   where
      -- a RowMajor r-by-c matrix has leading dimension c,
      -- a ColumnMajor one has leading dimension r
      leadingDimOf order rowMajorLd columnMajorLd =
         case order of
            RowMajor -> rowMajorLd
            ColumnMajor -> columnMajorLd



-- | Combine 'withInfo' and 'withAutoWorkspace':
-- run @computation@ with an optimally sized workspace and an info cell,
-- and report a non-zero info via 'withInfo' using @msg@ and @name@.
withAutoWorkspaceInfo ::
   (Class.Floating a) =>
   String -> String -> (Ptr a -> Ptr CInt -> Ptr CInt -> IO ()) -> IO ()
withAutoWorkspaceInfo msg name computation =
   withInfo msg name runWithWorkspace
   where
      runWithWorkspace infoPtr =
         withAutoWorkspace $ \workPtr lworkPtr ->
            computation workPtr lworkPtr infoPtr

{- |
Run a LAPACK-style computation with a workspace of the size it asks for.
First @computation@ is called with @lwork = -1@ and a one-element buffer
(a workspace query). The first element is then read back as the
requested size, rounded up by 'ceilingSize' and at least 1.
Finally a workspace of that size is allocated and @computation@ runs again
with @lwork@ set accordingly.
-}
withAutoWorkspace ::
   (Class.Floating a) =>
   (Ptr a -> Ptr CInt -> IO ()) -> IO ()
withAutoWorkspace computation = evalContT $ do
   lworkPtr <- Call.cint (-1)
   wantedSize <- liftIO $ alloca $ \queryPtr -> do
      computation queryPtr lworkPtr
      fmap (max 1 . ceilingSize) (peek queryPtr)
   workPtr <- Call.allocaArray wantedSize
   liftIO $ do
      pokeCInt lworkPtr wantedSize
      computation workPtr lworkPtr

{- |
Run @computation@ with a pointer to a fresh info cell and check the result.
Info 0 means success.
A negative info @-i@ raises an error built from 'argMsg' naming argument @i@.
A positive info raises an error formatted with the printf format @msg@,
prefixed by the routine @name@.
-}
withInfo :: String -> String -> (Ptr CInt -> IO ()) -> IO ()
withInfo msg name computation =
   alloca $ \infoPtr -> do
      computation infoPtr
      info <- peekCInt infoPtr
      when (info /= 0) $ error $
         if info < 0
            then printf argMsg name (negate info)
            else name ++ ": " ++ printf msg info

-- | printf format used by 'withInfo' for a negative info value;
-- it takes the routine name and the index of the offending argument.
argMsg :: String
argMsg = "%s: illegal value in %d-th argument"

-- | printf format with one @%d@ placeholder for an error code
-- without a more specific message.
-- NOTE(review): users are outside this section; presumably passed as @msg@ to 'withInfo'.
errorCodeMsg :: String
errorCodeMsg = "unknown error code %d"

-- | printf format with one @%d@ placeholder reporting a deficient rank.
-- NOTE(review): presumably passed as @msg@ to 'withInfo'; confirm with callers.
rankMsg :: String
rankMsg = "deficient rank %d"

-- | printf format with one @%d@ placeholder for the order of the leading minor
-- that is not positive definite.
-- NOTE(review): presumably passed as @msg@ to 'withInfo'; confirm with callers.
definiteMsg :: String
definiteMsg = "minor of order %d not positive definite"

-- | printf format with one @%d@ placeholder for the number of
-- off-diagonal elements that did not converge.
-- NOTE(review): presumably passed as @msg@ to 'withInfo'; confirm with callers.
eigenMsg :: String
eigenMsg = "%d off-diagonal elements not converging"


-- | Store an 'Int' into a C @int@ cell (converted with 'fromIntegral').
pokeCInt :: Ptr CInt -> Int -> IO ()
pokeCInt ptr n = poke ptr (fromIntegral n)

-- | Read a C @int@ cell as an 'Int'.
peekCInt :: Ptr CInt -> IO Int
peekCInt ptr = fmap fromIntegral (peek ptr)


-- | Round a scalar up to an 'Int'; for complex scalars only the real part counts.
-- Used by 'withAutoWorkspace' to read the workspace size that a query
-- returns as a floating-point element.
ceilingSize :: (Class.Floating a) => a -> Int
ceilingSize =
   getFlip
      (Class.switchFloating
         (Flip ceiling)
         (Flip ceiling)
         (Flip (ceiling . Complex.realPart))
         (Flip (ceiling . Complex.realPart)))


-- | Choose the second argument for real element types and the third
-- for complex ones. The first argument only fixes the element type @a@;
-- its value is never inspected.
caseRealComplexFunc :: (Class.Floating a) => f a -> b -> b -> b
caseRealComplexFunc f realCase complexCase =
   getConstFunc f
      (Class.switchFloating
         (Const realCase) (Const realCase)
         (Const complexCase) (Const complexCase))

-- | Unwrap a 'Const' whose phantom type is fixed by the (ignored) first argument.
getConstFunc :: f c -> Const a c -> a
getConstFunc _ (Const x) = x


-- | Tag for one of the two components of a complex number.
data ComplexPart
   = RealPart
   | ImaginaryPart
   deriving (Bounded, Enum, Eq, Ord, Show)