module Numeric.LAPACK.Matrix.Symmetric.Private where

import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Triangular.Private
         (diagonalPointerPairs, columnMajorPointers, rowMajorPointers,
          forPointers, pack, unpackToTemp, copyTriangleToTemp)
import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor,ColumnMajor), uploFromOrder)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated, Conjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Linear.Private (solver, withDeterminantInfo, withInfo)
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Private (copyBlock, copyToTemp, copyCondConjugate)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape (triangleSize)

import Foreign.Marshal.Array (advancePtr)
import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, peek)

import qualified System.IO.Lazy as LazyIO

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative ((<$>))


-- | Expand a packed triangle at @packedPtr@ into a full @n×n@ array at
-- @fullPtr@.
--
-- Every packed line enumerated by 'rowMajorPointers' /
-- 'columnMajorPointers' is written twice:
--
-- * once through 'copyCondConjugate' (conjugated or not according to
--   @conj@) with destination stride @n@, and
--
-- * once as a plain 'BlasGen.copy' with destination stride 1.
--
-- The plain copy is issued second, so it determines the value wherever
-- the two destinations overlap. In the 'RowMajor' case both writes start
-- at the same address.
unpack :: Class.Floating a =>
   Conjugation -> Order -> Int -> Ptr a -> Ptr a -> IO ()
unpack conj order n packedPtr fullPtr = evalContT $ do
   unitStride <- Call.cint 1
   lineStride <- Call.cint n
   let -- strided (possibly conjugated) write first, contiguous write second
       mirror lenPtr stridedDst contiguousDst src = do
          copyCondConjugate conj lenPtr src unitStride stridedDst lineStride
          BlasGen.copy lenPtr src unitStride contiguousDst unitStride
   liftIO $
      case order of
         RowMajor ->
            forPointers (rowMajorPointers n fullPtr packedPtr) $
               \lenPtr (dst,src) -> mirror lenPtr dst dst src
         ColumnMajor ->
            forPointers (columnMajorPointers n fullPtr packedPtr) $
               \lenPtr ((rowDst,colDst),src) ->
                  mirror lenPtr rowDst colDst src


-- | Multiply a packed symmetric (or, with 'Conjugated', Hermitian) matrix
-- by itself and write the packed result to @bpPtr@.
--
-- The packed input @a@ is first expanded into a full temporary with
-- 'unpack'. The product is then computed by BLAS @hemm@ (for
-- 'Conjugated') or @symm@ (otherwise) with alpha = 1 and beta = 0 into a
-- second full @n×n@ temporary. That temporary is packed again using the
-- same 'Order'.
square ::
   (Class.Floating a) =>
   Conjugation -> Order -> Int -> ForeignPtr a -> Ptr a -> IO ()
square conj order n a bpPtr = evalContT $ do
   side <- Call.char 'L'
   uplo <- Call.char 'U'
   size <- Call.cint n
   ld <- Call.leadingDim n
   full <- unpackToTemp (unpack conj order) n a
   result <- Call.allocaArray (n*n)
   alpha <- Call.number one
   beta <- Call.number zero
   let multiply =
         case conj of
            Conjugated -> BlasGen.hemm
            _ -> BlasGen.symm
   liftIO $ do
      -- both factors are the same full matrix
      multiply side uplo size size alpha full ld full ld beta result ld
      pack order n result bpPtr


-- | Solve a linear system with a packed symmetric (LAPACK @spsv@) or, for
-- 'Conjugated', Hermitian (LAPACK @hpsv@) coefficient matrix.
--
-- The packed matrix @a@ is copied to a temporary with
-- 'copyTriangleToTemp', because the LAPACK driver overwrites it with its
-- factorization. The right-hand sides are handled by 'solver', which
-- presumably supplies a writable copy at @xPtr@ that receives the
-- solution (TODO confirm). A nonzero LAPACK info is reported through
-- 'withInfo' under the routine name.
solve ::
   (Extent.C vert, Extent.C horiz,
    Shape.C width, Shape.C height, Eq height, Class.Floating a) =>
   String -> Conjugation -> Order -> height -> ForeignPtr a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
solve name conj order sh a =
   solver name sh $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      let (routine, factorAndSolve) =
            case conj of
               NonConjugated -> ("spsv", LapackGen.spsv)
               Conjugated -> ("hpsv", LapackGen.hpsv)
      uploPtr <- Call.char $ uploFromOrder order
      packedPtr <- copyTriangleToTemp conj order (triangleSize n) a
      pivotPtr <- Call.allocaArray n
      liftIO $ withInfo routine $
         factorAndSolve uploPtr nPtr nrhsPtr packedPtr pivotPtr xPtr ldxPtr


