{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Linear.LowerUpper (
   LowerUpper,
   Plain.Square,
   Plain.Tall,
   Plain.Wide,
   Plain.Transposition(..),
   Plain.Conjugation(..),
   Plain.Inversion(..),
   Plain.mapExtent,
   fromMatrix,
   toMatrix,
   solve,
   multiplyFull,

   Plain.determinant,

   extractP,
   multiplyP,

   extractL,
   wideExtractL,
   wideMultiplyL,
   wideSolveL,

   extractU,
   tallExtractU,
   tallMultiplyU,
   tallSolveU,

   Plain.caseTallWide,
   ) where

import qualified Numeric.LAPACK.Linear.Plain as Plain
import Numeric.LAPACK.Linear.Plain (LowerUpper)

import qualified Numeric.LAPACK.Matrix.Array.Triangular as Tri
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.Permutation as PermMatrix
import qualified Numeric.LAPACK.Matrix as Matrix
import Numeric.LAPACK.Matrix.Array (Full)
import Numeric.LAPACK.Matrix.Modifier (Transposition, Conjugation, Inversion)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape


{- |
@LowerUpper.fromMatrix a@
computes the LU decomposition of the matrix @a@ using row pivoting.

The matrix @a@ can be recovered from the result @lu@.
Which reconstruction applies depends on the shape of @a@:

> -- a is tall
> LU.multiplyP NonInverted lu $ LU.extractL lu ##*# LU.tallExtractU lu
> -- a is wide
> LU.multiplyP NonInverted lu $ LU.wideExtractL lu #*## LU.extractU lu
-}
fromMatrix ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   LowerUpper vert horiz height width a
fromMatrix a = Plain.fromMatrix $ ArrMatrix.toVector a

{- |
@solve lu b@ uses the decomposition @lu@ of a square matrix
to solve a linear system with right-hand side @b@.
It presumably returns @x@ such that @a x = b@ for @lu = fromMatrix a@;
the work is delegated to 'Plain.solve'.
The result has the same shape as @b@.
-}
solve ::
   (Extent.C vert, Extent.C horiz, Eq height, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Plain.Square height a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
solve lu = ArrMatrix.lift1 (Plain.solve lu)


{- |
Return the row permutation of the decomposition as a permutation matrix
of size @height@.
The 'Inversion' flag is forwarded to 'Plain.extractP'.
Judging by its name, it presumably chooses between the permutation
and its inverse.
-}
extractP ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
   Inversion -> LowerUpper vert horiz height width a ->
   Matrix.Permutation height a
extractP inverted lu =
   PermMatrix.fromPermutation (Plain.extractP inverted lu)

{- |
@multiplyP inverted lu b@ applies the row permutation of the decomposition
@lu@ to the full matrix @b@.
The 'Inversion' flag is forwarded to 'Plain.multiplyP'.
@b@ must have as many rows as @lu@, and the result has the shape of @b@.
With 'NonInverted', this is the final step of reconstructing a matrix
from its decomposition (see 'fromMatrix').
-}
multiplyP ::
   (Extent.C vertA, Extent.C horizA, Extent.C vertB, Extent.C horizB,
    Eq height, Shape.C height, Shape.C widthA, Shape.C widthB,
    Class.Floating a) =>
   Inversion ->
   LowerUpper vertA horizA height widthA a ->
   Full vertB horizB height widthB a ->
   Full vertB horizB height widthB a
multiplyP inverted lu = ArrMatrix.lift1 (Plain.multiplyP inverted lu)



{- |
Extract the lower factor of the decomposition.
It is returned as a full matrix with the same extents and dimensions
as @lu@.
For a tall decomposition, it is combined with 'tallExtractU'
to reconstruct the original matrix (see 'fromMatrix').
-}
extractL ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   LowerUpper vert horiz height width a ->
   Full vert horiz height width a
extractL lu = ArrMatrix.lift0 (Plain.extractL lu)

{- |
For a wide decomposition (vertical extent 'Extent.Small'),
extract the square lower factor as a unit lower triangular matrix
of size @height@.
-}
wideExtractL ::
   (Extent.C horiz, Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpper Extent.Small horiz height width a -> Tri.UnitLower height a
wideExtractL lu = ArrMatrix.lift0 (Plain.wideExtractL lu)

{- |
@wideMultiplyL transposed lu a@ multiplies the square, unit lower triangular
part of the wide decomposition @lu@ with the full matrix @a@.
If @transposed@ is 'Transposed', that triangular part is transposed first:

> wideMultiplyL NonTransposed lu a == wideExtractL lu #*## a
> wideMultiplyL Transposed lu a == Tri.transpose (wideExtractL lu) #*## a
-}
wideMultiplyL ::
   (Extent.C horizA, Extent.C vert, Extent.C horiz, Shape.C height, Eq height,
    Shape.C widthA, Shape.C widthB, Class.Floating a) =>
   Transposition ->
   LowerUpper Extent.Small horizA height widthA a ->
   Full vert horiz height widthB a ->
   Full vert horiz height widthB a
wideMultiplyL transposed lu =
   ArrMatrix.lift1 (Plain.wideMultiplyL transposed lu)

{- |
@wideSolveL transposed conjugated lu b@ solves a linear system with the
square lower triangular part of the wide decomposition @lu@.
It is presumably the inverse operation of 'wideMultiplyL'.
The 'Transposition' and 'Conjugation' flags are forwarded
to 'Plain.wideSolveL'.
@b@ must have as many rows as @lu@, and the result has the shape of @b@.
-}
wideSolveL ::
   (Extent.C horizA, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs, Class.Floating a) =>
   Transposition -> Conjugation ->
   LowerUpper Extent.Small horizA height width a ->
   Full vert horiz height nrhs a -> Full vert horiz height nrhs a
wideSolveL transposed conjugated lu =
   ArrMatrix.lift1 (Plain.wideSolveL transposed conjugated lu)


{- |
Extract the upper factor of the decomposition.
It is returned as a full matrix with the same extents and dimensions
as @lu@.
For a wide decomposition, it is combined with 'wideExtractL'
to reconstruct the original matrix (see 'fromMatrix').
-}
extractU ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   LowerUpper vert horiz height width a ->
   Full vert horiz height width a
extractU lu = ArrMatrix.lift0 (Plain.extractU lu)

{- |
For a tall decomposition (horizontal extent 'Extent.Small'),
extract the square upper factor as an upper triangular matrix
of size @width@.
-}
tallExtractU ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpper vert Extent.Small height width a -> Tri.Upper width a
tallExtractU lu = ArrMatrix.lift0 (Plain.tallExtractU lu)

{- |
@tallMultiplyU transposed lu a@ multiplies the square upper triangular
part of the tall decomposition @lu@ with the full matrix @a@.
If @transposed@ is 'Transposed', that triangular part is transposed first:

> tallMultiplyU NonTransposed lu a == tallExtractU lu #*## a
> tallMultiplyU Transposed lu a == Tri.transpose (tallExtractU lu) #*## a

The triangular part has size @width@, which is the width of @lu@.
Hence @a@ must have @width@ rows.
The type variables follow the same convention as in 'tallSolveU'.
-}
tallMultiplyU ::
   (Extent.C vertA, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Shape.C widthB,
    Class.Floating a) =>
   Transposition ->
   LowerUpper vertA Extent.Small height width a ->
   Full vert horiz width widthB a ->
   Full vert horiz width widthB a
tallMultiplyU transposed lu =
   ArrMatrix.lift1 (Plain.tallMultiplyU transposed lu)

{- |
@tallSolveU transposed conjugated lu b@ solves a linear system with the
square upper triangular part of the tall decomposition @lu@.
It is presumably the inverse operation of 'tallMultiplyU'.
The 'Transposition' and 'Conjugation' flags are forwarded
to 'Plain.tallSolveU'.
@b@ must have @width@ rows (the width of @lu@), and the result has the
shape of @b@.
-}
tallSolveU ::
   (Extent.C vertA, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Shape.C nrhs, Class.Floating a) =>
   Transposition -> Conjugation ->
   LowerUpper vertA Extent.Small height width a ->
   Full vert horiz width nrhs a -> Full vert horiz width nrhs a
tallSolveU transposed conjugated lu =
   ArrMatrix.lift1 (Plain.tallSolveU transposed conjugated lu)



{- |
Convert the decomposition back into a full matrix with the same extents
and dimensions as @lu@. The work is delegated to 'Plain.toMatrix'.
For @lu = fromMatrix a@, this presumably recovers @a@ up to rounding errors.
-}
toMatrix ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   LowerUpper vert horiz height width a ->
   Full vert horiz height width a
toMatrix lu = ArrMatrix.lift0 (Plain.toMatrix lu)


{- |
@multiplyFull lu b@ multiplies the matrix represented by the decomposition
@lu@ (dimensions @height@ by @fuse@) with the full matrix @b@
(dimensions @fuse@ by @width@).
The result has dimensions @height@ by @width@.
The work is delegated to 'Plain.multiplyFull'.
-}
multiplyFull ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   LowerUpper vert horiz height fuse a ->
   Full vert horiz fuse width a ->
   Full vert horiz height width a
multiplyFull lu = ArrMatrix.lift1 (Plain.multiplyFull lu)