{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.Basic where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.RowMajor as RowMajor
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor, ColumnMajor), transposeFromOrder, flipOrder)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated))
import Numeric.LAPACK.Matrix.Private
         (Full, Tall, Wide, General, ZeroInt, revealOrder)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, zero, one)
import Numeric.LAPACK.Private (copySubMatrix, copyBlock)

import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((:+:)((:+:)))

import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Complex (Complex)


{- |
Classify a full matrix as tall or wide.
The decision is taken by 'MatrixShape.caseTallWide' on the shape;
the element buffer is handed on unchanged to the resulting matrix.
-}
caseTallWide ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width) =>
   Full vert horiz height width a ->
   Either (Tall height width a) (Wide height width a)
caseTallWide (Array shape payload) =
   case MatrixShape.caseTallWide shape of
      Left tallShape -> Left (Array tallShape payload)
      Right wideShape -> Right (Array wideShape payload)


{- |
Transpose a matrix by applying 'MatrixShape.transpose' to its shape.
-}
transpose ::
   (Extent.C vert, Extent.C horiz) =>
   Full vert horiz height width a -> Full horiz vert width height a
transpose matrix = Array.mapShape MatrixShape.transpose matrix

{- |
Conjugate all elements with 'Vector.conjugate' and then 'transpose'.
-}
adjoint ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full vert horiz height width a -> Full horiz vert width height a
adjoint matrix = transpose (Vector.conjugate matrix)


{- |
Turn a multiplication that takes the other operand first
and works on the transposed matrix
into one that takes the full matrix first:
the matrix @a@ is transposed, passed together with @b@ to @multiplyTrans@,
and the result is transposed back.
-}
swapMultiply ::
   (Extent.C vertA, Extent.C vertB, Extent.C horizA, Extent.C horizB) =>
   (matrix ->
    Full horizA vertA widthA heightA a ->
    Full horizB vertB widthB heightB a) ->
   Full vertA horizA heightA widthA a ->
   matrix ->
   Full vertB horizB heightB widthB a
swapMultiply multiplyTrans a b =
   let aTransposed = transpose a
   in transpose (multiplyTrans b aTransposed)


{- |
Replace the height shape of a matrix by applying @f@ to it.
Only the shape is rewritten (via 'Extent.mapHeight');
the storage order is kept.
-}
mapHeight ::
   (Extent.GeneralTallWide vert horiz,
    Extent.GeneralTallWide horiz vert) =>
   (heightA -> heightB) ->
   Full vert horiz heightA width a ->
   Full vert horiz heightB width a
mapHeight f = Array.mapShape relabelHeight
   where
      relabelHeight (MatrixShape.Full order extent) =
         MatrixShape.Full order (Extent.mapHeight f extent)

{- |
Replace the width shape of a matrix by applying @f@ to it.
Only the shape is rewritten (via 'Extent.mapWidth');
the storage order is kept.
-}
mapWidth ::
   (Extent.GeneralTallWide vert horiz,
    Extent.GeneralTallWide horiz vert) =>
   (widthA -> widthB) ->
   Full vert horiz height widthA a ->
   Full vert horiz height widthB a
mapWidth f = Array.mapShape relabelWidth
   where
      relabelWidth (MatrixShape.Full order extent) =
         MatrixShape.Full order (Extent.mapWidth f extent)


{- |
Reinterpret a vector as a general matrix with height shape @()@
and the given storage order.
-}
singleRow :: Order -> Vector width a -> General () width a
singleRow order vector =
   Array.mapShape (\width -> MatrixShape.general order () width) vector

{- |
Reinterpret a vector as a general matrix with width shape @()@
and the given storage order.
-}
singleColumn :: Order -> Vector height a -> General height () a
singleColumn order vector =
   Array.mapShape (\height -> MatrixShape.general order height ()) vector

{- |
Reinterpret a matrix with height shape @()@ as a vector
by reducing its shape to the width ('MatrixShape.fullWidth').
-}
flattenRow :: General () width a -> Vector width a
flattenRow matrix = Array.mapShape MatrixShape.fullWidth matrix

{- |
Reinterpret a matrix with width shape @()@ as a vector
by reducing its shape to the height ('MatrixShape.fullHeight').
-}
flattenColumn :: General height () a -> Vector height a
flattenColumn matrix = Array.mapShape MatrixShape.fullHeight matrix

