{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.Basic where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.RowMajor as RowMajor
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Matrix.Layout.Private
         (Order(RowMajor, ColumnMajor), transposeFromOrder, flipOrder)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated))
import Numeric.LAPACK.Matrix.Private
         (Full, Tall, Wide, Square, General, fromFull, ShapeInt, revealOrder)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, zero, one)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))
import Numeric.LAPACK.Matrix.Extent (Extent)
import Numeric.LAPACK.Private
         (pointerSeq, copyTransposed, copySubMatrix, copyBlock)

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))
import Data.Array.Comfort.Shape ((::+)((::+)))

import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)

import Data.Complex (Complex)


-- | Classify a full matrix as tall or wide.
--
-- Only the shape is inspected (with 'Layout.caseTallWide');
-- the element buffer is reused unchanged in whichever result is produced.
caseTallWide ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
   Full meas vert horiz height width a ->
   Either (Tall height width a) (Wide height width a)
caseTallWide (Array shape a) =
   case Layout.caseTallWide shape of
      Left tallShape -> Left $ Array tallShape a
      Right wideShape -> Right $ Array wideShape a


-- | Swap height and width of a full matrix.
--
-- Only the shape is mapped (with 'Layout.transpose');
-- the elements themselves are not touched.
transpose ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Full meas vert horiz height width a -> Full meas horiz vert width height a
transpose matrix = Array.mapShape Layout.transpose matrix

-- | Conjugate transpose: conjugate every element, then swap height and width.
adjoint ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas vert horiz height width a -> Full meas horiz vert width height a
adjoint matrix = transpose (Vector.conjugate matrix)


-- | Turn a multiplication that receives the other operand first
-- into one that receives the full matrix first.
--
-- The full matrix is transposed and passed as the second argument of
-- @multiplyTrans@; the product obtained from it is transposed back.
swapMultiply ::
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA,
    Extent.Measure measB, Extent.C vertB, Extent.C horizB) =>
   (matrix ->
    Full measA horizA vertA widthA heightA a ->
    Full measB horizB vertB widthB heightB a) ->
   Full measA vertA horizA heightA widthA a ->
   matrix ->
   Full measB vertB horizB heightB widthB a
swapMultiply multiplyTrans a b =
   let aTransposed = transpose a
   in transpose (multiplyTrans b aTransposed)


-- | Apply a function to the extent of a full matrix.
-- The storage order and the element buffer are kept as they are.
mapExtent ::
   (Extent measA vertA horizA heightA widthA ->
    Extent measB vertB horizB heightB widthB) ->
   Full measA vertA horizA heightA widthA a ->
   Full measB vertB horizB heightB widthB a
mapExtent f = Array.mapShape remap
   where
      remap (Layout.Full order extent) = Layout.Full order (f extent)

-- | Transform the height shape of a matrix with 'Extent.Size' measure.
mapHeight ::
   (Extent.C vert, Extent.C horiz) =>
   (heightA -> heightB) ->
   Full Extent.Size vert horiz heightA width a ->
   Full Extent.Size vert horiz heightB width a
mapHeight f = mapExtent (Extent.mapHeight f)

-- | Transform the width shape of a matrix with 'Extent.Size' measure.
mapWidth ::
   (Extent.C vert, Extent.C horiz) =>
   (widthA -> widthB) ->
   Full Extent.Size vert horiz height widthA a ->
   Full Extent.Size vert horiz height widthB a
mapWidth f = mapExtent (Extent.mapWidth f)

-- | Wrap both the height and the width shape in 'Unchecked'.
uncheck ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Full meas vert horiz height width a ->
   Full meas vert horiz (Unchecked height) (Unchecked width) a
uncheck matrix = mapExtent (Extent.mapWrap Unchecked Unchecked) matrix

-- | Remove the 'Unchecked' wrappers from both dimension shapes
-- (the counterpart of 'uncheck'), using 'Extent.recheck'.
recheck ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Full meas vert horiz (Unchecked height) (Unchecked width) a ->
   Full meas vert horiz height width a
recheck matrix = mapExtent Extent.recheck matrix


