{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Split where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Mosaic.Private as Mos
import qualified Numeric.LAPACK.Matrix.Triangular.Basic as Tri
import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Matrix.Mosaic.Private (diagonalPointers)
import Numeric.LAPACK.Matrix.Triangular.Basic (Lower, Upper)
import Numeric.LAPACK.Matrix.Layout.Private
         (Order(RowMajor, ColumnMajor), transposeFromOrder,
          swapOnRowMajor, sideSwapFromOrder,
          Triangle, uploFromOrder, flipOrder)
import Numeric.LAPACK.Matrix.Extent.Private (Extent)
import Numeric.LAPACK.Matrix.Modifier
         (Transposition, transposeOrder,
          Conjugation(NonConjugated, Conjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Linear.Private (solver, withInfo)
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))
import Numeric.LAPACK.Private (copyBlock, conjugateToTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import System.IO.Unsafe (unsafePerformIO)

import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (poke)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)


-- | Array with a 'Layout.Split' shape. The element buffer has the same layout
-- as a full @height x width@ matrix with the same order and extent (see
-- 'toFull'); this module reads its lower and its upper triangular part as
-- two separate factors (see 'extractTriangle'). The @lower@ parameter is a
-- type-level tag for the lower part (e.g. 'Triangle' in 'wideMultiplyL').
type Split lower meas vert horiz height width =
      Array (Layout.Split lower meas vert horiz height width)

-- | Split matrix with a square extent: height and width are both @sh@.
type Square lower sh = Split lower Extent.Shape Extent.Small Extent.Small sh sh


-- | Change the extent type parameters (measure, vertical and horizontal size
-- classes) of a split matrix according to an 'Extent.Map'. Only the array
-- shape is rewritten (via 'Array.mapShape').
mapExtent ::
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA) =>
   (Extent.Measure measB, Extent.C vertB, Extent.C horizB) =>
   Extent.Map measA vertA horizA measB vertB horizB height width ->
   Split lower measA vertA horizA height width a ->
   Split lower measB vertB horizB height width a
mapExtent extentMap = Array.mapShape (Layout.splitMapExtent extentMap)

-- | Apply a function to the extent of a split matrix. The part tag and the
-- storage order are kept; only the shape changes (via 'Array.mapShape').
mapExtentSizes ::
   (Extent measA vertA horizA heightA widthA ->
    Extent measB vertB horizB heightB widthB) ->
   Split lower measA vertA horizA heightA widthA a ->
   Split lower measB vertB horizB heightB widthB a
mapExtentSizes f = Array.mapShape replaceExtent
   where
      replaceExtent (Layout.Split lowerPart order extent) =
         Layout.Split lowerPart order (f extent)

-- | Transform the height shape of a split matrix whose measure is
-- 'Extent.Size'.
mapHeight ::
   (Extent.C vert, Extent.C horiz) =>
   (heightA -> heightB) ->
   Split lower Extent.Size vert horiz heightA width a ->
   Split lower Extent.Size vert horiz heightB width a
mapHeight f = mapExtentSizes (Extent.mapHeight f)

-- | Transform the width shape of a split matrix whose measure is
-- 'Extent.Size'.
mapWidth ::
   (Extent.C vert, Extent.C horiz) =>
   (widthA -> widthB) ->
   Split lower Extent.Size vert horiz height widthA a ->
   Split lower Extent.Size vert horiz height widthB a
mapWidth f = mapExtentSizes (Extent.mapWidth f)

-- | Wrap both the height and the width shape in 'Unchecked';
-- counterpart of 'recheck'.
uncheck ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Split lower meas vert horiz height width a ->
   Split lower meas vert horiz (Unchecked height) (Unchecked width) a
uncheck matrix = mapExtentSizes (Extent.mapWrap Unchecked Unchecked) matrix

