{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Matrix.RowMajor where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Matrix.Shape.Private (Order(RowMajor, ColumnMajor))
import Numeric.LAPACK.Matrix.Private (Full, ShapeInt, shapeInt)
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated,Conjugated))
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Private (ComplexPart, pointerSeq)

import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Storable (Storable)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative (liftA2)

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import Data.Complex (Complex)
import Data.Foldable (forM_)


-- | Two-dimensional array whose shape is the pair @(height,width)@.
-- The functions of this module address its elements row by row,
-- i.e. the elements of one row are adjacent in the buffer.
type Matrix height width = Array (height,width)
-- | Plain synonym for 'Array'; the shape is supplied at the use site.
type Vector = Array

-- | Copy a single row out of a row-major matrix.
--
-- The result has the width shape of the matrix. The row selected by @ix@
-- starts at buffer element @rowLen * offset@, where @rowLen@ is the size of
-- the width shape and @offset@ is the position of @ix@ in the height shape.
takeRow ::
   (Shape.Indexed height, Shape.C width, Shape.Index height ~ ix,
    Storable a) =>
   ix -> Matrix height width a -> Vector width a
takeRow ix (Array (height,width) matFPtr) =
   Array.unsafeCreateWithSize width $ \rowLen dstPtr ->
      withForeignPtr matFPtr $ \srcPtr ->
         -- Rows are contiguous, hence one block copy is sufficient.
         let rowStart = rowLen * Shape.offset height ix
         in copyArray dstPtr (advancePtr srcPtr rowStart) rowLen

-- | Copy a single column out of a row-major matrix.
--
-- The result has the height shape of the matrix. The elements of a column
-- are not adjacent in the buffer, so the copy is delegated to the BLAS
-- @copy@ routine with a source stride equal to the size of the width shape
-- and a destination stride of 1.
takeColumn ::
   (Shape.C height, Shape.Indexed width, Shape.Index width ~ ix,
    Class.Floating a) =>
   ix -> Matrix height width a -> Vector height a
takeColumn ix (Array (height,width) matFPtr) =
   Array.unsafeCreateWithSize height $ \len dstPtr ->
      evalContT $ do
         lenPtr <- Call.cint len
         srcPtr <- ContT $ withForeignPtr matFPtr
         srcStridePtr <- Call.cint (Shape.size width)
         dstStridePtr <- Call.cint 1
         -- First element of the column: position of ix within the first row.
         let colStartPtr = advancePtr srcPtr (Shape.offset width ix)
         liftIO $
            BlasGen.copy lenPtr colStartPtr srcStridePtr dstPtr dstStridePtr


-- | Build a matrix by stacking the given vectors as rows.
--
-- The height of the result is the number of vectors, the width is the
-- explicitly passed shape. The shape of every vector is compared with
-- @width@ and the outcome is handed to @Call.assert@ before the vector
-- is copied into its row.
fromRows ::
   (Shape.C width, Eq width, Storable a) =>
   width -> [Vector width a] -> Matrix ShapeInt width a
fromRows width rows =
   Array.unsafeCreate (shapeInt (length rows), width) $ \matPtr ->
      -- pointerSeq yields the start pointer of every destination row;
      -- zipWith stops at the end of the list of source rows.
      sequence_ $ zipWith copyRow (pointerSeq rowLen matPtr) rows
   where
      rowLen = Shape.size width
      copyRow dstPtr (Array rowShape rowFPtr) =
         withForeignPtr rowFPtr $ \srcPtr -> do
            Call.assert
               "Matrix.fromRows: non-matching vector size"
               (width == rowShape)
            copyArray dstPtr srcPtr rowLen


-- | Outer product of a height vector and a width vector, computed by a
-- single BLAS @gemm@ call with inner dimension 1.
--
-- The row-major result is addressed by @gemm@ as a column-major matrix
-- with @size width@ rows and @size height@ columns; this is why the width
-- vector is passed as the first and the height vector as the second factor.
--
-- The @side@ argument selects the operand that receives the transposition
-- flag: for @Left c@ it is the operand backed by the width vector, for
-- @Right c@ the operand backed by the height vector. 'NonConjugated' is
-- translated to the flag @T@ and 'Conjugated' to the flag @C@.
tensorProduct ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Either Conjugation Conjugation ->
   Vector height a -> Vector width a -> Matrix height width a
