{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Linear.LowerUpper (
   LowerUpper,
   Plain.Tall,
   Plain.Wide,
   Plain.Square,
   Plain.LiberalSquare,
   Plain.Transposition(..),
   Plain.Conjugation(..),
   Plain.Inversion(..),
   Plain.mapExtent,
   fromMatrix,
   toMatrix,
   solve,
   multiplyFull,

   Plain.determinant,

   extractP,
   multiplyP,

   extractL,
   wideExtractL,
   wideMultiplyL,
   wideSolveL,

   extractU,
   tallExtractU,
   tallMultiplyU,
   tallSolveU,

   Plain.caseTallWide,
   ) where

import qualified Numeric.LAPACK.Linear.Plain as Plain
import Numeric.LAPACK.Linear.Plain (LowerUpper)

import qualified Numeric.LAPACK.Matrix.Array.Unpacked as Unpacked
import qualified Numeric.LAPACK.Matrix.Array.Mosaic as Tri
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.Permutation as PermMatrix
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Shape as ExtShape
import Numeric.LAPACK.Matrix.Array.Private (Full)
import Numeric.LAPACK.Matrix.Modifier (Transposition, Conjugation, Inversion)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape


{- |
@LowerUpper.fromMatrix a@
computes the LU decomposition of the matrix @a@ using row pivoting.

The matrix @a@ can be reconstructed from the decomposition @lu@.
Use the first formula when @a@ is tall and the second one when it is wide:

> LU.multiplyP NonInverted lu $ LU.extractL lu ##*# LU.tallExtractU lu
> LU.multiplyP NonInverted lu $ LU.wideExtractL lu #*## LU.extractU lu
-}
fromMatrix ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas vert horiz height width a ->
   LowerUpper meas vert horiz height width a
fromMatrix a = Plain.fromMatrix (ArrMatrix.toVector a)

{- |
@solve lu b@ solves the linear system whose square coefficient matrix
is given in LU-decomposed form as @lu@, with the right-hand side(s) @b@.
The result has the same shape as @b@.
This is 'Plain.solve' lifted to 'Full' array matrices.
-}
solve ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Eq height, Shape.C height, Shape.C width, Class.Floating a) =>
   Plain.Square height a ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
solve lu = ArrMatrix.lift1 (Plain.solve lu)


{- |
Extract the row permutation of an LU decomposition as a permutation matrix
over the row shape @height@.
The 'Inversion' flag is forwarded to 'Plain.extractP';
presumably it selects between the permutation and its inverse.
-}
extractP ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    ExtShape.Permutable height, Shape.C width) =>
   Inversion -> LowerUpper meas vert horiz height width a ->
   Matrix.Permutation height a
extractP inverted lu =
   PermMatrix.fromPermutation $ Plain.extractP inverted lu

{- |
@multiplyP inverted lu a@ applies the row permutation of @lu@ to @a@.
The height of @a@ must match the height of @lu@, and the result has the
same shape as @a@.
The 'Inversion' flag is forwarded to 'Plain.multiplyP'.
-}
multiplyP ::
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA,
    Extent.Measure measB, Extent.C vertB, Extent.C horizB,
    Eq height, ExtShape.Permutable height, Shape.C widthA, Shape.C widthB,
    Class.Floating a) =>
   Inversion ->
   LowerUpper measA vertA horizA height widthA a ->
   Full measB vertB horizB height widthB a ->
   Full measB vertB horizB height widthB a
multiplyP inverted lu = ArrMatrix.lift1 (Plain.multiplyP inverted lu)



{- |
Extract the lower factor of the decomposition as a lower trapezoidal matrix
with the same extent and dimensions as the decomposed matrix.
-}
extractL ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    ExtShape.Permutable height, ExtShape.Permutable width, Class.Floating a) =>
   LowerUpper meas vert horiz height width a ->
   Unpacked.LowerTrapezoid meas vert horiz height width a
extractL lu = ArrMatrix.liftUnpacked0 (Plain.extractL lu)

{- |
For a wide decomposition, extract the square lower factor
as a unit lower triangular matrix of size @height@.
-}
wideExtractL ::
   (Extent.Measure meas, Extent.C horiz,
    ExtShape.Permutable height, Shape.C width, Class.Floating a) =>
   LowerUpper meas Extent.Small horiz height width a -> Tri.UnitLower height a
wideExtractL lu = ArrMatrix.lift0 (Plain.wideExtractL lu)

