{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ConstraintKinds #-}
module Numeric.LAPACK.Matrix.Triangular.Linear (
   solve,
   inverse,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Banded.Linear as BandedLin
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Symmetric.Private as Symmetric
import qualified Numeric.LAPACK.Matrix.Triangular.Private as Tri
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Linear.Private (solver, withInfo)
import Numeric.LAPACK.Matrix.Triangular.Basic
         (Triangular, Symmetric, PowerDiag, takeDiagonal, strictNonUnitDiagonal)
import Numeric.LAPACK.Matrix.Shape.Private
         (transposeFromOrder, uploFromOrder, uploOrder, charFromTriDiag)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Private (copyBlock, copyToTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape (triangleSize)

import System.IO.Unsafe (unsafePerformIO)

import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peek)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)


-- | Solve a linear system whose coefficient matrix is the given
-- triangular-family matrix, treating each column of the full matrix as a
-- right-hand side.
--
-- The implementation is selected from the type-level shape by
-- 'MatrixShape.switchDiagUpLoSym':
--
-- * a diagonal matrix is converted with 'Banded.fromDiagonal' and handed to
--   the banded solver,
--
-- * both strictly triangular shapes use 'solveTriangular' (LAPACK @tptrs@),
--
-- * a symmetric matrix is normalised with 'strictNonUnitDiagonal' and
--   handed to 'solveSymmetric'.
solve ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Triangular lo diag up sh a ->
   Full vert horiz sh nrhs a -> Full vert horiz sh nrhs a
solve =
   Tri.getMultiplyRight
      (MatrixShape.switchDiagUpLoSym
         (Tri.MultiplyRight
            (\d -> BandedLin.solve (Banded.fromDiagonal d)))
         (Tri.MultiplyRight solveTriangular)
         (Tri.MultiplyRight solveTriangular)
         (Tri.MultiplyRight (solveSymmetric . strictNonUnitDiagonal)))

-- | Solve with a packed lower or upper triangular matrix via LAPACK @tptrs@.
--
-- The triangle/transpose/diagonal flags are derived from the stored shape.
-- The packed elements are copied into a temporary buffer ('copyToTemp')
-- whose pointer is passed to LAPACK. The right-hand sides are managed by
-- 'solver', which supplies the size, right-hand-side count, and
-- solution-buffer pointers.
--
-- NOTE(review): the transpose flag comes from the storage order,
-- presumably so that a row-major triangle is seen by LAPACK as the
-- transposed column-major one. Confirm against 'transposeFromOrder'.
solveTriangular ::
   (MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
    Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Triangular lo diag up sh a ->
   Full vert horiz sh nrhs a -> Full vert horiz sh nrhs a
solveTriangular (Array (MatrixShape.Triangular diagKind uplo order sh) packed) =
   solver "Triangular.solve" sh $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      let uploChar = uploFromOrder (uploOrder uplo order)
          transChar = transposeFromOrder order
          diagChar = charFromTriDiag diagKind
      uploPtr <- Call.char uploChar
      transPtr <- Call.char transChar
      diagPtr <- Call.char diagChar
      apPtr <- copyToTemp (triangleSize n) packed
      liftIO . withInfo "tptrs" $
         LapackGen.tptrs uploPtr transPtr diagPtr
            nPtr nrhsPtr apPtr xPtr ldxPtr

-- | Solve with a symmetric matrix held in packed triangular storage.
--
-- This delegates to 'Symmetric.solve' with 'NonConjugated', i.e. the
-- symmetric variant rather than the conjugating one. The diagonal and
-- triangle tags of the shape are ignored. Only the storage order and the
-- dimension are forwarded.
solveSymmetric ::
   (Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Symmetric sh a ->
   Full vert horiz sh nrhs a -> Full vert horiz sh nrhs a
solveSymmetric (Array shape packed) =
   case shape of
      MatrixShape.Triangular _ _ order sh ->
         Symmetric.solve "Symmetric.solve" NonConjugated order sh packed


-- | Invert a triangular-family matrix, keeping its triangle structure.
--
-- The diagonal tag of the result is given by the type family
-- @'PowerDiag' lo up diag@. The implementation is selected by
-- 'MatrixShape.switchDiagUpLoSym':
--
-- * diagonal matrices use 'inverseDiagonal',
--
-- * both strictly triangular shapes use 'inverseTriangular'
--   (LAPACK @tptri@),
--
-- * symmetric matrices are normalised with 'strictNonUnitDiagonal' and
--   inverted by 'inverseSymmetric'.
inverse ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a ->
   Triangular lo (PowerDiag lo up diag) up sh a
inverse =
   Tri.getPower
      (MatrixShape.switchDiagUpLoSym
         (Tri.Power inverseDiagonal)
         (Tri.Power inverseTriangular)
         (Tri.Power inverseTriangular)
         (Tri.Power (inverseSymmetric . strictNonUnitDiagonal)))

-- | Invert a diagonal matrix.
--
-- 'MatrixShape.caseTriDiag' chooses on the diagonal tag stored in the
-- shape. It either returns the input unchanged (the unit-diagonal case,
-- whose inverse is itself) or the elementwise reciprocals computed by
-- 'Vector.recip'. The reciprocals are only a lazy thunk, so they are not
-- computed when the input is returned as is.
inverseDiagonal ::
   (MatrixShape.TriDiag diag, Shape.C sh, Class.Floating a) =>
   Tri.FlexDiagonal diag sh a -> Tri.FlexDiagonal diag sh a
inverseDiagonal d =
   let diagKind = MatrixShape.triangularDiag (Array.shape d)
       reciprocals = Vector.recip d
   in  MatrixShape.caseTriDiag diagKind d reciprocals

-- | Invert a packed lower or upper triangular matrix via LAPACK @tptri@.
--
-- @tptri@ overwrites its matrix argument. The packed input is therefore
-- first copied into the freshly allocated result buffer, and the
-- inversion runs on that buffer. Only the result pointer is passed to
-- LAPACK, so the input array is never written.
inverseTriangular ::
   (MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a -> Triangular lo diag up sh a
inverseTriangular (Array shape@(MatrixShape.Triangular diagKind uplo order sh) src) =
   Array.unsafeCreateWithSize shape $ \triSize dstPtr -> evalContT $ do
      uploPtr <- Call.char (uploFromOrder (uploOrder uplo order))
      diagPtr <- Call.char (charFromTriDiag diagKind)
      nPtr <- Call.cint (Shape.size sh)
      srcPtr <- ContT (withForeignPtr src)
      liftIO $
         copyBlock triSize srcPtr dstPtr >>
         withInfo "tptri" (LapackGen.tptri uploPtr diagPtr nPtr dstPtr)

-- | Invert a symmetric matrix held in packed triangular storage.
--
-- A new array of the same shape is allocated, and 'Symmetric.inverse'
-- (with 'NonConjugated', i.e. the symmetric variant) fills it from the
-- input elements.
inverseSymmetric ::
   (Shape.C sh, Class.Floating a) => Symmetric sh a -> Symmetric sh a
inverseSymmetric (Array shape packed) =
   case shape of
      MatrixShape.Triangular _ _ order sh ->
         Array.unsafeCreateWithSize shape
            (Symmetric.inverse NonConjugated order (Shape.size sh) packed)


-- | Determinant of a triangular-family matrix.
--
-- The method is chosen from the type-level shape via
-- 'MatrixShape.switchDiagUpLoSym':
--
-- * diagonal and both strictly triangular shapes multiply the diagonal
--   elements ('determinantTriangular'),
--
-- * a symmetric matrix is normalised with 'strictNonUnitDiagonal' and
--   handed to 'determinantSymmetric'.
determinant ::
   (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a -> a
determinant =
   Tri.getMultiplyRight
      (MatrixShape.switchDiagUpLoSym
         (Tri.MultiplyRight determinantTriangular)
         (Tri.MultiplyRight determinantTriangular)
         (Tri.MultiplyRight determinantTriangular)
         (Tri.MultiplyRight (determinantSymmetric . strictNonUnitDiagonal)))

-- | Determinant of a diagonal or triangular matrix: the product of the
-- elements returned by 'takeDiagonal'.
--
-- 'product' is used on purpose. It keeps the same multiplication order as
-- a left-to-right fold, so floating-point rounding is unchanged.
determinantTriangular ::
   (MatrixShape.DiagUpLo lo up, Shape.C sh, Class.Floating a) =>
   Triangular lo diag up sh a -> a
determinantTriangular m = product (Array.toList (takeDiagonal m))

-- | Determinant of a symmetric matrix held in packed triangular storage.
--
-- This delegates to 'Symmetric.determinant' with 'NonConjugated'.
-- 'peekBlockDeterminant' evaluates each 1×1 or 2×2 diagonal block.
--
-- NOTE(review): 'unsafePerformIO' is presumed safe because the input array
-- is immutable and 'Symmetric.determinant' is expected to work only on its
-- own temporary buffers. Confirm in "Numeric.LAPACK.Matrix.Symmetric.Private".
determinantSymmetric ::
   (Shape.C sh, Class.Floating a) => Symmetric sh a -> a
determinantSymmetric (Array (MatrixShape.Triangular _ _ order sh) packed) =
   unsafePerformIO
      (Symmetric.determinant NonConjugated peekBlockDeterminant
         order (Shape.size sh) packed)

-- | Read one diagonal block from memory and return its determinant.
--
-- * @(p, Nothing)@: a 1×1 block, so the element at @p@ is returned as is.
--
-- * @(p0, Just (p1, q))@: a symmetric 2×2 block with diagonal elements at
--   @p0@ and @p1@ and off-diagonal element at @q@. The result is
--   @d0*d1 - off*off@. There is no conjugation, so this is the symmetric,
--   not the Hermitian, formula.
--
-- The elements are read in the order @p0@, @p1@, @q@.
peekBlockDeterminant ::
   (Class.Floating a) => (Ptr a, Maybe (Ptr a, Ptr a)) -> IO a
peekBlockDeterminant (diag0Ptr, offBlock) = do
   d0 <- peek diag0Ptr
   maybe (return d0)
      (\(diag1Ptr, offPtr) ->
         (\d1 off -> d0*d1 - off*off) <$> peek diag1Ptr <*> peek offPtr)
      offBlock