-- | View a vector as a general matrix with a single row (height @()@),
-- tagged with the given storage order.
singleRow :: Order -> Vector width a -> General () width a
singleRow order vec =
   Array.mapShape (\width -> Layout.general order () width) vec

-- | View a vector as a general matrix with a single column (width @()@),
-- tagged with the given storage order.
singleColumn :: Order -> Vector height a -> General height () a
singleColumn order vec =
   Array.mapShape (\height -> Layout.general order height ()) vec

-- | Reinterpret a single-row matrix as a vector over its width shape.
flattenRow :: General () width a -> Vector width a
flattenRow row = Array.mapShape Layout.fullWidth row

-- | Reinterpret a single-column matrix as a vector over its height shape.
flattenColumn :: General height () a -> Vector height a
flattenColumn column = Array.mapShape Layout.fullHeight column

-- | Lift a vector transformation to single-row matrices.
--
-- The row is flattened to a vector, transformed, and turned back into a
-- single-row matrix tagged with the given storage order.
-- (The type variables denote the width of the row.)
liftRow ::
   Order ->
   (Vector width0 a -> Vector width1 b) ->
   General () width0 a -> General () width1 b
liftRow order f = singleRow order . f . flattenRow

-- | Lift a vector transformation to single-column matrices.
-- The result is tagged with the given storage order.
liftColumn ::
   Order ->
   (Vector height0 a -> Vector height1 b) ->
   General height0 () a -> General height1 () b
liftColumn order f column = singleColumn order (f (flattenColumn column))

-- | Turn a transformation of single-row matrices into a vector transformation.
--
-- The vector is viewed as a single row with the given storage order,
-- transformed, and flattened again.
-- (The type variables denote the width of the row.)
unliftRow ::
   Order ->
   (General () width0 a -> General () width1 b) ->
   Vector width0 a -> Vector width1 b
unliftRow order f = flattenRow . f . singleRow order

-- | Turn a transformation of single-column matrices into a vector
-- transformation, viewing the vector as a column with the given order.
unliftColumn ::
   Order ->
   (General height0 () a -> General height1 () b) ->
   Vector height0 a -> Vector height1 b
unliftColumn order f vec = flattenColumn (f (singleColumn order vec))


-- | Return the same matrix with row-major storage.
--
-- A row-major input is returned as it is.
-- A column-major input is copied into a freshly allocated row-major array
-- with 'Private.copyTransposed'.
forceRowMajor ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
forceRowMajor (Array shape@(Layout.Full order extent) x) =
   case order of
      RowMajor -> Array shape x
      ColumnMajor ->
         Array.unsafeCreate (Layout.Full RowMajor extent) $ \yPtr ->
         withForeignPtr x $ \xPtr ->
            -- the destination rows are 'cols' elements long
            Private.copyTransposed cols rows xPtr cols yPtr
   where
      (height, width) = Extent.dimensions extent
      rows = Shape.size height
      cols = Shape.size width

-- | Make sure a matrix is stored in the requested order.
-- The input is copied only if its order differs.
--
-- The column-major case runs 'forceRowMajor' between two transpositions;
-- this presumes that 'Layout.transpose' flips the order tag.
forceOrder ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Order ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
forceOrder RowMajor = forceRowMajor
forceOrder ColumnMajor = transpose . forceRowMajor . transpose


-- | Copy a band of rows out of a matrix buffer into a new matrix.
--
-- @a@ is the buffer of a matrix whose row shape is @heightA@, stored in the
-- order given by @shape@. Its width is assumed to equal the width of @shape@.
-- The rows starting at row offset @k@ that make up @shape@ are copied:
--
-- * for row-major storage, as one contiguous range beginning @k@ rows in;
--
-- * for column-major storage, via 'copySubMatrix' with source leading
--   dimension @Shape.size heightA@.
takeSub ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C heightA, Shape.C height, Shape.C width, Class.Floating a) =>
   heightA -> Int -> ForeignPtr a ->
   Layout.Full meas vert horiz height width ->
   Full meas vert horiz height width a
