{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE StandaloneDeriving #-}
module Numeric.LAPACK.Linear.Plain (
   LowerUpper,
   Tall, Wide, Square, LiberalSquare,
   Transposition(..),
   Conjugation(..),
   Inversion(..),
   mapExtent,
   fromMatrix,
   toMatrix,
   solve,
   multiplyFull,

   determinant,

   extractP,
   multiplyP,

   extractL,
   wideExtractL,
   wideMultiplyL,
   wideSolveL,

   extractU,
   tallExtractU,
   tallMultiplyU,
   tallSolveU,

   caseTallWide,
   ) where

import qualified Numeric.LAPACK.Matrix.Divide as Divide
import qualified Numeric.LAPACK.Matrix.Multiply as Multiply
import qualified Numeric.LAPACK.Matrix.Type.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Plain.Format as ArrFormat
import qualified Numeric.LAPACK.Matrix.Layout as LayoutPub
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Strict as ExtentStrict
import qualified Numeric.LAPACK.Matrix.Extent.Private as ExtentPriv
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Matrix.Private as MatrixPriv
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Split as Split
import Numeric.LAPACK.Output ((/+/))
import Numeric.LAPACK.Matrix.Plain.Format (formatArray)
import Numeric.LAPACK.Matrix.Type.Private (Matrix)
import Numeric.LAPACK.Matrix.Triangular.Basic (Lower, Upper)
import Numeric.LAPACK.Matrix.Layout.Private
         (Order(RowMajor, ColumnMajor), Triangle(Triangle))
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed, Transposed),
          Conjugation(NonConjugated, Conjugated),
          Inversion(NonInverted, Inverted))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Linear.Private (solver, withInfo)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked), deconsUnchecked)
import Numeric.LAPACK.Private
         (copyBlock, copyTransposed, copyToColumnMajor, copyToColumnMajorTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array), (!))

import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (forM_)

import Data.Monoid ((<>))


{- |
Empty type that serves as the tag for the LU decomposition
instance of the 'Matrix' data family.
-}
data LU

{- |
Representation of an LU decomposition:
pivot data together with a split array.
The result type of the constructor is 'LowerUpperFlex',
which fixes both extra type parameters @xl@ and @xu@ to @()@.
-}
data instance Matrix LU xl xu lower upper meas vert horiz height width a where
   LowerUpper ::
      -- pivot elements, stored in a 'Banded.RectangularDiagonal' array
      Banded.RectangularDiagonal meas vert horiz
         height width (Perm.Element height) ->
      -- split array; 'extractL' and 'extractU' read both factors from it
      SplitArray meas vert horiz height width a ->
      LowerUpperFlex lower upper meas vert horiz height width a

-- Standalone-derived 'Show' instance for the LU decomposition.
-- The context lists the constraints requested for showing
-- the pivot array and the split array held by 'LowerUpper'.
deriving instance
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Storable a,
    Show height, Show width, Show a) =>
   Show (Matrix LU xl xu lower upper meas vert horiz height width a)

-- Split array tagged with 'Layout.Triangle'.
-- This is the type of the second field of 'LowerUpper'.
type SplitArray meas vert horiz height width a =
   Split.Split Layout.Triangle meas vert horiz height width a

-- Project out the split array of a decomposition
-- and drop the pivot data.
split_ ::
   Matrix LU xl xu lower upper meas vert horiz height width a ->
   SplitArray meas vert horiz height width a
split_ lu =
   case lu of
      LowerUpper _ factors -> factors

{- | 'Matrix' tagged with 'LU' and both extra type parameters set to @()@. -}
type LowerUpperFlex = Matrix LU () ()
{- | 'LowerUpperFlex' with both @lower@ and @upper@ set to 'Layout.Filled'. -}
type LowerUpper = LowerUpperFlex Layout.Filled Layout.Filled
{- | 'LowerUpper' with the extent tags 'Extent.Size', 'Extent.Big', 'Extent.Small'. -}
type Tall height width =
         LowerUpper Extent.Size Extent.Big Extent.Small height width
{- | 'LowerUpper' with the extent tags 'Extent.Size', 'Extent.Small', 'Extent.Big'. -}
type Wide height width =
         LowerUpper Extent.Size Extent.Small Extent.Big height width
{- | 'SquareMeas' with measure 'Extent.Size' and separate height and width types. -}
type LiberalSquare height width = SquareMeas Extent.Size height width
{- | 'SquareMeas' with measure 'Extent.Shape' and one shape for height and width. -}
type Square sh = SquareMeas Extent.Shape sh sh
{- | 'LowerUpper' with both the vertical and the horizontal tag set to 'Extent.Small'. -}
type SquareMeas meas height width =
         LowerUpper meas Extent.Small Extent.Small height width

