{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ConstraintKinds #-}
module Numeric.LAPACK.Matrix.Triangular.Private where

import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Box as Box
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor,ColumnMajor), flipOrder, uploFromOrder,
          Empty, Filled, NonUnit)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Scalar (zero)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))
import Numeric.LAPACK.Private
         (pointerSeq, copyBlock, copyCondConjugateToTemp,
          pokeCInt, fill, withInfo, errorCodeMsg)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((:+:)((:+:)))

import Foreign.Marshal.Alloc (alloca)
import Foreign.Marshal.Array (advancePtr)
import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Foldable (forM_)


-- | The first @n@ pointers to the diagonal elements of a packed triangle
-- starting at @aPtr@.
--
-- Consecutive diagonal elements are @n, n-1, n-2, ...@ elements apart in
-- 'RowMajor' order and @2, 3, 4, ...@ elements apart in 'ColumnMajor'
-- order, consistent with upper-triangular packed storage.
diagonalPointers :: (Storable a) => Order -> Int -> Ptr a -> [Ptr a]
diagonalPointers order n aPtr =
   let gaps =
         case order of
            RowMajor -> [n, n-1 ..]
            ColumnMajor -> [2 ..]
   in  take n (scanl advancePtr aPtr gaps)

-- | Pair the consecutive elements at @aPtr@ (stride 1) with the diagonal
-- elements of the packed triangle at @bPtr@ (see 'diagonalPointers').
-- The result has at most @n@ entries.
diagonalPointerPairs ::
   (Storable a, Storable b) =>
   Order -> Int -> Ptr a -> Ptr b -> [(Ptr a, Ptr b)]
diagonalPointerPairs order n aPtr bPtr =
   let diagonals = diagonalPointers order n bPtr
       elements = pointerSeq 1 aPtr
   in  zip elements diagonals


-- | Pointers for a column-by-column walk over an @n×n@ full array at
-- @fullPtr@ and its packed counterpart at @packedPtr@.
--
-- For each @i@ in @[0 .. n-1]@ the entry contains:
--
-- * the length @i+1@,
-- * @fullPtr@ advanced by @i@ elements and by @i*n@ elements,
-- * the start of the @i@-th packed column (offsets @0, 1, 3, 6, ...@).
columnMajorPointers ::
   (Storable a) => Int -> Ptr a -> Ptr a -> [(Int, ((Ptr a, Ptr a), Ptr a))]
columnMajorPointers n fullPtr packedPtr =
   take n
      [ (len, ((rowPtr, colPtr), packPtr))
      | (len, (rowPtr, colPtr), packPtr) <-
            zip3 lengths
               (zip (pointerSeq 1 fullPtr) (pointerSeq n fullPtr))
               (scanl advancePtr packedPtr lengths) ]
   where
      lengths = [1 ..]

-- | Pointers for a row-by-row walk over an @n×n@ full array at @fullPtr@
-- and its packed counterpart at @packedPtr@.
--
-- For each @i@ in @[0 .. n-1]@ the entry contains the row length @n-i@,
-- the pointer to element @(i,i)@ of the full array (stride @n+1@), and the
-- start of row @i@ in the packed array (offsets @0, n, 2n-1, ...@).
rowMajorPointers ::
   (Storable a) => Int -> Ptr a -> Ptr a -> [(Int, (Ptr a, Ptr a))]
rowMajorPointers n fullPtr packedPtr =
   take n
      [ (len, (diagPtr, rowPtr))
      | (len, diagPtr, rowPtr) <-
            zip3 lengths (pointerSeq (n+1) fullPtr)
               (scanl advancePtr packedPtr lengths) ]
   where
      lengths = [n, n-1 ..]


-- | Allocate a single 'CInt' cell and process @xs@ in list order.
-- For every @(d, ptrs)@ it writes @d@ into the cell, then runs @act@ with
-- the cell and @ptrs@.
forPointers :: [(Int, a)] -> (Ptr CInt -> a -> IO ()) -> IO ()
forPointers xs act =
   alloca $ \nPtr ->
      let step (d, ptrs) = pokeCInt nPtr d >> act nPtr ptrs
      in  mapM_ step xs