takeSub heightA k a shape@(Layout.Full order extentB) =
   Array.unsafeCreateWithSize shape $ \blockSize bPtr ->
   withForeignPtr a $ \aPtr ->
      case order of
         RowMajor -> copyBlock blockSize (advancePtr aPtr (k * cols)) bPtr
         ColumnMajor ->
            copySubMatrix rowsB cols (Shape.size heightA)
               (advancePtr aPtr k) rowsB bPtr
   where
      rowsB = Shape.size $ Extent.height extentB
      cols = Shape.size $ Extent.width extentB

-- | Keep the rows belonging to the first part of a @height0::+height1@
-- row shape.
takeTop ::
   (Extent.C vert, Shape.C height0, Shape.C height1, Shape.C width,
    Class.Floating a) =>
   Full Extent.Size vert Extent.Big (height0::+height1) width a ->
   Full Extent.Size vert Extent.Big height0 width a
takeTop (Array (Layout.Full order extentA) a) =
   takeSub heightA 0 a (Layout.Full order extentB)
   where
      heightA@(heightB ::+ _) = Extent.height extentA
      extentB = Extent.reduceWideHeight heightB extentA

-- | Keep the rows belonging to the second part of a @height0::+height1@
-- row shape, i.e. skip the first @Shape.size height0@ rows.
takeBottom ::
   (Extent.C vert, Shape.C height0, Shape.C height1, Shape.C width,
    Class.Floating a) =>
   Full Extent.Size vert Extent.Big (height0::+height1) width a ->
   Full Extent.Size vert Extent.Big height1 width a
takeBottom (Array (Layout.Full order extentA) a) =
   takeSub heightA (Shape.size topHeight) a (Layout.Full order extentB)
   where
      heightA@(topHeight ::+ heightB) = Extent.height extentA
      extentB = Extent.reduceWideHeight heightB extentA

-- | Keep the columns belonging to the first part of a @width0::+width1@
-- column shape. Implemented as 'takeTop' on the transpose.
takeLeft ::
   (Extent.C horiz, Shape.C height, Shape.C width0, Shape.C width1,
    Class.Floating a) =>
   Full Extent.Size Extent.Big horiz height (width0::+width1) a ->
   Full Extent.Size Extent.Big horiz height width0 a
takeLeft = transpose . takeTop . transpose

-- | Keep the columns belonging to the second part of a @width0::+width1@
-- column shape. Implemented as 'takeBottom' on the transpose.
takeRight ::
   (Extent.C horiz, Shape.C height, Shape.C width0, Shape.C width1,
    Class.Floating a) =>
   Full Extent.Size Extent.Big horiz height (width0::+width1) a ->
   Full Extent.Size Extent.Big horiz height width1 a
takeRight = transpose . takeBottom . transpose


-- | Split the zero-based row shape with 'Shape.zeroBasedSplit' at @k@.
-- Only the shape changes; the element buffer is kept.
splitRows ::
   (Extent.C vert, Extent.C horiz, Shape.C width, Class.Floating a) =>
   Int ->
   Full Extent.Size vert horiz ShapeInt width a ->
   Full Extent.Size vert horiz (ShapeInt::+ShapeInt) width a
splitRows k matrix =
   mapExtent (Extent.mapHeight (Shape.zeroBasedSplit k)) matrix

-- | 'takeRows' keeps the upper part and 'dropRows' the lower part
-- of the row split produced by 'splitRows' at @k@.
takeRows, dropRows ::
   (Extent.C vert, Shape.C width, Class.Floating a) =>
   Int ->
   Full Extent.Size vert Extent.Big ShapeInt width a ->
   Full Extent.Size vert Extent.Big ShapeInt width a
takeRows k matrix = takeTop (splitRows k matrix)
dropRows k matrix = takeBottom (splitRows k matrix)

-- | Column counterparts of 'takeRows' and 'dropRows',
-- obtained by working on the transpose.
takeColumns, dropColumns ::
   (Extent.C horiz, Shape.C height, Class.Floating a) =>
   Int ->
   Full Extent.Size Extent.Big horiz height ShapeInt a ->
   Full Extent.Size Extent.Big horiz height ShapeInt a
takeColumns k matrix = transpose (takeRows k (transpose matrix))
dropColumns k matrix = transpose (dropRows k (transpose matrix))