-- | Remove the 'Unchecked' wrappers from both dimension shapes
-- (delegates to 'Extent.recheck'); counterpart of 'uncheck'.
recheck ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Split lower meas vert horiz (Unchecked height) (Unchecked width) a ->
   Split lower meas vert horiz height width a
recheck matrix = mapExtentSizes Extent.recheck matrix


-- | Reinterpret a split matrix with small extents as a square one whose
-- dimensions are both the original height shape.
--
-- NOTE(review): the width shape is dropped without being compared to the
-- height; callers presumably guarantee they agree — confirm.
heightToQuadratic ::
   (Extent.Measure meas) =>
   Split lower meas Extent.Small Extent.Small height width a ->
   Square lower height a
heightToQuadratic = Array.mapShape squareFromHeight
   where
      squareFromHeight (Layout.Split part order extent) =
         Layout.Split part order (Extent.square (Extent.height extent))

-- | Reinterpret a split matrix with small extents as a square one whose
-- dimensions are both the original width shape.
--
-- NOTE(review): the height shape is dropped without being compared to the
-- width; callers presumably guarantee they agree — confirm.
widthToQuadratic ::
   (Extent.Measure meas) =>
   Split lower meas Extent.Small Extent.Small height width a ->
   Square lower width a
widthToQuadratic = Array.mapShape squareFromWidth
   where
      squareFromWidth (Layout.Split part order extent) =
         Layout.Split part order (Extent.square (Extent.width extent))


-- | Product of the first @min height width@ entries on the main diagonal of
-- the stored array. For the tall-or-square matrices accepted here, this is
-- the product of the diagonal of the upper-triangular factor
-- (cf. 'tallExtractR'), which is that factor's determinant.
determinantR ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Split lower meas vert Extent.Small height width a -> a
determinantR (Array (Layout.Split _ order extent) a) =
   -- unsafePerformIO: the buffer is only read (assuming Private.product does
   -- not write to it), so the result depends solely on the immutable array.
   unsafePerformIO $
      withForeignPtr a $ \aPtr ->
         Private.product (min m n) aPtr (leadingDim + 1)
   where
      (height, width) = Extent.dimensions extent
      m = Shape.size height
      n = Shape.size width
      -- consecutive diagonal entries are one row/column apart plus one,
      -- i.e. the leading dimension plus one
      leadingDim =
         case order of
            RowMajor -> n
            ColumnMajor -> m


-- | Copy one triangle of a split matrix into a new full matrix with the same
-- storage order and extent.
--
-- * @Left _@ selects the lower part: the entries strictly below the diagonal
--   are copied, the diagonal is set to one and the strict upper part to zero.
--
-- * @Right _@ selects the upper part: the entries on and above the diagonal
--   are copied and the strict lower part is set to zero.
extractTriangle ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Either lower Triangle ->
   Split lower meas vert horiz height width a ->
   Full meas vert horiz height width a
extractTriangle part (Array (Layout.Split _ order extent) qr) =

   Array.unsafeCreate (Layout.Full order extent) $ \rPtr -> do

   let (height,width) = Extent.dimensions extent
   -- LAPACK works column-major; a row-major buffer read column-major is the
   -- transpose. swapOnRowMajor presumably swaps the pair only for RowMajor,
   -- so that loup/m always name the source's lower triangle and LAPACK row
   -- count, and uplo/n its upper triangle and LAPACK column count.
   let ((loup,m), (uplo,n)) =
         swapOnRowMajor order
            (('L', Shape.size height), ('U', Shape.size width))
   evalContT $ do
      loupPtr <- Call.char loup
      uploPtr <- Call.char uplo
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      qrPtr <- ContT $ withForeignPtr qr
      -- source and result have identical layout, so both use the
      -- leading dimension m
      ldqrPtr <- Call.leadingDim m
      ldrPtr <- Call.leadingDim m
      zeroPtr <- Call.number zero
      onePtr <- Call.number one
      liftIO $
         case part of
            -- order matters: lacpy copies the lower triangle including the
            -- diagonal, then laset overwrites the strict upper part with zero
            -- (off-diagonal value) and the diagonal with one
            Left _ -> do
               LapackGen.lacpy loupPtr mPtr nPtr qrPtr ldqrPtr rPtr ldrPtr
               LapackGen.laset uploPtr mPtr nPtr zeroPtr onePtr rPtr ldrPtr
            -- order matters: laset zeroes the strict lower part and the
            -- diagonal, then lacpy copies the upper triangle including the
            -- diagonal on top of it
            Right _ -> do
               LapackGen.laset loupPtr mPtr nPtr zeroPtr zeroPtr rPtr ldrPtr
               LapackGen.lacpy uploPtr mPtr nPtr qrPtr ldqrPtr rPtr ldrPtr