-- The formatted output consists of the permutation obtained by 'extractP'
-- combined via '/+/' with the formatted split array.
instance Matrix.Format LU where
   type FormatExtra LU extra = extra ~ ()
   format fmt lu@(LowerUpper _ factors) =
      let permutationOut = Perm.format (extractP NonInverted lu)
          factorsOut = formatArray fmt factors
      in permutationOut /+/ factorsOut

-- The layout is derived from the split array only;
-- the pivot data is ignored.
instance Matrix.Layout LU where
   type LayoutExtra LU extra = extra ~ ()
   layout (LowerUpper _ factors) =
      let extent = Layout.splitExtent (Array.shape factors)
      in ArrFormat.splitArrayFromList2 extent (ArrFormat.layoutSplit factors)


{- |
Change the extent tags of a decomposition.
The same extent map is applied to the shape of the pivot array
and to the shape of the split array.
-}
mapExtent ::
   (Extent.C vertA, Extent.C horizA) =>
   (Extent.C vertB, Extent.C horizB) =>
   ExtentStrict.Map measA vertA horizA measB vertB horizB height width ->
   LowerUpperFlex lower upper measA vertA horizA height width a ->
   LowerUpperFlex lower upper measB vertB horizB height width a
mapExtent extentMap (LowerUpper pivot split) =
   LowerUpper
      (Banded.mapExtent applyMap pivot)
      (Array.mapShape (Layout.splitMapExtent applyMap) split)
   where
      applyMap = ExtentStrict.apply extentMap


-- Change the height type that indexes the pivot elements.
-- The function argument only determines the types, its value is ignored;
-- the array shape is kept and the buffer is reused via 'castForeignPtr'.
mapPivotHeight ::
   (sh0 -> sh1) ->
   Vector shape (Perm.Element sh0) -> Vector shape (Perm.Element sh1)
mapPivotHeight _ pivots =
   case pivots of
      Array sh fptr -> Array sh (castForeignPtr fptr)

-- Apply a function to the height shape of a decomposition.
-- The pivot array is adapted twice:
-- its element type via 'mapPivotHeight' and its shape via 'Banded.mapHeight'.
mapHeight ::
   (Extent.C vert, Extent.C horiz) =>
   (heightA -> heightB) ->
   LowerUpperFlex lower upper Extent.Size vert horiz heightA width a ->
   LowerUpperFlex lower upper Extent.Size vert horiz heightB width a
mapHeight f (LowerUpper pivot split) =
   let newPivot = Banded.mapHeight f (mapPivotHeight f pivot)
       newSplit = Split.mapHeight f split
   in LowerUpper newPivot newSplit

-- Apply a function to the width shape of a decomposition,
-- both in the shape of the pivot array and in the shape of the split array.
mapWidth ::
   (Extent.C vert, Extent.C horiz) =>
   (widthA -> widthB) ->
   LowerUpperFlex lower upper Extent.Size vert horiz height widthA a ->
   LowerUpperFlex lower upper Extent.Size vert horiz height widthB a
mapWidth f (LowerUpper pivot split) =
   let newPivot = Banded.mapWidth f pivot
       newSplit = Split.mapWidth f split
   in LowerUpper newPivot newSplit


{- |
Compute the LU decomposition of a full matrix using LAPACK's @getrf@.

The matrix data is passed through 'copyToColumnMajor'
into a newly created 'ColumnMajor' split array
and @getrf@ is then called on that new buffer.
The pivot array and the split array are created together
and combined by the 'LowerUpper' constructor.
The LAPACK call is wrapped in 'withInfo'
(presumably this checks the info result of @getrf@ -
confirm in "Numeric.LAPACK.Linear.Private").
-}
fromMatrix ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas vert horiz height width a ->
   LowerUpper meas vert horiz height width a
