{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE EmptyDataDecls #-}
module Numeric.LAPACK.Linear.Plain (
   LowerUpper,
   Square, Tall, Wide,
   Transposition(..),
   Conjugation(..),
   Inversion(..),
   mapExtent,
   fromMatrix,
   toMatrix,
   solve,
   multiplyFull,

   determinant,

   extractP,
   multiplyP,

   extractL,
   wideExtractL,
   wideMultiplyL,
   wideSolveL,

   extractU,
   tallExtractU,
   tallMultiplyU,
   tallSolveU,

   caseTallWide,
   ) where

import qualified Numeric.LAPACK.Matrix.Divide as Divide
import qualified Numeric.LAPACK.Matrix.Multiply as Multiply
import qualified Numeric.LAPACK.Matrix.Type as Type
import qualified Numeric.LAPACK.Matrix.Array as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.Extent as ExtentMap
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Shape as ExtShape
import qualified Numeric.LAPACK.Split as Split
import Numeric.LAPACK.Output ((/+/))
import Numeric.LAPACK.Matrix.Array.Format (formatArray)
import Numeric.LAPACK.Matrix.Type (FormatMatrix(formatMatrix))
import Numeric.LAPACK.Matrix.Triangular.Basic (UnitLower, Upper)
import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor, ColumnMajor), Triangle(Triangle))
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed, Transposed),
          Conjugation(NonConjugated, Conjugated),
          Inversion(NonInverted, Inverted))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Linear.Private (solver, withInfo)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Private (copyBlock, copyTransposed, copyToColumnMajor)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array), (!))

import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (forM_)

import Data.Monoid ((<>))


{- |
Type-level tag without constructors.
It selects the 'Type.Matrix' data instance declared right after it.
-}
data LU vert horiz height width

{- |
Record that pairs a pivot vector with a 'MatrixShape.Split' array.
'fromMatrix' fills both fields within a single call to LAPACK's @getrf@.
-}
data instance Type.Matrix (LU vert horiz height width) a =
   LowerUpper {
      -- | pivot elements; 'fromMatrix' lets @getrf@ write them
      _pivot ::
         Vector (ExtShape.Min width (Perm.Shape height)) (Perm.Element height),
      -- | array with 'MatrixShape.Triangle' split shape;
      -- 'fromMatrix' lets @getrf@ overwrite a copy of the input matrix here
      split_ ::
         Array
            (MatrixShape.Split MatrixShape.Triangle vert horiz height width) a
   } deriving (Show)

{- |
Shorthand for the 'Type.Matrix' type tagged with 'LU'.
The element type remains as the last, not yet applied, parameter.
-}
type LowerUpper vert horiz height width =
         Type.Matrix (LU vert horiz height width)
-- | Both extent tags are 'Extent.Small'; height and width share one shape.
type Square sh = LowerUpper Extent.Small Extent.Small sh sh
-- | Vertical tag 'Extent.Big', horizontal tag 'Extent.Small'.
type Tall height width = LowerUpper Extent.Big Extent.Small height width
-- | Vertical tag 'Extent.Small', horizontal tag 'Extent.Big'.
type Wide height width = LowerUpper Extent.Small Extent.Big height width

{- |
The formatted output consists of the permutation obtained from 'extractP'
and the formatted split array, combined with '/+/'.
-}
instance
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
      FormatMatrix (LU vert horiz height width) where
   formatMatrix fmt lu@(LowerUpper _ arr) =
      let permutation = extractP NonInverted lu
      in Perm.format permutation /+/ formatArray fmt arr

{- |
Change the extent tags of a decomposition.
The pivot vector is passed on as is.
The split array goes through 'Array.mapShape'
with the shape conversion 'MatrixShape.splitMapExtent'.
-}
mapExtent ::
   (Extent.C vertA, Extent.C horizA) =>
   (Extent.C vertB, Extent.C horizB) =>
   Extent.Map vertA horizA vertB horizB height width ->
   LowerUpper vertA horizA height width a ->
   LowerUpper vertB horizB height width a
