{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.Square (
   Square,
   size,
   mapSize,
   toFull,
   toGeneral,
   fromGeneral,
   fromScalar,
   toScalar,
   fromList,
   autoFromList,

   transpose,
   adjoint,

   identity,
   identityFrom,
   identityFromWidth,
   identityFromHeight,
   diagonal,
   takeDiagonal,
   trace,

   stack, (|=|),

   multiply,
   square,
   power,
   congruence,
   congruenceAdjoint,

   solve,
   inverse,
   determinant,

   eigenvalues,
   schur,
   schurComplex,
   eigensystem,
   ComplexOf,
   ) where

import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix.Square.Eigen as Eigen
import qualified Numeric.LAPACK.Matrix.Square.Linear as Linear
import qualified Numeric.LAPACK.Matrix.Square.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Basic as FullBasic

import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import Numeric.LAPACK.Matrix.Array (Full, General, Square)
import Numeric.LAPACK.Matrix.Private (ShapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (ComplexOf)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape ((:+:))

import Foreign.Storable (Storable)

import Data.Tuple.HT (mapPair, mapSnd, mapTriple)
import Data.Complex (Complex)


{- |
Size of a square matrix,
read from the height component of the matrix shape.
-}
size :: Square sh a -> sh
size matrix = MatrixShape.fullHeight (ArrMatrix.shape matrix)

{- |
Convert the size of a square matrix with the given function.
The conversion itself is delegated to @Basic.mapSize@.
-}
mapSize :: (sh0 -> sh1) -> Square sh0 a -> Square sh1 a
mapSize convert = ArrMatrix.lift1 (Basic.mapSize convert)

{- |
View a square matrix as a general matrix
with identical height and width shape.
This is 'toFull' specialised to the 'General' type.
-}
toGeneral :: Square sh a -> General sh sh a
toGeneral matrix = toFull matrix

{- |
Convert a square matrix to a full matrix
with identical height and width shape.
The extent tags @vert@ and @horiz@ are chosen by the caller.
-}
toFull ::
   (Extent.C vert, Extent.C horiz) => Square sh a -> Full vert horiz sh sh a
toFull matrix = ArrMatrix.lift1 Basic.toFull matrix

{- |
Turn a general matrix with equal height and width type into a square matrix.
The conversion is performed by @Basic.fromGeneral@;
the @Eq sh@ constraint suggests that the two shapes are compared there.
-}
fromGeneral :: (Eq sh) => General sh sh a -> Square sh a
fromGeneral matrix = ArrMatrix.lift1 Basic.fromGeneral matrix


{- |
Wrap a single value in a square matrix with unit shape @()@.
-}
fromScalar :: (Storable a) => a -> Square () a
fromScalar x = ArrMatrix.lift0 (Basic.fromScalar x)

{- |
Extract the value from a square matrix with unit shape @()@.
Counterpart of 'fromScalar'.
-}
toScalar :: (Storable a) => Square () a -> a
toScalar matrix = Basic.toScalar (ArrMatrix.toVector matrix)

{- |
Build a square matrix of the given size from a list of elements.
Element order and length handling are defined by @Basic.fromList@.
-}
fromList :: (Shape.C sh, Storable a) => sh -> [a] -> Square sh a
fromList sh xs = ArrMatrix.lift0 (Basic.fromList sh xs)

{- |
Build a square matrix with 'ShapeInt' size from a list of elements.
In contrast to 'fromList' no size is passed;
it is determined by @Basic.autoFromList@.
-}
autoFromList :: (Storable a) => [a] -> Square ShapeInt a
autoFromList xs = ArrMatrix.lift0 (Basic.autoFromList xs)

{- |
Transpose of a square matrix.
Requires no constraints on the shape or the element type.
-}
transpose :: Square sh a -> Square sh a
transpose matrix = ArrMatrix.lift1 Basic.transpose matrix

{- |
Conjugate transpose of a square matrix.
-}
adjoint :: (Shape.C sh, Class.Floating a) => Square sh a -> Square sh a
adjoint matrix = ArrMatrix.lift1 Basic.adjoint matrix

{- |
Identity matrix of the given size.
-}
identity :: (Shape.C sh, Class.Floating a) => sh -> Square sh a
identity sh = ArrMatrix.lift0 (Basic.identity sh)

{- |
Identity matrix with the same size as the given square matrix.
-}
identityFrom :: (Shape.C sh, Class.Floating a) => Square sh a -> Square sh a
identityFrom matrix = ArrMatrix.lift1 Basic.identityFrom matrix

{- |
Identity matrix whose size is the width of the given general matrix.
-}
identityFromWidth ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Square width a
identityFromWidth matrix = ArrMatrix.lift1 Basic.identityFromWidth matrix

{- |
Identity matrix whose size is the height of the given general matrix.
-}
identityFromHeight ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Square height a
identityFromHeight matrix = ArrMatrix.lift1 Basic.identityFromHeight matrix

{- |
Square matrix built from a vector of diagonal elements.
The matrix size equals the shape of the vector.
-}
diagonal :: (Shape.C sh, Class.Floating a) => Vector sh a -> Square sh a
diagonal xs = ArrMatrix.lift0 (Basic.diagonal xs)

{- |
Extract the diagonal of a square matrix as a vector.
The shape of the vector equals the matrix size.
-}
takeDiagonal :: (Shape.C sh, Class.Floating a) => Square sh a -> Vector sh a
takeDiagonal matrix = Basic.takeDiagonal (ArrMatrix.toVector matrix)

{- |
Trace of a square matrix.
-}
trace :: (Shape.C sh, Class.Floating a) => Square sh a -> a
trace matrix = Basic.trace (ArrMatrix.toVector matrix)

infix 3 |=|

{- |
Infix variant of 'stack' that takes the four blocks as two pairs:
@(a,b) |=| (c,d)@ equals @stack a b c d@.
-}
(|=|) ::
   (Extent.C vert, Extent.C horiz,
    Shape.C sizeA, Eq sizeA, Shape.C sizeB, Eq sizeB, Class.Floating a) =>
   (Square sizeA a, Full vert horiz sizeA sizeB a) ->
   (Full horiz vert sizeB sizeA a, Square sizeB a) ->
   Square (sizeA:+:sizeB) a
(|=|) (blockA, blockB) (blockC, blockD)  =
   stack blockA blockB blockC blockD

{- |
Compose a square matrix of size @sizeA:+:sizeB@ from four blocks.
According to the shapes, the first and the last argument
are the square blocks of size @sizeA@ and @sizeB@, respectively,
and the second and third argument are the off-diagonal blocks
of shape @sizeA@ by @sizeB@ and @sizeB@ by @sizeA@.
The assembly itself is performed by @Basic.stack@.
-}
stack ::
   (Extent.C vert, Extent.C horiz,
    Shape.C sizeA, Eq sizeA, Shape.C sizeB, Eq sizeB, Class.Floating a) =>
   Square sizeA a -> Full vert horiz sizeA sizeB a ->
   Full horiz vert sizeB sizeA a -> Square sizeB a ->
   Square (sizeA:+:sizeB) a
stack blockA blockB blockC blockD =
   ArrMatrix.lift4 Basic.stack blockA blockB blockC blockD

{- |
Product of two square matrices of the same size.
Implemented via the multiplication for full matrices.
-}
multiply ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Square sh a -> Square sh a -> Square sh a
multiply a b = ArrMatrix.lift2 FullBasic.multiply a b

{- |
Product of a square matrix with itself, computed by @Basic.square@.
-}
square :: (Shape.C sh, Class.Floating a) => Square sh a -> Square sh a
square matrix = ArrMatrix.lift1 Basic.square matrix

{- |
Raise a square matrix to an integer power.
The set of admissible exponents is determined by @Basic.power@.
-}
power ::
   (Shape.C sh, Class.Floating a) =>
   Integer -> Square sh a -> Square sh a
power n = ArrMatrix.lift1 (Basic.power n)

{- |
Congruence transformation of a square matrix by a general matrix:

> congruence B A = A^H * B * A


The meaning and order of matrix factors are consistent
across the following functions:

* 'Numeric.LAPACK.Matrix.Square.congruence'
* 'Numeric.LAPACK.Matrix.Hermitian.gramian'
* 'Numeric.LAPACK.Matrix.Hermitian.anticommutator'
* 'Numeric.LAPACK.Matrix.Hermitian.congruence'
* 'Numeric.LAPACK.Matrix.Hermitian.congruenceDiagonal'
-}
congruence ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Square height a -> General height width a -> Square width a
congruence b a = ArrMatrix.lift2 Basic.congruence b a

{- |
Congruence transformation with the adjoint factor on the right:

> congruenceAdjoint A B = A * B * A^H
-}
congruenceAdjoint ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a -> Square width a -> Square height a
congruenceAdjoint a b = ArrMatrix.lift2 Basic.congruenceAdjoint a b



{- |
Solve a system of linear equations
with a square coefficient matrix and one or more right-hand sides.
The second argument holds the right-hand sides as columns (shape @nrhs@)
and the result has the same shape as the second argument.
Presumably @solve a b@ yields @x@ with @a \<\> x = b@;
the computation itself is done by @Linear.solve@.
-}
solve ::
   (Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Square sh a -> Full vert horiz sh nrhs a -> Full vert horiz sh nrhs a
solve a b = ArrMatrix.lift2 Linear.solve a b

{- |
Inverse of a square matrix, computed by @Linear.inverse@.
-}
inverse :: (Shape.C sh, Class.Floating a) => Square sh a -> Square sh a
inverse matrix = ArrMatrix.lift1 Linear.inverse matrix

{- |
Determinant of a square matrix, computed by @Linear.determinant@.
-}
determinant :: (Shape.C sh, Class.Floating a) => Square sh a -> a
determinant matrix = Linear.determinant (ArrMatrix.toVector matrix)



{- |
Eigenvalues of a square matrix.
The result is always a vector of complex numbers ('ComplexOf'),
even if the input matrix is real.
-}
eigenvalues ::
   (Shape.C sh, Class.Floating a) =>
   Square sh a -> Vector sh (ComplexOf a)
eigenvalues matrix = Eigen.values (ArrMatrix.toVector matrix)

{- |
Schur decomposition.

If @(q,r) = schur a@, then @a = q \<\> r \<\> adjoint q@,
where @q@ is unitary (orthogonal).
For complex @a@ the factor @r@ is a right-upper triangular matrix,
for real @a@ it is a 1x1-or-2x2-block upper triangular matrix.
With @takeDiagonal r@ you get all eigenvalues of @a@ if @a@ is complex
and the real parts of the eigenvalues if @a@ is real.
Complex conjugated eigenvalues of a real matrix @a@
are encoded as 2x2 blocks along the diagonal.


The meaning and order of matrix factors are consistent
across the following functions:

* 'Numeric.LAPACK.Matrix.Square.schur'
* 'Numeric.LAPACK.Matrix.Square.schurComplex'
* 'Numeric.LAPACK.Matrix.Hermitian.eigensystem'
* 'Numeric.LAPACK.Matrix.BandedHermitian.eigensystem'
* 'Numeric.LAPACK.Matrix.Square.congruenceAdjoint'
* 'Numeric.LAPACK.Matrix.Hermitian.gramianAdjoint'
* 'Numeric.LAPACK.Matrix.Hermitian.anticommutatorAdjoint'
* 'Numeric.LAPACK.Matrix.Hermitian.congruenceAdjoint'
* 'Numeric.LAPACK.Matrix.Hermitian.congruenceDiagonalAdjoint'
-}
schur ::
   (Shape.C sh, Class.Floating a) =>
   Square sh a -> (Square sh a, Square sh a)
schur a =
   -- both factors come back as plain arrays and are wrapped as matrices
   mapPair (ArrMatrix.lift0, ArrMatrix.lift0) $
   Eigen.schur $ ArrMatrix.toVector a

{- |
Variant of 'schur' that is restricted to complex matrices.
It returns the second factor as an upper triangular matrix,
obtained by applying @Triangular.takeUpper@ to the factor from 'schur'.
-}
schurComplex ::
   (Shape.C sh, Class.Real a, Complex a ~ ac) =>
   Square sh ac -> (Square sh ac, Triangular.Upper sh ac)
schurComplex a = mapSnd Triangular.takeUpper (schur a)


{- |
Eigenvalue decomposition: @(vr,d,vlAdj) = eigensystem a@

Counterintuitively, @vr@ contains the right eigenvectors as columns
and @vlAdj@ contains the left conjugated eigenvectors as rows.
The idea is to provide a decomposition of @a@.
If @a@ is diagonalizable, then @vr@ and @vlAdj@
are almost inverse to each other.
More precisely, @vlAdj \<\> vr@ is a diagonal matrix,
but not necessarily an identity matrix.
This is because all eigenvectors are normalized to Euclidean norm 1.
With the following scaling, the decomposition becomes perfect:

> let scal = takeDiagonal $ vlAdj <> vr
> a == vr #*\ Vector.divide d scal ##*# vlAdj

If @a@ is non-diagonalizable,
then some columns of @vr@ and the corresponding rows of @vlAdj@
are left zero and the above property does not hold.


The meaning and order of result matrices are consistent
across the following functions:

* 'Numeric.LAPACK.Matrix.Square.eigensystem'
* 'Numeric.LAPACK.Matrix.Triangular.eigensystem'
* 'Numeric.LAPACK.Singular.decompose'
* 'Numeric.LAPACK.Singular.decomposeTall'
* 'Numeric.LAPACK.Singular.decomposeWide'
-}
eigensystem ::
   (Shape.C sh, Class.Floating a, ComplexOf a ~ ac) =>
   Square sh a -> (Square sh ac, Vector sh ac, Square sh ac)
eigensystem a =
   -- wrap both eigenvector arrays as matrices, pass the eigenvalues through
   mapTriple (ArrMatrix.lift0, id, ArrMatrix.lift0) $
   Eigen.decompose $ ArrMatrix.toVector a