fromMatrix (Array (Layout.Full order extent) a) =
   let (height,width) = Extent.dimensions extent
   in uncurry LowerUpper $
      -- outer creation: pivot array, inner creation: split array;
      -- the resulting pair provides the two fields of 'LowerUpper'
      Array.unsafeCreateWithSizeAndResult
         (snd $ Layout.rectangularDiagonal extent) $ \_ ipivPtr ->
      ArrayIO.unsafeCreate
         (Layout.Split Layout.Triangle ColumnMajor extent) $ \luPtr ->

   evalContT $ do
      let m = Shape.size height
      let n = Shape.size width
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      aPtr <- ContT $ withForeignPtr a
      -- the leading dimension is derived from the number of rows m
      ldaPtr <- Call.leadingDim m
      liftIO $ do
         -- fill the new buffer before getrf is called on it
         copyToColumnMajor order m n aPtr luPtr
         withInfo "getrf" $
            LapackGen.getrf mPtr nPtr luPtr ldaPtr
               (Perm.deconsElementPtr ipivPtr)

{- |
Solve a linear system using a square LU decomposition.
This is 'solveTrans' with 'NonTransposed'.
-}
solve ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Eq height, Shape.C height, Shape.C width, Class.Floating a) =>
   Square height a ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
solve lu rhs = solveTrans NonTransposed lu rhs

-- Solve a linear system using a square LU decomposition
-- via LAPACK's getrs.
-- The 'Transposition' argument selects the character
-- that is passed as first argument to getrs:
-- N for 'NonTransposed' and T for 'Transposed'.
-- The pointers for the right-hand side are supplied by 'solver';
-- the LAPACK call is wrapped in 'withInfo'.
solveTrans ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Eq height, Shape.C height, Shape.C width, Class.Floating a) =>
   Transposition ->
   LowerUpperFlex lower upper
      Extent.Shape Extent.Small Extent.Small height height a ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
solveTrans trans
   (LowerUpper
      (Array _ ipiv)
      (Array (Layout.Split Layout.Triangle orderLU extentLU) lu)) =

   solver "LowerUpper.solve" (Extent.squareSize extentLU) $
         \n nPtr nrhsPtr xPtr ldxPtr -> do
      -- the factor matrix is square, thus its leading dimension equals n
      let lda = n
      transPtr <- Call.char $
         case trans of
            NonTransposed -> 'N'
            Transposed -> 'T'
      aPtr <-
         case orderLU of
            -- row-major factors go through 'copyToColumnMajorTemp'
            -- (presumably a temporary column-major copy - confirm),
            -- column-major factors are passed to getrs directly
            RowMajor -> copyToColumnMajorTemp orderLU n n lu
            ColumnMajor -> ContT $ withForeignPtr lu
      ldaPtr <- Call.leadingDim lda
      ipivPtr <- fmap Perm.deconsElementPtr $ ContT $ withForeignPtr ipiv
      liftIO $
         withInfo "getrs" $
            LapackGen.getrs transPtr
               nPtr nrhsPtr aPtr ldaPtr ipivPtr xPtr ldxPtr

{- |
Determinant of the decomposed matrix.
It is computed by 'Split.determinantR' from the split array
and conditionally negated by 'Perm.condNegate'
according to the list of pivots.

Caution:
@LU.determinant . LU.fromMatrix@ will fail for singular matrices.
-}
determinant ::
   (Extent.Measure meas, Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpperFlex lower upper meas Extent.Small Extent.Small height width a -> a
determinant (LowerUpper ipiv split) =
   let pivots = map Perm.deconsElement (Array.toList ipiv)
   in Perm.condNegate pivots (Split.determinantR split)


{- |
Construct the permutation of the decomposition from the pivot array
using 'Perm.fromTruncatedPivots'.
The requested 'Inversion' is combined with 'Inverted' via '<>'
before it is passed on.
-}
extractP ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
   Inversion ->
   LowerUpperFlex lower upper meas vert horiz height width a ->
   Perm.Permutation height
extractP inverted (LowerUpper ipiv _) =
   let direction = Inverted <> inverted
       pivots = Array.mapShape Perm.Shape ipiv
   in Perm.fromTruncatedPivots direction pivots

{- |
Apply the pivot interchanges of the decomposition to a full matrix.
The height of the decomposition must match the height of the matrix;
this is checked by an assertion.

The result is a new array: the input is copied first
and the interchanges are performed on the copy.
For 'ColumnMajor' data this is done by LAPACK's @laswp@,
where the 'Inversion' argument determines the increment
(1 for 'Inverted', -1 for 'NonInverted').
For 'RowMajor' data the interchanges are performed by 'swapColumns'.
-}
multiplyP ::
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA,
    Extent.Measure measB, Extent.C vertB, Extent.C horizB,
    Eq height, Shape.C height, Shape.C widthA, Shape.C widthB,
    Class.Floating a) =>
   Inversion ->
   LowerUpperFlex lower upper measA vertA horizA height widthA a ->
   Full measB vertB horizB height widthB a ->
   Full measB vertB horizB height widthB a