-- | Preference for the storage order of the result of 'beside'
-- (and of 'above', which goes through 'beside'),
-- consulted only when the two operands are stored in different orders.
data OrderBias
   = LeftBias
     -- ^ adopt the storage order of the left operand
   | RightBias
     -- ^ adopt the storage order of the right operand
   | ContiguousBias
     -- ^ choose the order in which the two operand blocks
     -- occupy consecutive memory ranges of the result
   deriving (Eq, Ord, Enum, Show)

{- |
Put two matrices of equal height next to each other.

If both operands share a storage order, the result uses it, too.
For mixed orders, the 'OrderBias' selects the result order:

* row-major left, column-major right:
  'LeftBias' gives row-major, any other bias gives column-major.

* column-major left, row-major right:
  'RightBias' gives row-major, any other bias gives column-major.

Calls 'error' if the two heights differ.
-}
beside ::
   (Extent.C vertA, Extent.C vertB, Extent.C vertC,
    Shape.C height, Eq height, Shape.C widthA, Shape.C widthB,
    Class.Floating a) =>
   OrderBias ->
   Extent.AppendMode vertA vertB vertC height widthA widthB ->
   Full Extent.Size vertA Extent.Big height widthA a ->
   Full Extent.Size vertB Extent.Big height widthB a ->
   Full Extent.Size vertC Extent.Big height (widthA::+widthB) a
beside orderBias (Extent.AppendMode appendMode)
      (Array (Layout.Full orderA extentA) a)
      (Array (Layout.Full orderB extentB) b)
   | heightA /= heightB = error "beside: mismatching heights"
   | otherwise =
      case (orderA, orderB) of
         (RowMajor, RowMajor) ->
            withResult RowMajor $ \aPtr bPtr cPtr _ -> evalContT $ do
               colsAPtr <- Call.cint colsA
               colsBPtr <- Call.cint colsB
               incxPtr <- Call.cint 1
               incyPtr <- Call.cint 1
               -- fill the result row by row:
               -- a row of the left operand, then a row of the right one
               let copyRow aRowPtr bRowPtr cRowPtr = do
                     BlasGen.copy colsAPtr aRowPtr incxPtr cRowPtr incyPtr
                     BlasGen.copy colsBPtr bRowPtr incxPtr
                        (cRowPtr `advancePtr` colsA) incyPtr
               liftIO $
                  sequence_ $ take rows $
                  zipWith3 copyRow
                     (pointerSeq colsA aPtr)
                     (pointerSeq colsB bPtr)
                     (pointerSeq cols cPtr)
         (RowMajor, ColumnMajor) ->
            case orderBias of
               LeftBias ->
                  withResult RowMajor $ \aPtr bPtr leftPtr rightPtr -> do
                     copySubMatrix colsA rows colsA aPtr cols leftPtr
                     copyTransposed colsB rows bPtr cols rightPtr
               _ ->
                  withResult ColumnMajor $ \aPtr bPtr leftPtr rightPtr -> do
                     copyTransposed rows colsA aPtr rows leftPtr
                     copyBlock volB bPtr rightPtr
         (ColumnMajor, RowMajor) ->
            case orderBias of
               RightBias ->
                  withResult RowMajor $ \aPtr bPtr leftPtr rightPtr -> do
                     copyTransposed colsA rows aPtr cols leftPtr
                     copySubMatrix colsB rows colsB bPtr cols rightPtr
               _ ->
                  withResult ColumnMajor $ \aPtr bPtr leftPtr rightPtr -> do
                     copyBlock volA aPtr leftPtr
                     copyTransposed rows colsB bPtr rows rightPtr
         (ColumnMajor, ColumnMajor) ->
            withResult ColumnMajor $ \aPtr bPtr leftPtr rightPtr -> evalContT $ do
               volAPtr <- Call.cint volA
               volBPtr <- Call.cint volB
               incxPtr <- Call.cint 1
               incyPtr <- Call.cint 1
               liftIO $ do
                  BlasGen.copy volAPtr aPtr incxPtr leftPtr incyPtr
                  BlasGen.copy volBPtr bPtr incxPtr rightPtr incyPtr
   where
      (heightA, widthA) = Extent.dimensions extentA
      (heightB, widthB) = Extent.dimensions extentB
      rows = Shape.size heightA
      colsA = Shape.size widthA
      colsB = Shape.size widthB
      cols = colsA + colsB
      volA = rows * colsA
      volB = rows * colsB
      -- Allocate the result in the given order and hand the action
      -- both operand pointers, the start of the result, and the position
      -- where the right operand's part begins
      -- (after one left row for row-major, after the whole left block
      -- for column-major storage).
      withResult order act =
         Array.unsafeCreate
            (Layout.Full order $ appendMode extentA extentB) $ \cPtr ->
         withForeignPtr a $ \aPtr ->
         withForeignPtr b $ \bPtr ->
         act aPtr bPtr cPtr $ advancePtr cPtr $
            case order of
               RowMajor -> colsA
               ColumnMajor -> volA

