{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.Hermitian (
   Hermitian,
   Transposition(..),

   size,
   fromList,
   autoFromList,
   identity,
   diagonal,
   takeDiagonal,
   forceOrder,

   stack, (*%%%#),
   split,
   takeTopLeft,
   takeTopRight,
   takeBottomRight,

   multiplyVector,
   square,
   multiplyFull,
   outer,
   sumRank1, sumRank1NonEmpty,
   sumRank2, sumRank2NonEmpty,

   toSquare,
   gramian,            gramianAdjoint,
   congruenceDiagonal, congruenceDiagonalAdjoint,
   congruence,         congruenceAdjoint,
   anticommutator,     anticommutatorAdjoint,
   addAdjoint,

   solve,
   inverse,
   determinant,

   eigenvalues,
   eigensystem,
   ) where

import qualified Numeric.LAPACK.Matrix.Hermitian.Eigen as Eigen
import qualified Numeric.LAPACK.Matrix.Hermitian.Linear as Linear
import qualified Numeric.LAPACK.Matrix.Hermitian.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import Numeric.LAPACK.Matrix.Array.Triangular (Hermitian)
import Numeric.LAPACK.Matrix.Array (Full, General, Square)
import Numeric.LAPACK.Matrix.Shape.Private (Order)
import Numeric.LAPACK.Matrix.Modifier (Transposition(NonTransposed, Transposed))
import Numeric.LAPACK.Matrix.Private (ShapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, one)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape ((:+:))

import Foreign.Storable (Storable)

import qualified Data.NonEmpty as NonEmpty
import Data.Tuple.HT (mapFst)


{- |
Size of a Hermitian matrix,
obtained by applying @MatrixShape.hermitianSize@
to the shape reported by @ArrMatrix.shape@.
-}
size :: Hermitian sh a -> sh
size a = MatrixShape.hermitianSize $ ArrMatrix.shape a

{- |
Construct a Hermitian matrix from an 'Order', a shape and a list of elements.
The construction itself is delegated to @Basic.fromList@,
which also decides how the list elements are laid out.
-}
fromList :: (Shape.C sh, Storable a) => Order -> sh -> [a] -> Hermitian sh a
fromList order sh xs = ArrMatrix.lift0 $ Basic.fromList order sh xs

{- |
Variant of 'fromList' that takes no shape argument
and yields a matrix of shape 'ShapeInt'.
The shape is chosen by @Basic.autoFromList@
(presumably from the length of the list).
-}
autoFromList :: (Storable a) => Order -> [a] -> Hermitian ShapeInt a
autoFromList order xs = ArrMatrix.lift0 (Basic.autoFromList order xs)

{- |
Hermitian matrix produced by @Basic.identity@
for the given 'Order' and shape.
-}
identity :: (Shape.C sh, Class.Floating a) => Order -> sh -> Hermitian sh a
identity order sh = ArrMatrix.lift0 (Basic.identity order sh)

{- |
Turn a vector into a Hermitian matrix by means of @Basic.diagonal@.
Note that the vector elements are of type @RealOf a@,
whereas the matrix elements are of type @a@.
-}
diagonal ::
   (Shape.C sh, Class.Floating a) =>
   Order -> Vector sh (RealOf a) -> Hermitian sh a
diagonal order v = ArrMatrix.lift0 $ Basic.diagonal order v

{- |
Apply @Basic.takeDiagonal@ to the array underlying the matrix.
The resulting vector has elements of type @RealOf a@.
-}
takeDiagonal ::
   (Shape.C sh, Class.Floating a) =>
   Hermitian sh a -> Vector sh (RealOf a)
takeDiagonal a = Basic.takeDiagonal (ArrMatrix.toVector a)

{- |
Apply @Basic.forceOrder@ with the requested 'Order'
to the array underlying the matrix.
The shape type of the matrix is left unchanged.
-}
forceOrder ::
   (Shape.C sh, Class.Floating a) =>
   Order -> Hermitian sh a -> Hermitian sh a
forceOrder order a = ArrMatrix.lift1 (Basic.forceOrder order) a

{- |
Assemble a Hermitian block matrix from a Hermitian block, a general block
and another Hermitian block, such that

> toSquare (stack a b c)
>
> =
>
> toSquare a ||| b
> ===
> adjoint b ||| toSquare c

It holds @order (stack a b c) = order b@.
The function is most efficient when the orders of all blocks match.
-}
stack ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   Hermitian sh0 a -> General sh0 sh1 a -> Hermitian sh1 a ->
   Hermitian (sh0:+:sh1) a
stack topLeft topRight bottomRight =
   ArrMatrix.lift3 Basic.stack topLeft topRight bottomRight

infixr 2 *%%%#

{- |
Infix variant of 'stack' that takes the first two blocks as a pair:

> ab *%%%# c  =  stack (fst ab) (snd ab) c
-}
(*%%%#) ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   (Hermitian sh0 a, General sh0 sh1 a) -> Hermitian sh1 a ->
   Hermitian (sh0:+:sh1) a
ab *%%%# c = stack (fst ab) (snd ab) c


{- |
Decompose a Hermitian block matrix into the triple consisting of
the results of 'takeTopLeft', 'takeTopRight' and 'takeBottomRight'.
-}
split ::
   (Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Hermitian (sh0:+:sh1) a ->
   (Hermitian sh0 a, General sh0 sh1 a, Hermitian sh1 a)
split m = (topLeft, topRight, bottomRight)
   where
      topLeft = takeTopLeft m
      topRight = takeTopRight m
      bottomRight = takeBottomRight m

{- |
Hermitian block of shape @sh0@ of a block matrix of shape @sh0:+:sh1@,
extracted by @Basic.takeTopLeft@.
-}
takeTopLeft ::
   (Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Hermitian (sh0:+:sh1) a -> Hermitian sh0 a
takeTopLeft m = ArrMatrix.lift1 Basic.takeTopLeft m

{- |
General block of height @sh0@ and width @sh1@
of a block matrix of shape @sh0:+:sh1@,
extracted by @Basic.takeTopRight@.
-}
takeTopRight ::
   (Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Hermitian (sh0:+:sh1) a -> General sh0 sh1 a
takeTopRight m = ArrMatrix.lift1 Basic.takeTopRight m

{- |
Hermitian block of shape @sh1@ of a block matrix of shape @sh0:+:sh1@,
extracted by @Basic.takeBottomRight@.
-}
takeBottomRight ::
   (Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Hermitian (sh0:+:sh1) a -> Hermitian sh1 a
takeBottomRight m = ArrMatrix.lift1 Basic.takeBottomRight m


{- |
Matrix-vector product as computed by @Basic.multiplyVector@.
The 'Transposition' argument is passed on unchanged;
the matrix is handed over as its underlying array.
-}
multiplyVector ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Transposition -> Hermitian sh a -> Vector sh a -> Vector sh a
multiplyVector trans a x =
   Basic.multiplyVector trans (ArrMatrix.toVector a) x

{- |
Apply @Basic.square@ to a Hermitian matrix.
The result is a Hermitian matrix of the same shape type.
-}
square ::
   (Shape.C sh, Eq sh, Class.Floating a) => Hermitian sh a -> Hermitian sh a
square a = ArrMatrix.lift1 Basic.square a

{- |
Product of a Hermitian matrix and a full matrix
as computed by @Basic.multiplyFull@.
The 'Transposition' argument is passed on unchanged
and the result has the same type as the full matrix argument.
-}
multiplyFull ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width,
    Class.Floating a) =>
   Transposition -> Hermitian height a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
multiplyFull trans a b = ArrMatrix.lift2 (Basic.multiplyFull trans) a b

{- |
Hermitian matrix built from a single vector by @Basic.outer@,
using the given 'Order'.
-}
outer ::
   (Shape.C sh, Class.Floating a) => Order -> Vector sh a -> Hermitian sh a
outer order x = ArrMatrix.lift0 (Basic.outer order x)

{- |
Hermitian matrix built by @Basic.sumRank1@
from a list of pairs of a weight of type @RealOf a@ and a vector.
The shape is passed explicitly, so the list may be empty.
-}
sumRank1 ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(RealOf a, Vector sh a)] -> Hermitian sh a
sumRank1 order sh xs = ArrMatrix.lift0 $ Basic.sumRank1 order sh xs

{- |
Variant of 'sumRank1' for non-empty lists.
The shape is not passed explicitly
but taken from the vector of the first list element.
-}
sumRank1NonEmpty ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> NonEmpty.T [] (RealOf a, Vector sh a) -> Hermitian sh a
sumRank1NonEmpty order (NonEmpty.Cons x xs) = sumRank1 order sh (x:xs)
   where
      sh = Array.shape (snd x)

{- |
Hermitian matrix built by @Basic.sumRank2@
from a list of pairs of a weight of type @a@ and a pair of vectors.
The shape is passed explicitly, so the list may be empty.
-}
sumRank2 ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(a, (Vector sh a, Vector sh a))] -> Hermitian sh a
sumRank2 order sh xys = ArrMatrix.lift0 $ Basic.sumRank2 order sh xys

{- |
Variant of 'sumRank2' for non-empty lists.
The shape is not passed explicitly
but taken from the first vector of the first list element.
-}
sumRank2NonEmpty ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> NonEmpty.T [] (a, (Vector sh a, Vector sh a)) -> Hermitian sh a
sumRank2NonEmpty order (NonEmpty.Cons xy xys) = sumRank2 order sh (xy:xys)
   where
      sh = Array.shape (fst (snd xy))

{- |
Convert a Hermitian matrix to a 'Square' matrix of the same shape type
by means of @Basic.toSquare@.
-}
toSquare ::
   (Shape.C sh, Class.Floating a) => Hermitian sh a -> Square sh a
toSquare a = ArrMatrix.lift1 Basic.toSquare a

{- |
> gramian A = A^H * A

The result is a Hermitian matrix over the width of @A@.
-}
gramian ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Hermitian width a
gramian a = ArrMatrix.lift1 Basic.gramian a

{- |
> gramianAdjoint A = A * A^H = gramian (A^H)

The result is a Hermitian matrix over the height of @A@.
-}
gramianAdjoint ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Hermitian height a
gramianAdjoint a = ArrMatrix.lift1 Basic.gramianAdjoint a

{- |
> congruenceDiagonal D A = A^H * D * A

@D@ is given as a vector with elements of type @RealOf a@.
-}
congruenceDiagonal ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height (RealOf a) -> General height width a -> Hermitian width a
congruenceDiagonal d a = ArrMatrix.lift1 (Basic.congruenceDiagonal d) a

{- |
> congruenceDiagonalAdjoint A D = A * D * A^H

@D@ is given as a vector with elements of type @RealOf a@.
-}
congruenceDiagonalAdjoint ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a -> Vector width (RealOf a) -> Hermitian height a
congruenceDiagonalAdjoint a d =
   ArrMatrix.lift0 $
   Basic.congruenceDiagonalAdjoint (ArrMatrix.toVector a) d

{- |
> congruence B A = A^H * B * A

@B@ is Hermitian over the height of @A@;
the result is Hermitian over the width of @A@.
-}
congruence ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Hermitian height a -> General height width a -> Hermitian width a
congruence b a = ArrMatrix.lift2 Basic.congruence b a

{- |
> congruenceAdjoint B A = A * B * A^H

Mind the argument order of the function itself:
the general matrix @A@ is the first argument
and the Hermitian matrix @B@ is the second one.
@B@ is Hermitian over the width of @A@;
the result is Hermitian over the height of @A@.
-}
congruenceAdjoint ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a -> Hermitian width a -> Hermitian height a
congruenceAdjoint a b = ArrMatrix.lift2 Basic.congruenceAdjoint a b


{- |
> anticommutator A B  =  A^H * B + B^H * A

Not exactly a matrix anticommutator,
thus I like to call it Hermitian anticommutator.

Implemented as @Basic.scaledAnticommutator@ with scaling factor 'one'.
-}
anticommutator ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   Full vert horiz height width a -> Hermitian width a
anticommutator a b = ArrMatrix.lift2 (Basic.scaledAnticommutator one) a b

{- |
> anticommutatorAdjoint A B
>    = A * B^H + B * A^H
>    = anticommutator (adjoint A) (adjoint B)

Implemented as @Basic.scaledAnticommutatorAdjoint@
with scaling factor 'one'.
-}
anticommutatorAdjoint ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   Full vert horiz height width a -> Hermitian height a
anticommutatorAdjoint a b =
   ArrMatrix.lift2 (Basic.scaledAnticommutatorAdjoint one) a b

{- |
> _scaledAnticommutator alpha A B  =  alpha * A^H * B + conj alpha * B^H * A

Matrix-level wrapper around @Basic.scaledAnticommutator@.
It is not part of the export list of this module.
-}
_scaledAnticommutator ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a) =>
   a ->
   Full vert horiz height width a ->
   Full vert horiz height width a -> Hermitian width a
_scaledAnticommutator alpha a b =
   ArrMatrix.lift2 (Basic.scaledAnticommutator alpha) a b

{- |
> addAdjoint A = A^H + A

Takes a 'Square' matrix and yields a Hermitian one of the same shape type.
-}
addAdjoint :: (Shape.C sh, Class.Floating a) => Square sh a -> Hermitian sh a
addAdjoint a = ArrMatrix.lift1 Basic.addAdjoint a



{- |
Apply @Linear.solve@ to a Hermitian matrix and a full matrix
whose height matches the shape of the Hermitian matrix.
The result has the same type as the full matrix argument.
-}
solve ::
   (Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Hermitian sh a -> Full vert horiz sh nrhs a -> Full vert horiz sh nrhs a
solve a b = ArrMatrix.lift2 Linear.solve a b

{- |
Apply @Linear.inverse@ to a Hermitian matrix.
The result is again a Hermitian matrix of the same shape type.
-}
inverse :: (Shape.C sh, Class.Floating a) => Hermitian sh a -> Hermitian sh a
inverse a = ArrMatrix.lift1 Linear.inverse a

{- |
Apply @Linear.determinant@ to the array underlying the matrix.
The result is of type @RealOf a@.
-}
determinant :: (Shape.C sh, Class.Floating a) => Hermitian sh a -> RealOf a
determinant a = Linear.determinant (ArrMatrix.toVector a)



{- |
Apply @Eigen.values@ to the array underlying the matrix.
The resulting vector has elements of type @RealOf a@.
-}
eigenvalues ::
   (Shape.C sh, Class.Floating a) =>
   Hermitian sh a -> Vector sh (RealOf a)
eigenvalues a = Eigen.values (ArrMatrix.toVector a)

{- |
For symmetric eigenvalue problems, @eigensystem@ and @schur@ coincide.

The result is the pair delivered by @Eigen.decompose@,
where the first component is lifted to a 'Square' matrix
and the second component is a vector with elements of type @RealOf a@.
-}
eigensystem ::
   (Shape.C sh, Class.Floating a) =>
   Hermitian sh a -> (Square sh a, Vector sh (RealOf a))
eigensystem a =
   mapFst ArrMatrix.lift0 $ Eigen.decompose $ ArrMatrix.toVector a