{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Symmetric (
   Symmetric,
   takeUpper,
   fromUpper,
   pack,
   assureSymmetry,

   size,
   fromList,
   autoFromList,
   identity,
   diagonal,
   takeDiagonal,
   forceOrder,
   transpose,
   adjoint,

   stack, (#%%%#),
   split,
   takeTopLeft,
   takeTopRight,
   takeBottomRight,

   toSquare,
   fromHermitian,

   multiplyVector,
   multiplyFull,
   square,

   tensorProduct,
   sumRank1, sumRank1NonEmpty,

   gramian,            gramianTransposed,
   congruenceDiagonal, congruenceDiagonalTransposed,
   congruence,         congruenceTransposed,
   anticommutator,     anticommutatorTransposed,
   addTransposed,

   solve,
   inverse,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Symmetric.Linear as Linear
import qualified Numeric.LAPACK.Matrix.Symmetric.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Symmetric.Unified as Symmetric
import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix.Mosaic.Packed as Packed
import qualified Numeric.LAPACK.Matrix.Mosaic.Basic as Mosaic
import qualified Numeric.LAPACK.Matrix.Basic as FullBasic

import qualified Numeric.LAPACK.Matrix.Full as Full
import qualified Numeric.LAPACK.Matrix.Array.Unpacked as ArrUnpacked
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Array.Mosaic
         (Symmetric, SymmetricP, FlexHermitianP, Upper, assureMirrored)
import Numeric.LAPACK.Matrix.Array.Private (Full, General, Square, packTag)
import Numeric.LAPACK.Matrix.Layout.Private (Order)
import Numeric.LAPACK.Matrix.Private (ShapeInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (one)

import qualified Numeric.Netlib.Class as Class

import qualified Type.Data.Bool as TBool

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape ((::+))

import qualified Data.NonEmpty as NonEmpty
import Foreign.Storable (Storable)


{- |
The dimension of a symmetric matrix, i.e. its common height and width.
-}
size :: SymmetricP pack sh a -> sh
size m = Omni.squareSize (ArrMatrix.shape m)

{- |
Transposition returns its argument unchanged,
because a symmetric matrix equals its own transpose.
-}
transpose :: SymmetricP pack sh a -> SymmetricP pack sh a
transpose m = m

{- |
Adjoint of a symmetric matrix.
Since @A^T = A@, the adjoint @A^H@ reduces to conjugating every stored element.
-}
adjoint ::
   (Layout.Packing pack, Shape.C sh, Class.Floating a) =>
   SymmetricP pack sh a -> SymmetricP pack sh a
adjoint m = ArrMatrix.lift1 Vector.conjugate m


{- |
Build a packed symmetric matrix of dimension @sh@ from a list of
its stored elements, laid out according to @order@.
NOTE(review): which triangle the list describes is decided by
Mosaic.fromList -- confirm there.
-}
fromList :: (Shape.C sh, Storable a) => Order -> sh -> [a] -> Symmetric sh a
fromList order sh xs = ArrMatrix.lift0 (Mosaic.fromList order sh xs)

{- |
Like 'fromList', but the dimension is not given explicitly;
Mosaic.autoFromList derives it from the supplied list.
-}
autoFromList :: (Storable a) => Order -> [a] -> Symmetric ShapeInt a
autoFromList order xs = ArrMatrix.lift0 (Mosaic.autoFromList order xs)


{- |
View a symmetric matrix as a general square matrix.
Unpacked storage is reused as is;
packed storage is expanded by Symmetric.toSquare.
-}
toSquare ::
   (Layout.Packing pack, Shape.C sh, Class.Floating a) =>
   SymmetricP pack sh a -> Square sh a
toSquare m =
   case packTag m of
      Layout.Unpacked -> ArrMatrix.liftUnpacked1 id m
      Layout.Packed -> ArrMatrix.lift1 Symmetric.toSquare m

{- |
Reinterpret a Hermitian matrix over a real element type as a symmetric one.
Only the shape descriptor is changed; the stored elements are kept.
-}
fromHermitian ::
   (Layout.Packing pack, TBool.C neg, TBool.C zero, TBool.C pos,
    Shape.C sh, Class.Real a) =>
   FlexHermitianP pack neg zero pos sh a -> SymmetricP pack sh a
fromHermitian m =
   ArrMatrix.lift1 (Array.mapShape Layout.symmetricFromHermitian) m


{- |
The packed identity matrix of dimension @sh@ in the given storage order.
-}
identity :: (Shape.C sh, Class.Floating a) => Order -> sh -> Symmetric sh a
identity order sh = ArrMatrix.lift0 (Packed.identity order sh)

{- |
The packed diagonal matrix whose main diagonal is the given vector.
-}
diagonal ::
   (Shape.C sh, Class.Floating a) => Order -> Vector sh a -> Symmetric sh a
diagonal order d = ArrMatrix.lift0 (Packed.diagonal order d)

{- |
Extract the main diagonal.
The diagonal belongs to the upper triangle,
so it is read from the 'takeUpper' view.
-}
takeDiagonal :: (Shape.C sh, Class.Floating a) => Symmetric sh a -> Vector sh a
takeDiagonal m = Triangular.takeDiagonal (takeUpper m)

{- |
Convert the matrix to the requested storage order.
The packing (packed or unpacked) is preserved.
-}
forceOrder ::
   (Layout.Packing pack, Shape.C sh, Class.Floating a) =>
   Order -> SymmetricP pack sh a -> SymmetricP pack sh a
forceOrder order m =
   case packTag m of
      Layout.Unpacked -> ArrMatrix.liftUnpacked1 (FullBasic.forceOrder order) m
      Layout.Packed -> ArrMatrix.lift1 (Packed.forceOrder order) m


{- |
Convert a symmetric matrix of any packing to packed storage.
-}
pack ::
   (Layout.Packing pack, Shape.C sh, Class.Floating a) =>
   SymmetricP pack sh a -> Symmetric sh a
pack m = ArrMatrix.lift1 Mosaic.pack m

{- |
View a packed symmetric matrix as the upper triangular matrix
holding its stored triangle.
-}
takeUpper :: Symmetric sh a -> Upper sh a
takeUpper m = ArrMatrix.lift1 Mosaic.takeUpper m

{- |
Interpret a packed upper triangular matrix as a symmetric matrix,
i.e. mirror the upper triangle into the lower one.
This is the counterpart of 'takeUpper'.
-}
fromUpper :: Upper sh a -> Symmetric sh a
fromUpper m = ArrMatrix.lift1 Mosaic.fromUpper m

{- |
Turn a square matrix into a symmetric one by delegating to 'assureMirrored'.
NOTE(review): presumably this checks that the input really is symmetric;
confirm in 'assureMirrored' what happens on a mismatch.
-}
assureSymmetry ::
   (Layout.Packing pack, Shape.C sh, Class.Floating a) =>
   Square sh a -> SymmetricP pack sh a
assureSymmetry m = assureMirrored m


{- |
Assemble the block matrix

> [ A   B ]
> [ B^T C ]

from the symmetric diagonal blocks @A@ and @C@ and the off-diagonal block @B@.
The lower-left block is not passed in; it is @B^T@.
-}
stack ::
   (Layout.Packing pack,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   SymmetricP pack sh0 a ->
   General sh0 sh1 a ->
   SymmetricP pack sh1 a ->
   SymmetricP pack (sh0::+sh1) a
stack a b c =
   case packTag a of
      Layout.Packed -> ArrMatrix.lift3 Packed.stackUpper a b c
      Layout.Unpacked ->
         -- Unpacked storage keeps both triangles, so the mirrored
         -- lower-left block has to be supplied explicitly.
         ArrMatrix.liftUnpacked3
            (\topLeft topRight bottomRight ->
               FullBasic.stackMosaic
                  topLeft topRight (FullBasic.transpose topRight) bottomRight)
            a b c

infixr 2 #%%%#

{- |
Operator form of 'stack' taking the top row as a pair:
@(a, b) #%%%# c = stack a b c@.
-}
(#%%%#) ::
   (Layout.Packing pack,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   (SymmetricP pack sh0 a, General sh0 sh1 a) ->
   SymmetricP pack sh1 a ->
   SymmetricP pack (sh0::+sh1) a
-- The lazy pattern keeps the pair as non-strict as 'uncurry' would.
(#%%%#) ~(topLeft, topRight) bottomRight = stack topLeft topRight bottomRight


{- |
Decompose a block symmetric matrix into its top-left block,
its top-right block and its bottom-right block.
This is the inverse of 'stack'.
-}
split ::
   (Layout.Packing pack,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   SymmetricP pack (sh0::+sh1) a ->
   (SymmetricP pack sh0 a, General sh0 sh1 a, SymmetricP pack sh1 a)
split m =
   let topLeft = takeTopLeft m
       topRight = takeTopRight m
       bottomRight = takeBottomRight m
   in (topLeft, topRight, bottomRight)

{- |
The symmetric top-left block (rows and columns indexed by @sh0@).
-}
takeTopLeft ::
   (Layout.Packing pack, Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   SymmetricP pack (sh0::+sh1) a -> SymmetricP pack sh0 a
takeTopLeft m =
   case packTag m of
      Layout.Unpacked -> ArrUnpacked.takeTopLeft m
      Layout.Packed -> ArrMatrix.lift1 Packed.takeTopLeft m

{- |
The general off-diagonal block in the top-right corner
(rows indexed by @sh0@, columns by @sh1@).
-}
takeTopRight ::
   (Layout.Packing pack, Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   SymmetricP pack (sh0::+sh1) a -> General sh0 sh1 a
takeTopRight m =
   case packTag m of
      Layout.Unpacked -> ArrUnpacked.takeTopRight m
      Layout.Packed -> ArrMatrix.lift1 Packed.takeTopRight m

{- |
The symmetric bottom-right block (rows and columns indexed by @sh1@).
-}
takeBottomRight ::
   (Layout.Packing pack, Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   SymmetricP pack (sh0::+sh1) a -> SymmetricP pack sh1 a
takeBottomRight m =
   case packTag m of
      Layout.Packed -> ArrMatrix.lift1 Packed.takeBottomRight m
      Layout.Unpacked -> ArrUnpacked.takeBottomRight m



{- |
Multiply a symmetric matrix with a vector: @A * x@.
-}
multiplyVector ::
   (Layout.Packing pack) =>
   (Shape.C sh, Eq sh, Class.Floating a) =>
   SymmetricP pack sh a -> Vector sh a -> Vector sh a
multiplyVector m = Symmetric.multiplyVector (ArrMatrix.toVector m)

{- |
Multiply a symmetric matrix with a full matrix from the left: @A * B@.
The result has the same shape as @B@.
-}
multiplyFull ::
   (Layout.Packing pack) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width,
    Class.Floating a) =>
   SymmetricP pack height a ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
multiplyFull a b = ArrMatrix.lift2 Symmetric.multiplyFull a b

{- |
The square @A * A@ of a symmetric matrix, which is again symmetric.
-}
square ::
   (Layout.Packing pack, Shape.C sh, Class.Floating a) =>
   SymmetricP pack sh a -> SymmetricP pack sh a
square m = ArrMatrix.lift1 (Mosaic.square Omni.Arbitrary) m



{- |
The outer product @x * x^T@ of a vector with itself,
which is symmetric by construction.
-}
tensorProduct ::
   (Layout.Packing pack, Shape.C sh, Class.Floating a) =>
   Order -> Vector sh a -> SymmetricP pack sh a
tensorProduct order x = ArrMatrix.lift0 (Symmetric.outerUpper order x)

{- |
Sum of weighted rank-1 terms over a matrix of dimension @sh@.
NOTE(review): presumably each pair @(alpha, x)@ contributes
@alpha * x * x^T@ -- confirm in Basic.sumRank1.
-}
sumRank1 ::
   (Layout.Packing pack, Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> sh -> [(a, Vector sh a)] -> SymmetricP pack sh a
sumRank1 order sh terms = ArrMatrix.lift0 (Basic.sumRank1 order sh terms)

{- |
Like 'sumRank1', but the matrix dimension is taken from the
vector of the first term, so the list must not be empty.
-}
sumRank1NonEmpty ::
   (Layout.Packing pack, Shape.C sh, Eq sh, Class.Floating a) =>
   Order -> NonEmpty.T [] (a, Vector sh a) -> SymmetricP pack sh a
sumRank1NonEmpty order (NonEmpty.Cons x xs) =
   sumRank1 order sh (x:xs)
   where sh = Array.shape (snd x)

{-
There is deliberately no exported Gramian that is polymorphic
in the mirror parameter:
the symmetric and the Hermitian Gramian are genuinely different functions.
-}
{- |
gramian A = A^T * A
-}
gramian ::
   (Layout.Packing pack) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> SymmetricP pack width a
gramian m = ArrMatrix.lift1 Symmetric.gramian m

{- |
gramianTransposed A = A * A^T = gramian (A^T)

The result is indexed by the row shape of @A@.
-}
gramianTransposed ::
   (Layout.Packing pack) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> SymmetricP pack height a
gramianTransposed m = ArrMatrix.lift1 Symmetric.gramianTransposed m

{- |
congruenceDiagonal D A = A^T * D * A

@D@ is the diagonal matrix given by the vector argument.
-}
congruenceDiagonal ::
   (Layout.Packing pack) =>
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> General height width a -> SymmetricP pack width a
congruenceDiagonal d a = ArrMatrix.lift1 (Basic.congruenceDiagonal d) a

{- |
congruenceDiagonalTransposed A D = A * D * A^T

@D@ is the diagonal matrix given by the vector argument.
-}
congruenceDiagonalTransposed ::
   (Layout.Packing pack) =>
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a -> Vector width a -> SymmetricP pack height a
congruenceDiagonalTransposed a d =
   ArrMatrix.lift0
      (Basic.congruenceDiagonalTransposed (ArrMatrix.toVector a) d)

{- |
congruence B A = A^T * B * A

@B@ is first brought into an unpacked form via Mosaic.unpackDirty
before the product is computed.
-}
congruence ::
   (Layout.Packing pack) =>
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   SymmetricP pack height a -> General height width a -> SymmetricP pack width a
congruence b a =
   ArrMatrix.lift2 (Symmetric.congruence . Mosaic.unpackDirty) b a

{- |
congruenceTransposed B A = A * B * A^T

As in 'congruence', the symmetric factor is unpacked with
Mosaic.unpackDirty first.
-}
congruenceTransposed ::
   (Layout.Packing pack) =>
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   General height width a -> SymmetricP pack width a -> SymmetricP pack height a
congruenceTransposed a b =
   ArrMatrix.lift2
      (\general sym ->
         Symmetric.congruenceTransposed general (Mosaic.unpackDirty sym))
      a b


{- |
anticommutator A B  =  A^T * B + B^T * A

This is not the usual matrix anticommutator @A * B + B * A@.
Since @(A^T * B)^T = B^T * A@, the result is always symmetric,
which is why it is called the symmetric anticommutator here.
-}
anticommutator ::
   (Layout.Packing pack,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a) =>
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a -> SymmetricP pack width a
anticommutator a b =
   -- The scaled variant with factor 'one' yields the plain sum.
   ArrMatrix.lift2
      (Symmetric.scaledAnticommutator Layout.SimpleMirror one) a b

{- |
anticommutatorTransposed A B
   = A * B^T + B * A^T
   = anticommutator (transpose A) (transpose B)

The result is indexed by the row shape of the arguments.
-}
anticommutatorTransposed ::
   (Layout.Packing pack,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a) =>
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a -> SymmetricP pack height a
anticommutatorTransposed a b =
   -- The scaled variant with factor 'one' yields the plain sum.
   ArrMatrix.lift2
      (Symmetric.scaledAnticommutatorTransposed Layout.SimpleMirror one) a b


{- |
addTransposed A = A^T + A

The result is symmetric for every square @A@.
The packing of the result is chosen by the caller's result type;
'Layout.autoPacking' reflects that choice as a value,
so that each storage format gets its own branch.
-}
addTransposed ::
   (Layout.Packing pack, Shape.C sh, Class.Floating a) =>
   Square sh a -> SymmetricP pack sh a
addTransposed a =
   -- NOTE(review): 'pck' is a value-level witness of the result packing;
   -- 'ArrMatrix.requirePacking' presumably ties its type to the result type,
   -- which keeps the 'case' below from being ambiguous -- confirm.
   let pck = Layout.autoPacking
   in ArrMatrix.requirePacking pck $
      case pck of
         -- Packed result: delegate to Symmetric.addMirrored.
         Layout.Packed -> ArrMatrix.lift1 Symmetric.addMirrored a
         -- Unpacked result: add the full transpose to the matrix itself.
         -- The sum is computed on the 'ArrUnpacked.uncheck' view and passed
         -- through 'FullBasic.recheck' (presumably restoring the shape tag
         -- dropped by 'uncheck' -- confirm).
         Layout.Unpacked ->
            ArrMatrix.liftUnpacked1 FullBasic.recheck $
               let au = ArrUnpacked.uncheck a
               in ArrMatrix.add (Full.transpose au) au



{- |
Solve @A * X = B@ for @X@, where @A@ is symmetric
and @B@ holds one right-hand side per column.
-}
solve ::
   (Layout.Packing pack) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   SymmetricP pack sh a ->
   Full meas vert horiz sh nrhs a -> Full meas vert horiz sh nrhs a
solve a b = ArrMatrix.lift2 Symmetric.solve a b

{- |
Inverse of a symmetric matrix, returned with the same packing.
-}
inverse ::
   (Layout.Packing pack, Shape.C sh, Class.Floating a) =>
   SymmetricP pack sh a -> SymmetricP pack sh a
inverse m = ArrMatrix.lift1 Symmetric.inverse m

{- |
Determinant of a symmetric matrix,
computed by Linear.determinant on the underlying storage array.
-}
determinant ::
   (Layout.Packing pack, Shape.C sh, Class.Floating a) =>
   SymmetricP pack sh a -> a
determinant m = Linear.determinant (ArrMatrix.toVector m)