{- |
Lift a vector function to single-row matrices:
flatten with 'flattenRow', apply @f@,
and wrap the result with 'singleRow' using the given order.
-}
liftRow ::
   Order ->
   (Vector height0 a -> Vector height1 b) ->
   General () height0 a -> General () height1 b
liftRow order f matrix = singleRow order (f (flattenRow matrix))

{- |
Lift a vector function to single-column matrices:
flatten with 'flattenColumn', apply @f@,
and wrap the result with 'singleColumn' using the given order.
-}
liftColumn ::
   Order ->
   (Vector height0 a -> Vector height1 b) ->
   General height0 () a -> General height1 () b
liftColumn order f matrix = singleColumn order (f (flattenColumn matrix))

{- |
Run a function on single-row matrices on a plain vector:
wrap the vector with 'singleRow' using the given order, apply @f@,
and flatten the result with 'flattenRow'.
-}
unliftRow ::
   Order ->
   (General () height0 a -> General () height1 b) ->
   Vector height0 a -> Vector height1 b
unliftRow order f vector = flattenRow (f (singleRow order vector))

{- |
Run a function on single-column matrices on a plain vector:
wrap the vector with 'singleColumn' using the given order, apply @f@,
and flatten the result with 'flattenColumn'.
-}
unliftColumn ::
   Order ->
   (General height0 () a -> General height1 () b) ->
   Vector height0 a -> Vector height1 b
unliftColumn order f vector = flattenColumn (f (singleColumn order vector))


{- |
Return the matrix in 'RowMajor' storage order.
A matrix that is already 'RowMajor' is returned as is.
A 'ColumnMajor' matrix is copied into a newly created array
by 'Private.copyTransposed'.
-}
forceRowMajor ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full vert horiz height width a ->
   Full vert horiz height width a
forceRowMajor matrix@(Array (MatrixShape.Full order extent) x) =
   case order of
      RowMajor -> matrix
      ColumnMajor ->
         let (height, width) = Extent.dimensions extent
             numRows = Shape.size height
             numColumns = Shape.size width
         in Array.unsafeCreate (MatrixShape.Full RowMajor extent) $ \dstPtr ->
            withForeignPtr x $ \srcPtr ->
               Private.copyTransposed
                  numColumns numRows srcPtr numColumns dstPtr

{- |
Return the matrix in the requested storage order.
'RowMajor' delegates to 'forceRowMajor'.
'ColumnMajor' is obtained by applying 'forceRowMajor'
to the transposed matrix and transposing back.
-}
forceOrder ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Order ->
   Full vert horiz height width a ->
   Full vert horiz height width a
forceOrder RowMajor matrix = forceRowMajor matrix
forceOrder ColumnMajor matrix = transpose (forceRowMajor (transpose matrix))


{- |
Copy the upper part of a vertically composed matrix into a new matrix.
The storage order of the input is kept.
For 'RowMajor' the result is one contiguous block
from the start of the source buffer ('copyBlock');
for 'ColumnMajor' the part is extracted with 'copySubMatrix'.
-}
takeTop ::
   (Extent.C vert, Shape.C height0, Shape.C height1, Shape.C width,
    Class.Floating a) =>
   Full vert Extent.Big (height0:+:height1) width a ->
   Full vert Extent.Big height0 width a
takeTop (Array (MatrixShape.Full order extentA) a) =
   Array.unsafeCreateWithSize (MatrixShape.Full order extentB) $
         \blockSize dstPtr ->
      withForeignPtr a $ \srcPtr ->
      case order of
         RowMajor -> copyBlock blockSize srcPtr dstPtr
         ColumnMajor ->
            copySubMatrix topRows numColumns allRows srcPtr topRows dstPtr
   where
      (wholeHeight@(topHeight:+:_), width) = Extent.dimensions extentA
      extentB = Extent.reduceWideHeight topHeight extentA
      allRows = Shape.size wholeHeight
      topRows = Shape.size topHeight
      numColumns = Shape.size width

{- |
Copy the lower part of a vertically composed matrix into a new matrix.
The storage order of the input is kept.
For 'RowMajor' the source pointer is advanced
by (size of upper height) * (size of width) elements
and one contiguous block is copied ('copyBlock');
for 'ColumnMajor' the source pointer is advanced
by the size of the upper height and the part is extracted
with 'copySubMatrix'.
-}
takeBottom ::
   (Extent.C vert, Shape.C height0, Shape.C height1, Shape.C width,
    Class.Floating a) =>
   Full vert Extent.Big (height0:+:height1) width a ->
   Full vert Extent.Big height1 width a
