{-# LANGUAGE ViewPatterns        #-}
--------------------------------------------------------------------------------
-- |
-- Module      : ArrayFire.LAPACK
-- Copyright   : David Johnson (c) 2019-2020
-- License     : BSD 3
-- Maintainer  : David Johnson <djohnson.m@gmail.com>
-- Stability   : Experimental
-- Portability : GHC
--
-- LAPACK — Linear Algebra PACKage
--
-- @
-- >>> (u,e,d) = svd (constant @Double [3,3] 10)
-- >>> u
-- ArrayFire Array
-- [3 3 1 1]
--    -0.5774     0.8165    -0.0000
--    -0.5774    -0.4082    -0.7071
--    -0.5774    -0.4082     0.7071
--
-- >>> e
-- ArrayFire Array
-- [3 1 1 1]
--    30.0000
--     0.0000
--     0.0000
--
-- >>> d
-- ArrayFire Array
-- [3 3 1 1]
--   -0.5774    -0.5774    -0.5774
--   -0.8165     0.4082     0.4082
--   -0.0000     0.7071    -0.7071
--
-- @
--------------------------------------------------------------------------------
module ArrayFire.LAPACK where

import ArrayFire.Internal.LAPACK
import ArrayFire.FFI
import ArrayFire.Types
import ArrayFire.Internal.Types

-- | Singular Value Decomposition
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__factor__func__svd.htm)
--
-- Computes @(U, S, V^H)@ via @af_svd@. The singular values @S@ are returned
-- as a column vector holding the diagonal of sigma, not as the full diagonal
-- matrix.
--
svd
  :: AFType a
  => Array a
  -- ^ the input matrix
  -> (Array a, Array a, Array a)
  -- ^ Output 'Array's containing (U, diagonal values of sigma, V^H)
svd input = input `op3p` af_svd

-- | Singular Value Decomposition (in-place)
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__factor__func__svd.htm)
--
-- Same result layout as the out-of-place SVD, but backed by
-- @af_svd_inplace@. The singular values are returned as a column vector
-- holding the diagonal of sigma, not as the full diagonal matrix.
--
-- NOTE(review): the in-place C routine may overwrite its input buffer;
-- confirm that 'op3p' hands it a copy, otherwise the argument 'Array' could
-- be mutated behind this pure interface.
--
svdInPlace
  :: AFType a
  => Array a
  -- ^ the input matrix
  -> (Array a, Array a, Array a)
  -- ^ Output 'Array's containing (U, diagonal values of sigma, V^H)
svdInPlace input = input `op3p` af_svd_inplace

-- | Perform LU decomposition
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__factor__func__lu.htm)
--
-- Wraps @af_lu@, returning the lower and upper factors together with the
-- pivot 'Array'.
--
lu
  :: AFType a
  => Array a
  -- ^ is the input matrix
  -> (Array a, Array a, Array a)
  -- ^ Returns the output 'Array's (lower, upper, pivot)
lu input = input `op3p` af_lu

-- | Perform LU decomposition (in-place).
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__factor__func__lu.htm#ga0adcdc4b189c34644a7153c6ce9c4f7f)
--
-- Wraps @af_lu_inplace@. Only the pivot 'Array' is returned.
--
-- NOTE(review): the C routine writes the packed LU factors into its input.
-- This wrapper does not return them, so they are only observable if 'op1'
-- lets the C call mutate the caller's 'Array' (which would break purity).
-- Confirm against 'op1'.
--
luInPlace
  :: AFType a
  => Array a
  -- ^ contains the input on entry, the packed LU decomposition on exit.
  -> Bool
  -- ^ specifies if the pivot is returned in original LAPACK compliant format
  -> Array a
  -- ^ will contain the permutation indices to map the input to the decomposition
luInPlace input isLapackPivot =
  input `op1` (\out arr -> af_lu_inplace out arr lapackPivot)
  where
    -- The C API expects a C boolean (0/1) rather than a Haskell 'Bool'.
    lapackPivot = fromIntegral (fromEnum isLapackPivot)

-- | Perform QR decomposition
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__factor__func__qr.htm)
--
-- Wraps @af_qr@.
--
qr
  :: AFType a
  => Array a
  -- ^ the input matrix
  -> (Array a, Array a, Array a)
  -- ^ Returns (q, r, tau) 'Array's
  -- /q/ is the orthogonal matrix from QR decomposition
  -- /r/ is the upper triangular matrix from QR decomposition
  -- /tau/ will contain additional information needed for solving a least squares problem using /q/ and /r/
qr input = input `op3p` af_qr

-- | Perform QR decomposition (in-place).
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__factor__func__qr.htm)
--
-- Wraps @af_qr_inplace@. Only the /tau/ 'Array' is returned.
--
-- NOTE(review): the C routine writes the packed QR decomposition into its
-- input. This wrapper does not return it, so it is only observable if 'op1'
-- lets the C call mutate the caller's 'Array'. Confirm against 'op1'.
--
qrInPlace
  :: AFType a
  => Array a
  -- ^ is the input matrix on entry. It contains packed QR decomposition on exit
  -> Array a
  -- ^ will contain additional information needed for unpacking the data
qrInPlace input = input `op1` af_qr_inplace