multiplyP inverted
      (LowerUpper ipiv@(Array shapeIPiv ipivFPtr)
         (Array (Layout.Split _ _ extentLU) _lu))
      (Array shape@(Layout.Full order extent) a) =
   Array.unsafeCreate shape $ \bPtr -> do

   Call.assert "multiplyP: heights mismatch"
      (Extent.height extentLU == Extent.height extent)

   let (height,width) = Extent.dimensions extent
   let m = Shape.size height
   let n = Shape.size width
   -- number of pivot elements
   let k = Shape.size shapeIPiv

   evalContT $ do
      aPtr <- ContT $ withForeignPtr a
      ipivPtr <- ContT $ withForeignPtr ipivFPtr
      -- start with a copy of the input; all swaps below operate on bPtr
      liftIO $ copyBlock (n*m) aPtr bPtr
      case order of
         ColumnMajor -> do
            nPtr <- Call.cint n
            ldaPtr <- Call.leadingDim m
            -- pivots 1 to k are applied
            k1Ptr <- Call.cint 1
            k2Ptr <- Call.cint k
            incxPtr <-
               Call.cint $
               case inverted of
                  Inverted -> 1
                  NonInverted -> -1
            liftIO $
               LapackGen.laswp nPtr bPtr ldaPtr k1Ptr k2Ptr
                  (Perm.deconsElementPtr ipivPtr) incxPtr
         RowMajor ->
            -- exchange contiguous chunks of n elements one pair at a time;
            -- the order of the exchanges is given by 'Perm.indices'
            liftIO $
            swapColumns
               (Perm.Shape height) n bPtr (Array.mapShape Perm.Shape ipiv) $
            Perm.indices (Inverted <> inverted) (Perm.Shape shapeIPiv)