-- | Copy a packed triangle to a temporary buffer via
-- 'copyCondConjugateToTemp'.
-- The requested conjugation @conj@ is honoured only for 'RowMajor' storage.
-- For 'ColumnMajor' storage the copy is always 'NonConjugated'.
-- NOTE(review): presumably this is because row-major storage is handed to
-- LAPACK as the transposed matrix, so conjugating yields the adjoint —
-- confirm.
copyTriangleToTemp ::
   Class.Floating a =>
   Conjugation -> Order -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyTriangleToTemp conj order n a =
   copyCondConjugateToTemp effectiveConj n a
   where
      effectiveConj =
         case order of
            ColumnMajor -> NonConjugated
            RowMajor -> conj


-- | Run an unpacking routine into a freshly allocated temporary buffer.
--
-- Within 'ContT' this obtains the raw pointer of @a@ and allocates
-- @n*n@ elements with 'Call.allocaArray'. It then calls @f n packed temp@
-- and returns the temporary pointer.
-- NOTE(review): the buffer is presumably only valid inside the enclosing
-- 'evalContT' scope — do not let it escape.
unpackToTemp ::
   Storable a =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Int -> ForeignPtr a -> ContT r IO (Ptr a)
unpackToTemp f n a = do
   packedPtr <- ContT $ withForeignPtr a
   fullPtr <- Call.allocaArray (n*n)
   liftIO (f n packedPtr fullPtr)
   pure fullPtr


-- | Expand the packed @n×n@ triangle at @packedPtr@ into the full array at
-- @fullPtr@, which has leading dimension @n@, using LAPACK @?TPTTR@.
-- The triangle is selected by @'uploFromOrder' order@.
-- A non-zero @info@ is handled by 'withInfo' with 'errorCodeMsg'.
-- The other triangle is not cleared here; 'unpackZero' does that.
unpack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
unpack order n packedPtr fullPtr = evalContT $ do
   let uplo = uploFromOrder order
   uploPtr <- Call.char uplo
   nPtr <- Call.cint n
   ldaPtr <- Call.leadingDim n
   liftIO . withInfo errorCodeMsg "tpttr" $
      LapackGen.tpttr uploPtr nPtr packedPtr fullPtr ldaPtr

-- | Pack the triangle of the full @n×n@ array at @fullPtr@ into
-- @packedPtr@. This is 'packRect' with leading dimension @n@.
pack :: Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
pack order n fullPtr packedPtr = packRect order n n fullPtr packedPtr

-- | Pack the @n×n@ triangle of the array at @fullPtr@ into @packedPtr@
-- using LAPACK @?TRTTP@.
-- The array at @fullPtr@ has leading dimension @ld@, so it may be a larger
-- rectangle. The triangle is selected by @'uploFromOrder' order@.
-- A non-zero @info@ is handled by 'withInfo' with 'errorCodeMsg'.
packRect :: Class.Floating a => Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
packRect order n ld fullPtr packedPtr = evalContT $ do
   let uplo = uploFromOrder order
   uploPtr <- Call.char uplo
   nPtr <- Call.cint n
   ldaPtr <- Call.leadingDim ld
   liftIO . withInfo errorCodeMsg "trttp" $
      LapackGen.trttp uploPtr nPtr fullPtr ldaPtr packedPtr


-- | Expand a packed triangle into a full @n×n@ array whose remaining
-- elements are zero.
--
-- 'unpackZero' zeroes only the opposite triangle (via 'fillTriangle' with
-- @'flipOrder' order@) and then unpacks over it.
-- '_unpackZero' is an alternative that zeroes all @n*n@ elements first.
unpackZero, _unpackZero ::
   Class.Floating a => Order -> Int -> Ptr a -> Ptr a -> IO ()
unpackZero order n packedPtr fullPtr =
   fillTriangle zero (flipOrder order) n fullPtr >>
   unpack order n packedPtr fullPtr

_unpackZero order n packedPtr fullPtr =
   fill zero (n*n) fullPtr >> unpack order n packedPtr fullPtr