-- | Put two matrices of equal width on top of each other.
--
-- This is 'beside' applied to the transposed operands,
-- followed by a final transposition. A width mismatch is therefore
-- reported by the error raised in 'beside'.
above ::
   (Extent.C horizA, Extent.C horizB, Extent.C horizC,
    Shape.C width, Eq width, Shape.C heightA, Shape.C heightB,
    Class.Floating a) =>
   OrderBias ->
   Extent.AppendMode horizA horizB horizC width heightA heightB ->
   Full Extent.Size Extent.Big horizA heightA width a ->
   Full Extent.Size Extent.Big horizB heightB width a ->
   Full Extent.Size Extent.Big horizC (heightA::+heightB) width a
above orderBias appendMode top bottom =
   let sideBySide =
         beside orderBias appendMode (transpose top) (transpose bottom)
   in transpose sideBySide

-- | Assemble a block matrix from four blocks:
-- @a@ top-left, @b@ top-right, @c@ bottom-left, @d@ bottom-right.
-- The diagonal blocks carry the matrix type of the result, while the
-- off-diagonal blocks are general matrices.
-- 'RightBias' is used in both directions.
stack ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C heightA, Eq heightA, Shape.C heightB, Eq heightB,
    Shape.C widthA, Eq widthA, Shape.C widthB, Eq widthB, Class.Floating a) =>
   Full meas vert horiz heightA widthA a -> General heightA widthB a ->
   General heightB widthA a -> Full meas vert horiz heightB widthB a ->
   Full meas vert horiz (heightA::+heightB) (widthA::+widthB) a
stack a b c d = stackBiased RightBias RightBias a b c d

-- | Block assembly with square diagonal blocks, like 'stack', but using
-- 'LeftBias' for the vertical and 'RightBias' for the horizontal
-- concatenations.
stackMosaic ::
   (Shape.C shA, Eq shA, Shape.C shB, Eq shB, Class.Floating a) =>
   Square shA a -> General shA shB a ->
   General shB shA a -> Square shB a ->
   Square (shA::+shB) a
stackMosaic a b c d = stackBiased LeftBias RightBias a b c d

-- | Assemble four blocks (@a@ top-left, @b@ top-right, @c@ bottom-left,
-- @d@ bottom-right) into one matrix.
--
-- The first bias controls the vertical concatenation ('above') of the two
-- block rows. The second bias controls the horizontal concatenations
-- ('beside') inside each block row.
-- The extent of the concatenation is replaced by 'Extent.stack' of the
-- extents of the diagonal blocks @a@ and @d@.
stackBiased ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C heightA, Eq heightA, Shape.C heightB, Eq heightB,
    Shape.C widthA, Eq widthA, Shape.C widthB, Eq widthB, Class.Floating a) =>
   OrderBias -> OrderBias ->
   Full meas vert horiz heightA widthA a -> General heightA widthB a ->
   General heightB widthA a -> Full meas vert horiz heightB widthB a ->
   Full meas vert horiz (heightA::+heightB) (widthA::+widthB) a
stackBiased vertBias horizBias a b c d =
   mapExtent (const combinedExtent) $
   above vertBias Extent.appendAny topRow bottomRow
   where
      combinedExtent =
         Extent.stack
            (Layout.fullExtent $ Array.shape a)
            (Layout.fullExtent $ Array.shape d)
      topRow = beside horizBias Extent.appendAny (fromFull a) b
      bottomRow = beside horizBias Extent.appendAny c (fromFull d)