tensorProduct side (Array height x) (Array width y) =
   Array.unsafeCreate (height,width) $ \cPtr -> evalContT $ do
      transaPtr <- Call.char transA
      transbPtr <- Call.char transB
      mPtr <- Call.cint widthSize
      nPtr <- Call.cint heightSize
      kPtr <- Call.cint 1
      alphaPtr <- Call.number one
      aPtr <- ContT $ withForeignPtr y
      ldaPtr <- Call.leadingDim ldA
      bPtr <- ContT $ withForeignPtr x
      ldbPtr <- Call.leadingDim ldB
      betaPtr <- Call.number zero
      ldcPtr <- Call.leadingDim widthSize
      liftIO $
         BlasGen.gemm
            transaPtr transbPtr mPtr nPtr kPtr alphaPtr
            aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr
   where
      widthSize = Shape.size width
      heightSize = Shape.size height
      transposeFlag NonConjugated = 'T'
      transposeFlag Conjugated = 'C'
      -- A transposed operand is read as a matrix with a single row
      -- (leading dimension 1), a non-transposed one as a single column
      -- (leading dimension equal to the vector length).
      (transA, transB, ldA, ldB) =
         case side of
            Left conj -> (transposeFlag conj, 'N', 1, 1)
            Right conj -> ('N', transposeFlag conj, widthSize, heightSize)


-- | View a complex matrix as a real matrix without copying.
--
-- The buffer is reused through a pointer cast; only the shape changes:
-- the width is paired with an enumeration of 'ComplexPart'.
-- NOTE(review): this presumably relies on a complex number being stored
-- as two adjacent real numbers - confirm against the Storable instance.
decomplex ::
   (Class.Real a) =>
   Matrix height width (Complex a) ->
   Matrix height (width, Shape.Enumeration ComplexPart) a
decomplex (Array (height,width) complexFPtr) =
   Array realShape (castForeignPtr complexFPtr)
   where
      realShape = (height, (width, Shape.Enumeration))

-- | View a real matrix, whose width is paired with an enumeration of
-- 'ComplexPart', as a complex matrix without copying.
--
-- The buffer is reused through a pointer cast and the enumeration is
-- dropped from the width shape.
recomplex ::
   (Class.Real a) =>
   Matrix height (width, Shape.Enumeration ComplexPart) a ->
   Matrix height width (Complex a)
recomplex (Array (height, (width, Shape.Enumeration)) realFPtr) =
   let complexShape = (height,width)
   in Array complexShape (castForeignPtr realFPtr)


-- | Multiply every row of the matrix by the vector element that belongs
-- to that row.
--
-- The shape of the vector is compared with the height of the matrix and
-- the outcome is handed to @Call.assert@. Every row is first copied into
-- the result buffer with BLAS @copy@ and then scaled there with BLAS @scal@.
scaleRows ::
   (Shape.C height, Eq height, Shape.C width, Class.Floating a) =>
   Vector height a -> Matrix height width a -> Matrix height width a