mapExtent extentMap (LowerUpper ipiv arr) =
   LowerUpper ipiv
      (Array.mapShape (MatrixShape.splitMapExtent extentMap) arr)

{- |
Decompose a full matrix by means of LAPACK's @getrf@.

The input is copied into a freshly allocated 'ColumnMajor' buffer
using 'copyToColumnMajor'.
Then @getrf@ is run on that buffer
and stores the pivots in the newly allocated pivot vector.
The @getrf@ call is wrapped in 'withInfo'.
-}
fromMatrix ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   LowerUpper vert horiz height width a
fromMatrix (Array (MatrixShape.Full order extent) a) =
   uncurry LowerUpper $
   Array.unsafeCreateWithSizeAndResult pivotShape $ \_ ipivPtr ->
   ArrayIO.unsafeCreate splitShape $ \luPtr ->
   evalContT $ do
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.leadingDim m
      liftIO $ do
         -- getrf works in-place, thus operate on a copy of the input
         copyToColumnMajor order m n aPtr luPtr
         withInfo "getrf"
            (LapackGen.getrf mPtr nPtr luPtr ldaPtr
               (Perm.deconsElementPtr ipivPtr))
   where
      (height,width) = Extent.dimensions extent
      m = Shape.size height
      n = Shape.size width
      pivotShape = ExtShape.Min width (Perm.Shape height)
      splitShape = MatrixShape.Split MatrixShape.Triangle ColumnMajor extent