-- | Apply one of two array transformations to the storage of a full matrix,
-- depending on how 'revealOrder' exposes it.
--
-- The first function handles an array of shape @(height, width)@,
-- the second an array of shape @(width, height)@.
-- Either result is reshaped back to the shape of the input matrix.
liftRowMajor ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Array (height, width) a -> Array (height, width) b) ->
   (Array (width, height) a -> Array (width, height) b) ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width b
liftRowMajor fr fc a =
   case revealOrder a of
      Left heightWidth -> Array.reshape shape (fr heightWidth)
      Right widthHeight -> Array.reshape shape (fc widthHeight)
   where
      shape = Array.shape a

-- | Scale the rows of a matrix by the entries of @x@.
--
-- Delegates to 'RowMajor.scaleRows' on a @(height, width)@ view of the
-- storage and to 'RowMajor.scaleColumns' on a @(width, height)@ view.
scaleRows ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
scaleRows x matrix =
   liftRowMajor (RowMajor.scaleRows x) (RowMajor.scaleColumns x) matrix

-- | Scale the columns of a matrix by the entries of @x@,
-- expressed as 'scaleRows' on the transpose.
scaleColumns ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
scaleColumns x matrix = transpose (scaleRows x (transpose matrix))


-- | Scale the rows of a complex matrix by real factors.
--
-- The storage is viewed as a real array via 'RowMajor.decomplex',
-- scaled, and converted back with 'RowMajor.recomplex'.
-- For the @(width, height)@ view, the factor vector is first expanded by a
-- tensor product with a vector of ones over 'Shape.Enumeration'
-- (presumably so that each factor hits both components of an entry).
scaleRowsComplex ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Real a) =>
   Vector height a ->
   Full meas vert horiz height width (Complex a) ->
   Full meas vert horiz height width (Complex a)
scaleRowsComplex x =
   liftRowMajor
      (RowMajor.recomplex . RowMajor.scaleRows x . RowMajor.decomplex)
      (RowMajor.recomplex . RowMajor.scaleColumns expandedX . RowMajor.decomplex)
   where
      expandedX =
         RowMajor.tensorProduct (Left NonConjugated) x
            (Vector.one Shape.Enumeration)

-- | Scale the columns of a complex matrix by real factors,
-- expressed as 'scaleRowsComplex' on the transpose.
scaleColumnsComplex ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Class.Real a) =>
   Vector width a ->
   Full meas vert horiz height width (Complex a) ->
   Full meas vert horiz height width (Complex a)
scaleColumnsComplex x matrix =
   transpose (scaleRowsComplex x (transpose matrix))


-- | Scale the rows by real factors, whatever the element type.
--
-- 'Class.switchFloating' selects 'scaleRows' for the two real element
-- types and 'scaleRowsComplex' for the two complex ones.
-- 'ScaleRowsReal' wraps each candidate so that its type fits the form
-- expected by 'Class.switchFloating'.
scaleRowsReal ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height (RealOf a) ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
scaleRowsReal x matrix =
   getScaleRowsReal
      (Class.switchFloating
         (ScaleRowsReal scaleRows)
         (ScaleRowsReal scaleRows)
         (ScaleRowsReal scaleRowsComplex)
         (ScaleRowsReal scaleRowsComplex))
      x matrix

-- | A "scale by real factors" function, with the element type @a@ as the
-- last type parameter, as needed for dispatch with 'Class.switchFloating'
-- in 'scaleRowsReal'.
newtype ScaleRowsReal f g a =
   ScaleRowsReal {
      getScaleRowsReal :: f (RealOf a) -> g a -> g a
   }

-- | Scale the columns by real factors, expressed as 'scaleRowsReal'
-- on the transpose.
scaleColumnsReal ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width (RealOf a) ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
scaleColumnsReal x matrix = transpose (scaleRowsReal x (transpose matrix))



-- | Matrix-vector product.
--
-- Calls 'error' if the shape of the vector differs from the width
-- of the matrix; otherwise defers to 'multiplyVectorUnchecked'.
multiplyVector ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Full meas vert horiz height width a -> Vector width a -> Vector height a
multiplyVector a x
   | Layout.fullWidth (Array.shape a) == Array.shape x =
      multiplyVectorUnchecked a x
   | otherwise = error "multiplyVector: width shapes mismatch"

