{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Singular (
   values,
   valuesTall,
   valuesWide,
   decompose,
   decomposeTall,
   decomposeWide,
   determinantAbsolute,
   leastSquaresMinimumNormRCond,
   pseudoInverseRCond,
   decomposePolar,
   RealOf,
   ) where

import qualified Numeric.LAPACK.Singular.Plain as Plain

import qualified Numeric.LAPACK.Matrix.Hermitian.Basic as HermitianBasic
import qualified Numeric.LAPACK.Matrix.Hermitian as Hermitian
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Shape as ExtShape
import Numeric.LAPACK.Matrix.Array (ArrayMatrix, Full, General, Square)
import Numeric.LAPACK.Matrix.Multiply ((##*#), (#*##))
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)

import Data.Tuple.HT (mapFst, mapSnd, mapPair, mapTriple)


{- |
Singular values of a general matrix.

The result vector is indexed by 'ExtShape.Min' of the height and the width.
The work is delegated to 'Plain.values' on the matrix's underlying array.
-}
values ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Vector (ExtShape.Min height width) (RealOf a)
values matrix = Plain.values (ArrMatrix.toVector matrix)

{- |
Singular values of a matrix whose horizontal extent tag is 'Extent.Small'
(the vertical tag is arbitrary).

The result vector is indexed by the width.
Delegates to 'Plain.valuesTall' on the underlying array.
-}
valuesTall ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   Full vert Extent.Small height width a -> Vector width (RealOf a)
valuesTall matrix = Plain.valuesTall $ ArrMatrix.toVector matrix

{- |
Singular values of a matrix whose vertical extent tag is 'Extent.Small'
(the horizontal tag is arbitrary).

The result vector is indexed by the height.
Delegates to 'Plain.valuesWide' on the underlying array.
-}
valuesWide ::
   (Extent.C horiz, Shape.C height, Shape.C width, Class.Floating a) =>
   Full Extent.Small horiz height width a -> Vector height (RealOf a)
valuesWide matrix = Plain.valuesWide $ ArrMatrix.toVector matrix


{- |
Absolute value of the determinant, computed by 'Plain.determinantAbsolute'
on the matrix's underlying array.

The argument is a 'General' matrix, so squareness is not enforced at the
type level. How non-square input is treated is decided by
'Plain.determinantAbsolute'.
-}
determinantAbsolute ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> RealOf a
determinantAbsolute matrix =
   Plain.determinantAbsolute (ArrMatrix.toVector matrix)


{- |
Singular value decomposition @(u, s, vt)@ of a general matrix.

@u@ is square over the height and @vt@ is square over the width.
@s@ is indexed by 'ExtShape.Min' of the two dimensions.
The factors come from 'Plain.decompose' and are re-wrapped as matrices.
-}
decompose ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a ->
   (Square height a,
    Vector (ExtShape.Min height width) (RealOf a),
    Square width a)
decompose matrix = liftDecompose Plain.decompose matrix

{- |
Singular value decomposition of a matrix whose vertical extent tag is
'Extent.Small'.

@u@ is square over the height, @s@ is indexed by the height, and @vt@ has the
same shape as the input. The factors come from 'Plain.decomposeWide'.

> let (u,s,vt) = Singular.decomposeWide a
> in a  ==  u #*## Matrix.scaleRowsReal s vt
-}
decomposeWide ::
   (Extent.C horiz, Shape.C height, Shape.C width, Class.Floating a) =>
   Full Extent.Small horiz height width a ->
   (Square height a, Vector height (RealOf a),
      Full Extent.Small horiz height width a)
decomposeWide matrix = liftDecompose Plain.decomposeWide matrix

{- |
Singular value decomposition of a matrix whose horizontal extent tag is
'Extent.Small'.

@u@ has the same shape as the input, @s@ is indexed by the width, and @vt@ is
square over the width. The factors come from 'Plain.decomposeTall'.

> let (u,s,vt) = Singular.decomposeTall a
> in a  ==  u ##*# Matrix.scaleRowsReal s vt
-}
decomposeTall ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   Full vert Extent.Small height width a ->
   (Full vert Extent.Small height width a,
      Vector width (RealOf a), Square width a)
decomposeTall matrix = liftDecompose Plain.decomposeTall matrix

{- |
Lift a three-part decomposition on plain arrays to one on 'ArrayMatrix'
values.

The input is unwrapped with 'ArrMatrix.toVector' and the decomposition is
run. The first and third components are wrapped with 'ArrMatrix.lift0'; the
middle component is passed through unchanged.

The triple is bound with a lazy @let@ pattern, so the result tuple can be
inspected without forcing the decomposition.
-}
liftDecompose ::
   (Array sha a -> (Array shb b, f, Array shc c)) ->
   ArrayMatrix sha a -> (ArrayMatrix shb b, f, ArrayMatrix shc c)
liftDecompose decomp matrix =
   let (left, middle, right) = decomp (ArrMatrix.toVector matrix)
   in  (ArrMatrix.lift0 left, middle, ArrMatrix.lift0 right)



{- |
Minimum-norm least-squares solution @x@ of @a * x = b@, computed by
'Plain.leastSquaresMinimumNormRCond'.

@rcond@ is passed straight through to the plain routine (presumably the
relative singular-value cutoff, as in LAPACK; confirm).
The returned 'Int' is whatever the plain routine reports (presumably the
effective rank; TODO confirm). It is paired with the solution, re-wrapped as
a matrix.

Note that the extent tags of @a@ are given in the order @horiz vert@. This is
swapped relative to @b@ and to the result.
-}
leastSquaresMinimumNormRCond ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs, Class.Floating a) =>
   RealOf a ->
   Full horiz vert height width a ->
   Full vert horiz height nrhs a ->
   (Int, Full vert horiz width nrhs a)
leastSquaresMinimumNormRCond rcond a b =
   let (k, x) =
         Plain.leastSquaresMinimumNormRCond rcond
            (ArrMatrix.toVector a) (ArrMatrix.toVector b)
   in  (k, ArrMatrix.lift0 x)

{- |
Pseudo-inverse of a full matrix, computed by 'Plain.pseudoInverseRCond' with
the threshold @rcond@.

The input dimensions are wrapped with 'Basic.uncheck' before the call. The
result is restored with 'Basic.recheck' before it is re-wrapped as a matrix.
The accompanying 'Int' is passed through from the plain routine unchanged.
The result swaps both the dimensions and the extent tags of the input.
-}
pseudoInverseRCond ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   RealOf a ->
   Full vert horiz height width a ->
   (Int, Full horiz vert width height a)
pseudoInverseRCond rcond a =
   let (k, inverse) =
         Plain.pseudoInverseRCond rcond $
         Basic.uncheck $ ArrMatrix.toVector a
   in  (k, ArrMatrix.lift0 (Basic.recheck inverse))



{- |
Polar decomposition of a full matrix. It yields a factor with the input's
shape and a Hermitian factor over the width.

The worker is selected from the extent tags with 'Extent.switchTagPair'.
When the tags leave the shape open, 'Matrix.caseTallWide' picks the tall or
the wide worker at run time.
Dimensions are wrapped with 'Basic.uncheck' before dispatch. They are
restored with 'Basic.recheck' and 'HermitianBasic.recheck' afterwards.
-}
decomposePolar ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   (Full vert horiz height width a, Matrix.Hermitian width a)
decomposePolar a =
   let (factor, hermitian) =
         getDecomposePolar
            (Extent.switchTagPair
               -- first two cases: decomposePolarWide needs vertical tag Small
               (DecomposePolar decomposePolarWide)
               (DecomposePolar decomposePolarWide)
               -- third case: decomposePolarTall needs horizontal tag Small
               (DecomposePolar decomposePolarTall)
               -- fourth case: shape not fixed by the tags,
               -- so decide tall vs. wide from the actual matrix
               (DecomposePolar $
                  either
                     (mapFst Matrix.fromFull . decomposePolarTall)
                     (mapFst Matrix.fromFull . decomposePolarWide)
                  .
                  Matrix.caseTallWide))
            (ArrMatrix.lift1 Basic.uncheck a)
   in  (ArrMatrix.lift1 Basic.recheck factor,
        ArrMatrix.lift1 HermitianBasic.recheck hermitian)

{- |
Wrapper around a polar-decomposition worker for one combination of extent
tags.

The extent tags @vert@ and @horiz@ are the last two type parameters. This
matches how the type is used with 'Extent.switchTagPair' in 'decomposePolar'.
-}
newtype DecomposePolar height width a vert horiz = DecomposePolar {
      getDecomposePolar ::
         Full vert horiz height width a ->
         (Full vert horiz height width a, Matrix.Hermitian width a)
   }

{- |
Polar decomposition for a matrix whose horizontal extent tag is
'Extent.Small', built from 'decomposeTall'.

The first factor is the product of @u@ and @vt@. The Hermitian factor is
'Hermitian.congruenceDiagonal' applied to the singular values @s@ and to
@vt@, after @vt@ is converted with 'Matrix.fromFull'.
-}
decomposePolarTall ::
   (Extent.C vert, Shape.C height, Shape.C width, Eq width,
    Class.Floating a) =>
   Full vert Extent.Small height width a ->
   (Full vert Extent.Small height width a, Matrix.Hermitian width a)
decomposePolarTall a =
   (u ##*# vt, Hermitian.congruenceDiagonal s (Matrix.fromFull vt))
   where (u, s, vt) = decomposeTall a

{- |
Polar decomposition for a matrix whose vertical extent tag is
'Extent.Small', built from 'decomposeWide'.

The first factor is the product of @u@ and @vt@. The Hermitian factor is
'Hermitian.congruenceDiagonal' applied to the singular values @s@ and to
@vt@, after @vt@ is converted with 'Matrix.fromFull'.
-}
decomposePolarWide ::
   (Extent.C horiz, Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a) =>
   Full Extent.Small horiz height width a ->
   (Full Extent.Small horiz height width a, Matrix.Hermitian width a)
decomposePolarWide a =
   (u #*## vt, Hermitian.congruenceDiagonal s (Matrix.fromFull vt))
   where (u, s, vt) = decomposeWide a