{- |
Solve with a decomposed square matrix for a matrix of right-hand sides.
This is 'solveTrans' fixed to 'NonTransposed'.
-}
solve ::
   (Extent.C vert, Extent.C horiz, Eq height, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Square height a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
solve lu rhs = solveTrans NonTransposed lu rhs

{- |
Solve with a decomposed square matrix, optionally transposed,
using LAPACK's @getrs@.

@getrs@ is fed a column major array.
If the decomposition is stored in 'RowMajor' order
then it is first copied to a temporary column major buffer.
The 'Transposition' argument is turned into
the character @\'N\'@ or @\'T\'@ for @getrs@.
-}
solveTrans ::
   (Extent.C vert, Extent.C horiz, Eq height, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Transposition -> Square height a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
solveTrans trans
   (LowerUpper
      (Array _ ipiv)
      (Array (MatrixShape.Split MatrixShape.Triangle orderLU extentLU) lu)) =

   solver "LowerUpper.solve" (Extent.squareSize extentLU) $
         \n nPtr nrhsPtr xPtr ldxPtr -> do
      transPtr <- Call.char transChar
      storedPtr <- ContT $ withForeignPtr lu
      aPtr <-
         case orderLU of
            ColumnMajor -> return storedPtr
            RowMajor -> do
               tmpPtr <- Call.allocaArray (n*n)
               liftIO $ copyToColumnMajor orderLU n n storedPtr tmpPtr
               return tmpPtr
      -- the matrix is square, thus the leading dimension equals the order
      ldaPtr <- Call.leadingDim n
      ipivPtr <- fmap Perm.deconsElementPtr (ContT (withForeignPtr ipiv))
      liftIO $
         withInfo "getrs"
            (LapackGen.getrs transPtr
               nPtr nrhsPtr aPtr ldaPtr ipivPtr xPtr ldxPtr)
   where
      transChar =
         case trans of
            Transposed -> 'T'
            NonTransposed -> 'N'

{- |
Determinant computed from the decomposition of a square matrix.
It is 'Split.determinantR' of the split array,
passed through 'Perm.condNegate' together with the list of pivots.

Caution:
@LU.determinant . LU.fromMatrix@ will fail for singular matrices.
-}
determinant :: (Shape.C sh, Class.Floating a) => Square sh a -> a
determinant (LowerUpper ipiv split) =
   Perm.condNegate pivots (Split.determinantR split)
   where
      pivots = map Perm.deconsElement (Array.toList ipiv)


{- |
Build the permutation from the stored pivot vector.
The requested 'Inversion' is combined with 'Inverted' via '<>'
before it is handed to 'Perm.fromTruncatedPivots'.
-}
extractP ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
   Inversion -> LowerUpper vert horiz height width a -> Perm.Permutation height
extractP inverted (LowerUpper ipiv _) =
   Perm.fromTruncatedPivots direction ipiv
   where
      direction = Inverted <> inverted

{- |
Apply the pivot exchanges of a decomposition to a full matrix.

The heights of decomposition and matrix must match,
otherwise the assertion @multiplyP: heights mismatch@ is triggered.
The input matrix is copied and the exchanges are performed on the copy:

* 'ColumnMajor': LAPACK's @laswp@ over pivots 1 to @k@,
  with increment @1@ for 'Inverted' and @-1@ for 'NonInverted',
  where @k@ is the size of the pivot vector.

* 'RowMajor': 'swapColumns' along the index list
  that 'Perm.indices' yields for @Inverted <> inverted@.
-}
multiplyP ::
   (Extent.C vertA, Extent.C horizA, Extent.C vertB, Extent.C horizB,
    Eq height, Shape.C height, Shape.C widthA, Shape.C widthB,
    Class.Floating a) =>
   Inversion ->
   LowerUpper vertA horizA height widthA a ->
   Full vertB horizB height widthB a ->
   Full vertB horizB height widthB a
multiplyP inverted
      (LowerUpper ipiv@(Array shapeIPiv ipivFPtr)
         (Array (MatrixShape.Split _ _ extentLU) _))
      (Array shape@(MatrixShape.Full order extent) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "multiplyP: heights mismatch"
         (Extent.height extentLU == Extent.height extent)
      evalContT $ do
         srcPtr <- ContT $ withForeignPtr a
         ipivPtr <- ContT $ withForeignPtr ipivFPtr
         liftIO $ copyBlock (n*m) srcPtr bPtr
         case order of
            RowMajor ->
               liftIO $
                  swapColumns height n bPtr ipiv
                     (Perm.indices (Inverted <> inverted) shapeIPiv)
            ColumnMajor -> do
               nPtr <- Call.cint n
               ldaPtr <- Call.leadingDim m
               k1Ptr <- Call.cint 1
               k2Ptr <- Call.cint k
               incxPtr <- Call.cint incx
               liftIO $
                  LapackGen.laswp nPtr bPtr ldaPtr k1Ptr k2Ptr
                     (Perm.deconsElementPtr ipivPtr) incxPtr
   where
      (height,width) = Extent.dimensions extent
      m = Shape.size height
      n = Shape.size width
      k = Shape.size shapeIPiv
      -- direction in which laswp walks through the pivots
      incx =
         case inverted of
            NonInverted -> -1
            Inverted -> 1

{- |
Exchange contiguous chunks of @n@ elements within the buffer @xPtr@.
For every index @i@ of the list in turn,
the chunk that starts at @n@ times the offset of @i@
is swapped with the chunk that belongs to @ipiv!i@, using BLAS' @swap@.
-}
{-# INLINE swapColumns #-}
swapColumns ::
   (Shape.C sh, Shape.C width, Class.Floating a) =>
   sh -> Int -> Ptr a ->
   Array (ExtShape.Min width (Perm.Shape sh)) (Perm.Element sh) ->
   [Perm.Element sh] -> IO ()
swapColumns sh n xPtr ipiv is = evalContT $ do
   lenPtr <- Call.cint n
   stridePtr <- Call.cint 1
   let chunkPtr ix =
         advancePtr xPtr (n * Shape.uncheckedOffset (Perm.Shape sh) ix)
       exchange i =
         BlasGen.swap
            lenPtr (chunkPtr i) stridePtr (chunkPtr (ipiv!i)) stridePtr
   liftIO $ mapM_ exchange is



{- |
Full matrix obtained by 'Split.extractTriangle'
for the @Left Triangle@ part of the split array.
-}
extractL ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   LowerUpper vert horiz height width a ->
   Full vert horiz height width a
extractL lu = Split.extractTriangle (Left Triangle) (split_ lu)

{- |
'UnitLower' matrix obtained by 'Split.wideExtractL' from the split array.
Requires the vertical extent tag 'Extent.Small'.
-}
wideExtractL ::
   (Extent.C horiz, Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpper Extent.Small horiz height width a -> UnitLower height a
wideExtractL lu = Split.wideExtractL (split_ lu)

{- |
Delegates to 'Split.wideMultiplyL' on the split array,
passing the 'Transposition' through.
Requires the vertical extent tag 'Extent.Small'.
-}
wideMultiplyL ::
   (Extent.C horizA, Extent.C vert, Extent.C horiz, Shape.C height, Eq height,
    Shape.C widthA, Shape.C widthB, Class.Floating a) =>
   Transposition ->
   LowerUpper Extent.Small horizA height widthA a ->
   Full vert horiz height widthB a ->
   Full vert horiz height widthB a
wideMultiplyL transposed lu =
   Split.wideMultiplyL transposed (split_ lu)

{- |
Delegates to 'Split.wideSolveL' on the split array,
passing 'Transposition' and 'Conjugation' through.
Requires the vertical extent tag 'Extent.Small'.
-}
wideSolveL ::
   (Extent.C horizA, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs, Class.Floating a) =>
   Transposition -> Conjugation ->
   LowerUpper Extent.Small horizA height width a ->
   Full vert horiz height nrhs a -> Full vert horiz height nrhs a
wideSolveL transposed conjugated lu =
   Split.wideSolveL transposed conjugated (split_ lu)


{- |
Full matrix obtained by 'Split.extractTriangle'
for the @Right Triangle@ part of the split array.
-}
extractU ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   LowerUpper vert horiz height width a ->
   Full vert horiz height width a
extractU lu = Split.extractTriangle (Right Triangle) (split_ lu)

{- |
'Upper' matrix obtained by 'Split.tallExtractR' from the split array.
Requires the horizontal extent tag 'Extent.Small'.
-}
tallExtractU ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpper vert Extent.Small height width a -> Upper width a
tallExtractU lu = Split.tallExtractR (split_ lu)

{- |
Delegates to 'Split.tallMultiplyR' on the split array,
passing the 'Transposition' through.
Requires the horizontal extent tag 'Extent.Small'.
-}
tallMultiplyU ::
   (Extent.C vertA, Extent.C vert, Extent.C horiz, Shape.C height, Eq height,
    Shape.C heightA, Shape.C widthB, Class.Floating a) =>
   Transposition ->
   LowerUpper vertA Extent.Small heightA height a ->
   Full vert horiz height widthB a ->
   Full vert horiz height widthB a
tallMultiplyU transposed lu =
   Split.tallMultiplyR transposed (split_ lu)

{- |
Delegates to 'Split.tallSolveR' on the split array,
passing 'Transposition' and 'Conjugation' through.
Requires the horizontal extent tag 'Extent.Small'.
-}
tallSolveU ::
   (Extent.C vertA, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Shape.C nrhs, Class.Floating a) =>
   Transposition -> Conjugation ->
   LowerUpper vertA Extent.Small height width a ->
   Full vert horiz width nrhs a -> Full vert horiz width nrhs a
tallSolveU transposed conjugated lu =
   Split.tallSolveR transposed conjugated (split_ lu)



{- |
Multiply the parts of the decomposition back into a full matrix.

'Extent.switchTagPair' selects 'wideToMatrix' or 'tallToMatrix'
according to the extent tags.
For the remaining alternative
'caseTallWide' decides between both of them at runtime
and the result is converted with 'Matrix.fromFull'.
-}
toMatrix ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   LowerUpper vert horiz height width a ->
   Full vert horiz height width a
toMatrix lu =
   getToMatrix
      (Extent.switchTagPair
         (ToMatrix wideToMatrix)
         (ToMatrix wideToMatrix)
         (ToMatrix tallToMatrix)
         (ToMatrix $ \general ->
            case caseTallWide general of
               Left tall -> Matrix.fromFull (tallToMatrix tall)
               Right wide -> Matrix.fromFull (wideToMatrix wide)))
      lu

{- |
Wrapper around the function type of 'toMatrix'
with the extent tags @vert@ and @horiz@ as the last type parameters.
'toMatrix' passes values of this type to 'Extent.switchTagPair'.
-}
newtype ToMatrix height width a vert horiz =
   ToMatrix {
      getToMatrix ::
         LowerUpper vert horiz height width a ->
         Full vert horiz height width a
   }

{- |
'toMatrix' for decompositions with horizontal extent tag 'Extent.Small'.

'tallMultiplyU' is applied in 'Transposed' mode to the transposed 'extractL'.
The result is transposed back and finally passed to
'multiplyP' with 'NonInverted'.
-}
tallToMatrix ::
   (Extent.C vert, Shape.C height, Shape.C width, Eq height, Eq width,
    Class.Floating a) =>
   LowerUpper vert Extent.Small height width a ->
   Full vert Extent.Small height width a
tallToMatrix lu =
   multiplyP NonInverted lu (Basic.transpose product_)
   where
      lowerT = Basic.transpose (extractL lu)
      product_ = tallMultiplyU Transposed lu lowerT

{- |
'toMatrix' for decompositions with vertical extent tag 'Extent.Small'.

'extractU' is fed to 'wideMultiplyL' in 'NonTransposed' mode
and then to 'multiplyP' with 'NonInverted'.
-}
wideToMatrix ::
   (Extent.C horiz, Shape.C height, Shape.C width, Eq height, Eq width,
    Class.Floating a) =>
   LowerUpper Extent.Small horiz height width a ->
   Full Extent.Small horiz height width a
wideToMatrix lu =
   multiplyP NonInverted lu lowerTimesUpper
   where
      upper = extractU lu
      lowerTimesUpper = wideMultiplyL NonTransposed lu upper


{- |
Multiply a decomposition from the left to a full matrix.

'Extent.switchTagPair' selects
'wideMultiplyFullRight' or 'tallMultiplyFullRight'
according to the extent tags.
For the remaining alternative
'caseTallWide' decides between both of them at runtime.
-}
multiplyFull ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   LowerUpper vert horiz height fuse a ->
   Full vert horiz fuse width a ->
   Full vert horiz height width a
multiplyFull lu =
   getMultiplyFullRight
      (Extent.switchTagPair
         {-
         Using squareFull for the first alternative is not an option,
         because squareFull requires height~width.
         -}
         (MultiplyFullRight wideMultiplyFullRight)
         (MultiplyFullRight wideMultiplyFullRight)
         (MultiplyFullRight tallMultiplyFullRight)
         (MultiplyFullRight $ \general ->
            case caseTallWide general of
               Left tall -> tallMultiplyFullRight tall
               Right wide -> wideMultiplyFullRight wide))
      lu

{- |
Wrapper around the function type of 'multiplyFull'
with the extent tags @vert@ and @horiz@ as the last type parameters.
'multiplyFull' passes values of this type to 'Extent.switchTagPair'.
-}
newtype MultiplyFullRight height fuse width a vert horiz =
   MultiplyFullRight {
      getMultiplyFullRight ::
         LowerUpper vert horiz height fuse a ->
         Full vert horiz fuse width a ->
         Full vert horiz height width a
   }

{- |
'multiplyFull' for decompositions with horizontal extent tag 'Extent.Small'.

The operand passes through three stages:
'tallMultiplyU' in 'NonTransposed' mode,
'Basic.multiply' with the generalized 'extractL',
and 'multiplyP' with 'NonInverted'.
-}
tallMultiplyFullRight ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq height, Eq fuse,
    Class.Floating a) =>
   LowerUpper vert Extent.Small height fuse a ->
   Full vert horiz fuse width a ->
   Full vert horiz height width a
tallMultiplyFullRight lu b =
   multiplyP NonInverted lu $
   Basic.multiply (Matrix.generalizeTall (extractL lu)) $
   tallMultiplyU NonTransposed lu b

{- |
'multiplyFull' for decompositions with vertical extent tag 'Extent.Small'.

The operand passes through three stages:
'Basic.multiply' with the generalized 'extractU',
'wideMultiplyL' in 'NonTransposed' mode,
and 'multiplyP' with 'NonInverted'.
-}
wideMultiplyFullRight ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq height, Eq fuse,
    Class.Floating a) =>
   LowerUpper Extent.Small horiz height fuse a ->
   Full vert horiz fuse width a ->
   Full vert horiz height width a
wideMultiplyFullRight lu b =
   multiplyP NonInverted lu $
   wideMultiplyL NonTransposed lu $
   Basic.multiply (Matrix.generalizeWide (extractU lu)) b


{- |
Multiply with a decomposition in transposed sense,
for decompositions with horizontal extent tag 'Extent.Small'.

The operand passes through three stages:
'multiplyP' with 'Inverted',
'Basic.multiply' with the transposed generalized 'extractL',
and 'tallMultiplyU' in 'Transposed' mode.
-}
tallTransMultiplyFullRight ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq height, Eq fuse,
    Class.Floating a) =>
   LowerUpper horiz Extent.Small fuse height a ->
   Full vert horiz fuse width a ->
   Full vert horiz height width a
tallTransMultiplyFullRight lu b =
   tallMultiplyU Transposed lu $
   Basic.multiply (Basic.transpose (Matrix.generalizeTall (extractL lu))) $
   multiplyP Inverted lu b

{- |
Multiply with a decomposition in transposed sense,
for decompositions with vertical extent tag 'Extent.Small'.

The operand passes through three stages:
'multiplyP' with 'Inverted',
'wideMultiplyL' in 'Transposed' mode,
and 'Basic.multiply' with the transposed generalized 'extractU'.
-}
wideTransMultiplyFullRight ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq height, Eq fuse,
    Class.Floating a) =>
   LowerUpper Extent.Small vert fuse height a ->
   Full vert horiz fuse width a ->
   Full vert horiz height width a
wideTransMultiplyFullRight lu b =
   Basic.multiply (Basic.transpose (Matrix.generalizeWide (extractU lu))) $
   wideMultiplyL Transposed lu $
   multiplyP Inverted lu b


{- |
Classify a decomposition as 'Tall' or 'Wide' at runtime.
The decision is made by 'MatrixShape.caseTallWideSplit'
on the shape of the split array.
Pivot vector and array data are reused, only the shape is replaced.
-}
caseTallWide ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
   LowerUpper vert horiz height width a ->
   Either (Tall height width a) (Wide height width a)
caseTallWide (LowerUpper ipiv (Array shape a)) =
   case MatrixShape.caseTallWideSplit shape of
      Left tallShape -> Left (LowerUpper ipiv (Array tallShape a))
      Right wideShape -> Right (LowerUpper ipiv (Array wideShape a))


{- |
Convert the split array to 'RowMajor' order.
An array that is already in 'RowMajor' order is returned unchanged,
a 'ColumnMajor' array is copied with 'copyTransposed' into a new array.
The pivot vector is kept.
-}
_toRowMajor ::
   (Extent.C vert, Extent.C horiz, Eq height, Shape.C height, Shape.C width,
    Class.Floating a) =>
   LowerUpper vert horiz height width a ->
   LowerUpper vert horiz height width a
_toRowMajor
   (LowerUpper ipiv
      arr@(Array (MatrixShape.Split MatrixShape.Triangle order extent) a)) =
   LowerUpper ipiv rowMajorArr
   where
      (height, width) = Extent.dimensions extent
      m = Shape.size height
      n = Shape.size width
      rowMajorShape = MatrixShape.Split MatrixShape.Triangle RowMajor extent
      rowMajorArr =
         case order of
            ColumnMajor ->
               Array.unsafeCreate rowMajorShape $ \bPtr ->
                  withForeignPtr a $ \aPtr ->
                     copyTransposed n m aPtr n bPtr
            RowMajor -> arr


{- |
Height and width are read from the shape of the split array
via 'MatrixShape.splitHeight' and 'MatrixShape.splitWidth'.
-}
instance
   (Extent.C vert, Extent.C horiz) =>
      Type.Box (LU vert horiz height width) where
   type HeightOf (LU vert horiz height width) = height
   type WidthOf (LU vert horiz height width) = width
   height lu = MatrixShape.splitHeight (Array.shape (split_ lu))
   width lu = MatrixShape.splitWidth (Array.shape (split_ lu))

{- |
Matrix-vector product.
The decomposition is converted with 'ExtentMap.toGeneral'
and 'multiplyFull' is lifted to vectors by 'Basic.unliftColumn'.
-}
instance
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width) =>
      Multiply.MultiplyRight (LU vert horiz height width) where
   matrixVector lu x =
      let general = mapExtent ExtentMap.toGeneral lu
      in Basic.unliftColumn ColumnMajor (multiplyFull general) x