-- Exchange contiguous chunks of n elements in the buffer at xPtr.
-- For every index i from the list 'is'
-- the chunk that belongs to i is swapped with the chunk
-- that belongs to the pivot element (ipiv!i) using BLAS swap.
-- The start of a chunk is n times the offset of its index
-- with respect to the shape sh, computed by 'Shape.uncheckedOffset'.
{-# INLINE swapColumns #-}
swapColumns ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   (diagShape ~ Layout.RectangularDiagonal meas vert horiz height width) =>
   Perm.Shape height -> Int -> Ptr a ->
   Array (Perm.Shape diagShape) (Perm.Element height) ->
   [Perm.Element diagShape] -> IO ()
swapColumns sh n xPtr ipiv is = evalContT $ do
   nPtr <- Call.cint n
   -- both chunks are traversed with increment 1
   incPtr <- Call.cint 1
   -- re-wrap the index such that it can be used with columnPtr;
   -- the contained number is unchanged
   let mapIx (Perm.Element i) = Perm.Element i
   let columnPtr ix = advancePtr xPtr (n * Shape.uncheckedOffset sh ix)
   liftIO $ forM_ is $ \i ->
      BlasGen.swap nPtr (columnPtr $ mapIx i) incPtr (columnPtr (ipiv!i)) incPtr



{- |
Extract the L factor as a full matrix.
This is 'Split.extractTriangle' applied to the 'Left' part of the split array.
-}
extractL ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpperFlex lower upper meas vert horiz height width a ->
   Full meas vert horiz height width a
extractL lu = Split.extractTriangle (Left Triangle) (split_ lu)

{- |
Extract the L factor as a 'Lower' matrix;
the work is delegated to 'Split.wideExtractL'.
-}
wideExtractL ::
   (Extent.Measure meas, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpperFlex lower upper meas Extent.Small horiz height width a ->
   Lower height a
wideExtractL lu = Split.wideExtractL (split_ lu)

{- |
Multiply a full matrix by the L factor stored in the split array;
the work is delegated to 'Split.wideMultiplyL'.
-}
wideMultiplyL ::
   (Extent.Measure measA, Extent.C horizA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height,
    Shape.C widthA, Shape.C widthB, Class.Floating a) =>
   Transposition ->
   LowerUpperFlex lower upper measA Extent.Small horizA height widthA a ->
   Full meas vert horiz height widthB a ->
   Full meas vert horiz height widthB a
wideMultiplyL transposed lu = Split.wideMultiplyL transposed (split_ lu)

{- |
Solve a system with the L factor stored in the split array;
the work is delegated to 'Split.wideSolveL'.
-}
wideSolveL ::
   (Extent.Measure measA, Extent.C horizA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs, Class.Floating a) =>
   Transposition -> Conjugation ->
   LowerUpperFlex lower upper measA Extent.Small horizA height width a ->
   Full meas vert horiz height nrhs a -> Full meas vert horiz height nrhs a
wideSolveL transposed conjugated lu =
   Split.wideSolveL transposed conjugated (split_ lu)


{- |
Extract the U factor as a full matrix.
This is 'Split.extractTriangle' applied to the 'Right' part of the split array.
-}
extractU ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpperFlex lower upper meas vert horiz height width a ->
   Full meas vert horiz height width a
extractU lu = Split.extractTriangle (Right Triangle) (split_ lu)

{- |
Extract the U factor as an 'Upper' matrix;
the work is delegated to 'Split.tallExtractR'.
-}
tallExtractU ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpperFlex lower upper meas vert Extent.Small height width a ->
   Upper width a
tallExtractU lu = Split.tallExtractR (split_ lu)

{- |
Multiply a full matrix by the U factor stored in the split array;
the work is delegated to 'Split.tallMultiplyR'.
-}
tallMultiplyU ::
   (Extent.Measure measA, Extent.C vertA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height,
    Shape.C heightA, Shape.C widthB, Class.Floating a) =>
   Transposition ->
   LowerUpperFlex lower upper measA vertA Extent.Small heightA height a ->
   Full meas vert horiz height widthB a ->
   Full meas vert horiz height widthB a
tallMultiplyU transposed lu = Split.tallMultiplyR transposed (split_ lu)

{- |
Solve a system with the U factor stored in the split array;
the work is delegated to 'Split.tallSolveR'.
-}
tallSolveU ::
   (Extent.Measure measA, Extent.C vertA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Shape.C nrhs, Class.Floating a) =>
   Transposition -> Conjugation ->
   LowerUpperFlex lower upper measA vertA Extent.Small height width a ->
   Full meas vert horiz width nrhs a -> Full meas vert horiz width nrhs a
tallSolveU transposed conjugated lu =
   Split.tallSolveR transposed conjugated (split_ lu)



{- |
Multiply the factors of the decomposition to a full matrix.
The implementation is selected by the extent tags:
four tag combinations use 'wideToMatrix' or 'tallToMatrix' directly,
the remaining one decides at runtime using 'caseTallWide'.
-}
toMatrix ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   LowerUpperFlex lower upper meas vert horiz height width a ->
   Full meas vert horiz height width a
toMatrix =
   getToMatrix $
   ExtentPriv.switchTagTriple
      (ToMatrix wideToMatrix)
      (ToMatrix wideToMatrix)
      (ToMatrix wideToMatrix)
      (ToMatrix tallToMatrix)
      (ToMatrix $ \lu ->
         case caseTallWide lu of
            Left tall -> MatrixPriv.fromFull (tallToMatrix tall)
            Right wide -> MatrixPriv.fromFull (wideToMatrix wide))

-- Wrapper around the function type of 'toMatrix'
-- with the extent tags meas, vert, horiz as the last type parameters.
-- 'toMatrix' passes values of this type to 'ExtentPriv.switchTagTriple'.
newtype ToMatrix lower upper height width a meas vert horiz =
   ToMatrix {
      getToMatrix ::
         LowerUpperFlex lower upper meas vert horiz height width a ->
         Full meas vert horiz height width a
   }

-- Variant of 'toMatrix' for decompositions with horizontal tag 'Extent.Small'.
-- The transposed L factor is multiplied by the transposed U factor,
-- the product is transposed back
-- and finally 'multiplyP' is applied.
tallToMatrix ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Eq height, Eq width, Class.Floating a) =>
   LowerUpperFlex lower upper meas vert Extent.Small height width a ->
   Full meas vert Extent.Small height width a
tallToMatrix a =
   let lowerT = Basic.transpose (extractL a)
       productT = tallMultiplyU Transposed a lowerT
   in multiplyP NonInverted a (Basic.transpose productT)

-- Variant of 'toMatrix' for decompositions with vertical tag 'Extent.Small'.
-- The extracted U factor is multiplied by the L factor
-- and finally 'multiplyP' is applied.
wideToMatrix ::
   (Extent.Measure meas, Extent.C horiz,
    Shape.C height, Shape.C width, Eq height, Eq width, Class.Floating a) =>
   LowerUpperFlex lower upper meas Extent.Small horiz height width a ->
   Full meas Extent.Small horiz height width a
wideToMatrix a =
   let upper = extractU a
       product = wideMultiplyL NonTransposed a upper
   in multiplyP NonInverted a product


{- |
Multiply the decomposed matrix by a full matrix.
For a 'ExtentPriv.Square' extent 'multiplyFullAux' is called directly.
For a 'ExtentPriv.Separate' extent the height is wrapped in 'Unchecked'
for the call and unwrapped in the result.
-}
multiplyFull ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   LowerUpperFlex lower upper meas vert horiz height fuse a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
multiplyFull a b =
   case Matrix.extent a of
      ExtentPriv.Square _ -> multiplyFullAux a b
      ExtentPriv.Separate _ _ ->
         Basic.mapHeight deconsUnchecked $
         multiplyFullAux (mapHeight Unchecked a) b

-- Worker of 'multiplyFull' that additionally requires @Eq height@.
-- The implementation is selected by the extent tags;
-- the last alternative decides at runtime using 'caseTallWide'.
multiplyFullAux ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   LowerUpperFlex lower upper meas vert horiz height fuse a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
multiplyFullAux =
   getMultiplyFullRight $
   ExtentPriv.switchTagTriple
      {-
      Using squareFull for the first alternative is not possible,
      because that would require height~width.
      -}
      (MultiplyFullRight wideMultiplyFullRight)
      (MultiplyFullRight wideMultiplyFullRight)
      (MultiplyFullRight wideMultiplyFullRight)
      (MultiplyFullRight tallMultiplyFullRight)
      (MultiplyFullRight $ \lu ->
         case caseTallWide lu of
            Left tall -> tallMultiplyFullRight tall
            Right wide -> wideMultiplyFullRight wide)

-- Wrapper around the function type of 'multiplyFullAux'
-- with the extent tags meas, vert, horiz as the last type parameters.
-- 'multiplyFullAux' passes values of this type
-- to 'ExtentPriv.switchTagTriple'.
newtype MultiplyFullRight lower upper height fuse width a meas vert horiz =
   MultiplyFullRight {
      getMultiplyFullRight ::
         LowerUpperFlex lower upper meas vert horiz height fuse a ->
         Full meas vert horiz fuse width a ->
         Full meas vert horiz height width a
   }

-- Multiply a decomposition with horizontal tag 'Extent.Small'
-- by a full matrix.
-- The steps are applied in this order:
-- 'tallMultiplyU', multiplication by the extracted L factor, 'multiplyP'.
tallMultiplyFullRight ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq height, Eq fuse,
    Class.Floating a) =>
   LowerUpperFlex lower upper meas vert Extent.Small height fuse a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