-- | Invert a packed symmetric or, for 'Conjugated', Hermitian matrix.
--
-- The first @triSize@ elements of @a@ are copied to @bPtr@. The matrix is
-- then factored in place with LAPACK @sptrf@/@hptrf@ and inverted in
-- place with @sptri@/@hptri@, so @bPtr@ holds the packed inverse
-- afterwards. Each step's info value goes through 'withInfo'.
inverse ::
   Class.Floating a =>
   Conjugation -> Order -> Int -> ForeignPtr a -> Int -> Ptr a -> IO ()
inverse conj order n a triSize bPtr = evalContT $ do
   uploPtr <- Call.char $ uploFromOrder order
   nPtr <- Call.cint n
   srcPtr <- ContT $ withForeignPtr a
   pivotPtr <- Call.allocaArray n
   workPtr <- Call.allocaArray n
   let ((factorName, factorize), (invertName, invertFactored)) =
         case conj of
            NonConjugated ->
               (("sptrf", LapackGen.sptrf), ("sptri", LapackGen.sptri))
            Conjugated ->
               (("hptrf", LapackGen.hptrf), ("hptri", LapackGen.hptri))
   liftIO $ do
      copyBlock triSize srcPtr bPtr
      -- the inversion routine consumes the factorization and pivots
      -- produced by the first call
      withInfo factorName $ factorize uploPtr nPtr bPtr pivotPtr
      withInfo invertName $ invertFactored uploPtr nPtr bPtr pivotPtr workPtr


-- | Group the (pivot pointer, diagonal pointer) pairs of a factored packed
-- matrix into its diagonal blocks.
--
-- Each pivot entry is read lazily via 'LazyIO.interleave'.
--
-- * A non-negative entry yields a 1×1 block: @(diag, Nothing)@.
--
-- * A negative entry starts a 2×2 block. It consumes the next pair as
--   well and yields @(diag, Just (nextDiag, offDiag))@. @offDiag@ lies one
--   element after @diag@ for 'RowMajor' and one element before @nextDiag@
--   for 'ColumnMajor'.
--
-- This follows the LAPACK @?sptrf@ IPIV convention. A trailing negative
-- entry with no partner is reported with 'error'.
blockDiagonalPointers ::
   (Storable a) =>
   Order -> [(Ptr CInt, Ptr a)] -> LazyIO.T [(Ptr a, Maybe (Ptr a, Ptr a))]
blockDiagonalPointers order = walk
   where
      -- position of the off-diagonal entry of a 2x2 block in packed storage
      offDiagonal firstDiag secondDiag =
         case order of
            RowMajor -> advancePtr firstDiag 1
            ColumnMajor -> advancePtr secondDiag (-1)

      walk [] = return []
      walk ((pivotPtr, diagPtr) : rest) = do
         pivot <- LazyIO.interleave $ peek pivotPtr
         (extra, blocks) <-
            if pivot < 0
               then
                  case rest of
                     (_, nextDiag) : rest' ->
                        (,) (Just (nextDiag, offDiagonal diagPtr nextDiag))
                           <$> walk rest'
                     [] -> error "Symmetric.determinant: incomplete 2x2 block"
               else (,) Nothing <$> walk rest
         return $ (diagPtr, extra) : blocks

-- | Determinant of a packed symmetric or, for 'Conjugated', Hermitian
-- matrix.
--
-- A temporary copy of the packed matrix is factored with LAPACK
-- @sptrf@/@hptrf@. The diagonal blocks of the factor are located with
-- 'blockDiagonalPointers'. Each block's determinant is read by the
-- caller-supplied @peekBlockDeterminant@, and their product is returned.
-- The factorization's info value and the product are combined by
-- 'withDeterminantInfo'. Its exact handling of a nonzero info is defined
-- elsewhere.
determinant ::
   (Class.Floating a, Class.Floating ar) =>
   Conjugation -> ((Ptr a, Maybe (Ptr a, Ptr a)) -> IO ar) ->
   Order -> Int -> ForeignPtr a -> IO ar
determinant conj peekBlockDeterminant order n a = evalContT $ do
   uploPtr <- Call.char $ uploFromOrder order
   nPtr <- Call.cint n
   factorPtr <- copyToTemp (triangleSize n) a
   pivotPtr <- Call.allocaArray n
   let (routine, factorize) =
         case conj of
            NonConjugated -> ("sptrf", LapackGen.sptrf)
            Conjugated -> ("hptrf", LapackGen.hptrf)
       -- lazily walk the diagonal blocks and multiply their determinants
       blockDeterminants = do
         blocks <-
            blockDiagonalPointers order
               (diagonalPointerPairs order n pivotPtr factorPtr)
         product <$> mapM (LazyIO.interleave . peekBlockDeterminant) blocks
       -- run the lazy computation and force the product to WHNF
       forcedProduct = do
         det <- LazyIO.run blockDeterminants
         return $! det
   liftIO $
      withDeterminantInfo routine
         (factorize uploPtr nPtr factorPtr pivotPtr)
         forcedProduct