takeBottom (Array (MatrixShape.Full order extentA) a) =
   Array.unsafeCreateWithSize (MatrixShape.Full order extentB) $
         \blockSize dstPtr ->
      withForeignPtr a $ \srcPtr ->
      case order of
         RowMajor ->
            copyBlock blockSize
               (advancePtr srcPtr (skippedRows*numColumns)) dstPtr
         ColumnMajor ->
            copySubMatrix bottomRows numColumns allRows
               (advancePtr srcPtr skippedRows) bottomRows dstPtr
   where
      (wholeHeight@(topHeight:+:bottomHeight), width) =
         Extent.dimensions extentA
      extentB = Extent.reduceWideHeight bottomHeight extentA
      skippedRows = Shape.size topHeight
      allRows = Shape.size wholeHeight
      bottomRows = Shape.size bottomHeight
      numColumns = Shape.size width

{- |
Extract the left part of a horizontally composed matrix.
Implemented as 'takeTop' on the transposed matrix.
-}
takeLeft ::
   (Extent.C vert, Shape.C height, Shape.C width0, Shape.C width1,
    Class.Floating a) =>
   Full Extent.Big vert height (width0:+:width1) a ->
   Full Extent.Big vert height width0 a
takeLeft matrix = transpose (takeTop (transpose matrix))

{- |
Extract the right part of a horizontally composed matrix.
Implemented as 'takeBottom' on the transposed matrix.
-}
takeRight ::
   (Extent.C vert, Shape.C height, Shape.C width0, Shape.C width1,
    Class.Floating a) =>
   Full Extent.Big vert height (width0:+:width1) a ->
   Full Extent.Big vert height width1 a
takeRight matrix = transpose (takeBottom (transpose matrix))


{- |
Relabel the zero-based height of a matrix as a composed height,
obtained from 'Shape.zeroBasedSplit' with the argument @k@.
Only the shape is changed; the storage order is kept.
-}
splitRows ::
   (Extent.C vert, Shape.C width, Class.Floating a) =>
   Int ->
   Full vert Extent.Big ZeroInt width a ->
   Full vert Extent.Big (ZeroInt:+:ZeroInt) width a
splitRows k = Array.mapShape splitHeight
   where
      splitHeight (MatrixShape.Full order extent) =
         let composedHeight = Shape.zeroBasedSplit k (Extent.height extent)
         in MatrixShape.Full order
               (Extent.reduceWideHeight composedHeight extent)

{- |
Split the rows with @'splitRows' k@ and keep one of the two parts:
'takeRows' keeps the upper part ('takeTop'),
'dropRows' keeps the lower part ('takeBottom').
-}
takeRows, dropRows ::
   (Extent.C vert, Shape.C width, Class.Floating a) =>
   Int ->
   Full vert Extent.Big ZeroInt width a ->
   Full vert Extent.Big ZeroInt width a
takeRows k matrix = takeTop (splitRows k matrix)
dropRows k matrix = takeBottom (splitRows k matrix)

{- |
Column counterparts of 'takeRows' and 'dropRows',
implemented by applying those functions to the transposed matrix
and transposing the result back.
-}
takeColumns, dropColumns ::
   (Extent.C horiz, Shape.C height, Class.Floating a) =>
   Int ->
   Full Extent.Big horiz height ZeroInt a ->
   Full Extent.Big horiz height ZeroInt a
takeColumns k matrix = transpose (takeRows k (transpose matrix))
dropColumns k matrix = transpose (dropRows k (transpose matrix))


{- |
Apply an array function that depends on the storage order.
'revealOrder' presents the matrix either
as an array of shape @(height, width)@, which is processed by @fr@,
or as an array of shape @(width, height)@, which is processed by @fc@.
The result is reshaped to the shape of the input matrix.
-}
liftRowMajor ::
   (Extent.C vert, Extent.C horiz) =>
   (Array (height, width) a -> Array (height, width) b) ->
   (Array (width, height) a -> Array (width, height) b) ->
   Full vert horiz height width a ->
   Full vert horiz height width b
liftRowMajor fr fc a =
   -- The shape is bound as a plain value,
   -- because the reshape is applied at two different source shapes.
   let originalShape = Array.shape a
   in case revealOrder a of
         Left heightWidth -> Array.reshape originalShape (fr heightWidth)
         Right widthHeight -> Array.reshape originalShape (fc widthHeight)