{- |
@wideMultiplyL transposed lu a@ multiplies @a@ by the square part of the
wide decomposition @lu@ that holds the lower triangular factor.
The height of @a@ must match the height of @lu@.

> wideMultiplyL NonTransposed lu a == wideExtractL lu #*## a
> wideMultiplyL Transposed lu a == Tri.transpose (wideExtractL lu) #*## a
-}
wideMultiplyL ::
   (Extent.Measure measA, Extent.C horizA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    ExtShape.Permutable height, Eq height,
    Shape.C widthA, Shape.C widthB, Class.Floating a) =>
   Transposition ->
   LowerUpper measA Extent.Small horizA height widthA a ->
   Full meas vert horiz height widthB a ->
   Full meas vert horiz height widthB a
wideMultiplyL transposed lu =
   ArrMatrix.lift1 (Plain.wideMultiplyL transposed lu)

{- |
@wideSolveL transposed conjugated lu b@ solves a system whose coefficient
matrix is presumably the square lower triangular factor of the wide
decomposition @lu@, with the right-hand sides @b@.
The 'Transposition' and 'Conjugation' flags are forwarded to
'Plain.wideSolveL'. The result has the same shape as @b@.
-}
wideSolveL ::
   (Extent.Measure measA, Extent.C horizA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    ExtShape.Permutable height, Eq height, Shape.C width, Shape.C nrhs,
    Class.Floating a) =>
   Transposition -> Conjugation ->
   LowerUpper measA Extent.Small horizA height width a ->
   Full meas vert horiz height nrhs a -> Full meas vert horiz height nrhs a
wideSolveL transposed conjugated lu =
   ArrMatrix.lift1 $ Plain.wideSolveL transposed conjugated lu


{- |
Extract the upper factor of the decomposition as an upper trapezoidal matrix
with the same extent and dimensions as the decomposed matrix.
-}
extractU ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    ExtShape.Permutable height, ExtShape.Permutable width, Class.Floating a) =>
   LowerUpper meas vert horiz height width a ->
   Unpacked.UpperTrapezoid meas vert horiz height width a
extractU lu = ArrMatrix.liftUnpacked0 (Plain.extractU lu)

{- |
For a tall decomposition, extract the square upper factor
as an upper triangular matrix of size @width@.
-}
tallExtractU ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, ExtShape.Permutable width, Class.Floating a) =>
   LowerUpper meas vert Extent.Small height width a -> Tri.Upper width a
tallExtractU lu = ArrMatrix.lift0 (Plain.tallExtractU lu)

{- |
@tallMultiplyU transposed lu a@ multiplies @a@ by the square part of the
tall decomposition @lu@ that holds the upper triangular factor.
The height of @a@ must match the width of @lu@.

> tallMultiplyU NonTransposed lu a == tallExtractU lu #*## a
> tallMultiplyU Transposed lu a == Tri.transpose (tallExtractU lu) #*## a
-}
-- The type variables are named after the dimensions of @lu@
-- (@height@ by @width@), as in 'tallExtractU' and 'tallSolveU'.
-- The right-hand side therefore has height @width@.
tallMultiplyU ::
   (Extent.Measure measA, Extent.C vertA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, ExtShape.Permutable width, Eq width,
    Shape.C widthB, Class.Floating a) =>
   Transposition ->
   LowerUpper measA vertA Extent.Small height width a ->
   Full meas vert horiz width widthB a ->
   Full meas vert horiz width widthB a
tallMultiplyU transposed = ArrMatrix.lift1 . Plain.tallMultiplyU transposed

{- |
@tallSolveU transposed conjugated lu b@ solves a system whose coefficient
matrix is presumably the square upper triangular factor of the tall
decomposition @lu@, with the right-hand sides @b@.
The height of @b@ must match the width of @lu@.
The 'Transposition' and 'Conjugation' flags are forwarded to
'Plain.tallSolveU'. The result has the same shape as @b@.
-}
tallSolveU ::
   (Extent.Measure measA, Extent.C vertA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, ExtShape.Permutable width, Eq width, Shape.C nrhs,
    Class.Floating a) =>
   Transposition -> Conjugation ->
   LowerUpper measA vertA Extent.Small height width a ->
   Full meas vert horiz width nrhs a -> Full meas vert horiz width nrhs a
tallSolveU transposed conjugated lu =
   ArrMatrix.lift1 $ Plain.tallSolveU transposed conjugated lu



{- |
Convert the decomposition back to a full matrix with the same extent and
dimensions. Presumably this yields the product of the permutation and the
triangular factors, i.e. the originally decomposed matrix.
-}
toMatrix ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   LowerUpper meas vert horiz height width a ->
   Full meas vert horiz height width a
toMatrix lu = ArrMatrix.lift0 $ Plain.toMatrix lu


{- |
@multiplyFull lu b@ multiplies the matrix represented by @lu@
(of shape @height@ by @fuse@) with @b@ (of shape @fuse@ by @width@).
The result is a @height@ by @width@ matrix.
-}
multiplyFull ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   LowerUpper meas vert horiz height fuse a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
multiplyFull lu = ArrMatrix.lift1 (Plain.multiplyFull lu)