-- | Lower-triangular factor of a wide (or square) split matrix. It is built
-- with 'Mos.fromLowerPart' from the full view of the storage; the callback
-- overwrites every diagonal entry with one (unit diagonal).
wideExtractL ::
   (Extent.Measure meas, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Split lower meas Extent.Small horiz height width a -> Lower height a
wideExtractL split =
   Mos.fromLowerPart
      (\order n ptr ->
         mapM_ (\diagPtr -> poke diagPtr one) (diagonalPointers order n ptr))
      Layout.NoMirror
      (toFull split)

-- | Upper-triangular factor of a tall (or square) split matrix: the upper
-- triangle ('Tri.takeUpper') of the full view of the storage.
tallExtractR ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Split lower meas vert Extent.Small height width a -> Upper width a
tallExtractR split = Tri.takeUpper (toFull split)

-- | View a split matrix as an ordinary full matrix with the same storage
-- order and extent. Only the shape changes; the element array is reused.
toFull ::
   Split lower meas vert horiz height width a ->
   Full meas vert horiz height width a
toFull = Array.mapShape fullShape
   where
      fullShape (Layout.Split _ order extent) = Layout.Full order extent


-- | Multiply the unit-lower-triangular part of a wide (or square) split
-- matrix, optionally transposed, with a full matrix on its right.
--
-- Fails with 'error' if the split matrix's height differs from the height
-- of the full matrix.
wideMultiplyL ::
   (Extent.Measure measA, Extent.C horizA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height,
    Shape.C widthA, Shape.C widthB, Class.Floating a) =>
   Transposition ->
   Split Triangle measA Extent.Small horizA height widthA a ->
   Full meas vert horiz height widthB a ->
   Full meas vert horiz height widthB a
wideMultiplyL transposed a b
   | Layout.splitHeight (Array.shape a) == Matrix.height b =
      multiplyTriangular ('L','U') 'U' transposed a b
   | otherwise = error "wideMultiplyL: height shapes mismatch"

-- | Multiply the (non-unit) upper-triangular part of a tall (or square)
-- split matrix, optionally transposed, with a full matrix on its right.
--
-- Fails with 'error' if the split matrix's width differs from the height of
-- the full matrix. The message names this function so the failure can be
-- traced to the right call site.
tallMultiplyR ::
   (Extent.Measure measA, Extent.C vertA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height,
    Shape.C heightA, Shape.C widthB, Class.Floating a) =>
   Transposition ->
   Split lower measA vertA Extent.Small heightA height a ->
   Full meas vert horiz height widthB a ->
   Full meas vert horiz height widthB a
tallMultiplyR transposed a b =
   if Layout.splitWidth (Array.shape a) == Matrix.height b
      then multiplyTriangular ('U','L') 'N' transposed a b
      else error "tallMultiplyR: height shapes mismatch"

-- | Multiply a triangular part of the split matrix @A@ with the full matrix
-- @B@ using BLAS @trmm@ with alpha = 1. The result buffer has @B@'s layout:
-- it is first filled with a copy of @B@ and then overwritten in place by
-- @trmm@.
--
-- Parameters:
--
-- * the pair gives the triangle code used when @A@ is column-major and when
--   @A@ is row-major (where the buffer appears transposed to BLAS);
--
-- * the 'Char' is the @trmm@ diagonal code (unit or non-unit diagonal);
--
-- * the 'Transposition' selects whether the triangle is transposed.
--
-- No shape check is done here; the callers ('wideMultiplyL',
-- 'tallMultiplyR') compare the triangle size with @B@'s height.
multiplyTriangular ::
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA,
    Extent.Measure measB, Extent.C vertB, Extent.C horizB,
    Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   (Char,Char) -> Char -> Transposition ->
   Split lower measA vertA horizA heightA widthA a ->
   Full measB vertB horizB heightB widthB a ->
   Full measB vertB horizB heightB widthB a
multiplyTriangular (normalPart,transposedPart) diag transposed
   (Array (Layout.Split _ orderA extentA) a)
   (Array (Layout.Full orderB extentB) b) =

   Array.unsafeCreate (Layout.Full orderB extentB) $ \cPtr -> do

   let (heightA,widthA) = Extent.dimensions extentA
   let (heightB,widthB) = Extent.dimensions extentB
   -- B's storage order, adjusted for the requested transposition; it
   -- determines the BLAS transa flag below
   let transOrderB = transposeOrder transposed orderB
   -- A row-major A read column-major is its transpose: use the other
   -- triangle code and flip the order. The leading dimension is the row
   -- length (width) for RowMajor and the column length (height) for
   -- ColumnMajor.
   let ((uplo, transa), lda) =
         case orderA of
            RowMajor ->
               ((transposedPart, flipOrder transOrderB), Shape.size widthA)
            ColumnMajor ->
               ((normalPart, transOrderB), Shape.size heightA)
   -- NOTE(review): sideSwapFromOrder presumably chooses the trmm side from
   -- B's order and swaps the dimensions for RowMajor, so that trmm sees B's
   -- buffer in column-major terms with leading dimension m — confirm.
   let (side,(m,n)) =
         sideSwapFromOrder orderB (Shape.size heightB, Shape.size widthB)
   evalContT $ do
      sidePtr <- Call.char side
      uploPtr <- Call.char uplo
      transaPtr <- Call.char $ transposeFromOrder transa
      diagPtr <- Call.char diag
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.leadingDim lda
      bPtr <- ContT $ withForeignPtr b
      -- C has the same layout as B
      ldcPtr <- Call.leadingDim m
      alphaPtr <- Call.number one
      liftIO $ do
         -- trmm works in place, so B is copied into the result first
         copyBlock (m*n) bPtr cPtr
         BlasGen.trmm sidePtr uploPtr transaPtr diagPtr
            mPtr nPtr alphaPtr aPtr ldaPtr cPtr ldcPtr


-- | Solve a linear system whose coefficient matrix is the unit lower
-- triangle of a wide (or square) split matrix, optionally transposed and/or
-- conjugated. The columns of the full matrix argument are the right-hand
-- sides; the result has the same shape.
--
-- The triangle has order @height@, but the stored array is a full
-- @height x width@ buffer. In row-major order its leading dimension is
-- therefore the width, as in 'determinantR' and 'multiplyTriangular'.
-- Hence the width, not the triangle order, is passed to 'solveTriangular'
-- as the column count of the stored array. This keeps @lda@ and the size of
-- a conjugated temporary copy correct for strictly wide row-major matrices.
--
-- NOTE(review): 'solver' receives the height shape and presumably checks it
-- against the right-hand side's height — confirm.
wideSolveL ::
   (Extent.Measure measA, Extent.C horizA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs, Class.Floating a) =>
   Transposition -> Conjugation ->
   Split Triangle measA Extent.Small horizA height width a ->
   Full meas vert horiz height nrhs a -> Full meas vert horiz height nrhs a
wideSolveL transposed conjugated
      (Array (Layout.Split _ orderA extentA) a) =
   let (heightA,widthA) = Extent.dimensions extentA
   in solver "Split.wideSolveL" heightA $ \_n nPtr nrhsPtr xPtr ldxPtr -> do

      -- the lower triangle of A shows up as LAPACK's other triangle when
      -- the buffer is row-major, hence the flipped order
      uploPtr <- Call.char $ uploFromOrder $ flipOrder orderA
      -- unit diagonal: trtrs does not reference the stored diagonal
      diagPtr <- Call.char 'U'
      -- row and column counts of the stored array (not of the triangle)
      let m = Shape.size heightA
      let n = Shape.size widthA
      solveTriangular transposed conjugated orderA m n a
         uploPtr diagPtr nPtr nrhsPtr xPtr ldxPtr

-- | Solve a linear system whose coefficient matrix is the (non-unit) upper
-- triangle of a tall (or square) split matrix, optionally transposed and/or
-- conjugated. The columns of the full matrix argument are the right-hand
-- sides; the result has the same shape.
--
-- NOTE(review): 'solver' receives the width shape and presumably checks it
-- against the right-hand side's height — confirm.
tallSolveR ::
   (Extent.Measure measA, Extent.C vertA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Shape.C nrhs, Class.Floating a) =>
   Transposition -> Conjugation ->
   Split lower measA vertA Extent.Small height width a ->
   Full meas vert horiz width nrhs a -> Full meas vert horiz width nrhs a
tallSolveR transposed conjugated
      (Array (Layout.Split _ orderA extentA) a) =
   solver "Split.tallSolveR" widthA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      -- LAPACK triangle code for the upper part in A's storage order
      uploPtr <- Call.char (uploFromOrder orderA)
      -- non-unit diagonal: the stored diagonal is part of the factor
      diagPtr <- Call.char 'N'
      solveTriangular transposed conjugated orderA (Shape.size heightA) n a
         uploPtr diagPtr nPtr nrhsPtr xPtr ldxPtr
   where
      (heightA, widthA) = Extent.dimensions extentA

-- | Call LAPACK @trtrs@ on the triangular coefficient stored in @a@.
-- The right-hand sides in @xPtr@ are overwritten by the solution.
--
-- @m@ and @n@ are the row and column counts of the stored array. They
-- determine the leading dimension (@m@ for ColumnMajor, @n@ for RowMajor)
-- and the element count of a conjugated temporary copy. The triangle's
-- order itself is passed through @nPtr@.
--
-- LAPACK offers transposition (T) and conjugate transposition (C), but no
-- plain conjugation. That case is handled by solving with a conjugated
-- temporary copy of @a@ and trans = N.
solveTriangular ::
   Class.Floating a =>
   Transposition -> Conjugation ->
   Order -> Int -> Int -> ForeignPtr a ->
   Ptr CChar -> Ptr CChar -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr CInt -> ContT r IO ()
solveTriangular transposed conjugated orderA m n a
   uploPtr diagPtr nPtr nrhsPtr xPtr ldxPtr = do
      transPtr <- Call.char transChar
      aPtr <- acquireA
      ldaPtr <- Call.leadingDim leadingDimA
      liftIO $ withInfo "trtrs" $
         LapackGen.trtrs uploPtr transPtr diagPtr
            nPtr nrhsPtr aPtr ldaPtr xPtr ldxPtr
   where
      borrowA = ContT $ withForeignPtr a
      (transChar, acquireA) =
         case (transposeOrder transposed orderA, conjugated) of
            (RowMajor, NonConjugated) -> ('T', borrowA)
            (RowMajor, Conjugated) -> ('C', borrowA)
            (ColumnMajor, NonConjugated) -> ('N', borrowA)
            (ColumnMajor, Conjugated) -> ('N', conjugateToTemp (m*n) a)
      leadingDimA =
         case orderA of
            ColumnMajor -> m
            RowMajor -> n