{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.Basic where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Matrix.Shape.Private (Order(RowMajor, ColumnMajor))
import Numeric.LAPACK.Matrix.Private (Full, General)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Private (pointerSeq)

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Storable (poke, peek)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)


-- | Swap the roles of rows and columns of a full matrix.
--
-- Only the shape descriptor is rewritten (via 'MatrixShape.transpose');
-- the element buffer is passed through untouched, so no data is copied.
-- NOTE(review): presumably 'MatrixShape.transpose' swaps the extent and
-- flips the storage order so the same buffer reads as the transpose.
transpose ::
   (Extent.C vert, Extent.C horiz) =>
   Full vert horiz height width a -> Full horiz vert width height a
transpose (Array shape buffer) = Array (MatrixShape.transpose shape) buffer

-- | Relabel the row (height) shape of a full matrix with @f@.
--
-- The storage order is kept, and the extent is updated with
-- 'Extent.mapHeight'. Only the shape changes; the element buffer is reused.
mapHeight ::
   (Extent.GeneralTallWide vert horiz,
    Extent.GeneralTallWide horiz vert) =>
   (heightA -> heightB) ->
   Full vert horiz heightA width a ->
   Full vert horiz heightB width a
mapHeight f = Array.mapShape relabel
   where
      relabel (MatrixShape.Full order extent) =
         MatrixShape.Full order (Extent.mapHeight f extent)

-- | Relabel the column (width) shape of a full matrix with @f@.
--
-- The storage order is kept, and the extent is updated with
-- 'Extent.mapWidth'. Only the shape changes; the element buffer is reused.
mapWidth ::
   (Extent.GeneralTallWide vert horiz,
    Extent.GeneralTallWide horiz vert) =>
   (widthA -> widthB) ->
   Full vert horiz height widthA a ->
   Full vert horiz height widthB a
mapWidth f = Array.mapShape $ \fullShape ->
   case fullShape of
      MatrixShape.Full order extent ->
         MatrixShape.Full order (Extent.mapWidth f extent)


-- | Reinterpret a vector as a general matrix with unit height shape @()@
-- and the vector's shape as width, tagged with the given storage order.
-- Only the shape changes; the data is shared.
singleRow :: Order -> Vector width a -> General () width a
singleRow order vector =
   Array.mapShape (\width -> MatrixShape.general order () width) vector

-- | Reinterpret a vector as a general matrix with the vector's shape as
-- height and unit width shape @()@, tagged with the given storage order.
-- Only the shape changes; the data is shared.
singleColumn :: Order -> Vector height a -> General height () a
singleColumn order vector =
   Array.mapShape (\height -> MatrixShape.general order height ()) vector

-- | Drop the unit height of a one-row matrix, keeping its width shape as
-- the shape of the resulting vector (via 'MatrixShape.fullWidth').
flattenRow :: General () width a -> Vector width a
flattenRow row = Array.mapShape MatrixShape.fullWidth row

-- | Drop the unit width of a one-column matrix, keeping its height shape
-- as the shape of the resulting vector (via 'MatrixShape.fullHeight').
flattenColumn :: General height () a -> Vector height a
flattenColumn column = Array.mapShape MatrixShape.fullHeight column

-- | Turn a vector transformation into a transformation on one-row
-- matrices. The input row is flattened, @f@ is applied, and the result is
-- rewrapped as a row with the given storage order.
liftRow ::
   Order ->
   (Vector height0 a -> Vector height1 b) ->
   General () height0 a -> General () height1 b
liftRow order f row = singleRow order (f (flattenRow row))

-- | Turn a vector transformation into a transformation on one-column
-- matrices. The input column is flattened, @f@ is applied, and the result
-- is rewrapped as a column with the given storage order.
liftColumn ::
   Order ->
   (Vector height0 a -> Vector height1 b) ->
   General height0 () a -> General height1 () b
liftColumn order f column = singleColumn order (f (flattenColumn column))

-- | Turn a one-row-matrix transformation into a vector transformation.
-- The vector is wrapped as a row with the given storage order, @f@ is
-- applied, and the resulting row is flattened back to a vector.
unliftRow ::
   Order ->
   (General () height0 a -> General () height1 b) ->
   Vector height0 a -> Vector height1 b
unliftRow order f vector = flattenRow (f (singleRow order vector))

-- | Turn a one-column-matrix transformation into a vector transformation.
-- The vector is wrapped as a column with the given storage order, @f@ is
-- applied, and the resulting column is flattened back to a vector.
unliftColumn ::
   Order ->
   (General height0 () a -> General height1 () b) ->
   Vector height0 a -> Vector height1 b
unliftColumn order f vector = flattenColumn (f (singleColumn order vector))