-- | Set one triangle of the @n×n@ array at @aPtr@ to @z@, including the
-- diagonal, using LAPACK @?LASET@.
-- The triangle is selected by @'uploFromOrder' order@. The off-diagonal
-- value (alpha) and the diagonal value (beta) are both @z@.
--
-- The leading dimension goes through 'Call.leadingDim', as in 'unpack' and
-- 'packRect'. LASET requires @LDA >= max(1,M)@, and passing the raw @n@
-- violated that contract for @n = 0@.
-- NOTE(review): this relies on 'Call.leadingDim' clamping to at least 1.
fillTriangle :: Class.Floating a => a -> Order -> Int -> Ptr a -> IO ()
fillTriangle z order n aPtr = evalContT $ do
   uploPtr <- Call.char $ uploFromOrder order
   nPtr <- Call.cint n
   zPtr <- Call.number z
   ldaPtr <- Call.leadingDim n
   liftIO $ LapackGen.laset uploPtr nPtr nPtr zPtr zPtr aPtr ldaPtr



-- | Wrap the size component of a triangular array's shape in 'Unchecked'.
-- The element data is left untouched.
uncheck :: Triangular lo diag up sh a -> Triangular lo diag up (Unchecked sh) a
uncheck = Array.mapShape wrapSize
   where
      wrapSize (MatrixShape.Triangular diag uplo order sh) =
         MatrixShape.Triangular diag uplo order (Unchecked sh)

-- | Remove the 'Unchecked' wrapper from the size component of a triangular
-- array's shape. This is the inverse of 'uncheck'.
recheck :: Triangular lo diag up (Unchecked sh) a -> Triangular lo diag up sh a
recheck = Array.mapShape unwrapSize
   where
      unwrapSize (MatrixShape.Triangular diag uplo order (Unchecked sh)) =
         MatrixShape.Triangular diag uplo order sh


-- | Assemble a packed block triangle of size @height:+:width@ from three
-- parts:
--
-- * the first array goes to the top-left triangle,
-- * the general @height×width@ matrix goes to the top-right rectangle,
-- * the third array goes to the bottom-right triangle.
--
-- @name@ prefixes the assertion messages. The checks compare the general
-- matrix's height with the first array's height, and its width with the
-- third array's width.
-- All three parts are copied using the order of the general matrix.
-- NOTE(review): this assumes the two triangles share that order —
-- presumably ensured by callers; confirm.
stack ::
   (Box.Box sh0, Box.HeightOf sh0 ~ height, Shape.C height, Eq height,
    Box.Box sh1, Box.WidthOf sh1 ~ width, Shape.C width, Eq width,
    Shape.C sh2, Class.Floating a) =>
   String -> (height:+:width -> sh2) ->
   Array sh0 a -> Matrix.General height width a -> Array sh1 a -> Array sh2 a
stack name consShape
      (Array sha a) (Array (MatrixShape.Full order extent) b) (Array shc c) =
   Array.unsafeCreate (consShape (height :+: width)) $ \xPtr -> do
      Call.assert (name++".stack: height shapes mismatch") $
         height == Box.height sha
      Call.assert (name++".stack: width shapes mismatch") $
         width == Box.width shc
      let copyPart fptr part =
            withForeignPtr fptr $ \srcPtr ->
               part copyBlock order m n srcPtr xPtr
      copyPart a copyTriangleA
      copyPart b copyRectangle
      copyPart c copyTriangleC
   where
      (height, width) = Extent.dimensions extent
      m = Shape.size height
      n = Shape.size width

-- | Extract the top-right @height×width@ rectangle of a packed block
-- triangle as a general matrix in the same order.
-- @getShapes@ supplies the order and the split @height:+:width@. This is
-- the inverse of the rectangle part of 'stack'.
takeTopRight ::
   (Shape.C sh, Shape.C height, Shape.C width, Class.Floating a) =>
   (sh -> (MatrixShape.Order, height:+:width)) ->
   Array sh a -> Matrix.General height width a
