{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Hermitian (
   FlexHermitian,
   Hermitian,
   HermitianPosDef,
   HermitianPosSemidef,
   Transposition(..),

   Hermitian.Semidefinite,
   Hermitian.assureFullRank,
   Hermitian.assureAnyRank,
   Hermitian.relaxSemidefinite,
   Hermitian.relaxIndefinite,
   Hermitian.assurePositiveDefiniteness,
   Hermitian.relaxDefiniteness,
   Hermitian.asUnknownDefiniteness,

   pack,
   size,
   fromList,
   autoFromList,
   identity,
   diagonal,
   takeDiagonal,
   forceOrder,

   stack, (*%%%#),
   split,
   takeTopLeft,
   takeTopRight,
   takeBottomRight,

   toSquare,
   fromSymmetric,

   negate,

   multiplyVector,
   multiplyFull,
   square,

   outer,
   sumRank1, sumRank1NonEmpty,
   sumRank2, sumRank2NonEmpty,

   gramian,            gramianAdjoint,
   congruenceDiagonal, congruenceDiagonalAdjoint,
   congruence,         congruenceAdjoint,
   anticommutator,     anticommutatorAdjoint,
   addAdjoint,

   solve,
   inverse,
   determinant,

   eigenvalues,
   eigensystem,
   ) where

import qualified Numeric.LAPACK.Matrix.Hermitian.Linear as Linear
import qualified Numeric.LAPACK.Matrix.HermitianPositiveDefinite.Linear
                                                         as LinearPD
import qualified Numeric.LAPACK.Matrix.Hermitian.Eigen as Eigen
import qualified Numeric.LAPACK.Matrix.Hermitian.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Symmetric.Unified as Symmetric
import qualified Numeric.LAPACK.Matrix.Mosaic.Packed as Packed
import qualified Numeric.LAPACK.Matrix.Mosaic.Basic as Mosaic
import qualified Numeric.LAPACK.Matrix.Basic as FullBasic

import qualified Numeric.LAPACK.Matrix.Full as Full
import qualified Numeric.LAPACK.Matrix.Array.Hermitian as Hermitian
import qualified Numeric.LAPACK.Matrix.Array.Unpacked as ArrUnpacked
import qualified Numeric.LAPACK.Matrix.Array.Basic as OmniMatrix
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Scalar as Scalar
import qualified Numeric.LAPACK.Shape as ExtShape
import Numeric.LAPACK.Matrix.Array.Mosaic
         (FlexHermitian, FlexHermitianP,
          Hermitian, HermitianP,
          HermitianPosSemidef, HermitianPosSemidefP, HermitianPosDef,
          SymmetricP)
import Numeric.LAPACK.Matrix.Array.Private (Full, General, Square, packTag)
import Numeric.LAPACK.Matrix.Layout.Private (Order)
import Numeric.LAPACK.Matrix.Modifier (Transposition(NonTransposed, Transposed))
import Numeric.LAPACK.Matrix.Private (ShapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, one)

import qualified Numeric.Netlib.Class as Class

import qualified Type.Data.Bool as TBool
import Type.Data.Bool (True)

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape ((::+))

import qualified Data.NonEmpty as NonEmpty
import Data.Tuple.HT (mapFst)

import Prelude hiding (negate)


{- |
Dimension of the square Hermitian matrix, read off its shape.
-}
size :: FlexHermitianP pack neg zero pos sh a -> sh
size a = Omni.squareSize (ArrMatrix.shape a)


{- |
Build a 'Hermitian' matrix of shape @sh@ from a list of elements
laid out according to @order@.
The expected element layout is the one of 'Mosaic.fromList'.
-}
fromList ::
   (Shape.C sh, Class.Floating a) => Order -> sh -> [a] -> Hermitian sh a
fromList order sh xs = ArrMatrix.fromVector $ Mosaic.fromList order sh xs

{- |
Like 'fromList', but 'Mosaic.autoFromList' determines the 'ShapeInt'
dimension (presumably from the length of the list).
-}
autoFromList :: (Class.Floating a) => Order -> [a] -> Hermitian ShapeInt a
autoFromList order xs = ArrMatrix.fromVector (Mosaic.autoFromList order xs)


{- |
Identity matrix of the given shape, tagged as positive definite.
-}
identity ::
   (Shape.C sh, Class.Floating a) =>
   Order -> sh -> HermitianPosDef sh a
identity order sh = ArrMatrix.lift0 $ Packed.identity order sh

{- |
Diagonal Hermitian matrix whose diagonal entries are the given real numbers.
-}
diagonal ::
   (Shape.C sh, Class.Floating a) =>
   Order -> Vector sh (RealOf a) -> Hermitian sh a
diagonal order d = ArrMatrix.lift0 (Basic.diagonal order d)

{- |
Extract the diagonal.
It is returned as a real vector, because the diagonal of a Hermitian
matrix is always real.
-}
takeDiagonal ::
   (TBool.C neg, TBool.C zero, TBool.C pos,
    Shape.C sh, Class.Floating a) =>
   FlexHermitian neg zero pos sh a -> Vector sh (RealOf a)
takeDiagonal a = Basic.takeDiagonal $ ArrMatrix.toVector a

{- |
Store the matrix in the requested element order.
This works for both the packed and the unpacked representation.
-}
forceOrder ::
   (Layout.Packing pack, TBool.C neg, TBool.C zero, TBool.C pos,
    Shape.C sh, Class.Floating a) =>
   Order ->
   FlexHermitianP pack neg zero pos sh a ->
   FlexHermitianP pack neg zero pos sh a
forceOrder order a =
   case packTag a of
      Layout.Unpacked ->
         ArrMatrix.liftUnpacked1 (FullBasic.forceOrder order) a
      Layout.Packed ->
         ArrMatrix.lift1 (Packed.forceOrder order) a


{- |
Convert a matrix of either packing to packed storage.
The definiteness tags are kept unchanged.
-}
pack ::
   (Layout.Packing pack, TBool.C neg, TBool.C zero, TBool.C pos,
    Shape.C sh, Class.Floating a) =>
   FlexHermitianP pack neg zero pos sh a -> FlexHermitian neg zero pos sh a
pack a = ArrMatrix.lift1 Mosaic.pack a


{- |
Assemble a Hermitian block matrix from its two diagonal blocks
and its top-right block:

> toSquare (stack a b c)
>    =
> toSquare a ||| b
> ===
> adjoint b ||| toSquare c

The bottom-left block is not passed; it is the adjoint of @b@.
It holds @order (stack a b c) = order b@.
The function is most efficient when the orders of all blocks match.
-}
stack ::
   (Layout.Packing pack,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   HermitianP pack sh0 a -> General sh0 sh1 a -> HermitianP pack sh1 a ->
   HermitianP pack (sh0::+sh1) a
stack a b c =
   case packTag a of
      Layout.Packed -> ArrMatrix.lift3 Packed.stackUpper a b c
      Layout.Unpacked ->
         -- Full (unpacked) stacking takes all four blocks,
         -- so the bottom-left one is built as the adjoint of the top-right.
         ArrMatrix.liftUnpacked3
            (\topLeft topRight bottomRight ->
               FullBasic.stackMosaic
                  topLeft topRight (FullBasic.adjoint topRight) bottomRight)
            a b c

infixr 2 *%%%#

{- |
Infix variant of 'stack' that takes the top-left and the top-right
block as a pair:

> (a,b) *%%%# c  =  stack a b c
-}
(*%%%#) ::
   (Layout.Packing pack,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   (HermitianP pack sh0 a, General sh0 sh1 a) -> HermitianP pack sh1 a ->
   HermitianP pack (sh0::+sh1) a
-- Projections instead of a pair pattern keep the pair argument lazy.
p *%%%# c = stack (fst p) (snd p) c


{- |
Split a Hermitian block matrix into its top-left block, its top-right block
and its bottom-right block.
The omitted bottom-left block is the adjoint of the top-right one.

The definiteness is transferred from the big matrix to its diagonal blocks.
The quadratic form @x^H*A*x@ of a diagonal block is the quadratic form
of the whole matrix, restricted to vectors whose other part is zero.
-}
split ::
   (Layout.Packing pack, TBool.C neg, TBool.C zero, TBool.C pos,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexHermitianP pack neg zero pos (sh0::+sh1) a ->
   (FlexHermitianP pack neg zero pos sh0 a,
    General sh0 sh1 a,
    FlexHermitianP pack neg zero pos sh1 a)
split a =
   let topLeft = takeTopLeft a
       topRight = takeTopRight a
       bottomRight = takeBottomRight a
   in (topLeft, topRight, bottomRight)

{- |
Top-left diagonal block.

Sub-matrices maintain the definiteness of the original matrix.
Suppose @x^* A x > 0@ for every nonzero @x@.
Then @y^* (take A) y = x^* A x@, where @x@ is @y@ padded with zero
components, so the block is positive definite as well.
-}
takeTopLeft ::
   (Layout.Packing pack, TBool.C neg, TBool.C zero, TBool.C pos,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexHermitianP pack neg zero pos (sh0::+sh1) a ->
   FlexHermitianP pack neg zero pos sh0 a
takeTopLeft a =
   case packTag a of
      Layout.Unpacked -> ArrUnpacked.takeTopLeft a
      Layout.Packed -> ArrMatrix.lift1 Packed.takeTopLeft a

{- |
Top-right off-diagonal block, as a general matrix.
The bottom-left block of a Hermitian matrix is its adjoint.
-}
takeTopRight ::
   (Layout.Packing pack, TBool.C neg, TBool.C zero, TBool.C pos,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexHermitianP pack neg zero pos (sh0::+sh1) a -> General sh0 sh1 a
takeTopRight a =
   case packTag a of
      Layout.Unpacked -> ArrUnpacked.takeTopRight a
      Layout.Packed -> ArrMatrix.lift1 Packed.takeTopRight a

{- |
Bottom-right diagonal block.
Like the top-left block, it keeps the definiteness tags of the whole matrix.
-}
takeBottomRight ::
   (Layout.Packing pack, TBool.C neg, TBool.C zero, TBool.C pos,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexHermitianP pack neg zero pos (sh0::+sh1) a ->
   FlexHermitianP pack neg zero pos sh1 a
takeBottomRight a =
   case packTag a of
      Layout.Packed -> ArrMatrix.lift1 Packed.takeBottomRight a
      Layout.Unpacked -> ArrUnpacked.takeBottomRight a


{- |
Negate a Hermitian matrix stored in full, packed or banded form.
Negation mirrors the spectrum, so the @neg@ and @pos@ tags swap places.
-}
negate ::
   (TBool.C neg, TBool.C zero, TBool.C pos, Shape.C sh, Class.Floating a) =>
   Hermitian.AnyHermitianP pack neg zero pos bands sh a ->
   Hermitian.AnyHermitianP pack pos zero neg bands sh a
negate a =
   case ArrMatrix.shape a of
      Omni.Hermitian _ -> ArrMatrix.lift1 Vector.negate a
      Omni.BandedHermitian _ -> ArrMatrix.lift1 Vector.negate a
      -- Full storage is the unpacked case and needs the unpacked lifting.
      Omni.Full _ -> ArrMatrix.liftUnpacked1 Vector.negate a


{- |
Multiply the matrix, or its transpose, with a vector.
For a Hermitian matrix, the transpose is the elementwise complex conjugate.
-}
multiplyVector ::
   (Layout.Packing pack) =>
   (TBool.C neg, TBool.C zero, TBool.C pos,
    Shape.C sh, Eq sh, Class.Floating a) =>
   Transposition -> FlexHermitianP pack neg zero pos sh a ->
   Vector sh a -> Vector sh a
multiplyVector trans a =
   let storage = ArrMatrix.toVector a
   in case trans of
         NonTransposed -> Symmetric.multiplyVector storage
         Transposed -> Symmetric.multiplyVector (Mosaic.transpose storage)

{- |
Multiply the matrix, or its transpose, from the left onto a full matrix.
-}
multiplyFull ::
   (Layout.Packing pack) =>
   (TBool.C neg, TBool.C zero, TBool.C pos,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width,
    Class.Floating a) =>
   Transposition -> FlexHermitianP pack neg zero pos height a ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
multiplyFull trans a b =
   case trans of
      NonTransposed ->
         ArrMatrix.lift2 Symmetric.multiplyFull a b
      Transposed ->
         ArrMatrix.lift2 (Symmetric.multiplyFull . Mosaic.transpose) a b

{- |
Product of the matrix with itself, @A * A@.
The result keeps the storage packing and the definiteness tags of the input.
-}
square ::
   (Layout.Packing pack) =>
   (TBool.C neg, TBool.C zero, TBool.C pos,
    Shape.C sh, Class.Floating a) =>
   FlexHermitianP pack neg zero pos sh a ->
   FlexHermitianP pack neg zero pos sh a
square a =
   -- NOTE(review): the square of a negative (semi)definite matrix is
   -- positive (semi)definite, yet the result keeps the input's tags.
   -- A negative definite input thus yields a result tagged as negative
   -- definite, and 'determinant' would take its negative-definite path on it.
   -- Confirm whether these result tags are intended.
   ArrMatrix.lift1 (Mosaic.square Omni.Arbitrary) a


{- |
Rank-1 Hermitian matrix built from a single vector.
It is tagged positive semidefinite.
-}
outer ::
   (Layout.Packing pack, Shape.C sh, Class.Floating a) =>
   Order -> Vector sh a -> HermitianPosSemidefP pack sh a
outer order x = ArrMatrix.lift0 (Symmetric.outerUpper order x)

{- |
Sum of rank-1 Hermitian matrices, each one given by a real weight
and a vector of shape @sh@.

The result is tagged positive semidefinite.
NOTE(review): that tag presumably assumes non-negative weights; confirm.
-}
sumRank1 ::
   (Layout.Packing pack, Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(RealOf a, Vector sh a)] -> HermitianPosSemidefP pack sh a
sumRank1 order sh xs = ArrMatrix.lift0 $ Basic.sumRank1 order sh xs

{- |
Like 'sumRank1', but the matrix shape is taken from the first vector.
-}
sumRank1NonEmpty ::
   (Layout.Packing pack, Shape.C sh, Eq sh, Class.Floating a) =>
   Order ->
   NonEmpty.T [] (RealOf a, Vector sh a) -> HermitianPosSemidefP pack sh a
sumRank1NonEmpty order (NonEmpty.Cons first rest) =
   let (_, firstVector) = first
   in sumRank1 order (Array.shape firstVector) (first : rest)

{- |
Sum of Hermitian rank-2 terms, each one built from a scalar weight
and a pair of vectors of shape @sh@.
See 'Basic.sumRank2' for the exact combination.
-}
sumRank2 ::
   (Layout.Packing pack, Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(a, (Vector sh a, Vector sh a))] -> HermitianP pack sh a
sumRank2 order sh xys = ArrMatrix.lift0 (Basic.sumRank2 order sh xys)

{- |
Like 'sumRank2', but the matrix shape is taken from the first vector
of the first pair.
-}
sumRank2NonEmpty ::
   (Layout.Packing pack, Shape.C sh, Eq sh, Class.Floating a) =>
   Order ->
   NonEmpty.T [] (a, (Vector sh a, Vector sh a)) -> HermitianP pack sh a
sumRank2NonEmpty order (NonEmpty.Cons first rest) =
   let (_, (firstX, _)) = first
   in sumRank2 order (Array.shape firstX) (first : rest)


{- |
Convert to a general square matrix.
Unpacked input is already stored in full, so only its type changes.
-}
toSquare ::
   (Layout.Packing pack,
    TBool.C neg, TBool.C zero, TBool.C pos, Shape.C sh, Class.Floating a) =>
   FlexHermitianP pack neg zero pos sh a -> Square sh a
toSquare a =
   case packTag a of
      Layout.Unpacked -> ArrMatrix.liftUnpacked1 id a
      Layout.Packed -> ArrMatrix.lift1 Symmetric.toSquare a

{- |
Reinterpret a real symmetric matrix as Hermitian.
For real elements both notions coincide, so only the shape tag changes.
-}
fromSymmetric ::
   (Layout.Packing pack, Shape.C sh, Class.Real a) =>
   SymmetricP pack sh a -> HermitianP pack sh a
fromSymmetric a =
   ArrMatrix.lift1 (Array.mapShape Layout.hermitianFromSymmetric) a


{- |
@gramian A = A^H * A@

The result is positive semidefinite by construction.
-}
gramian ::
   (Layout.Packing pack) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> HermitianPosSemidefP pack width a
gramian a = ArrMatrix.lift1 Symmetric.gramian a

{- |
@gramianAdjoint A = A * A^H = gramian (A^H)@

The result is positive semidefinite by construction.
-}
gramianAdjoint ::
   (Layout.Packing pack) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> HermitianPosSemidefP pack height a
gramianAdjoint a = ArrMatrix.lift1 Symmetric.gramianTransposed a

{- |
@congruenceDiagonal D A = A^H * D * A@

The real diagonal matrix @D@ is given by the vector of its diagonal entries.
-}
congruenceDiagonal ::
   (Layout.Packing pack) =>
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height (RealOf a) -> General height width a -> HermitianP pack width a
congruenceDiagonal d a = ArrMatrix.lift1 (Symmetric.congruenceRealDiagonal d) a

{- |
@congruenceDiagonalAdjoint A D = A * D * A^H@

The real diagonal matrix @D@ is given by the vector of its diagonal entries.
-}
congruenceDiagonalAdjoint ::
   (Layout.Packing pack) =>
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a -> Vector width (RealOf a) -> HermitianP pack height a
congruenceDiagonalAdjoint a d =
   ArrMatrix.lift0 $
   Symmetric.congruenceRealDiagonalTransposed (ArrMatrix.toVector a) d

{- |
@congruence B A = A^H * B * A@

Input and result both carry the zero tag 'True'.
@A^H * B * A@ can be singular even when @B@ is not,
for instance when @A@ has a non-trivial kernel.
-}
congruence ::
   (Layout.Packing pack) =>
   (TBool.C neg, TBool.C pos,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   FlexHermitianP pack neg True pos height a ->
   General height width a ->
   FlexHermitianP pack neg True pos width a
congruence b a =
   ArrMatrix.lift2 (Symmetric.congruence . Mosaic.unpackDirty) b a

{- |
@congruenceAdjoint B A = A * B * A^H@

Like 'congruence', input and result carry the zero tag 'True'.
-}
congruenceAdjoint ::
   (Layout.Packing pack) =>
   (TBool.C neg, TBool.C pos,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a ->
   FlexHermitianP pack neg True pos width a ->
   FlexHermitianP pack neg True pos height a
congruenceAdjoint a b =
   ArrMatrix.lift2
      (\general herm ->
         Symmetric.congruenceTransposed general (Mosaic.unpackDirty herm))
      a b


{- |
@anticommutator A B  =  A^H * B + B^H * A@

This is not exactly a matrix anticommutator,
so it may be called a Hermitian anticommutator.
-}
anticommutator ::
   (Layout.Packing pack) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a) =>
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a -> HermitianP pack width a
anticommutator a b =
   ArrMatrix.lift2
      (Symmetric.scaledAnticommutator Layout.ConjugateMirror one) a b

{- |
Hermitian anticommutator of the adjoints:

> anticommutatorAdjoint A B
>    = A * B^H + B * A^H
>    = anticommutator (adjoint A) (adjoint B)
-}
anticommutatorAdjoint ::
   (Layout.Packing pack) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a) =>
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a -> HermitianP pack height a
anticommutatorAdjoint a b =
   ArrMatrix.lift2
      (Symmetric.scaledAnticommutatorTransposed Layout.ConjugateMirror one)
      a b

{- |
@scaledAnticommutator alpha A B  =  alpha * A^H * B + conj alpha * B^H * A@

Not exported from this module; the leading underscore suppresses the
unused-binding warning.
-}
_scaledAnticommutator ::
   (Layout.Packing pack) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a) =>
   a ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a -> HermitianP pack width a
_scaledAnticommutator alpha =
   ArrMatrix.lift2
      (Symmetric.scaledAnticommutator Layout.ConjugateMirror alpha)

{- |
@addAdjoint A = A^H + A@

The packing of the result comes from the result type.
The implementation is selected by a case analysis on 'Layout.autoPacking'.
-}
addAdjoint ::
   (Layout.Packing pack, Shape.C sh, Class.Floating a) =>
   Square sh a -> HermitianP pack sh a
addAdjoint a =
   -- 'packing' has no signature, so the monomorphism restriction ties both
   -- of its uses below to the same packing type.
   let packing = Layout.autoPacking
   in ArrMatrix.requirePacking packing $
      case packing of
         Layout.Unpacked ->
            let unchecked = ArrUnpacked.uncheck a
            in ArrMatrix.liftUnpacked1 FullBasic.recheck $
               ArrMatrix.add (Full.adjoint unchecked) unchecked
         Layout.Packed -> ArrMatrix.lift1 Symmetric.addMirrored a



{- |
Solve @A * X = B@ for @X@.

* Matrices tagged with no negative and some positive part (positive definite
  or semidefinite) use the positive-definite solver 'LinearPD.solve'.

* Matrices tagged negative (semi)definite reuse that solver on @-A@.
  @A * X = B@ is equivalent to @(-A) * (-X) = B@.

* Everything else uses the general Hermitian solver.
-}
solve ::
   (Layout.Packing pack) =>
   (TBool.C neg, TBool.C zero, TBool.C pos,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   FlexHermitianP pack neg zero pos sh a ->
   Full meas vert horiz sh nrhs a -> Full meas vert horiz sh nrhs a
solve a b =
   case Omni.hermitianSet (ArrMatrix.shape a) of
      (TBool.False, _, TBool.True) -> ArrMatrix.lift2 LinearPD.solve a b
      (TBool.True, _, TBool.False) ->
         ArrMatrix.negate (ArrMatrix.lift2 LinearPD.solve (negate a) b)
      _ -> ArrMatrix.lift2 Symmetric.solve a b

{- |
Inverse of the matrix; the definiteness tags are preserved.
Negative (semi)definite matrices are handled by the positive-definite path,
using @A^-1 = -((-A)^-1)@.
-}
inverse ::
   (Layout.Packing pack,
    TBool.C neg, TBool.C zero, TBool.C pos, Shape.C sh, Class.Floating a) =>
   FlexHermitianP pack neg zero pos sh a ->
   FlexHermitianP pack neg zero pos sh a
inverse a =
   case Omni.hermitianSet (ArrMatrix.shape a) of
      (TBool.False, _, TBool.True) -> ArrMatrix.lift1 LinearPD.inverse a
      (TBool.True, _, TBool.False) ->
         negate (ArrMatrix.lift1 LinearPD.inverse (negate a))
      _ -> ArrMatrix.lift1 Symmetric.inverse a

{- |
Determinant of the matrix; it is real for a Hermitian matrix.

Only matrices tagged strictly definite (zero tag 'False') take the
positive-definite path.
A semidefinite matrix may be singular, which is presumably why
those go through the general routine.
-}
determinant ::
   (Layout.Packing pack,
    TBool.C neg, TBool.C zero, TBool.C pos, Shape.C sh, Class.Floating a) =>
   FlexHermitianP pack neg zero pos sh a -> RealOf a
determinant a =
   case Omni.hermitianSet (ArrMatrix.shape a) of
      (TBool.False, TBool.False, TBool.True) ->
         LinearPD.determinant (ArrMatrix.toVector a)
      (TBool.True, TBool.False, TBool.False) ->
         -- Both alternatives are identical.
         -- Matching on the singleton presumably brings the 'Class.Real'
         -- instance for @RealOf a@ into scope, which determinantNegDef needs.
         case Scalar.complexSingletonOfFunctor a of
            Scalar.Real -> determinantNegDef a
            Scalar.Complex -> determinantNegDef a
      _ -> Linear.determinant (ArrMatrix.toVector a)

{- |
Determinant of a negative definite matrix, computed as
@det A = s * det (-A)@, where @-A@ is positive definite.
The sign @s@ comes from 'OmniMatrix.signNegativeDeterminant';
it is presumably @(-1)^n@ for dimension @n@.
-}
determinantNegDef ::
   (Layout.Packing pack,
    Shape.C sh, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   FlexHermitianP pack TBool.True TBool.False TBool.False sh a -> ar
determinantNegDef a =
   let sign = OmniMatrix.signNegativeDeterminant (ArrMatrix.shape a)
       detOfNegated = LinearPD.determinant (ArrMatrix.toVector (negate a))
   in sign * detOfNegated



{- |
Eigenvalues of the matrix; they are real because the matrix is Hermitian.
-}
eigenvalues ::
   (Layout.Packing pack, TBool.C neg, TBool.C zero, TBool.C pos,
    ExtShape.Permutable sh, Class.Floating a) =>
   FlexHermitianP pack neg zero pos sh a -> Vector sh (RealOf a)
eigenvalues a = Eigen.values (ArrMatrix.toVector a)

{- |
Eigen decomposition.
Returns the square matrix produced by 'Eigen.decompose'
(presumably the eigenvectors) together with the real eigenvalues.

For symmetric eigenvalue problems, @eigensystem@ and @schur@ coincide.
-}
eigensystem ::
   (Layout.Packing pack, TBool.C neg, TBool.C zero, TBool.C pos,
    ExtShape.Permutable sh, Class.Floating a) =>
   FlexHermitianP pack neg zero pos sh a -> (Square sh a, Vector sh (RealOf a))
eigensystem a =
   mapFst ArrMatrix.lift0 $ Eigen.decompose $ ArrMatrix.toVector a