{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Matrix.BandedHermitianPositiveDefinite.Linear (
   solve,
   solveDecomposed,
   decompose,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Linear.Private (solver)
import Numeric.LAPACK.Matrix.BandedHermitian.Basic (BandedHermitian)
import Numeric.LAPACK.Matrix.Hermitian.Private (Determinant(..))
import Numeric.LAPACK.Matrix.Triangular.Private (copyTriangleToTemp)
import Numeric.LAPACK.Matrix.Shape.Private (uploFromOrder)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(Conjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Scalar (RealOf, realPart)
import Numeric.LAPACK.Private (copyBlock, withInfo, rankMsg, definiteMsg)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Type.Data.Num.Unary as Unary
import Type.Data.Num (integralFromProxy)
import Type.Base.Proxy (Proxy(Proxy))

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)


-- | Solve @A·X = B@ where @A@ is a Hermitian positive definite banded
-- matrix and @B@ is a full right-hand-side matrix, using LAPACK's @pbsv@.
--
-- @pbsv@ overwrites its matrix argument with a Cholesky factor, so the
-- band of @A@ is first copied into a temporary buffer; the input array
-- itself is never written. The LAPACK @info@ result is checked by
-- 'withInfo' using 'definiteMsg' (per LAPACK, a positive @info@ means the
-- matrix is not positive definite).
solve ::
   (Unary.Natural offDiag, Shape.C size, Eq size,
    Extent.C vert, Extent.C horiz, Shape.C nrhs, Class.Floating a) =>
   BandedHermitian offDiag size a ->
   Full vert horiz size nrhs a -> Full vert horiz size nrhs a
solve (Array (MatrixShape.BandedHermitian numOff orderA shA) a) =
   -- 'solver' hands us the dimension n plus pointers for N, NRHS, the
   -- right-hand side buffer X (LAPACK overwrites it with the solution)
   -- and its leading dimension.
   solver "BandedHermitian.solve" shA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      -- The LAPACK triangle flag is derived from the storage order.
      uploPtr <- Call.char $ uploFromOrder orderA
      -- k is the number of off-diagonals (LAPACK's KD).
      let k = integralFromProxy numOff
      -- LAPACK band storage requires LDAB >= KD+1.
      let lda = k+1
      kPtr <- Call.cint k
      -- Temporary copy of the n*(k+1) band elements, because pbsv
      -- destroys its matrix argument. The 'Conjugated' flag is forwarded
      -- to 'copyTriangleToTemp' (presumably to adapt the storage order to
      -- LAPACK's convention -- confirm there).
      aPtr <- copyTriangleToTemp Conjugated orderA (n*lda) a
      ldaPtr <- Call.leadingDim lda
      liftIO $
         withInfo definiteMsg "pbsv" $
            LapackGen.pbsv uploPtr nPtr kPtr nrhsPtr aPtr ldaPtr xPtr ldxPtr

-- | Solve @A·X = B@ given the Cholesky factor of @A@ (as produced by
-- 'decompose') and a full right-hand-side matrix @B@, using LAPACK's
-- @pbtrs@.
--
-- The factor is an upper banded matrix with no sub-diagonals. Its band is
-- copied into a temporary buffer before the LAPACK call, so the input
-- array is not written. The LAPACK @info@ result is checked by 'withInfo'
-- using 'rankMsg'.
solveDecomposed ::
   (Unary.Natural offDiag, Shape.C size, Eq size,
    Extent.C vert, Extent.C horiz, Shape.C nrhs, Class.Floating a) =>
   Banded.Upper offDiag size a ->
   Full vert horiz size nrhs a -> Full vert horiz size nrhs a
solveDecomposed (Array (MatrixShape.Banded (_zero,numOff) orderA shA) a) =
   -- The factor is square, so its extent determines the system size.
   solver "BandedHermitian.solveDecomposed" (Extent.squareSize shA) $
         \n nPtr nrhsPtr xPtr ldxPtr -> do
      -- The LAPACK triangle flag is derived from the storage order.
      uploPtr <- Call.char $ uploFromOrder orderA
      -- k is the number of super-diagonals of the factor (LAPACK's KD).
      let k = integralFromProxy numOff
      -- LAPACK band storage requires LDAB >= KD+1.
      let lda = k+1
      kPtr <- Call.cint k
      -- Temporary copy of the n*(k+1) band elements. The 'Conjugated'
      -- flag is forwarded to 'copyTriangleToTemp' (presumably to adapt
      -- the storage order to LAPACK's convention -- confirm there).
      aPtr <- copyTriangleToTemp Conjugated orderA (n*lda) a
      ldaPtr <- Call.leadingDim lda
      liftIO $
         withInfo rankMsg "pbtrs" $
            LapackGen.pbtrs uploPtr nPtr kPtr nrhsPtr aPtr ldaPtr xPtr ldxPtr


-- | Cholesky decomposition of a Hermitian positive definite banded matrix,
-- computed with LAPACK's @pbtrf@.
--
-- The result is a square banded matrix with zero sub-diagonals and the
-- same number of super-diagonals, storage order and size as the input.
-- It is suitable as the first argument of 'solveDecomposed'.
-- The LAPACK @info@ result is checked by 'withInfo' using 'definiteMsg'
-- (per LAPACK, a positive @info@ means the matrix is not positive
-- definite).
decompose ::
   (Unary.Natural offDiag, Shape.C size, Class.Floating a) =>
   BandedHermitian offDiag size a -> Banded.Upper offDiag size a
decompose (Array (MatrixShape.BandedHermitian numOff order sh) a) =
   Array.unsafeCreateWithSize
      (MatrixShape.bandedSquare (Proxy,numOff) order sh) $ \bSize bPtr ->
      evalContT $ do
         -- k is the number of off-diagonals (LAPACK's KD).
         let k = integralFromProxy numOff
         -- The LAPACK triangle flag is derived from the storage order.
         uploPtr <- Call.char $ uploFromOrder order
         nPtr <- Call.cint $ Shape.size sh
         kPtr <- Call.cint k
         aPtr <- ContT $ withForeignPtr a
         -- LAPACK band storage requires LDAB >= KD+1.
         ldbPtr <- Call.leadingDim $ k+1
         liftIO $ do
            -- pbtrf factorizes in place, so copy the input band into the
            -- freshly allocated result buffer and factorize that.
            copyBlock bSize aPtr bPtr
            withInfo definiteMsg "pbtrf" $
               LapackGen.pbtrf uploPtr nPtr kPtr bPtr ldbPtr


-- | Determinant of a Hermitian positive definite banded matrix.
--
-- The value is computed from the Cholesky factor (see 'determinantAux'),
-- so it is real. If 'decompose' rejects the matrix, this function fails
-- in the same way.
determinant ::
   (Unary.Natural offDiag, Shape.C size, Class.Floating a) =>
   BandedHermitian offDiag size a -> RealOf a
determinant =
   -- 'determinantAux' needs @Class.Real (RealOf a)@, which the bare
   -- 'Class.Floating' constraint does not provide. Dispatching on the four
   -- concrete element types instantiates it where that constraint is known
   -- to hold.
   getDeterminant $
   Class.switchFloating
      (Determinant determinantAux) (Determinant determinantAux)
      (Determinant determinantAux) (Determinant determinantAux)

-- | Determinant via the Cholesky factor: take the factor's diagonal, keep
-- its real parts, multiply them and square the product.
--
-- This relies on the identity @det A = |det U|^2@ for @A = U^H·U@. The
-- diagonal of a LAPACK Cholesky factor is real, so taking 'realPart'
-- discards no information.
determinantAux ::
   (Unary.Natural offDiag, Shape.C size,
    Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   BandedHermitian offDiag size a -> ar
determinantAux =
   (^(2::Int)) . product . map realPart . Array.toList .
   Banded.takeDiagonal . decompose