-- | Return a matrix with the same extent whose storage order is 'RowMajor'.
--
-- A matrix that is already row-major is returned as-is, with no copy.
-- A column-major matrix is copied into a freshly allocated row-major
-- buffer with 'Private.copyTransposed'. The arguments passed are the
-- column count, the row count, the source pointer, the column count again
-- (presumably the destination leading dimension), and the destination
-- pointer.
forceRowMajor ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   Full vert horiz height width a
forceRowMajor matrix@(Array (MatrixShape.Full order extent) x) =
   case order of
      RowMajor -> matrix
      ColumnMajor ->
         let (height, width) = Extent.dimensions extent
             rows = Shape.size height
             cols = Shape.size width
         in  Array.unsafeCreate (MatrixShape.Full RowMajor extent) $ \dst ->
             withForeignPtr x $ \src ->
                Private.copyTransposed cols rows src cols dst

-- | Return a matrix with the same extent stored in the requested order.
--
-- Row-major output is produced directly by 'forceRowMajor'. Column-major
-- output is obtained by transposing, forcing row-major storage, and
-- transposing back.
-- NOTE(review): this relies on 'transpose' flipping the storage order
-- without copying data.
forceOrder ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Order ->
   Full vert horiz height width a ->
   Full vert horiz height width a
forceOrder RowMajor = forceRowMajor
forceOrder ColumnMajor = transpose . forceRowMajor . transpose


-- | @scaleRows x a@ multiplies row @i@ of @a@ by element @i@ of @x@,
-- i.e. computes diag(x) * a. The result has the same shape (including
-- storage order) as @a@ and lives in a freshly created buffer.
--
-- The vector's shape must equal the matrix height. This is checked with
-- 'Call.assert' and the message @"scaleRows: sizes mismatch"@ inside the
-- creation callback, so the check happens only when the result is
-- evaluated.
-- NOTE(review): presumably 'Call.assert' raises an exception on failure;
-- confirm against netlib-ffi.
scaleRows ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
scaleRows
   (Array heightX x) (Array shape@(MatrixShape.Full order extent) a) =
      Array.unsafeCreate shape $ \bPtr -> do
   let (height,width) = Extent.dimensions extent
   Call.assert "scaleRows: sizes mismatch" (heightX == height)
   case order of
      -- Row-major: m rows of n contiguous elements each. For every row k,
      -- copy row k of a into row k of b, then scale it in place by x[k]
      -- (BLAS copy followed by scal).
      RowMajor -> evalContT $ do
         let m = Shape.size height
         let n = Shape.size width
         -- Scratch cell for the per-row scale factor; overwritten each row.
         alphaPtr <- Call.alloca
         nPtr <- Call.cint n
         xPtr <- ContT $ withForeignPtr x
         aPtr <- ContT $ withForeignPtr a
         incaPtr <- Call.cint 1
         incbPtr <- Call.cint 1
         -- pointerSeq s p presumably yields p, p+s, p+2s, ... (in elements);
         -- only the first m entries are used.
         liftIO $ sequence_ $ take m $
            zipWith3
               (\xkPtr akPtr bkPtr -> do
                  poke alphaPtr =<< peek xkPtr
                  BlasGen.copy nPtr akPtr incaPtr bkPtr incbPtr
                  BlasGen.scal nPtr alphaPtr bkPtr incbPtr)
               (pointerSeq 1 xPtr)
               (pointerSeq n aPtr)
               (pointerSeq n bPtr)
      -- Column-major: note m and n are swapped relative to the branch
      -- above. Here m counts columns and n counts rows (the contiguous
      -- length). Each column is computed as b_k := 1 * D * a_k + 0 * b_k
      -- with a band matrix-vector product. D is n-by-n with kl = ku = 0
      -- (only the diagonal), and its band storage is the vector x with
      -- leading dimension 1, so D = diag(x).
      -- NOTE(review): assumes Private.gbmv follows the BLAS ?gbmv argument
      -- order (trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy).
      ColumnMajor -> evalContT $ do
         let m = Shape.size width
         let n = Shape.size height
         transPtr <- Call.char 'N'
         nPtr <- Call.cint n
         klPtr <- Call.cint 0
         kuPtr <- Call.cint 0
         alphaPtr <- Call.number one
         xPtr <- ContT $ withForeignPtr x
         ldxPtr <- Call.leadingDim 1
         aPtr <- ContT $ withForeignPtr a
         incaPtr <- Call.cint 1
         betaPtr <- Call.number zero
         incbPtr <- Call.cint 1
         liftIO $ sequence_ $ take m $
            zipWith
               (\akPtr bkPtr ->
                  Private.gbmv transPtr
                     nPtr nPtr klPtr kuPtr alphaPtr xPtr ldxPtr
                     akPtr incaPtr betaPtr bkPtr incbPtr)
               (pointerSeq n aPtr)
               (pointerSeq n bPtr)

-- | @scaleColumns x a@ multiplies column @j@ of @a@ by element @j@ of @x@,
-- i.e. computes a * diag(x).
--
-- It is implemented by scaling the rows of the transposed matrix with
-- 'scaleRows' and transposing the result back. The vector's shape must
-- therefore equal the matrix width (checked inside 'scaleRows').
scaleColumns ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
scaleColumns x a = transpose (scaleRows x (transpose a))