takeTopRight getShapes (Array sh x) =
   Array.unsafeCreate (MatrixShape.general order height width) $ \bPtr ->
      withForeignPtr x $
         copyRectangle (flip . copyBlock) order
            (Shape.size height) (Shape.size width) bPtr
   where
      (order, height:+:width) = getShapes sh

-- | Extract the top-left triangle of a packed block triangle into an array
-- of shape @sha@.
-- @getShapes@ supplies @sha@, the order and the split @height:+:width@.
-- This is the inverse of the first part of 'stack'.
takeTopLeft ::
   (Shape.C sh, Shape.C sha, Shape.C height, Shape.C width, Class.Floating a) =>
   (sh -> (sha, (MatrixShape.Order, height:+:width))) ->
   Array sh a -> Array sha a
takeTopLeft getShapes (Array sh x) =
   Array.unsafeCreate sha $ \aPtr ->
      withForeignPtr x $
         copyTriangleA (flip . copyBlock) order
            (Shape.size height) (Shape.size width) aPtr
   where
      (sha, (order, height:+:width)) = getShapes sh

-- | Extract the bottom-right triangle of a packed block triangle into an
-- array of shape @shc@.
-- @getShapes@ supplies @shc@, the order and the split @height:+:width@.
-- This is the inverse of the last part of 'stack'.
takeBottomRight ::
   (Shape.C sh, Shape.C shc, Shape.C height, Shape.C width, Class.Floating a) =>
   (sh -> (shc, (MatrixShape.Order, height:+:width))) ->
   Array sh a -> Array shc a
takeBottomRight getShapes (Array sh x) =
   Array.unsafeCreate shc $ \cPtr ->
      withForeignPtr x $
         copyTriangleC (flip . copyBlock) order
            (Shape.size height) (Shape.size width) cPtr
   where
      (shc, (order, height:+:width)) = getShapes sh

