{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.HermitianPositiveDefinite.Linear (
   solve,
   solveDecomposed,
   inverse,
   decompose,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Symmetric.Unified as Symmetric
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Hermitian.Private (Determinant(..))
import Numeric.LAPACK.Matrix.Mosaic.Private
         (withPackingLinear, label, applyFuncPair, triArg, copyTriangleToTemp)
import Numeric.LAPACK.Matrix.Mosaic.Basic (takeDiagonal)
import Numeric.LAPACK.Matrix.Layout.Private (uploFromOrder)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(Conjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Linear.Private (solver)
import Numeric.LAPACK.Scalar (RealOf, realPart, zero)
import Numeric.LAPACK.Private
         (copySubTrapezoid, copyBlock, fill, rankMsg, definiteMsg)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)


-- | Hermitian matrix; the @packing@ argument selects packed or unpacked
-- storage.
--
-- NOTE: deliberately eta-reduced (no element-type parameter) so that the
-- alias can be used partially applied as a type constructor; do not add an
-- element-type parameter without checking its uses.
type Hermitian packing shape = Array (Layout.HermitianP packing shape)

-- | Upper triangular matrix.  In this module it holds the Cholesky factor
-- returned by 'decompose' and consumed by 'solveDecomposed'.
type Upper packing shape = Array (Layout.UpperTriangularP packing shape)

-- | Solve @A * X = B@ for @X@, where @A@ is Hermitian positive definite.
--
-- Dispatches on the storage of @A@ to LAPACK @ppsv@ (packed) or @posv@
-- (unpacked).  Those routines overwrite their matrix argument with the
-- Cholesky factor, so @A@ is first copied into a temporary buffer.
-- LAPACK failures go through 'withPackingLinear' with 'definiteMsg'
-- (LAPACK signals a positive @info@ when @A@ is not positive definite).
solve ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Hermitian pack sh a ->
   Full meas vert horiz sh nrhs a -> Full meas vert horiz sh nrhs a
solve (Array mosaic@(Layout.Mosaic packing _mirror _upper order shapeA) aFPtr) =
   solver "Hermitian.solve" shapeA $
   \dim dimPtr rhsCountPtr rhsPtr ldRhsPtr -> do
      let aSize = Shape.size mosaic
      uploPtr <- Call.char (uploFromOrder order)
      -- Scratch copy of A, since ppsv/posv factorize it in place.
      aTmpPtr <- copyTriangleToTemp Conjugated order aSize aFPtr
      withPackingLinear definiteMsg packing
         (applyFuncPair
            (label "ppsv" LapackGen.ppsv)
            (label "posv" LapackGen.posv)
            uploPtr dimPtr rhsCountPtr (triArg aTmpPtr dim) rhsPtr ldRhsPtr)

-- | Solve @A * X = B@ for @X@, given the upper triangular Cholesky factor of
-- @A@ (as produced by 'decompose') instead of @A@ itself.
--
-- Dispatches on storage to LAPACK @pptrs@ (packed) or @potrs@ (unpacked).
-- The factor is copied into a temporary buffer before the call.
solveDecomposed ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Upper pack sh a ->
   Full meas vert horiz sh nrhs a -> Full meas vert horiz sh nrhs a
solveDecomposed
   (Array
      mosaic@(Layout.Mosaic packing Layout.NoMirror _upper order shapeA)
      factorFPtr) =
   solver "Hermitian.solveDecomposed" shapeA $
   \dim dimPtr rhsCountPtr rhsPtr ldRhsPtr -> do
      let factorSize = Shape.size mosaic
      uploPtr <- Call.char (uploFromOrder order)
      factorTmpPtr <-
         copyTriangleToTemp Conjugated order factorSize factorFPtr
      -- NOTE(review): LAPACK documents no positive info for pptrs/potrs
      -- (only negative values for illegal arguments), so the 'rankMsg'
      -- error is presumably never produced here.
      withPackingLinear rankMsg packing
         (applyFuncPair
            (label "pptrs" LapackGen.pptrs)
            (label "potrs" LapackGen.potrs)
            uploPtr dimPtr rhsCountPtr (triArg factorTmpPtr dim)
            rhsPtr ldRhsPtr)


-- | Inverse of a Hermitian positive definite matrix.
--
-- Works on a copy of the input:
--
-- 1. Cholesky-factorize the copy in place (@pptrf@/@potrf@).  Failures
--    are reported with 'definiteMsg'.
--
-- 2. Overwrite the factor with the inverse (@pptri@/@potri@).  Failures
--    are reported with 'rankMsg'.
--
-- LAPACK's inversion routines only write one triangle of the result.
-- 'Symmetric.complement' is then applied with 'Conjugated'.
-- NOTE(review): presumably this fills the other triangle in the unpacked
-- layout; confirm.
inverse ::
   (Shape.C sh, Class.Floating a) => Hermitian pack sh a -> Hermitian pack sh a
inverse
   (Array mosaic@(Layout.Mosaic packing _mirror _upper order sh) srcFPtr) =
   Array.unsafeCreateWithSize mosaic $ \triSize dstPtr -> do
      evalContT $ do
         uploPtr <- Call.char (uploFromOrder order)
         dimPtr <- Call.cint dim
         srcPtr <- ContT (withForeignPtr srcFPtr)
         -- LAPACK works in place, so start from a copy of the input.
         liftIO (copyBlock triSize srcPtr dstPtr)
         -- Step 1: Cholesky factorization of the copy.
         withPackingLinear definiteMsg packing
            (applyFuncPair
               (label "pptrf" LapackGen.pptrf)
               (label "potrf" LapackGen.potrf)
               uploPtr dimPtr (triArg dstPtr dim))
         -- Step 2: replace the factor by the inverse.
         withPackingLinear rankMsg packing
            (applyFuncPair
               (label "pptri" LapackGen.pptri)
               (label "potri" LapackGen.potri)
               uploPtr dimPtr (triArg dstPtr dim))
      Symmetric.complement packing Conjugated order dim dstPtr
   where
      dim = Shape.size sh

-- | Cholesky decomposition of a Hermitian positive definite matrix.
--
-- Returns the upper triangular factor @U@ (LAPACK convention
-- @A = U^H * U@).  The result keeps the input's packing, order and shape,
-- but is tagged 'Layout.NoMirror'.
--
-- Uses LAPACK @pptrf@ (packed) or @potrf@ (unpacked).  Failures are
-- reported with 'definiteMsg'.
decompose ::
   (Shape.C sh, Class.Floating a) => Hermitian pack sh a -> Upper pack sh a
decompose (Array (Layout.Mosaic packing _mirror upper order sh) srcFPtr) =
   Array.unsafeCreateWithSize
      (Layout.Mosaic packing Layout.NoMirror upper order sh) $
   \triSize dstPtr -> evalContT $ do
      uploPtr <- Call.char uplo
      dimPtr <- Call.cint dim
      srcPtr <- ContT (withForeignPtr srcFPtr)
      liftIO $
         if isPacked
            then copyBlock triSize srcPtr dstPtr
            else do
               -- potrf does not reference the opposite triangle, so zero
               -- the whole square first.  That leaves the unused half of
               -- the triangular result cleared.
               fill zero (dim*dim) dstPtr
               copySubTrapezoid uplo dim dim dim srcPtr dim dstPtr
      withPackingLinear definiteMsg packing
         (applyFuncPair
            (label "pptrf" LapackGen.pptrf)
            (label "potrf" LapackGen.potrf)
            uploPtr dimPtr (triArg dstPtr dim))
   where
      uplo = uploFromOrder order
      dim = Shape.size sh
      -- The rigid 'Bool' signature keeps the GADT match on the packing
      -- singleton from needing outer type information.
      isPacked :: Bool
      isPacked =
         case packing of
            Layout.Packed -> True
            Layout.Unpacked -> False


-- | Determinant of a Hermitian positive definite matrix (always real).
--
-- 'Class.switchFloating' selects one implementation per supported element
-- type.  All four are the generic 'determinantAux'.  The per-type dispatch
-- presumably exists so that its @Class.Real (RealOf a)@ requirement can be
-- met once the element type is concrete.
determinant :: (Shape.C sh, Class.Floating a) => Hermitian pack sh a -> RealOf a
determinant matrix =
   getDeterminant
      (Class.switchFloating
         (Determinant determinantAux)
         (Determinant determinantAux)
         (Determinant determinantAux)
         (Determinant determinantAux))
      matrix

-- | Element-type-generic worker for 'determinant'.
--
-- With the Cholesky factor @U@ from 'decompose', @det A = |det U|^2@, and
-- @det U@ is the product of @U@'s diagonal.  That diagonal is real for a
-- Cholesky factor, so the real parts of the diagonal are multiplied and the
-- product is squared.
determinantAux ::
   (Shape.C sh, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   Hermitian pack sh a -> ar
determinantAux hermitian = diagonalProduct ^ (2::Int)
   where
      factor = decompose hermitian
      diagonalProduct =
         Vector.product (Array.map realPart (takeDiagonal factor))