{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.Symmetric (
   Symmetric,
   size,
   fromList, autoFromList,
   identity,
   diagonal,
   takeDiagonal,
   transpose,
   adjoint,

   stack, (#%%%#),
   split,

   toSquare,

   gramian,            gramianTransposed,
   congruenceDiagonal, congruenceDiagonalTransposed,
   congruence,         congruenceTransposed,
   anticommutator,     anticommutatorTransposed,
   ) where

import qualified Numeric.LAPACK.Matrix.Symmetric.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Triangular as Triangular

import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import Numeric.LAPACK.Matrix.Array.Triangular (Symmetric)
import Numeric.LAPACK.Matrix.Array (Full, General, Square)
import Numeric.LAPACK.Matrix.Shape.Private (Order)
import Numeric.LAPACK.Matrix.Private (ShapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (one)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape ((:+:))

import Foreign.Storable (Storable)


{- |
Dimension shape of a symmetric matrix.
A symmetric matrix is square, so this single shape describes both
its rows and its columns; it is read from the triangular matrix shape.
-}
size :: Symmetric sh a -> sh
size m = MatrixShape.triangularSize $ ArrMatrix.shape m

{- |
Transpose of a symmetric matrix, computed by 'Triangular.transpose'.
The result keeps the shape of the input.
-}
transpose :: Symmetric sh a -> Symmetric sh a
transpose m = Triangular.transpose m

{- |
Adjoint of a symmetric matrix, computed by 'Triangular.adjoint'.
The result keeps the shape of the input.
-}
adjoint :: (Shape.C sh, Class.Floating a) => Symmetric sh a -> Symmetric sh a
adjoint m = Triangular.adjoint m


{- |
Build a symmetric matrix with the given storage order and shape
from a list of elements.
Delegates to 'Triangular.symmetricFromList'.
Which triangle the elements populate is defined there; consult its documentation.
-}
fromList :: (Shape.C sh, Storable a) => Order -> sh -> [a] -> Symmetric sh a
fromList order sh xs = Triangular.symmetricFromList order sh xs

{- |
Like 'fromList', but no shape is passed in.
The 'ShapeInt' size is determined from the list by
'Triangular.autoSymmetricFromList'.
-}
autoFromList :: (Storable a) => Order -> [a] -> Symmetric ShapeInt a
autoFromList order xs = Triangular.autoSymmetricFromList order xs


{- |
Expand a symmetric matrix into a full 'Square' matrix of the same shape,
using 'Triangular.toSquare'.
-}
toSquare :: (Shape.C sh, Class.Floating a) => Symmetric sh a -> Square sh a
toSquare m = Triangular.toSquare m


{- |
Identity matrix of the given order and shape, typed as 'Symmetric'.
'Triangular.identity' yields a unit-diagonal matrix.
'Triangular.relaxUnitDiagonal' drops that unit-diagonal marker so the
result fits the plain 'Symmetric' type.
-}
identity :: (Shape.C sh, Class.Floating a) => Order -> sh -> Symmetric sh a
identity order sh =
   Triangular.relaxUnitDiagonal $ Triangular.identity order sh

{- |
Symmetric matrix built from a vector of diagonal entries,
with the given storage order.
Delegates to 'Triangular.diagonal'.
-}
diagonal ::
   (Shape.C sh, Class.Floating a) => Order -> Vector sh a -> Symmetric sh a
diagonal order v = Triangular.diagonal order v

{- |
Extract the diagonal of a symmetric matrix as a vector,
using 'Triangular.takeDiagonal'.
-}
takeDiagonal :: (Shape.C sh, Class.Floating a) => Symmetric sh a -> Vector sh a
takeDiagonal m = Triangular.takeDiagonal m


{- |
Assemble a symmetric block matrix from three blocks:

> [ A  B ]
> [ .  C ]

@A@ is the symmetric top-left block over @sh0@.
@B@ is the general off-diagonal block with height @sh0@ and width @sh1@.
@C@ is the symmetric bottom-right block over @sh1@.
The bottom-left block is not passed; it is implied by symmetry.
Delegates to 'Triangular.stackSymmetric'.
-}
stack ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   Symmetric sh0 a ->
   General sh0 sh1 a ->
   Symmetric sh1 a ->
   Symmetric (sh0:+:sh1) a
stack topLeft offDiag bottomRight =
   Triangular.stackSymmetric topLeft offDiag bottomRight

infixr 2 #%%%#

{- |
Operator form of 'stack'.
The top-left and off-diagonal blocks are given as a pair, so a block
matrix reads as @(a, b) #%%%# c@.
-}
(#%%%#) ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   (Symmetric sh0 a, General sh0 sh1 a) ->
   Symmetric sh1 a ->
   Symmetric (sh0:+:sh1) a
-- The irrefutable pattern keeps the pair lazy, matching 'uncurry'.
(#%%%#) ~(topLeft, offDiag) bottomRight = stack topLeft offDiag bottomRight


{- |
Counterpart of 'stack': decompose a symmetric matrix over @sh0:+:sh1@ into
three blocks.
The results are the symmetric top-left block, the general off-diagonal block
(height @sh0@, width @sh1@) and the symmetric bottom-right block.
Delegates to 'Triangular.splitSymmetric'.
-}
split ::
   (Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   Symmetric (sh0:+:sh1) a ->
   (Symmetric sh0 a, General sh0 sh1 a, Symmetric sh1 a)
split m = Triangular.splitSymmetric m



{- |
Gramian of a general matrix:

> gramian A = A^T * A

The result is symmetric over the width of @A@.
-}
gramian ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Symmetric width a
gramian a = ArrMatrix.lift1 Basic.gramian a

{- |
Gramian of the transposed matrix:

> gramianTransposed A = A * A^T = gramian (A^T)

The result is symmetric over the height of @A@.
-}
gramianTransposed ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Symmetric height a
gramianTransposed a = ArrMatrix.lift1 Basic.gramianTransposed a

{- |
Congruence transformation of a diagonal matrix:

> congruenceDiagonal D A = A^T * D * A

@D@ is passed as the vector of diagonal entries, indexed by the height of @A@.
-}
congruenceDiagonal ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> General height width a -> Symmetric width a
congruenceDiagonal d a = ArrMatrix.lift1 (Basic.congruenceDiagonal d) a

{- |
Transposed congruence transformation of a diagonal matrix:

> congruenceDiagonalTransposed A D = A * D * A^T

@D@ is passed as the vector of diagonal entries, indexed by the width of @A@.
-}
congruenceDiagonalTransposed ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a -> Vector width a -> Symmetric height a
congruenceDiagonalTransposed a d =
   let rawA = ArrMatrix.toVector a
   in  ArrMatrix.lift0 (Basic.congruenceDiagonalTransposed rawA d)

{- |
Congruence transformation of a symmetric matrix:

> congruence B A = A^T * B * A

@B@ is symmetric over the height of @A@.
The result is symmetric over the width of @A@.
-}
congruence ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Symmetric height a -> General height width a -> Symmetric width a
congruence b a = ArrMatrix.lift2 Basic.congruence b a

{- |
Transposed congruence transformation of a symmetric matrix:

> congruenceTransposed A B = A * B * A^T

@B@ is symmetric over the width of @A@.
The result is symmetric over the height of @A@.
-}
congruenceTransposed ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a -> Symmetric width a -> Symmetric height a
congruenceTransposed a b = ArrMatrix.lift2 Basic.congruenceTransposed a b


{- |
Symmetric anticommutator of two equally shaped matrices:

> anticommutator A B  =  A^T * B + B^T * A

This is not the usual matrix anticommutator @A*B + B*A@.
Because of the transposes, the sum always equals its own transpose.
That is why the result can be returned as a 'Symmetric' matrix over the
width, and why this library calls it the symmetric anticommutator.
Implemented as 'Basic.scaledAnticommutator' with scale factor 'one'.
-}
anticommutator ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   Full vert horiz height width a -> Symmetric width a
anticommutator a b = ArrMatrix.lift2 (Basic.scaledAnticommutator one) a b

{- |
Symmetric anticommutator of the transposed matrices:

> anticommutatorTransposed A B
>    = A * B^T + B * A^T
>    = anticommutator (transpose A) (transpose B)

The result is symmetric over the height.
Implemented as 'Basic.scaledAnticommutatorTransposed' with scale factor 'one'.
-}
anticommutatorTransposed ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   Full vert horiz height width a -> Symmetric height a
anticommutatorTransposed a b =
   ArrMatrix.lift2 (Basic.scaledAnticommutatorTransposed one) a b