tallMultiplyFullRight a b =
   multiplyP NonInverted a $
   Basic.multiply (MatrixPriv.weakenTall (extractL a)) $
   tallMultiplyU NonTransposed a b

-- Multiply a decomposition with vertical tag 'Extent.Small'
-- by a full matrix.
-- The steps are applied in this order:
-- multiplication by the extracted U factor, 'wideMultiplyL', 'multiplyP'.
wideMultiplyFullRight ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq height, Eq fuse,
    Class.Floating a) =>
   LowerUpperFlex lower upper meas Extent.Small horiz height fuse a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
wideMultiplyFullRight a b =
   multiplyP NonInverted a $
   wideMultiplyL NonTransposed a $
   Basic.multiply (MatrixPriv.weakenWide (extractU a)) b


-- Map a vector indexed by the height to a vector indexed by the width.
-- 'caseTallWide' selects between the tall and the wide implementation
-- and 'Basic.unliftColumn' turns the selected matrix function
-- into a vector function.
transMultiplyVector ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width,
    Eq height, Eq width, Class.Floating a) =>
   LowerUpperFlex lower upper meas vert horiz height width a ->
   Vector height a -> Vector width a
transMultiplyVector lu =
   Basic.unliftColumn Layout.ColumnMajor $
   case caseTallWide lu of
      Left tall -> tallTransMultiplyFullRight tall
      Right wide -> wideTransMultiplyFullRight wide