-- | Matrix-vector product without checking the vector shape
-- against the matrix width.
--
-- Calls 'Private.gemv' with a transpose flag derived from the storage
-- order ('transposeFromOrder'), @alpha = one@, @beta = zero@ and unit
-- increments. Both sizes come from 'Layout.dimensions';
-- the first one also serves as leading dimension.
multiplyVectorUnchecked ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas vert horiz height width a -> Vector width a -> Vector height a
multiplyVectorUnchecked
   (Array shape@(Layout.Full order extent) a) (Array _ x) =
      Array.unsafeCreate (Extent.height extent) $ \yPtr ->
      evalContT $ do
         transPtr <- Call.char $ transposeFromOrder order
         mPtr <- Call.cint dimM
         nPtr <- Call.cint dimN
         alphaPtr <- Call.number one
         aPtr <- ContT $ withForeignPtr a
         ldaPtr <- Call.leadingDim dimM
         xPtr <- ContT $ withForeignPtr x
         incxPtr <- Call.cint 1
         betaPtr <- Call.number zero
         incyPtr <- Call.cint 1
         liftIO $
            Private.gemv
               transPtr mPtr nPtr alphaPtr aPtr ldaPtr
               xPtr incxPtr betaPtr yPtr incyPtr
   where
      (dimM, dimN) = Layout.dimensions shape

{- |
Multiply two matrices with the same dimension constraints.
E.g. you can multiply 'General' and 'General' matrices,
or 'Square' and 'Square' matrices.
It may seem to be overly strict in this respect,
but that design supports type inference the best.
You can lift the restrictions by generalizing operands
with 'Square.toFull', 'Matrix.fromFull',
'Matrix.generalizeTall' or 'Matrix.generalizeWide'.

'multiply' stores the product in the storage order of the right factor,
whereas 'multiplyColumnMajor' always stores it in column-major order.
Both call 'error' if 'Extent.fuse' rejects the two extents.
-}
multiply, multiplyColumnMajor ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height,
    Shape.C fuse, Eq fuse,
    Shape.C width,
    Class.Floating a) =>
   Full meas vert horiz height fuse a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
-- Keep the storage order of the right factor.
-- For a row-major result the factors are swapped and their order tags
-- flipped before calling 'Private.multiplyMatrix'.
multiply
   (Array (Layout.Full orderA extentA) a)
   (Array (Layout.Full orderB extentB) b) =
   case Extent.fuse extentA extentB of
      Nothing -> error "multiply: fuse shapes mismatch"
      Just extent ->
         Array.unsafeCreate (Layout.Full orderB extent) $ \cPtr ->
            case orderB of
               RowMajor ->
                  Private.multiplyMatrix (flipOrder orderB) (flipOrder orderA)
                     cols inner rows b a cPtr
               ColumnMajor ->
                  Private.multiplyMatrix orderA orderB rows inner cols a b cPtr
   where
      (height, fuse) = Extent.dimensions extentA
      rows = Shape.size height
      inner = Shape.size fuse
      cols = Shape.size $ Extent.width extentB

-- Always returns a column-major product, regardless of the storage orders
-- of the factors. Calls 'error' if 'Extent.fuse' rejects the two extents;
-- the message names this function so the failing call site is identifiable.
-- The sizes are bound in a 'where' clause, so the layout no longer depends
-- on the NondecreasingIndentation extension.
multiplyColumnMajor
   (Array (Layout.Full orderA extentA) a)
   (Array (Layout.Full orderB extentB) b) =
   case Extent.fuse extentA extentB of
      Nothing -> error "multiplyColumnMajor: fuse shapes mismatch"
      Just extent ->
         Array.unsafeCreate (Layout.Full ColumnMajor extent) $ \cPtr ->
            Private.multiplyMatrix orderA orderB m k n a b cPtr
   where
      (height, fuse) = Extent.dimensions extentA
      m = Shape.size height
      k = Shape.size fuse
      n = Shape.size $ Extent.width extentB