scaleRows (Array heightX x) (Array shape@(height,width) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "scaleRows: sizes mismatch" (heightX == height)
      let rowCount = Shape.size height
      let rowLen = Shape.size width
      evalContT $ do
         rowLenPtr <- Call.cint rowLen
         xPtr <- ContT $ withForeignPtr x
         aPtr <- ContT $ withForeignPtr a
         srcIncPtr <- Call.cint 1
         dstIncPtr <- Call.cint 1
         -- One triple per row: pointer to the factor, to the source row
         -- and to the destination row.
         let rowTriples =
               take rowCount $
               zip3
                  (pointerSeq 1 xPtr)
                  (pointerSeq rowLen aPtr)
                  (pointerSeq rowLen bPtr)
         liftIO $
            forM_ rowTriples $ \(factorPtr, srcRowPtr, dstRowPtr) -> do
               BlasGen.copy rowLenPtr srcRowPtr srcIncPtr dstRowPtr dstIncPtr
               BlasGen.scal rowLenPtr factorPtr dstRowPtr dstIncPtr

-- | Multiply every column of the matrix by the vector element that belongs
-- to that column.
--
-- The shape of the vector is compared with the width of the matrix and
-- the outcome is handed to @Call.assert@. Every row of the matrix is passed
-- to @Private.gbmv@ together with the vector as a square band matrix without
-- sub- and superdiagonals and leading dimension 1.
-- NOTE(review): Private.gbmv presumably wraps the BLAS banded
-- matrix-vector product, which would make each call an element-wise
-- product of the vector and the row - confirm against its definition.
scaleColumns ::
   (Shape.C height, Shape.C width, Eq width, Class.Floating a) =>
   Vector width a -> Matrix height width a -> Matrix height width a
scaleColumns (Array widthX x) (Array shape@(height,width) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "scaleColumns: sizes mismatch" (widthX == width)
      let rowCount = Shape.size height
      let rowLen = Shape.size width
      evalContT $ do
         transPtr <- Call.char 'N'
         dimPtr <- Call.cint rowLen
         klPtr <- Call.cint 0
         kuPtr <- Call.cint 0
         alphaPtr <- Call.number one
         diagPtr <- ContT $ withForeignPtr x
         ldDiagPtr <- Call.leadingDim 1
         aPtr <- ContT $ withForeignPtr a
         srcIncPtr <- Call.cint 1
         betaPtr <- Call.number zero
         dstIncPtr <- Call.cint 1
         -- One pair per row: start of the source and of the destination row.
         let rowPairs =
               take rowCount $
               zip (pointerSeq rowLen aPtr) (pointerSeq rowLen bPtr)
         liftIO $
            forM_ rowPairs $ \(srcRowPtr, dstRowPtr) ->
               Private.gbmv transPtr
                  dimPtr dimPtr klPtr kuPtr alphaPtr diagPtr ldDiagPtr
                  srcRowPtr srcIncPtr betaPtr dstRowPtr dstIncPtr


-- | Kronecker product of a full matrix (row- or column-major) and a
-- row-major matrix; the result is row-major.
--
-- For every pair of a row @i@ of the first and a row @j@ of the second
-- matrix one BLAS @gemm@ call with inner dimension 1 writes the outer
-- product of these two rows into the result row with number
-- @j + rowsB*i@, which consists of @colsA*colsB@ elements.
-- The storage order of the first matrix only affects how one of its rows
-- is addressed: distance between successive rows and leading dimension.
kronecker ::
   (Extent.C vert, Extent.C horiz,
    Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   Full vert horiz heightA widthA a ->
   Matrix heightB widthB a ->
   Matrix (heightA,heightB) (widthA,widthB) a
kronecker
      (Array (MatrixShape.Full orderA extentA) a) (Array (heightB,widthB) b) =
   Array.unsafeCreate ((heightA,heightB), (widthA,widthB)) $ \cPtr ->
      evalContT $ do
         transaPtr <- Call.char 'N'
         transbPtr <- Call.char 'T'
         mPtr <- Call.cint colsA
         nPtr <- Call.cint colsB
         kPtr <- Call.cint 1
         alphaPtr <- Call.number one
         aPtr <- ContT $ withForeignPtr a
         ldaPtr <- Call.leadingDim ldA
         bPtr <- ContT $ withForeignPtr b
         ldbPtr <- Call.leadingDim 1
         betaPtr <- Call.number zero
         ldcPtr <- Call.leadingDim colsB
         liftIO $
            forM_ (liftA2 (,) [0 .. rowsA-1] [0 .. rowsB-1]) $ \(i,j) ->
               let aRowPtr = advancePtr aPtr (rowStepA*i)
                   bRowPtr = advancePtr bPtr (colsB*j)
                   blockPtr = advancePtr cPtr (colsA*colsB*(j+rowsB*i))
               -- The row of the second matrix is the first gemm operand,
               -- thus its index runs fastest within the result row.
               in BlasGen.gemm
                     transbPtr transaPtr nPtr mPtr kPtr alphaPtr
                     bRowPtr ldbPtr aRowPtr ldaPtr betaPtr blockPtr ldcPtr
   where
      (heightA,widthA) = Extent.dimensions extentA
      rowsA = Shape.size heightA
      colsA = Shape.size widthA
      rowsB = Shape.size heightB
      colsB = Shape.size widthB
      -- Leading dimension for reading one row of the first matrix as a
      -- matrix with a single row, and the distance between row starts.
      (ldA, rowStepA) =
         case orderA of
            RowMajor -> (1, colsA)
            ColumnMajor -> (rowsA, 1)