-- Transposed counterpart of 'tallMultiplyFullRight'.
-- The steps are applied in this order:
-- 'multiplyP' with 'Inverted',
-- multiplication by the transposed L factor,
-- 'tallMultiplyU' with 'Transposed'.
tallTransMultiplyFullRight ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq height, Eq fuse,
    Class.Floating a) =>
   LowerUpperFlex lower upper meas horiz Extent.Small fuse height a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
tallTransMultiplyFullRight a b =
   tallMultiplyU Transposed a $
   Basic.multiply (Basic.transpose $ MatrixPriv.weakenTall $ extractL a) $
   multiplyP Inverted a b

-- Transposed counterpart of 'wideMultiplyFullRight'.
-- The steps are applied in this order:
-- 'multiplyP' with 'Inverted',
-- 'wideMultiplyL' with 'Transposed',
-- multiplication by the transposed U factor.
wideTransMultiplyFullRight ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq height, Eq fuse,
    Class.Floating a) =>
   LowerUpperFlex lower upper meas Extent.Small vert fuse height a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
wideTransMultiplyFullRight a b =
   Basic.multiply (Basic.transpose $ MatrixPriv.weakenWide $ extractU a) $
   wideMultiplyL Transposed a $
   multiplyP Inverted a b


{- |
Reinterpret a decomposition as either 'Tall' or 'Wide'
according to the result of 'Layout.caseTallWideSplit' on its shape.
Both buffers are reused; only the shapes are replaced:
the split array gets the new shape
and the extent of the pivot array shape is set to the new split extent.
-}
caseTallWide ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
   LowerUpperFlex lower upper meas vert horiz height width a ->
   Either (Tall height width a) (Wide height width a)
caseTallWide (LowerUpper ipiv (Array shape a)) =
   -- NOTE(review): consLU receives the pivots and the buffer as arguments
   -- instead of capturing them, presumably such that the local binding
   -- can be used at both the Tall and the Wide type below -
   -- confirm before refactoring.
   let consLU ipivb b newShape =
         LowerUpper
            (Array.mapShape
               (\bandShape ->
                  bandShape{Layout.bandedExtent = Layout.splitExtent newShape})
               ipivb)
            (Array newShape b)
   in either (Left . consLU ipiv a) (Right . consLU ipiv a) $
      Layout.caseTallWideSplit shape


-- Convert the split array to 'RowMajor' order; the pivots are kept.
-- An array that is already 'RowMajor' is returned unchanged,
-- a 'ColumnMajor' array is copied into a new array by 'copyTransposed'.
-- The function is not exported by this module.
_toRowMajor ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Eq height, Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpperFlex lower upper meas vert horiz height width a ->
   LowerUpperFlex lower upper meas vert horiz height width a
_toRowMajor
   (LowerUpper ipiv
      arr@(Array (Layout.Split Layout.Triangle order extent) a)) =
   LowerUpper ipiv $
   case order of
      RowMajor -> arr
      ColumnMajor ->
         Array.unsafeCreate
            (Layout.Split Layout.Triangle RowMajor extent) $ \bPtr ->
         withForeignPtr a $ \aPtr -> do
            let (height, width) = Extent.dimensions extent
            -- n is the size of the width, m is the size of the height
            let n = Shape.size width
            let m = Shape.size height
            -- NOTE(review): argument order and leading dimension
            -- follow the contract of 'copyTransposed',
            -- which is defined in "Numeric.LAPACK.Private" - confirm there
            copyTransposed n m aPtr n bPtr