{-# INLINE copyTriangleA #-}
-- | Transfer data between the packed @m×m@ triangle at @aPtr@ and the
-- top-left corner of the packed @(m+n)×(m+n)@ triangle at @xPtr@.
--
-- @copy count aSide xSide@ is called once per contiguous run. Pass
-- 'copyBlock' to copy into @xPtr@, or @flip . copyBlock@ to copy out of it.
-- In 'ColumnMajor' order the corner is a contiguous prefix of @x@.
-- In 'RowMajor' order each of the @m@ rows is copied separately, starting
-- at its diagonal element.
copyTriangleA ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyTriangleA copy order m n aPtr xPtr =
   case order of
      ColumnMajor -> copy (Shape.triangleSize m) aPtr xPtr
      RowMajor ->
         sequence_ $
         zipWith3 copy [m, m-1 ..]
            (diagonalPointers order m aPtr)
            (diagonalPointers order (m+n) xPtr)

{-# INLINE copyTriangleC #-}
-- | Transfer data between the packed @n×n@ triangle at @cPtr@ and the
-- bottom-right corner of the packed @(m+n)×(m+n)@ triangle at @xPtr@.
--
-- @copy count cSide xSide@ is called once per contiguous run. Pass
-- 'copyBlock' to copy into @xPtr@, or @flip . copyBlock@ to copy out of it.
-- In 'RowMajor' order the corner is a contiguous suffix of @x@.
-- In 'ColumnMajor' order column @k@ of the corner consists of the @k+1@
-- elements ending at the corresponding diagonal element of @x@.
copyTriangleC ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyTriangleC copy order m n cPtr xPtr =
   case order of
      RowMajor ->
         let triSize = Shape.triangleSize n
             offset = Shape.triangleSize (m+n) - triSize
         in  copy triSize cPtr (advancePtr xPtr offset)
      ColumnMajor ->
         sequence_ $
         zipWith3
            (\k cjPtr xjPtr ->
               copy (k+1) (advancePtr cjPtr (-k)) (advancePtr xjPtr (-k)))
            [0 ..]
            (diagonalPointers order n cPtr)
            (drop m $ diagonalPointers order (m+n) xPtr)

{-# INLINE copyRectangle #-}
-- | Transfer data between the general @m×n@ matrix at @bPtr@ and the
-- top-right rectangle of the packed @(m+n)×(m+n)@ triangle at @xPtr@.
--
-- @copy count bSide xSide@ is called once per contiguous run. Pass
-- 'copyBlock' to copy into @xPtr@, or @flip . copyBlock@ to copy out of it.
-- In 'RowMajor' order row @i@ of @B@ starts @m-i@ elements after the
-- @i@-th diagonal element of @x@.
-- In 'ColumnMajor' order column @j@ of @B@ starts @m+j@ elements before the
-- @(m+j)@-th diagonal element of @x@.
copyRectangle ::
   (Class.Floating a) =>
   (Int -> Ptr a -> Ptr a -> IO ()) ->
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyRectangle copy order m n bPtr xPtr =
   case order of
      RowMajor ->
         sequence_ $ take m $
         zipWith3 (\k biPtr xiPtr -> copy n biPtr (advancePtr xiPtr k))
            [m, m-1 ..]
            (pointerSeq n bPtr)
            (diagonalPointers order (m+n) xPtr)
      ColumnMajor ->
         sequence_ $ take n $
         zipWith3 (\k bjPtr xjPtr -> copy m bjPtr (advancePtr xjPtr (-k)))
            [m ..]
            (pointerSeq m bPtr)
            (drop m $ diagonalPointers order (m+n) xPtr)



-- | Storable array whose shape is a 'MatrixShape.Triangular' with lower part
-- @lo@, diagonal kind @diag@, upper part @up@ and size @sh@.
-- The element type is left as the final, unapplied parameter.
type Triangular lo diag up sh = Array (MatrixShape.Triangular lo diag up sh)

-- | A 'Triangular' array whose lower and upper parts are both
-- 'MatrixShape.Empty'.
-- NOTE(review): presumably this represents a diagonal matrix with a
-- flexible diagonal kind @diag@ — confirm.
type FlexDiagonal diag sh =
         Triangular MatrixShape.Empty diag MatrixShape.Empty sh

-- | Wrapper around a function that consumes a 'Triangular' array with parts
-- @lo@ and @up@ and produces a @b@.
-- NOTE(review): @lo@ and @up@ are the last type parameters, presumably so the
-- wrapper can be selected per lower/upper content in class-based dispatch —
-- confirm.
newtype MultiplyRight diag sh a b lo up =
   MultiplyRight {getMultiplyRight :: Triangular lo diag up sh a -> b}

-- | Wrapper around a function between 'Triangular' arrays. It keeps @lo@,
-- @diag@ and @up@ but may change the size parameter from @sh0@ to @sh1@.
newtype Map diag sh0 sh1 a lo up =
   Map {getMap :: Triangular lo diag up sh0 a -> Triangular lo diag up sh1 a}

-- | Wrapper around a function on 'Triangular' arrays whose result has the
-- diagonal kind @'PowerDiag' lo up diag@.
newtype Power diag sh a lo up =
   Power {
      getPower ::
         Triangular lo diag up sh a ->
         Triangular lo (PowerDiag lo up diag) up sh a
   }

-- | Diagonal kind of the result of a 'Power' on a shape with lower part @lo@
-- and upper part @up@.
-- It keeps @diag@ unless both parts are 'Filled', in which case it is
-- 'NonUnit'.
-- NOTE(review): presumably because a power of a matrix that is filled on both
-- sides need not keep a unit diagonal — confirm.
type family PowerDiag lo up diag
type instance PowerDiag Empty up diag = diag
type instance PowerDiag Filled Empty diag = diag
type instance PowerDiag Filled Filled diag = NonUnit

-- | Constraint bundle for shapes whose diagonal kind survives 'PowerDiag'.
-- It requires:
--
-- * both parts are 'MatrixShape.Content',
-- * @diag@ is a 'MatrixShape.TriDiag',
-- * 'PowerDiag' maps @diag@ to itself for both argument orders
--   @(lo, up)@ and @(up, lo)@.
type PowerContentDiag lo diag up =
      (MatrixShape.Content lo, MatrixShape.Content up, MatrixShape.TriDiag diag,
       PowerDiag lo up diag ~ diag, PowerDiag up lo diag ~ diag)


-- | Convert a band matrix stored at @a@ into a packed @n×n@ triangle at
-- @bPtr@.
-- The band has @k@ off-diagonals and leading dimension @k+1@.
--
-- First the @bSize@ elements at @bPtr@ are zeroed.
-- In 'ColumnMajor' order, band column @i@ holds the diagonal at position
-- @k@. Its last @min i k + 1@ entries are copied so that they end at the
-- @i@-th diagonal of the packed triangle.
-- In 'RowMajor' order, band row @i@ starts with the diagonal and
-- @min (k+1) (n-i)@ entries are copied from it.
-- NOTE(review): this matches LAPACK-style upper band storage — confirm
-- against callers.
fromBanded ::
   (Class.Floating a) =>
   Int -> Order -> Int -> ForeignPtr a -> Int -> Ptr a -> IO ()
fromBanded k order n a bSize bPtr =
   withForeignPtr a $ \aPtr -> do
      fill zero bSize bPtr
      let lda = k+1
          copyLine i xPtr yPtr =
             case order of
                ColumnMajor ->
                   let j = min i k
                   in  copyBlock (j+1) (advancePtr xPtr (k-j)) (advancePtr yPtr (-j))
                RowMajor ->
                   copyBlock (min lda (n-i)) xPtr yPtr
      sequence_ $
         zipWith3 copyLine [0 ..]
            (pointerSeq lda aPtr)
            (diagonalPointers order n bPtr)


-- | Storable array with a 'MatrixShape.LowerTriangular' shape of diagonal
-- kind @diag@ and size @sh@.
type FlexLower diag sh = Array (MatrixShape.LowerTriangular diag sh)

-- | Extract the lower @m×m@ triangle (@m@ = height) of a full matrix whose
-- height is the small dimension.
--
-- The triangle is packed with 'packRect' using the flipped order. The
-- leading dimension is the width for 'RowMajor' and the height for
-- 'ColumnMajor'.
-- Then @fillDiag@ is applied to the packed result, also with the flipped
-- order.
-- The result shape is lower-triangular with the supplied @diag@, the
-- original order and @height@ as size.
takeLower ::
   (MatrixShape.TriDiag diag,
    Extent.C horiz, Shape.C height, Shape.C width, Class.Floating a) =>
   (diag, Order -> Int -> Ptr a -> IO ()) ->
   Full Extent.Small horiz height width a -> FlexLower diag height a
takeLower (diag, fillDiag) (Array (MatrixShape.Full order extent) a) =
   Array.unsafeCreate
         (MatrixShape.Triangular diag MatrixShape.lower order height) $ \lPtr ->
      withForeignPtr a $ \aPtr -> do
         packRect dstOrder m ld aPtr lPtr
         fillDiag dstOrder m lPtr
   where
      (height, width) = Extent.dimensions extent
      m = Shape.size height
      dstOrder = flipOrder order
      ld =
         case order of
            RowMajor -> Shape.size width
            ColumnMajor -> m


-- | Pack the upper @n×n@ triangle (@n@ = width) of a full matrix whose width
-- is the small dimension.
--
-- 'packRect' is used with the matrix's own order. The leading dimension is
-- the width for 'RowMajor' and the height for 'ColumnMajor'.
-- The result shape is built by @shape order width@.
fromUpperPart ::
   (Extent.C vert, Shape.C height, Shape.C width, Shape.C shape,
    Class.Floating a) =>
   (Order -> width -> shape) ->
   Full vert Extent.Small height width a -> Array shape a
fromUpperPart shape (Array (MatrixShape.Full order extent) a) =
   Array.unsafeCreate (shape order width) $ \bPtr ->
      withForeignPtr a $ \aPtr -> packRect order n ld aPtr bPtr
   where
      (height, width) = Extent.dimensions extent
      n = Shape.size width
      ld =
         case order of
            RowMajor -> n
            ColumnMajor -> Shape.size height