{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Quadratic (
   asDiagonal,
   asSymmetric,

   size,
   mapSize,

   identity,
   diagonal, Diagonal,
   OmniMatrix.takeDiagonal,

   takeTopLeft,
   takeTopRight,
   takeBottomLeft,
   takeBottomRight,
   ) where

import qualified Numeric.LAPACK.Matrix.Symmetric as Symmetric
import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix.Mosaic.Private as Mos

import qualified Numeric.LAPACK.Matrix.BandedHermitian.Basic as BandedHermitian
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Square.Basic as Square
import qualified Numeric.LAPACK.Matrix.Basic as Full

import qualified Numeric.LAPACK.Matrix.Type.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Class as MatrixClass
import qualified Numeric.LAPACK.Matrix.Array.Basic as OmniMatrix
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import Numeric.LAPACK.Matrix.Array.Banded (FlexDiagonal)
import Numeric.LAPACK.Matrix.Array.Mosaic (Symmetric)
import Numeric.LAPACK.Matrix.Array.Private (General, Quadratic)
import Numeric.LAPACK.Matrix.Shape.Omni (Arbitrary)
import Numeric.LAPACK.Matrix.Layout.Private (Order)
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape ((::+))
import Data.Function.HT (Id)



{- |
Identity function restricted to 'FlexDiagonal' matrices.
The argument is returned unchanged;
the function only serves to fix the type of a matrix expression.
-}
asDiagonal :: Id (FlexDiagonal diag sh a)
asDiagonal matrix = matrix

{- |
Identity function restricted to 'Symmetric' matrices.
The argument is returned unchanged;
the function only serves to fix the type of a matrix expression.
-}
asSymmetric :: Id (Symmetric sh a)
asSymmetric matrix = matrix


{- |
Shape of one side of a quadratic matrix.
It is read from the matrix shape by means of @Omni.squareSize@.
-}
size :: Quadratic pack property lower upper sh a -> sh
size matrix = Omni.squareSize (ArrMatrix.shape matrix)

{- |
Replace the size shape of a quadratic matrix
by applying the given function to it.

The shape mapping function must maintain the number of rows and columns.
This module does not check that condition itself;
the work is delegated to @OmniMatrix.mapSquareSize@.
-}
mapSize ::
   (Shape.C shA, Shape.C shB) =>
   (shA -> shB) ->
   Quadratic pack property lower upper shA a ->
   Quadratic pack property lower upper shB a
mapSize resize matrix = OmniMatrix.mapSquareSize resize matrix


{- |
Identity matrix for the given element order and size shape.
The matrix format is chosen by the result type;
construction is delegated to @OmniMatrix.identityOrder@.
-}
identity ::
   (Omni.Quadratic pack property lower upper, Shape.C sh, Class.Floating a) =>
   Order -> sh -> Quadratic pack property lower upper sh a
identity order sh = OmniMatrix.identityOrder order sh

{- |
Build a quadratic matrix from a vector of diagonal elements.
The size shape of the matrix is the shape of the vector.
The matrix shape is created with @Omni.quadratic@
and then passed to 'diagonalAux', which selects the implementation.
-}
diagonal ::
   (Diagonal property, Omni.Quadratic pack property lower upper,
    Shape.C sh, Class.Floating a) =>
   Order -> Vector sh a -> Quadratic pack property lower upper sh a
diagonal order v = diagonalAux omni v
   where
      -- shape of the result, derived from the vector shape
      omni = Omni.quadratic order (Array.shape v)

{- |
Matrix properties for which 'diagonal' can construct a matrix
from a vector of diagonal elements.
-}
class (Omni.Property property) => Diagonal property where
   {- |
   Worker for 'diagonal'.
   It receives the matrix shape in addition to the vector,
   such that an instance can dispatch on the constructor of the shape.
   -}
   diagonalAux ::
      (Omni.Quadratic pack property lower upper) =>
      (Shape.C sh, Class.Floating a) =>
      MatrixShape.Quadratic pack property lower upper sh -> Vector sh a ->
      Quadratic pack property lower upper sh a

{- |
For 'Arbitrary' matrices the constructor of the shape selects
between an unpacked square, a triangular and a banded result.
-}
instance Diagonal Arbitrary where
   diagonalAux omni v =
      case omni of
         Omni.Banded _ ->
            -- the banded array is re-tagged with the omni shape constructor
            ArrMatrix.Array
               (Array.mapShape Omni.Banded
                  (Banded.diagonalFat (Omni.order omni) v))
         Omni.UpperTriangular _ -> Triangular.diagonal (Omni.order omni) v
         Omni.LowerTriangular _ -> Triangular.diagonal (Omni.order omni) v
         Omni.Full _ -> squareDiagonal omni v

{- |
For 'Omni.Symmetric' matrices the result is either an unpacked square matrix
or a matrix built by @Symmetric.diagonal@.
The triangular shape constructors are answered with an 'error'.
-}
instance Diagonal Omni.Symmetric where
   diagonalAux omni v =
      case omni of
         Omni.Symmetric _ -> Symmetric.diagonal (Omni.order omni) v
         Omni.Full _ -> squareDiagonal omni v
         -- NOTE(review): presumably these two cases cannot occur
         -- for a symmetric matrix - confirm against the Omni shape type.
         Omni.UpperTriangular _ -> error "upper triangular not symmetric"
         Omni.LowerTriangular _ -> error "lower triangular not symmetric"

{- |
Diagonal matrix in unpacked format.
Only the element order is taken from the given matrix shape;
the array is built by @Square.diagonalOrder@
and wrapped by @ArrMatrix.liftUnpacked0@.
-}
squareDiagonal ::
   (Omni.Property property, Omni.Strip lower, Omni.Strip upper) =>
   (Shape.C sh, Class.Floating a) =>
   MatrixShape.Quadratic pack property lower upper sh -> Vector sh a ->
   Quadratic Layout.Unpacked property lower upper sh a
squareDiagonal omni v =
   ArrMatrix.liftUnpacked0 (Square.diagonalOrder (Omni.order omni) v)


{- |
Cut the top left block out of a quadratic matrix
whose size is a sum shape @sh0 ::+ sh1@.
The result is a quadratic matrix of size @sh0@
with the same @pack@, @property@, @lower@ and @upper@ indices.
The implementation is selected by the constructor of the matrix shape.
-}
takeTopLeft ::
   (Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Quadratic pack property lower upper (sh0::+sh1) a ->
   Quadratic pack property lower upper sh0 a
takeTopLeft a =
   case ArrMatrix.shape a of
      Omni.UnitBandedTriangular _ -> ArrMatrix.lift1 Banded.takeTopLeftSquare a
      Omni.BandedHermitian _ -> ArrMatrix.lift1 BandedHermitian.takeTopLeft a
      Omni.Banded _ -> ArrMatrix.lift1 Banded.takeTopLeftSquare a
      Omni.Hermitian _ -> ArrMatrix.lift1 Mos.takeTopLeft a
      Omni.Symmetric _ -> Symmetric.takeTopLeft a
      Omni.UpperTriangular _ -> Triangular.takeTopLeft a
      Omni.LowerTriangular _ -> Triangular.takeTopLeft a
      Omni.Full _ ->
         -- Full.takeLeft is applied first, then Full.takeTop;
         -- the uncheck/recheck pair encloses the conversion
         -- of the remaining full matrix back to a square one.
         ArrMatrix.liftUnpacked1
            (\square ->
               Full.recheck $ Square.fromFull $ Full.uncheck $
               Full.takeTop $ Full.takeLeft $ Square.toFull square)
            a

{- |
Cut the bottom left block out of a quadratic matrix
whose size is a sum shape @sh0 ::+ sh1@.
The result is a general matrix with height @sh1@ and width @sh0@.
The signature demands a 'Layout.Filled' lower part.

For symmetric matrices the block is computed
as the transpose of the top right block,
for Hermitian matrices as the adjoint of the top right block.
-}
takeBottomLeft ::
   (Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Quadratic pack property Layout.Filled upper (sh0::+sh1) a ->
   General sh1 sh0 a
takeBottomLeft a =
   -- NOTE(review): the shape constructors not listed here are presumably
   -- ruled out by the 'Layout.Filled' index - confirm against the Omni type.
   case ArrMatrix.shape a of
      Omni.Hermitian _ ->
         MatrixClass.adjoint (ArrMatrix.lift1 Mos.takeTopRight a)
      Omni.Symmetric _ -> Matrix.transpose (Symmetric.takeTopRight a)
      Omni.LowerTriangular _ -> Triangular.takeBottomLeft a
      Omni.Full _ ->
         -- Full.takeLeft is applied first, then Full.takeBottom
         ArrMatrix.liftUnpacked1
            (\square -> Full.takeBottom $ Full.takeLeft $ Square.toFull square)
            a

{- |
Cut the top right block out of a quadratic matrix
whose size is a sum shape @sh0 ::+ sh1@.
The result is a general matrix with height @sh0@ and width @sh1@.
The signature demands a 'Layout.Filled' upper part.
-}
takeTopRight ::
   (Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Quadratic pack property lower Layout.Filled (sh0::+sh1) a ->
   General sh0 sh1 a
takeTopRight a =
   -- NOTE(review): the shape constructors not listed here are presumably
   -- ruled out by the 'Layout.Filled' index - confirm against the Omni type.
   case ArrMatrix.shape a of
      Omni.Hermitian _ -> ArrMatrix.lift1 Mos.takeTopRight a
      Omni.Symmetric _ -> Symmetric.takeTopRight a
      Omni.UpperTriangular _ -> Triangular.takeTopRight a
      Omni.Full _ ->
         -- Full.takeRight is applied first, then Full.takeTop
         ArrMatrix.liftUnpacked1
            (\square -> Full.takeTop $ Full.takeRight $ Square.toFull square)
            a

{- |
Cut the bottom right block out of a quadratic matrix
whose size is a sum shape @sh0 ::+ sh1@.
The result is a quadratic matrix of size @sh1@
with the same @pack@, @property@, @lower@ and @upper@ indices.
The implementation is selected by the constructor of the matrix shape.
-}
takeBottomRight ::
   (Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   Quadratic pack property lower upper (sh0::+sh1) a ->
   Quadratic pack property lower upper sh1 a
takeBottomRight a =
   case ArrMatrix.shape a of
      Omni.UnitBandedTriangular _ ->
         ArrMatrix.lift1 Banded.takeBottomRightSquare a
      Omni.BandedHermitian _ ->
         ArrMatrix.lift1 BandedHermitian.takeBottomRight a
      Omni.Banded _ -> ArrMatrix.lift1 Banded.takeBottomRightSquare a
      Omni.Hermitian _ -> ArrMatrix.lift1 Mos.takeBottomRight a
      Omni.Symmetric _ -> Symmetric.takeBottomRight a
      Omni.UpperTriangular _ -> Triangular.takeBottomRight a
      Omni.LowerTriangular _ -> Triangular.takeBottomRight a
      Omni.Full _ ->
         -- Full.takeRight is applied first, then Full.takeBottom;
         -- the uncheck/recheck pair encloses the conversion
         -- of the remaining full matrix back to a square one.
         ArrMatrix.liftUnpacked1
            (\square ->
               Full.recheck $ Square.fromFull $ Full.uncheck $
               Full.takeBottom $ Full.takeRight $ Square.toFull square)
            a