{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Triangular.Linear (
   solve,
   inverse,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Linear.Private (solver, diagonalMsg)
import Numeric.LAPACK.Matrix.Mosaic.Private
         (withPackingLinear, label, applyFuncPair, triArg)
import Numeric.LAPACK.Matrix.Mosaic.Basic (takeDiagonal)
import Numeric.LAPACK.Matrix.Shape.Omni (TriDiag, DiagSingleton, charFromTriDiag)
import Numeric.LAPACK.Matrix.Layout.Private
         (transposeFromOrder, uploFromOrder, uploOrder)
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Private (copyBlock, copyToTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)



-- | Shorthand for a triangular matrix: a non-mirrored 'Layout.Mosaic' array.
-- The @pack@ index decides whether the packed or the full-storage LAPACK
-- routine is used below, and @uplo@ decides which triangle is referenced.
-- The element type is deliberately left off so the synonym can be used
-- partially applied.
type Triangular pack uplo sh
   = Array (Layout.Mosaic pack Layout.NoMirror uplo sh)


-- | Solve the triangular system @A · X = B@ for @X@.
--
-- The packing of @A@ selects the LAPACK routine: @?tptrs@ for packed
-- storage, @?trtrs@ for full storage. LAPACK's @diag@ character comes from
-- the 'DiagSingleton' argument (via 'charFromTriDiag').
--
-- The storage order of @A@ feeds both the @uplo@ character (via 'uploOrder')
-- and the @trans@ character (via 'transposeFromOrder'). This is presumably
-- so that a row-major @A@ is solved as the transposed column-major matrix.
--
-- The right-hand side is handled by 'solver'. That covers the dimension
-- @n@, the @nrhs@/@ldx@ arguments and the result buffer @xPtr@.
-- NOTE(review): presumably 'solver' also checks that @B@'s height equals
-- @A@'s shape (hence the @Eq sh@ constraint); confirm in
-- "Numeric.LAPACK.Linear.Private".
solve ::
   (Layout.UpLo uplo, TriDiag diag,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   DiagSingleton diag ->
   Triangular pack uplo sh a ->
   Full meas vert horiz sh nrhs a -> Full meas vert horiz sh nrhs a
solve diag
      (Array shape@(Layout.Mosaic pack Layout.NoMirror uplo orderA shA) a) =
   solver "Triangular.solve" shA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      uploPtr <- Call.char $ uploFromOrder $ uploOrder uplo orderA
      transPtr <- Call.char $ transposeFromOrder orderA
      diagPtr <- Call.char $ charFromTriDiag diag
      -- LAPACK gets a temporary copy of A's storage, never the
      -- caller's buffer.
      aPtr <- copyToTemp (Shape.size shape) a
      -- 'withPackingLinear' picks the packed or unpacked routine according
      -- to 'pack'. NOTE(review): 'diagonalMsg' is presumably used to report
      -- a zero diagonal element signalled through LAPACK's info result.
      withPackingLinear diagonalMsg pack $
         applyFuncPair
            (label "tptrs" LapackGen.tptrs)
            (label "trtrs" LapackGen.trtrs)
            uploPtr transPtr diagPtr nPtr nrhsPtr
            (triArg aPtr n) xPtr ldxPtr


-- | Invert a triangular matrix.
--
-- The packing of the input selects the routine: @?tptri@ for packed
-- storage, @?trtri@ for full storage.
--
-- The input array is not modified. A fresh array of the same shape is
-- allocated, all @triSize@ stored elements of the input are copied into it,
-- and LAPACK inverts that copy in place.
--
-- The 'DiagSingleton' argument is passed as LAPACK's @diag@ character (via
-- 'charFromTriDiag'). The storage order determines the @uplo@ character
-- (via 'uploOrder').
inverse ::
   (Layout.UpLo uplo, TriDiag diag, Shape.C sh, Class.Floating a) =>
   DiagSingleton diag ->
   Triangular pack uplo sh a -> Triangular pack uplo sh a
inverse diag
      (Array shape@(Layout.Mosaic pack Layout.NoMirror uplo order sh) a) =
   Array.unsafeCreateWithSize shape $ \triSize bPtr -> evalContT $ do
      uploPtr <- Call.char $ uploFromOrder $ uploOrder uplo order
      diagPtr <- Call.char $ charFromTriDiag diag
      -- Matrix dimension, as opposed to 'triSize', the number of stored
      -- elements.
      let n = Shape.size sh
      nPtr <- Call.cint n
      aPtr <- ContT $ withForeignPtr a
      -- Work on the output buffer so the input array stays untouched.
      liftIO $ copyBlock triSize aPtr bPtr
      withPackingLinear diagonalMsg pack $
         applyFuncPair
            (label "tptri" LapackGen.tptri)
            (label "trtri" LapackGen.trtri)
            uploPtr diagPtr nPtr (triArg bPtr n)


-- | Determinant of a triangular matrix. It is the product of the stored
-- diagonal elements, as extracted by 'takeDiagonal'.
determinant ::
   (Layout.UpLo uplo, Shape.C sh, Class.Floating a) =>
   Triangular pack uplo sh a -> a
determinant = Vector.product . takeDiagonal