{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Triangular (
   Triangular, MatrixShape.UpLo,
   Upper, FlexUpper, Triangular.UnitUpper, Triangular.QuasiUpper,
   Lower, FlexLower, Triangular.UnitLower,
   size,
   fromList, autoFromList,
   lowerFromList, autoLowerFromList,
   upperFromList, autoUpperFromList,
   asLower, asUpper,
   requireUnitDiagonal, requireArbitraryDiagonal,
   relaxUnitDiagonal, strictArbitraryDiagonal,
   OmniMatrix.identityOrder,
   diagonal,
   takeDiagonal,
   transpose,
   adjoint,

   stackLower, (#%%%),
   stackUpper, (%%%#),
   splitLower,
   splitUpper,
   takeTopLeft,
   takeTopRight,
   takeBottomLeft,
   takeBottomRight,

   pack,
   toSquare,
   takeLower,
   takeUpper,

   fromLowerRowMajor, toLowerRowMajor,
   fromUpperRowMajor, toUpperRowMajor,
   forceOrder, adaptOrder,

   add, sub,

   multiplyVector,
   square,
   multiply,
   multiplyFull,

   solve,
   inverse,
   determinant,

   eigenvalues,
   eigensystem,
   ) where

import qualified Numeric.LAPACK.Matrix.Triangular.Eigen as Eigen
import qualified Numeric.LAPACK.Matrix.Triangular.Linear as Linear
import qualified Numeric.LAPACK.Matrix.Triangular.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Mosaic.Basic as Mosaic
import qualified Numeric.LAPACK.Matrix.Mosaic.Packed as Packed
import qualified Numeric.LAPACK.Matrix.Mosaic.Generic as Mos
import qualified Numeric.LAPACK.Matrix.Basic as FullBasic

import qualified Numeric.LAPACK.Matrix.Array.Mosaic as Triangular
import qualified Numeric.LAPACK.Matrix.Array.Unpacked as ArrUnpacked
import qualified Numeric.LAPACK.Matrix.Array.Basic as OmniMatrix
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Type.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Class as MatrixClass
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix.Array.Mosaic (
   Triangular, TriangularP,
   Lower, FlexLower, FlexLowerP,
   Upper, FlexUpper, FlexUpperP,
   )
import Numeric.LAPACK.Matrix.Array.Unpacked (Unpacked)
import Numeric.LAPACK.Matrix.Array.Private (Full, General, Square, packTag, diagTag)
import Numeric.LAPACK.Matrix.Shape.Omni (Arbitrary, Unit)
import Numeric.LAPACK.Matrix.Layout.Private (Order, Filled)
import Numeric.LAPACK.Matrix.Private (ShapeInt)
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)
import Data.Array.Comfort.Shape ((::+))

import Foreign.Storable (Storable)

import Data.Tuple.HT (mapPair)
import Data.Function.HT (Id)


{- |
Side shape of a triangular matrix,
obtained by applying 'Omni.squareSize' to the shape of the matrix.
-}
size :: TriangularP pack lo diag up sh a -> sh
size a = Omni.squareSize (ArrMatrix.shape a)

{- |
Transpose a triangular matrix.
The type-level tags for the lower and the upper part are exchanged,
the diagonal tag and the packing are kept.
Delegates to 'Matrix.transpose'.
-}
transpose ::
   (MatrixShape.PowerStrip lo, MatrixShape.PowerStrip up,
    MatrixShape.TriDiag diag, Shape.C sh, Class.Floating a) =>
   TriangularP pack lo diag up sh a -> TriangularP pack up diag lo sh a
transpose a = Matrix.transpose a

{- |
Adjoint of a triangular matrix.
As for 'transpose', the type-level tags for the lower and the upper part
are exchanged.
Delegates to 'MatrixClass.adjoint'.
-}
adjoint ::
   (MatrixShape.PowerStrip lo, MatrixShape.PowerStrip up,
    MatrixShape.TriDiag diag, Shape.C sh, Class.Floating a) =>
   TriangularP pack lo diag up sh a -> TriangularP pack up diag lo sh a
adjoint a = MatrixClass.adjoint a


{- |
Build a triangular matrix with arbitrary diagonal
from an element order, a shape and a list of elements.
Depending on the triangle requested by the result type
this is either 'upperFromList' or 'lowerFromList'.
-}
fromList ::
   (MatrixShape.UpLo lo up, Shape.C sh, Storable a) =>
   Order -> sh -> [a] -> Triangular lo Arbitrary up sh a
fromList order sh xs = result
   where
      -- 'uploTag' ignores its argument;
      -- the result is fed back only to select the tag by its type.
      result =
         case uploTag result of
            MatrixShape.Lower -> lowerFromList order sh xs
            MatrixShape.Upper -> upperFromList order sh xs

{- |
Variant of 'fromList' that is fixed to lower triangular matrices.
Wraps the result of 'Mos.fromList' with 'ArrMatrix.lift0'.
-}
lowerFromList :: (Shape.C sh, Storable a) => Order -> sh -> [a] -> Lower sh a
lowerFromList order sh xs = ArrMatrix.lift0 (Mos.fromList order sh xs)

{- |
Variant of 'fromList' that is fixed to upper triangular matrices.
Wraps the result of 'Mos.fromList' with 'ArrMatrix.lift0'.
-}
upperFromList :: (Shape.C sh, Storable a) => Order -> sh -> [a] -> Upper sh a
upperFromList order sh xs = ArrMatrix.lift0 (Mos.fromList order sh xs)


{- |
Like 'fromList' but without an explicit shape argument;
the result always has shape type 'ShapeInt'.
Depending on the triangle requested by the result type
this is either 'autoUpperFromList' or 'autoLowerFromList'.
-}
autoFromList ::
   (MatrixShape.UpLo lo up, Storable a) =>
   Order -> [a] -> Triangular lo Arbitrary up ShapeInt a
autoFromList order xs = result
   where
      -- 'uploTag' ignores its argument;
      -- the result is fed back only to select the tag by its type.
      result =
         case uploTag result of
            MatrixShape.Lower -> autoLowerFromList order xs
            MatrixShape.Upper -> autoUpperFromList order xs

{- |
Variant of 'autoFromList' that is fixed to lower triangular matrices.
Wraps the result of 'Mos.autoFromList' with 'ArrMatrix.lift0'.
-}
autoLowerFromList :: (Storable a) => Order -> [a] -> Lower ShapeInt a
autoLowerFromList order xs = ArrMatrix.lift0 (Mos.autoFromList order xs)

{- |
Variant of 'autoFromList' that is fixed to upper triangular matrices.
Wraps the result of 'Mos.autoFromList' with 'ArrMatrix.lift0'.
-}
autoUpperFromList :: (Storable a) => Order -> [a] -> Upper ShapeInt a
autoUpperFromList order xs = ArrMatrix.lift0 (Mos.autoFromList order xs)


{- |
Identity function that restricts the type of its argument
to 'FlexLowerP'. It does not alter the matrix.
-}
asLower :: Id (FlexLowerP pack diag sh a)
asLower a = a

{- |
Identity function that restricts the type of its argument
to 'FlexUpperP'. It does not alter the matrix.
-}
asUpper :: Id (FlexUpperP pack diag sh a)
asUpper a = a


{- |
Identity function that restricts the diagonal tag of its argument
to 'Unit'. It does not alter the matrix.
-}
requireUnitDiagonal :: Id (TriangularP pack lo Unit up sh a)
requireUnitDiagonal a = a

{- |
Identity function that restricts the diagonal tag of its argument
to 'Arbitrary'. It does not alter the matrix.
-}
requireArbitraryDiagonal :: Id (TriangularP pack lo Arbitrary up sh a)
requireArbitraryDiagonal a = a


{- |
Convert a triangular matrix of any packing
to the 'Triangular' representation by applying 'Mosaic.pack'
to the underlying array.
-}
pack ::
   (Layout.Packing pack, MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   TriangularP pack lo diag up sh a -> Triangular lo diag up sh a
pack tri =
   -- Both alternatives read the same,
   -- but each one is type-checked for its own triangle.
   case uploTag tri of
      MatrixShape.Upper -> ArrMatrix.lift1 Mosaic.pack tri
      MatrixShape.Lower -> ArrMatrix.lift1 Mosaic.pack tri

{- |
Convert a triangular matrix to a 'Square' matrix of the same shape.
Delegates to 'OmniMatrix.toFull'.
-}
toSquare ::
   (MatrixShape.PowerStrip lo, MatrixShape.PowerStrip up,
    MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   TriangularP pack lo diag up sh a -> Square sh a
toSquare a = OmniMatrix.toFull a

{- |
Turn an unpacked matrix with 'Filled' lower part
and 'Extent.Small' vertical extent
into a 'Lower' matrix over the height shape.
The conversion of the underlying array is done by 'Basic.takeLower'.
-}
takeLower ::
   (Omni.Property property, Omni.Strip upper) =>
   (Extent.Measure meas, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Unpacked property Filled upper meas Extent.Small horiz height width a ->
   Lower height a
takeLower a =
   ArrMatrix.lift0 $ Basic.takeLower $ ArrMatrix.unpackedToVector a

{- |
Turn an unpacked matrix with 'Filled' upper part
and 'Extent.Small' horizontal extent
into an 'Upper' matrix over the width shape.
The conversion of the underlying array is done by 'Basic.takeUpper'.
-}
takeUpper ::
   (Omni.Property property, Omni.Strip lower) =>
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Unpacked property lower Filled meas vert Extent.Small height width a ->
   Upper width a
takeUpper a =
   ArrMatrix.lift0 $ Basic.takeUpper $ ArrMatrix.unpackedToVector a

{- |
Convert an array with a lower 'Shape.Triangular' shape
into a 'Lower' matrix using 'Basic.fromLowerRowMajor'.
-}
fromLowerRowMajor ::
   (Shape.C sh, Class.Floating a) =>
   Array (Shape.Triangular Shape.Lower sh) a -> Lower sh a
fromLowerRowMajor arr = ArrMatrix.lift0 (Basic.fromLowerRowMajor arr)

{- |
Convert an array with an upper 'Shape.Triangular' shape
into an 'Upper' matrix using 'Basic.fromUpperRowMajor'.
-}
fromUpperRowMajor ::
   (Shape.C sh, Class.Floating a) =>
   Array (Shape.Triangular Shape.Upper sh) a -> Upper sh a
fromUpperRowMajor arr = ArrMatrix.lift0 (Basic.fromUpperRowMajor arr)

{- |
Convert a 'Lower' matrix into an array
with a lower 'Shape.Triangular' shape using 'Basic.toLowerRowMajor'.
-}
toLowerRowMajor ::
   (Shape.C sh, Class.Floating a) =>
   Lower sh a -> Array (Shape.Triangular Shape.Lower sh) a
toLowerRowMajor a = Basic.toLowerRowMajor (ArrMatrix.toVector a)

{- |
Convert an 'Upper' matrix into an array
with an upper 'Shape.Triangular' shape using 'Basic.toUpperRowMajor'.
-}
toUpperRowMajor ::
   (Shape.C sh, Class.Floating a) =>
   Upper sh a -> Array (Shape.Triangular Shape.Upper sh) a
toUpperRowMajor a = Basic.toUpperRowMajor (ArrMatrix.toVector a)

{- |
Specialisation of 'ArrMatrix.forceOrder' to triangular matrices.
The first argument is the requested element order.
-}
forceOrder ::
   (MatrixShape.PowerStrip lo, MatrixShape.PowerStrip up,
    MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   Order -> TriangularP pack lo diag up sh a -> TriangularP pack lo diag up sh a
forceOrder order a = ArrMatrix.forceOrder order a

{- |
The result of @adaptOrder x y@ holds the data of @y@
arranged in the layout of @x@.
Specialisation of 'ArrMatrix.adaptOrder' to triangular matrices.
-}
adaptOrder ::
   (MatrixShape.PowerStrip lo, MatrixShape.PowerStrip up,
    MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   TriangularP pack lo diag up sh a ->
   TriangularP pack lo diag up sh a ->
   TriangularP pack lo diag up sh a
adaptOrder x y = ArrMatrix.adaptOrder x y

{- |
Sum and difference of two triangular matrices
of the same triangle with arbitrary diagonal.
Specialisations of 'ArrMatrix.add' and 'ArrMatrix.sub'.
-}
add, sub ::
   (MatrixShape.PowerStrip lo, MatrixShape.PowerStrip up,
    Eq lo, Eq up, Eq sh, Shape.C sh, Class.Floating a) =>
   TriangularP pack lo Arbitrary up sh a ->
   TriangularP pack lo Arbitrary up sh a ->
   TriangularP pack lo Arbitrary up sh a
add a b = ArrMatrix.add a b
sub a b = ArrMatrix.sub a b


{-
identity ::
   (MatrixShape.PowerStrip lo, MatrixShape.PowerStrip up,
    Shape.C sh, Class.Floating a) =>
   Order -> sh -> Triangular lo Unit up sh a
identity order = ArrMatrix.lift0 . Basic.identity order
-}

{- |
Build a triangular matrix with arbitrary diagonal
from an element order and a vector using 'Packed.diagonal'.
The triangle is chosen by the result type.
-}
diagonal ::
   (MatrixShape.UpLo lo up, Shape.C sh, Class.Floating a) =>
   Order -> Vector sh a -> Triangular lo Arbitrary up sh a
diagonal order diag =
   -- Both alternatives read the same,
   -- but each one is type-checked for its own triangle.
   getDiagonal
      (MatrixShape.switchUpLo
         (Diagonal (ArrMatrix.lift0 (Packed.diagonal order diag)))
         (Diagonal (ArrMatrix.lift0 (Packed.diagonal order diag))))

{-
Helper for 'diagonal':
wraps a triangular matrix such that the triangle tags @lo@ and @up@
become the last two type parameters,
as needed for the use with 'MatrixShape.switchUpLo'.
-}
newtype Diagonal_ sh a lo up =
   Diagonal {getDiagonal :: Triangular lo Arbitrary up sh a}

{- |
Extract the diagonal of a triangular matrix as a vector.
Delegates to 'OmniMatrix.takeDiagonal'.
-}
takeDiagonal ::
   (MatrixShape.PowerStrip lo, MatrixShape.PowerStrip up,
    MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   TriangularP pack lo diag up sh a -> Vector sh a
takeDiagonal a = OmniMatrix.takeDiagonal a



{- |
Reinterpret a matrix with 'Unit' diagonal tag
as a matrix with the diagonal tag @diag@ requested by the caller.
The function dispatches on the constructor of the matrix shape
and rewraps the matrix with @id@ as array transformation.
-}
relaxUnitDiagonal ::
   (MatrixShape.TriDiag diag) =>
   TriangularP pack lo Unit up sh a -> TriangularP pack lo diag up sh a
-- NOTE(review): the match on 'ArrMatrix.Array' presumably serves only
-- to bring type information of the constructor into scope - confirm.
relaxUnitDiagonal a@(ArrMatrix.Array _arr) =
   case ArrMatrix.shape a of
      Omni.Full _ -> ArrMatrix.liftUnpacked1 id a
      Omni.LowerTriangular _ -> ArrMatrix.lift1 id a
      Omni.UpperTriangular _ -> ArrMatrix.lift1 id a
      Omni.UnitBandedTriangular _ ->
         -- The result @m@ is passed to 'diagTag'
         -- in order to dispatch on the requested diagonal tag.
         -- If 'Unit' is requested, the input is returned as is.
         let m =
               case diagTag m of
                  Omni.Unit -> a
                  Omni.Arbitrary -> ArrMatrix.lift1 id a
         in m

{- |
Reinterpret a matrix with any diagonal tag
as a matrix with 'Arbitrary' diagonal tag.
A matrix that is already tagged 'Arbitrary' is returned as is.
A matrix tagged 'Unit' is rewrapped depending on its shape constructor,
with @id@ as array transformation.
-}
strictArbitraryDiagonal ::
   (MatrixShape.TriDiag diag) =>
   TriangularP pack lo diag up sh a -> TriangularP pack lo Arbitrary up sh a
strictArbitraryDiagonal tri =
   case diagTag tri of
      Omni.Unit ->
         case ArrMatrix.shape tri of
            Omni.UnitBandedTriangular _ -> ArrMatrix.lift1 id tri
            Omni.UpperTriangular _ -> ArrMatrix.lift1 id tri
            Omni.LowerTriangular _ -> ArrMatrix.lift1 id tri
            Omni.Full _ -> ArrMatrix.liftUnpacked1 id tri
      Omni.Arbitrary -> tri

{-
Fixities of the stacking operators:
@%%%#@ associates to the right and @#%%%@ associates to the left,
both with precedence 2.
-}
infixr 2 %%%#
infixl 2 #%%%

{- |
Assemble a lower triangular matrix over the shape @sh0 ::+ sh1@
from a lower triangular matrix over @sh0@,
a general matrix with height @sh1@ and width @sh0@
and a lower triangular matrix over @sh1@.

For packed matrices this is 'Packed.stackLower'.
For unpacked matrices the blocks are passed to 'FullBasic.stackMosaic',
where the second block is a zero array
whose shape is 'Layout.inverse' of the shape of the general block.
-}
stackLower ::
   (Layout.Packing pack, MatrixShape.TriDiag diag,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   FlexLowerP pack diag sh0 a ->
   General sh1 sh0 a ->
   FlexLowerP pack diag sh1 a ->
   FlexLowerP pack diag (sh0::+sh1) a
stackLower topLeft =
   case packTag topLeft of
      Layout.Unpacked ->
         ArrMatrix.liftUnpacked3
            (\a b c ->
               FullBasic.stackMosaic
                  a (Vector.zero (Layout.inverse (Array.shape b)))
                  b c)
            topLeft
      Layout.Packed -> ArrMatrix.lift3 Packed.stackLower topLeft

{- |
Infix variant of 'stackLower'
that expects the second and third block as a pair on the right-hand side.
-}
(#%%%) ::
   (Layout.Packing pack, MatrixShape.TriDiag diag,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   FlexLowerP pack diag sh0 a ->
   (General sh1 sh0 a, FlexLowerP pack diag sh1 a) ->
   FlexLowerP pack diag (sh0::+sh1) a
(#%%%) topLeft = uncurry (stackLower topLeft)

{- |
Assemble an upper triangular matrix over the shape @sh0 ::+ sh1@
from an upper triangular matrix over @sh0@,
a general matrix with height @sh0@ and width @sh1@
and an upper triangular matrix over @sh1@.

For packed matrices this is 'Packed.stackUpper'.
For unpacked matrices the blocks are passed to 'FullBasic.stackMosaic',
where the third block is a zero array
whose shape is 'Layout.inverse' of the shape of the general block.
-}
stackUpper ::
   (Layout.Packing pack, MatrixShape.TriDiag diag,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   FlexUpperP pack diag sh0 a ->
   General sh0 sh1 a ->
   FlexUpperP pack diag sh1 a ->
   FlexUpperP pack diag (sh0::+sh1) a
stackUpper topLeft =
   case packTag topLeft of
      Layout.Unpacked ->
         ArrMatrix.liftUnpacked3
            (\a b c ->
               FullBasic.stackMosaic
                  a b
                  (Vector.zero (Layout.inverse (Array.shape b))) c)
            topLeft
      Layout.Packed -> ArrMatrix.lift3 Packed.stackUpper topLeft

{- |
Infix variant of 'stackUpper'
that expects the first and second block as a pair on the left-hand side.
-}
(%%%#) ::
   (Layout.Packing pack, MatrixShape.TriDiag diag,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   (FlexUpperP pack diag sh0 a, General sh0 sh1 a) ->
   FlexUpperP pack diag sh1 a ->
   FlexUpperP pack diag (sh0::+sh1) a
(%%%#) topBlocks = uncurry stackUpper topBlocks


{- |
Decompose a lower triangular matrix over @sh0 ::+ sh1@
into the triple of the results of
'takeTopLeft', 'takeBottomLeft' and 'takeBottomRight'.
-}
splitLower ::
   (Layout.Packing pack, MatrixShape.TriDiag diag,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   FlexLowerP pack diag (sh0::+sh1) a ->
   (FlexLowerP pack diag sh0 a, General sh1 sh0 a, FlexLowerP pack diag sh1 a)
splitLower a = (topLeft, bottomLeft, bottomRight)
   where
      topLeft = takeTopLeft a
      bottomLeft = takeBottomLeft a
      bottomRight = takeBottomRight a

{- |
Decompose an upper triangular matrix over @sh0 ::+ sh1@
into the triple of the results of
'takeTopLeft', 'takeTopRight' and 'takeBottomRight'.
-}
splitUpper ::
   (Layout.Packing pack, MatrixShape.TriDiag diag,
    Shape.C sh0, Eq sh0, Shape.C sh1, Eq sh1, Class.Floating a) =>
   FlexUpperP pack diag (sh0::+sh1) a ->
   (FlexUpperP pack diag sh0 a, General sh0 sh1 a, FlexUpperP pack diag sh1 a)
splitUpper a = (topLeft, topRight, bottomRight)
   where
      topLeft = takeTopLeft a
      topRight = takeTopRight a
      bottomRight = takeBottomRight a


{- |
Extract the triangular block over @sh0@
from a triangular matrix over @sh0 ::+ sh1@.
Unpacked matrices are handled by 'ArrUnpacked.takeTopLeft',
packed matrices by 'Packed.takeTopLeft'.
-}
takeTopLeft ::
   (Layout.Packing pack, MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   TriangularP pack lo diag up (sh0::+sh1) a ->
   TriangularP pack lo diag up sh0 a
takeTopLeft tri =
   case packTag tri of
      Layout.Packed ->
         -- Both alternatives read the same,
         -- but each one is type-checked for its own triangle.
         case uploTag tri of
            MatrixShape.Upper -> ArrMatrix.lift1 Packed.takeTopLeft tri
            MatrixShape.Lower -> ArrMatrix.lift1 Packed.takeTopLeft tri
      Layout.Unpacked -> ArrUnpacked.takeTopLeft tri

{- |
Extract the general block with height @sh1@ and width @sh0@
from a lower triangular matrix over @sh0 ::+ sh1@.
Unpacked matrices are handled by 'ArrUnpacked.takeBottomLeft',
packed matrices by 'Packed.takeBottomLeft'.
-}
takeBottomLeft ::
   (Layout.Packing pack, MatrixShape.TriDiag diag,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexLowerP pack diag (sh0::+sh1) a -> General sh1 sh0 a
takeBottomLeft tri =
   case packTag tri of
      Layout.Unpacked -> ArrUnpacked.takeBottomLeft tri
      Layout.Packed -> ArrMatrix.lift1 Packed.takeBottomLeft tri

{- |
Extract the general block with height @sh0@ and width @sh1@
from an upper triangular matrix over @sh0 ::+ sh1@.
Unpacked matrices are handled by 'ArrUnpacked.takeTopRight',
packed matrices by 'Packed.takeTopRight'.
-}
takeTopRight ::
   (Layout.Packing pack, MatrixShape.TriDiag diag,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   FlexUpperP pack diag (sh0::+sh1) a -> General sh0 sh1 a
takeTopRight tri =
   case packTag tri of
      Layout.Unpacked -> ArrUnpacked.takeTopRight tri
      Layout.Packed -> ArrMatrix.lift1 Packed.takeTopRight tri

{- |
Extract the triangular block over @sh1@
from a triangular matrix over @sh0 ::+ sh1@.
Unpacked matrices are handled by 'ArrUnpacked.takeBottomRight',
packed matrices by 'Packed.takeBottomRight'.
-}
takeBottomRight ::
   (Layout.Packing pack, MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
    Shape.C sh0, Shape.C sh1, Class.Floating a) =>
   TriangularP pack lo diag up (sh0::+sh1) a ->
   TriangularP pack lo diag up sh1 a
takeBottomRight tri =
   case packTag tri of
      Layout.Packed ->
         -- Both alternatives read the same,
         -- but each one is type-checked for its own triangle.
         case uploTag tri of
            MatrixShape.Upper -> ArrMatrix.lift1 Packed.takeBottomRight tri
            MatrixShape.Lower -> ArrMatrix.lift1 Packed.takeBottomRight tri
      Layout.Unpacked -> ArrUnpacked.takeBottomRight tri


{-
Provide the singleton value for the triangle tags @lo@ and @up@
of a triangular matrix.
The matrix argument is ignored; it only determines the type of the result.
-}
uploTag ::
   (MatrixShape.UpLo lo up) =>
   TriangularP pack lo diag up sh a -> MatrixShape.UpLoSingleton lo up
uploTag = const MatrixShape.autoUplo


{- |
Product of a triangular matrix and a vector.
Delegates to 'Basic.multiplyVector',
which is additionally told the diagonal tag of the matrix.
-}
multiplyVector ::
   (Layout.Packing pack, MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
    Shape.C sh, Eq sh, Class.Floating a) =>
   TriangularP pack lo diag up sh a -> Vector sh a -> Vector sh a
multiplyVector tri x =
   -- Both alternatives read the same,
   -- but each one is type-checked for its own triangle.
   case uploTag tri of
      MatrixShape.Lower ->
         Basic.multiplyVector (diagTag tri) (ArrMatrix.toVector tri) x
      MatrixShape.Upper ->
         Basic.multiplyVector (diagTag tri) (ArrMatrix.toVector tri) x

{- |
Square of a triangular matrix; the result has the same type as the input.
Delegates to 'Mosaic.square',
which is additionally told the diagonal tag of the matrix.
-}
square ::
   (Layout.Packing pack, MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   TriangularP pack lo diag up sh a ->
   TriangularP pack lo diag up sh a
square tri =
   -- Both alternatives read the same,
   -- but each one is type-checked for its own triangle.
   case uploTag tri of
      MatrixShape.Lower -> ArrMatrix.lift1 (Mosaic.square (diagTag tri)) tri
      MatrixShape.Upper -> ArrMatrix.lift1 (Mosaic.square (diagTag tri)) tri

{- |
Product of two triangular matrices of the same type;
the result has that type, too.
Delegates to 'Basic.multiply',
which is additionally told the diagonal tag of the first operand.
-}
multiply ::
   (Layout.Packing pack, MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
    Shape.C sh, Eq sh, Class.Floating a) =>
   TriangularP pack lo diag up sh a ->
   TriangularP pack lo diag up sh a ->
   TriangularP pack lo diag up sh a
multiply a b =
   -- Both alternatives read the same,
   -- but each one is type-checked for its own triangle.
   case uploTag a of
      MatrixShape.Lower -> ArrMatrix.lift2 (Basic.multiply (diagTag a)) a b
      MatrixShape.Upper -> ArrMatrix.lift2 (Basic.multiply (diagTag a)) a b

{- |
Product of a triangular matrix and a full matrix;
the result has the type of the full operand.
Delegates to 'Basic.multiplyFull',
which is additionally told the diagonal tag of the triangular operand.
-}
multiplyFull ::
   (Layout.Packing pack, MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   TriangularP pack lo diag up height a ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
multiplyFull tri b =
   -- Both alternatives read the same,
   -- but each one is type-checked for its own triangle.
   case uploTag tri of
      MatrixShape.Lower ->
         ArrMatrix.lift2 (Basic.multiplyFull (diagTag tri)) tri b
      MatrixShape.Upper ->
         ArrMatrix.lift2 (Basic.multiplyFull (diagTag tri)) tri b



{- |
Solve a linear system with a triangular coefficient matrix
and a full matrix of right-hand sides;
the result has the type of the right-hand side.
Delegates to 'Linear.solve',
which is additionally told the diagonal tag of the coefficient matrix.
-}
solve ::
   (Layout.Packing pack, MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   TriangularP pack lo diag up sh a ->
   Full meas vert horiz sh nrhs a -> Full meas vert horiz sh nrhs a
solve tri rhs =
   -- Both alternatives read the same,
   -- but each one is type-checked for its own triangle.
   case uploTag tri of
      MatrixShape.Lower -> ArrMatrix.lift2 (Linear.solve (diagTag tri)) tri rhs
      MatrixShape.Upper -> ArrMatrix.lift2 (Linear.solve (diagTag tri)) tri rhs

{- |
Inverse of a triangular matrix; the result has the same type as the input.
Delegates to 'Linear.inverse',
which is additionally told the diagonal tag of the matrix.
-}
inverse ::
   (Layout.Packing pack, MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   TriangularP pack lo diag up sh a ->
   TriangularP pack lo diag up sh a
inverse tri =
   -- Both alternatives read the same,
   -- but each one is type-checked for its own triangle.
   case uploTag tri of
      MatrixShape.Lower -> ArrMatrix.lift1 (Linear.inverse (diagTag tri)) tri
      MatrixShape.Upper -> ArrMatrix.lift1 (Linear.inverse (diagTag tri)) tri

{- |
Determinant of a triangular matrix.
For a matrix with unit diagonal tag the result is 'Scalar.one',
otherwise the computation is delegated to 'Linear.determinant'.
-}
determinant ::
   (Layout.Packing pack, MatrixShape.UpLo lo up, MatrixShape.TriDiag diag,
    Shape.C sh, Class.Floating a) =>
   TriangularP pack lo diag up sh a -> a
determinant tri =
   case diagTag tri of
      MatrixShape.Arbitrary ->
         -- Both alternatives read the same,
         -- but each one is type-checked for its own triangle.
         case uploTag tri of
            MatrixShape.Lower -> Linear.determinant (ArrMatrix.toVector tri)
            MatrixShape.Upper -> Linear.determinant (ArrMatrix.toVector tri)
      MatrixShape.Unit -> Scalar.one



{- |
Eigenvalues of a triangular matrix.
They are obtained by 'OmniMatrix.takeDiagonal',
that is, they are the diagonal elements of the matrix.
-}
eigenvalues ::
   (Layout.Packing pack, MatrixShape.DiagUpLo lo up,
    Shape.C sh, Class.Floating a) =>
   TriangularP pack lo diag up sh a -> Vector sh a
eigenvalues a = OmniMatrix.takeDiagonal a

{- |
@(vr,d,vlAdj) = eigensystem a@

Counterintuitively, @vr@ contains the right eigenvectors as columns
and @vlAdj@ contains the left conjugated eigenvectors as rows.
The idea is to provide a decomposition of @a@.
If @a@ is diagonalizable, then @vr@ and @vlAdj@
are almost inverse to each other.
More precisely, @vlAdj \<\> vr@ is a diagonal matrix,
but not necessarily an identity matrix.
This is because all eigenvectors are normalized
such that 'Numeric.LAPACK.Vector.normInf1' is 1.
With the following scaling, the decomposition becomes perfect:

> let scal = takeDiagonal $ vlAdj <> vr
> a == vr <> diagonal (Vector.divide d scal) <> vlAdj

If @a@ is non-diagonalizable
then some columns of @vr@ and corresponding rows of @vlAdj@ are left zero
and the above property does not hold.

The middle component @d@ is the result of 'eigenvalues'.
-}
eigensystem ::
   (Layout.Packing pack, MatrixShape.DiagUpLo lo up,
    Shape.C sh, Class.Floating a) =>
   TriangularP pack lo Arbitrary up sh a ->
   (TriangularP pack lo Arbitrary up sh a, Vector sh a,
    TriangularP pack lo Arbitrary up sh a)
eigensystem a =
   let (vr,vl) =
         getEigensystem
            (MatrixShape.switchDiagUpLo
               -- First case: no decomposition is computed.
               -- An identity matrix of the size of @a@ is built
               -- and returned together with its transpose.
               -- Presumably this is the case of a diagonal matrix - confirm.
               (Eigensystem $
                  (\eye -> (eye, Matrix.transpose eye)) .
                  OmniMatrix.identityFromShape .
                  Omni.uncheckedDiagonal Layout.ColumnMajor .
                  Omni.squareSize . ArrMatrix.shape)
               -- Remaining two cases: both run 'Eigen.decompose'
               -- on the underlying array and wrap both results.
               -- They read the same but are type-checked separately;
               -- presumably one per triangle - confirm.
               (Eigensystem $
                  mapPair (ArrMatrix.lift0, ArrMatrix.lift0) .
                  Eigen.decompose . ArrMatrix.toVector)
               (Eigensystem $
                  mapPair (ArrMatrix.lift0, ArrMatrix.lift0) .
                  Eigen.decompose . ArrMatrix.toVector))
            a
   in (vr, eigenvalues a, vl)

{-
Helper for 'eigensystem':
wraps a function from a triangular matrix to a pair of such matrices
such that the triangle tags @lo@ and @up@ become the last two type parameters,
as needed for the use with 'MatrixShape.switchDiagUpLo'.
-}
newtype Eigensystem pack sh a lo up =
   Eigensystem {
      getEigensystem ::
         TriangularP pack lo Arbitrary up sh a ->
            (TriangularP pack lo Arbitrary up sh a,
             TriangularP pack lo Arbitrary up sh a)
   }