{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Triangular.Linear (
   solve,
   inverse,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Linear.Private (solver, diagonalMsg)
import Numeric.LAPACK.Matrix.Mosaic.Private
         (withPackingLinear, label, applyFuncPair, triArg)
import Numeric.LAPACK.Matrix.Mosaic.Basic (takeDiagonal)
import Numeric.LAPACK.Matrix.Shape.Omni (TriDiag, DiagSingleton, charFromTriDiag)
import Numeric.LAPACK.Matrix.Layout.Private
         (transposeFromOrder, uploFromOrder, uploOrder)
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Private (copyBlock, copyToTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)



-- | Local shorthand for an array whose shape is a 'Layout.Mosaic' with
-- 'Layout.NoMirror'.  The element type is left off so it can be supplied as
-- the final argument, e.g. @Triangular pack uplo sh a@.
type Triangular pack uplo sh
   = Array (Layout.Mosaic pack Layout.NoMirror uplo sh)


-- | Solve @A·X = B@ for @X@, where @A@ is triangular and the columns of @B@
-- are the right-hand sides.
--
-- * The 'DiagSingleton' argument is forwarded to LAPACK as the @DIAG@
--   flag. It selects whether the stored diagonal of @A@ is used.
-- * The storage order of @A@ is compensated through the @TRANS@ flag.
-- * The LAPACK routine is picked by the packing of @A@: @tptrs@ for packed
--   storage, @trtrs@ otherwise.
-- * A failure reported by LAPACK through @INFO@ is turned into an error
--   using 'diagonalMsg'.
solve ::
   (Layout.UpLo uplo, TriDiag diag,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   DiagSingleton diag ->
   Triangular pack uplo sh a ->
   Full meas vert horiz sh nrhs a -> Full meas vert horiz sh nrhs a
solve diag
   (Array (Layout.Mosaic pack Layout.NoMirror uplo orderA shA) a) =
   solver "Triangular.solve" shA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      uploPtr <- Call.char $ uploFromOrder $ uploOrder uplo orderA
      transPtr <- Call.char $ transposeFromOrder orderA
      diagPtr <- Call.char $ charFromTriDiag diag
      -- tptrs/trtrs only read the triangular factor (A resp. AP is an
      -- input-only argument); only the right-hand side X is overwritten.
      -- So we hand LAPACK the array's own buffer instead of first copying
      -- it to a temporary, just as 'inverse' reads its input.
      aPtr <- ContT $ withForeignPtr a
      withPackingLinear diagonalMsg pack $
         applyFuncPair
            (label "tptrs" LapackGen.tptrs) (label "trtrs" LapackGen.trtrs)
            uploPtr transPtr diagPtr nPtr nrhsPtr
            (triArg aPtr n) xPtr ldxPtr


-- | Invert a triangular matrix.
--
-- * The result has exactly the same shape descriptor as the input.
-- * The input elements are first copied into the freshly allocated result
--   buffer.
-- * That buffer is then overwritten in place by LAPACK: @tptri@ for packed
--   storage, @trtri@ otherwise.
-- * The 'DiagSingleton' argument is forwarded as the @DIAG@ flag.
-- * A failure reported via @INFO@ is turned into an error using
--   'diagonalMsg'.
inverse ::
   (Layout.UpLo uplo, TriDiag diag, Shape.C sh, Class.Floating a) =>
   DiagSingleton diag ->
   Triangular pack uplo sh a -> Triangular pack uplo sh a
inverse diag
   (Array shape@(Layout.Mosaic pack Layout.NoMirror uplo order sh) a) =
   Array.unsafeCreateWithSize shape $ \size dstPtr -> evalContT $ do
      let dim = Shape.size sh
      -- seed the output with the input, since the LAPACK routine works
      -- in place on its matrix argument
      srcPtr <- ContT $ withForeignPtr a
      liftIO $ copyBlock size srcPtr dstPtr
      nPtr <- Call.cint dim
      diagPtr <- Call.char $ charFromTriDiag diag
      uploPtr <- Call.char $ uploFromOrder $ uploOrder uplo order
      withPackingLinear diagonalMsg pack $
         applyFuncPair
            (label "tptri" LapackGen.tptri)
            (label "trtri" LapackGen.trtri)
            uploPtr diagPtr nPtr
            (triArg dstPtr dim)


-- | Determinant of a triangular matrix.
--
-- Computed as the product of the diagonal entries returned by
-- 'takeDiagonal'.  For an empty matrix this is the empty product.
determinant ::
   (Layout.UpLo uplo, Shape.C sh, Class.Floating a) =>
   Triangular pack uplo sh a -> a
determinant tri = Vector.product (takeDiagonal tri)