-- The extent of a decomposition is the extent of its split array.
instance Matrix.Box LU where
   type BoxExtra LU extra = extra ~ ()
   extent lu = Layout.splitExtent (Array.shape (split_ lu))

-- Conversion to a quadratic decomposition.
-- The pivot array gets a diagonal shape built from the banded height
-- or the banded width, respectively;
-- the split array is converted by the corresponding 'Split' function.
-- In the width case the pivot elements are additionally re-indexed
-- by 'mapPivotHeight'.
instance Matrix.ToQuadratic LU where
   heightToQuadratic (LowerUpper pivot split) =
      let squarePivot =
            Array.mapShape (layoutPivotSquare . Layout.bandedHeight) pivot
      in LowerUpper squarePivot (Split.heightToQuadratic split)
   widthToQuadratic (LowerUpper pivot split) =
      let width = Layout.bandedWidth (Array.shape pivot)
          squarePivot =
            Array.mapShape (layoutPivotSquare . Layout.bandedWidth) pivot
      in LowerUpper
            (mapPivotHeight (const width) squarePivot)
            (Split.widthToQuadratic split)

-- Diagonal layout in 'Layout.ColumnMajor' order for the given shape.
layoutPivotSquare :: sh -> Layout.Diagonal sh
layoutPivotSquare sh = LayoutPub.diagonal Layout.ColumnMajor sh

-- The class method forwards to the top-level 'mapExtent' of this module.
instance Matrix.MapExtent LU where
   type MapExtentExtra LU extra = extra ~ ()
   type MapExtentStrip LU strip = ()
   mapExtent extentMap lu = mapExtent extentMap lu

-- Matrix-vector products for LU decompositions.
-- 'matrixVector' converts the decomposition with 'Extent.toGeneral'
-- and lifts 'multiplyFull' to vectors.
-- 'vectorMatrix' uses 'transMultiplyVector';
-- for a 'ExtentPriv.Separate' extent the width is wrapped in 'Unchecked'
-- for the call and unwrapped in the result.
instance Multiply.MultiplyVector LU where
   type MultiplyVectorExtra LU extra = extra ~ ()
   matrixVector lu v =
      Basic.unliftColumn Layout.ColumnMajor
         (multiplyFull (mapExtent Extent.toGeneral lu))
         v
   vectorMatrix v lu =
      case Matrix.extent lu of
         ExtentPriv.Square _ -> transMultiplyVector lu v
         ExtentPriv.Separate _ _ ->
            Array.mapShape deconsUnchecked $
            transMultiplyVector (mapWidth Unchecked lu) v

-- Products of a square LU decomposition with a full matrix.
-- 'squareFull' applies 'tallMultiplyU', 'wideMultiplyL' and 'multiplyP'
-- in this order.
-- 'fullSquare' transposes the operand, applies 'multiplyP' with 'Inverted',
-- then the transposed L and U multiplications, and transposes back.
instance Multiply.MultiplySquare LU where
   type MultiplySquareExtra LU extra = extra ~ ()
   squareFull lu =
      ArrMatrix.lift1 $ \rhs ->
         multiplyP NonInverted lu $
         wideMultiplyL NonTransposed lu $
         tallMultiplyU NonTransposed lu rhs

   fullSquare lhs lu =
      ArrMatrix.lift1
         (\m ->
            Basic.transpose $
            tallMultiplyU Transposed lu $
            wideMultiplyL Transposed lu $
            multiplyP Inverted lu $
            Basic.transpose m)
         lhs

-- The class method forwards to the top-level 'determinant' of this module.
instance Divide.Determinant LU where
   type DeterminantExtra LU extra = extra ~ ()
   determinant lu = determinant lu

-- Solving is 'solveTrans' lifted to array matrices by 'ArrMatrix.lift1'.
instance Divide.Solve LU where
   type SolveExtra LU extra = extra ~ ()
   solve trans lu = ArrMatrix.lift1 (solveTrans trans lu)