{- |
Vector-matrix product.
'caseTallWide' selects between
'tallTransMultiplyFullRight' and 'wideTransMultiplyFullRight'
and the selected function is lifted to vectors by 'Basic.unliftColumn'.
-}
instance
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width) =>
      Multiply.MultiplyLeft (LU vert horiz height width) where
   vectorMatrix x lu =
      Basic.unliftColumn ColumnMajor
         (case caseTallWide lu of
            Left tall -> tallTransMultiplyFullRight tall
            Right wide -> wideTransMultiplyFullRight wide)
         x

{- |
Products of a square decomposition with a full matrix.

'squareFull' passes the operand through
'tallMultiplyU', 'wideMultiplyL' and 'multiplyP' 'NonInverted'.

'fullSquare' transposes the operand,
passes it through 'multiplyP' 'Inverted'
and the 'Transposed' variants of 'wideMultiplyL' and 'tallMultiplyU'
and transposes the result back.
-}
instance
   (vert ~ Extent.Small, horiz ~ Extent.Small,
    Shape.C height, height ~ width) =>
      Multiply.MultiplySquare (LU vert horiz height width) where

   squareFull lu =
      ArrMatrix.lift1 $ \rhs ->
         multiplyP NonInverted lu $
         wideMultiplyL NonTransposed lu $
         tallMultiplyU NonTransposed lu rhs

   fullSquare lhs lu =
      ArrMatrix.lift1
         (\m ->
            Basic.transpose $
            tallMultiplyU Transposed lu $
            wideMultiplyL Transposed lu $
            multiplyP Inverted lu $
            Basic.transpose m)
         lhs

{- |
Solving with a square decomposition:
'solveTrans' lifted by 'ArrMatrix.lift1'.
-}
instance
   (vert ~ Extent.Small, horiz ~ Extent.Small,
    Shape.C height, height ~ width) =>
      Divide.Solve (LU vert horiz height width) where
   solve trans lu = ArrMatrix.lift1 (solveTrans trans lu)