{- |
Scale the rows of a matrix by a vector indexed by the height.
Depending on the storage order (see 'liftRowMajor')
this is 'RowMajor.scaleRows' on the @(height, width)@ array
or 'RowMajor.scaleColumns' on the @(width, height)@ array.
-}
scaleRows ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
scaleRows x matrix =
   liftRowMajor (RowMajor.scaleRows x) (RowMajor.scaleColumns x) matrix

{- |
Scale the columns of a matrix by a vector indexed by the width.
Implemented as 'scaleRows' on the transposed matrix.
-}
scaleColumns ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a ->
   Full vert horiz height width a ->
   Full vert horiz height width a
scaleColumns x matrix = transpose (scaleRows x (transpose matrix))


{- |
Scale the rows of a complex matrix by a real vector.
Both storage orders work on the array
converted by 'RowMajor.decomplex' and convert back with 'RowMajor.recomplex'.
For the @(height, width)@ array the vector is applied
with 'RowMajor.scaleRows'.
For the @(width, height)@ array the scaling vector is first combined
by 'RowMajor.tensorProduct' with a vector of ones
and then applied with 'RowMajor.scaleColumns'.
-}
scaleRowsComplex ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Real a) =>
   Vector height a ->
   Full vert horiz height width (Complex a) ->
   Full vert horiz height width (Complex a)
scaleRowsComplex x matrix =
   liftRowMajor scaleHeightWidth scaleWidthHeight matrix
   where
      scaleHeightWidth =
         RowMajor.recomplex . RowMajor.scaleRows x . RowMajor.decomplex
      expandedScale =
         RowMajor.tensorProduct (Left NonConjugated) x
            (Vector.one Shape.Enumeration)
      scaleWidthHeight =
         RowMajor.recomplex .
         RowMajor.scaleColumns expandedScale .
         RowMajor.decomplex

{- |
Scale the columns of a complex matrix by a real vector.
Implemented as 'scaleRowsComplex' on the transposed matrix.
-}
scaleColumnsComplex ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Class.Real a) =>
   Vector width a ->
   Full vert horiz height width (Complex a) ->
   Full vert horiz height width (Complex a)
scaleColumnsComplex x matrix =
   transpose (scaleRowsComplex x (transpose matrix))


{- |
Scale the rows of a matrix by a vector of the real type
that belongs to the element type.
'Class.switchFloating' selects the implementation:
'scaleRows' for the first two alternatives
and 'scaleRowsComplex' for the last two.
-}
scaleRowsReal ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Eq height, Shape.C width,
    Class.Floating a) =>
   Vector height (RealOf a) ->
   Full vert horiz height width a ->
   Full vert horiz height width a
scaleRowsReal x matrix =
   -- The four wrappers are written out separately,
   -- because each one is used at a different element type.
   getScaleRowsReal
      (Class.switchFloating
         (ScaleRowsReal scaleRows)
         (ScaleRowsReal scaleRows)
         (ScaleRowsReal scaleRowsComplex)
         (ScaleRowsReal scaleRowsComplex))
      x matrix

{- |
Wrapper that makes the type of a scaling function
a function of the element type @a@ alone,
as needed for the dispatch by 'Class.switchFloating' in 'scaleRowsReal'.
The scaling argument has element type @RealOf a@,
the scaled object has element type @a@.
-}
newtype ScaleRowsReal f g a =
   ScaleRowsReal {getScaleRowsReal :: f (RealOf a) -> g a -> g a}

{- |
Scale the columns of a matrix by a vector of the real type
that belongs to the element type.
Implemented as 'scaleRowsReal' on the transposed matrix.
-}
scaleColumnsReal ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width (RealOf a) ->
   Full vert horiz height width a ->
   Full vert horiz height width a
scaleColumnsReal x matrix = transpose (scaleRowsReal x (transpose matrix))



{- |
Multiply a matrix with a vector.
The width shape of the matrix must equal the shape of the vector,
otherwise the function fails with 'error'.
The actual computation is done by 'multiplyVectorUnchecked'.
-}
multiplyVector ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width, Eq width,
    Class.Floating a) =>
   Full vert horiz height width a -> Vector width a -> Vector height a
multiplyVector a x
   | MatrixShape.fullWidth (Array.shape a) == Array.shape x =
      multiplyVectorUnchecked a x
   | otherwise = error "multiplyVector: width shapes mismatch"