-- | Perform Cholesky Decomposition
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__factor__func__cholesky.htm)
--
-- This function decomposes a positive definite matrix A into two triangular
-- matrices by calling @af_cholesky@.
--
cholesky
  :: AFType a
  => Array a
  -- ^ input 'Array'
  -> Bool
  -- ^ a boolean determining if the output is upper or lower triangular
  -- (the C API's @is_upper@ flag: 'True' selects upper)
  -> (Int, Array a)
  -- ^ @(info, factor)@. @info@ is 0 if the decomposition passes; otherwise
  -- it is the rank at which the decomposition failed. @factor@ is the
  -- triangular matrix; multiplying it with its conjugate transpose
  -- reproduces the input array.
cholesky input isUpper = (fromIntegral info, factor)
  where
    (info, factor) =
      op1b input (\out infoPtr arr -> af_cholesky out infoPtr arr upperFlag)
    -- The C API expects a C boolean (0/1) rather than a Haskell 'Bool'.
    upperFlag = fromIntegral (fromEnum isUpper)

-- | Perform Cholesky Decomposition (in-place).
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__factor__func__cholesky.htm)
--
-- C Interface for in place cholesky decomposition (@af_cholesky_inplace@).
--
-- NOTE(review): the C routine writes the triangular factor into its input.
-- This wrapper returns only the status code, so the factor is only
-- observable if 'infoFromArray' lets the C call mutate the caller's
-- 'Array'. Confirm against 'infoFromArray'.
--
choleskyInplace
  :: AFType a
  => Array a
  -- ^ is the input matrix on entry. It contains the triangular matrix on exit.
  -> Bool
  -- ^ a boolean determining if in is upper or lower triangular
  -> Int
  -- ^ is 0 if cholesky decomposition passes, if not it returns the rank at which the decomposition failed.
choleskyInplace input isUpper =
  fromIntegral $
    infoFromArray input (\infoPtr arr -> af_cholesky_inplace infoPtr arr upperFlag)
  where
    -- The C API expects a C boolean (0/1) rather than a Haskell 'Bool'.
    upperFlag = fromIntegral (fromEnum isUpper)

-- | Solve a system of equations
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__solve__func__gen.htm)
--
-- Wraps @af_solve@, passing the 'MatProp' hint converted via 'toMatProp'.
--
solve
  :: AFType a
  => Array a
  -- ^ is the coefficient matrix
  -> Array a
  -- ^ is the measured values
  -> MatProp
  -- ^ determining various properties of matrix a
  -> Array a
  -- ^ is the matrix of unknown variables
solve coeffs measured props =
  op2 coeffs measured (\out lhs rhs -> af_solve out lhs rhs (toMatProp props))

-- | Solve a system of equations.
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__solve__lu__func__gen.htm)
--
-- Wraps @af_solve_lu@, using a previously computed packed LU decomposition
-- and its pivot array.
--
solveLU
  :: AFType a
  => Array a
  -- ^ is the output matrix from packed LU decomposition of the coefficient matrix
  -> Array a
  -- ^ is the pivot array from packed LU decomposition of the coefficient matrix
  -> Array a
  -- ^ is the matrix of measured values
  -> MatProp
  -- ^ determining various properties of matrix a
  -> Array a
  -- ^ will contain the matrix of unknown variables
solveLU packedLU pivot measured props =
  op3 packedLU pivot measured $ \out luArr pivArr rhs ->
    af_solve_lu out luArr pivArr rhs (toMatProp props)

-- | Invert a matrix.
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__ops__func__inv.htm)
--
-- C Interface for inverting a matrix (@af_inverse@).
--
inverse
  :: AFType a
  => Array a
  -- ^ is input matrix
  -> MatProp
  -- ^ determining various properties of matrix in
  -> Array a
  -- ^ will contain the inverse of matrix in
inverse input props =
  input `op1` (\out arr -> af_inverse out arr (toMatProp props))

-- | Find the rank of the input matrix
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__factor__func__rank.htm)
--
-- This function uses af::qr to find the rank of the input matrix within the
-- given tolerance. The unsigned C result is converted to 'Int'.
--
rank
  :: AFType a
  => Array a
  -- ^ is input matrix
  -> Double
  -- ^ is the tolerance value
  -> Int
  -- ^ will contain the rank of in
rank input tolerance =
  fromIntegral (input `infoFromArray` (\out arr -> af_rank out arr tolerance))

-- | Find the determinant of a Matrix
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__ops__func__det.htm)
--
-- C Interface for finding the determinant of a matrix (@af_det@).
--
det
  :: AFType a
  => Array a
  -- ^ is input matrix
  -> (Double, Double)
  -- ^ will contain the real and imaginary part of the determinant of in
det input = input `infoFromArray2` af_det

-- | Find the norm of the input matrix.
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__ops__func__norm.htm)
--
-- This function can return the norm using various metrics based on the type
-- parameter. Wraps @af_norm@.
--
norm
  :: AFType a
  => Array a
  -- ^ is the input matrix
  -> NormType
  -- ^ specifies the 'NormType'
  -> Double
  -- ^ specifies the value of P when type is one of AF_NORM_VECTOR_P, AF_NORM_MATRIX_L_PQ is used. It is ignored for other values of type
  -> Double
  -- ^ specifies the value of Q when type is AF_NORM_MATRIX_L_PQ. This parameter is ignored if type is anything else
  -> Double
  -- ^ will contain the norm of in
norm arr normType p q =
  arr `infoFromArray` (\out input -> af_norm out input (fromNormType normType) p q)

-- | Is LAPACK available
--
-- [ArrayFire Docs](http://arrayfire.org/docs/group__lapack__helper__func__available.htm)
--
-- Queries @af_is_lapack_available@. Any non-zero C boolean is treated as
-- 'True'; unlike @toEnum@, this conversion is total.
--
isLAPACKAvailable
  :: Bool
  -- ^ Returns if LAPACK is available
isLAPACKAvailable = afCall1' af_is_lapack_available /= 0