{- |
Multiply a matrix with a vector without comparing
the matrix width with the vector shape.
The shape of the vector argument is ignored.
The product is computed by a single call to 'Private.gemv'
with factor 'one' for the product, factor 'zero' for the target vector
and increment 1 for both vectors.
The transposition flag is derived from the storage order
by 'transposeFromOrder'.
-}
multiplyVectorUnchecked ::
   (Extent.C vert, Extent.C horiz, Shape.C height, Shape.C width,
    Class.Floating a) =>
   Full vert horiz height width a -> Vector width a -> Vector height a
multiplyVectorUnchecked
   (Array shape@(MatrixShape.Full order extent) a) (Array _ x) =
   Array.unsafeCreate (Extent.height extent) $ \yPtr ->
      evalContT $ do
         transPtr <- Call.char $ transposeFromOrder order
         mPtr <- Call.cint dimM
         nPtr <- Call.cint dimN
         alphaPtr <- Call.number one
         aPtr <- ContT $ withForeignPtr a
         ldaPtr <- Call.leadingDim leadingDimension
         xPtr <- ContT $ withForeignPtr x
         incxPtr <- Call.cint 1
         betaPtr <- Call.number zero
         incyPtr <- Call.cint 1
         liftIO $
            Private.gemv
               transPtr mPtr nPtr alphaPtr aPtr ldaPtr
               xPtr incxPtr betaPtr yPtr incyPtr
   where
      -- dimensions as reported by 'MatrixShape.dimensions' for the shape
      (dimM, dimN) = MatrixShape.dimensions shape
      -- the buffer is used with the first dimension as leading dimension
      leadingDimension = dimM

{- |
Multiply two matrices that have the same dimension constraints.
E.g. you can multiply 'General' with 'General' matrices,
or 'Square' with 'Square' matrices.
This may seem overly strict,
but this design supports type inference best.
You can lift the restriction by generalizing the operands
with 'Square.toFull', 'Matrix.fromFull',
'Matrix.generalizeTall' or 'Matrix.generalizeWide'.

The width shape of the left factor must equal
the height shape of the right factor,
otherwise the function fails with 'error'.
-}
multiply, multiplyColumnMajor ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height,
    Shape.C fuse, Eq fuse,
    Shape.C width,
    Class.Floating a) =>
   Full vert horiz height fuse a ->
   Full vert horiz fuse width a ->
   Full vert horiz height width a
-- The result has the storage order of the right factor.
multiply
   (Array (MatrixShape.Full orderA extentA) a)
   (Array (MatrixShape.Full orderB extentB) b) =
   case Extent.fuse extentA extentB of
      Nothing -> error "multiply: fuse shapes mismatch"
      Just extent ->
         Array.unsafeCreate (MatrixShape.Full orderB extent) $ \cPtr ->
            case orderB of
               -- For a RowMajor result the factors are passed
               -- in swapped order with flipped storage orders
               -- and with the outer sizes exchanged.
               RowMajor ->
                  Private.multiplyMatrix
                     (flipOrder orderB) (flipOrder orderA)
                     n k m b a cPtr
               ColumnMajor ->
                  Private.multiplyMatrix orderA orderB m k n a b cPtr
   where
      (height, fuse) = Extent.dimensions extentA
      m = Shape.size height
      n = Shape.size (Extent.width extentB)
      k = Shape.size fuse

-- Variant of the matrix product
-- that always returns the result in ColumnMajor order,
-- whatever the storage orders of the factors are.
-- The width shape of the left factor must equal
-- the height shape of the right factor,
-- otherwise the function fails with 'error'.
multiplyColumnMajor
   (Array (MatrixShape.Full orderA extentA) a)
   (Array (MatrixShape.Full orderB extentB) b) =
   case Extent.fuse extentA extentB of
      -- The message names this function,
      -- such that a failure is not attributed to 'multiply'.
      Nothing -> error "multiplyColumnMajor: fuse shapes mismatch"
      Just extent ->
         Array.unsafeCreate (MatrixShape.Full ColumnMajor extent) $ \cPtr -> do

      let (height,fuse) = Extent.dimensions extentA
      let width = Extent.width extentB
      -- m-by-k times k-by-n yields m-by-n
      let m = Shape.size height
      let n = Shape.size width
      let k = Shape.size fuse
      Private.multiplyMatrix orderA orderB m k n a b cPtr