{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.LAPACK.Matrix.Array.Basic where

import qualified Numeric.LAPACK.Matrix.BandedHermitian.Basic as BandedHermitian
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Triangular.Basic as Triangular
import qualified Numeric.LAPACK.Matrix.Mosaic.Packed as Packed
import qualified Numeric.LAPACK.Matrix.Mosaic.Basic as Mosaic
import qualified Numeric.LAPACK.Matrix.Symmetric.Unified as Symmetric
import qualified Numeric.LAPACK.Matrix.Square.Basic as Square

import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix.Array.Private (ArrayMatrix, Quadratic)
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)


{- |
Replace the height shape of an array-backed matrix by applying @f@ to it.

Only the shape descriptor changes; the stored elements are reused as they are.
Therefore @f@ must leave the number of rows unchanged. This is presumably not
verified here, because the unchecked 'Array.mapShape' is used.
-}
mapHeight ::
   (ArrayMatrix pack property lower upper Extent.Size vert horiz ~ matrix,
    Extent.C vert, Extent.C horiz,
    Shape.C heightA, Shape.C heightB, Shape.C width) =>
   (heightA -> heightB) ->
   matrix heightA width a -> matrix heightB width a
mapHeight f m =
   ArrMatrix.Array (Array.mapShape (Omni.mapHeight f) (ArrMatrix.unwrap m))

{- |
Replace the width shape of an array-backed matrix by applying @f@ to it.

Only the shape descriptor changes; the stored elements are reused as they are.
Therefore @f@ must leave the number of columns unchanged. This is presumably
not verified here, because the unchecked 'Array.mapShape' is used.
-}
mapWidth ::
   (ArrayMatrix pack property lower upper Extent.Size vert horiz ~ matrix,
    Extent.C vert, Extent.C horiz,
    Shape.C widthA, Shape.C widthB, Shape.C height) =>
   (widthA -> widthB) ->
   matrix height widthA a -> matrix height widthB a
mapWidth f m =
   ArrMatrix.Array (Array.mapShape (Omni.mapWidth f) (ArrMatrix.unwrap m))

{- |
Replace the size shape of a square matrix by applying @f@ to it.

Only the shape descriptor is rewritten via 'Omni.mapSquareSize'; the element
storage is kept. Consequently @f@ should preserve the number of rows and
columns.
-}
mapSquareSize ::
   (Shape.C shA, Shape.C shB) =>
   (shA -> shB) ->
   Quadratic pack property lower upper shA a ->
   Quadratic pack property lower upper shB a
mapSquareSize f m = ArrMatrix.Array resized
   where
      resized = Array.mapShape (Omni.mapSquareSize f) (ArrMatrix.unwrap m)


{- |
Convert any array-backed matrix to a full (general, unpacked) matrix.

The conversion dispatches on the matrix's @Omni@ shape:

* Full matrices are re-wrapped via 'ArrMatrix.liftUnpacked1' applied to 'id'.
* Upper and lower triangular matrices are expanded with 'Triangular.toSquare'.
* Symmetric and Hermitian matrices are expanded with 'Symmetric.toSquare'.
* Banded and unit banded triangular matrices are expanded with
  'Banded.toFull'.
* Banded Hermitian matrices are first turned into ordinary banded matrices
  with 'BandedHermitian.toBanded', then expanded with 'Banded.toFull'.

Pattern matching on the GADT shape refines the type parameters in each
branch, which is why every alternative can use a different converter.
-}
toFull ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   ArrayMatrix pack property lower upper meas vert horiz height width a ->
   ArrMatrix.Full meas vert horiz height width a
toFull a =
   case ArrMatrix.shape a of
      Omni.Full _ -> ArrMatrix.liftUnpacked1 id a
      Omni.UpperTriangular _ -> ArrMatrix.lift1 Triangular.toSquare a
      Omni.LowerTriangular _ -> ArrMatrix.lift1 Triangular.toSquare a
      Omni.Symmetric _ -> ArrMatrix.lift1 Symmetric.toSquare a
      Omni.Hermitian _ -> ArrMatrix.lift1 Symmetric.toSquare a
      Omni.Banded _ -> ArrMatrix.lift1 Banded.toFull a
      Omni.UnitBandedTriangular _ -> ArrMatrix.lift1 Banded.toFull a
      Omni.BandedHermitian _ ->
         ArrMatrix.lift1 (Banded.toFull . BandedHermitian.toBanded) a

{- |
Convert a matrix to the 'Layout.Unpacked' storage variant. The @property@,
@lower@ and @upper@ type parameters of the result are the same as those of
the input.

* Full matrices are re-wrapped via 'ArrMatrix.liftUnpacked1' applied to 'id'.
* Triangular, symmetric and Hermitian matrices are converted with
  'Mosaic.unpack'.
* Banded and unit banded triangular matrices are expanded with
  'Banded.toFull'. Banded Hermitian matrices are first converted with
  'BandedHermitian.toBanded'. The resulting array is wrapped with
  'ArrMatrix.liftUnpacked0'.

NOTE(review): how 'ArrMatrix.liftUnpacked0' derives the result shape from
the full array is defined in "Numeric.LAPACK.Matrix.Array.Private". Confirm
it there.
-}
unpack ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   ArrayMatrix pack property lower upper meas vert horiz height width a ->
   ArrayMatrix Layout.Unpacked
      property lower upper meas vert horiz height width a
unpack a =
   case ArrMatrix.shape a of
      Omni.Full _ -> ArrMatrix.liftUnpacked1 id a
      Omni.UpperTriangular _ -> ArrMatrix.lift1 Mosaic.unpack a
      Omni.LowerTriangular _ -> ArrMatrix.lift1 Mosaic.unpack a
      Omni.Symmetric _ -> ArrMatrix.lift1 Mosaic.unpack a
      Omni.Hermitian _ -> ArrMatrix.lift1 Mosaic.unpack a
      Omni.Banded _ ->
         ArrMatrix.liftUnpacked0 $ Banded.toFull $ ArrMatrix.toVector a
      Omni.UnitBandedTriangular _ ->
         ArrMatrix.liftUnpacked0 $ Banded.toFull $ ArrMatrix.toVector a
      Omni.BandedHermitian _ ->
         ArrMatrix.liftUnpacked0 $ Banded.toFull $
         BandedHermitian.toBanded $ ArrMatrix.toVector a

{- |
Extract the main diagonal of a square matrix. The result is a vector indexed
by the matrix's square size shape @sh@.

* Full: the storage is reshaped to the 'Omni.Full' payload shape and passed
  to 'Square.takeDiagonal'.
* Triangular, symmetric, Hermitian: 'Mosaic.takeDiagonal'.
* Banded, unit banded triangular: 'Banded.takeDiagonal'.
* Banded Hermitian: the result of 'BandedHermitian.takeDiagonal' is lifted
  to the element type with 'Vector.fromReal'. Presumably this is because
  that function yields real-valued entries; confirm in
  "Numeric.LAPACK.Matrix.BandedHermitian.Basic".
-}
takeDiagonal ::
   (Shape.C sh, Class.Floating a) =>
   Quadratic pack property lower upper sh a -> Vector sh a
takeDiagonal a =
   case ArrMatrix.shape a of
      Omni.Full fullShape ->
         Square.takeDiagonal $ Array.reshape fullShape $ ArrMatrix.unwrap a
      Omni.UpperTriangular _ -> Mosaic.takeDiagonal $ ArrMatrix.toVector a
      Omni.LowerTriangular _ -> Mosaic.takeDiagonal $ ArrMatrix.toVector a
      Omni.Symmetric _ -> Mosaic.takeDiagonal $ ArrMatrix.toVector a
      Omni.Hermitian _ -> Mosaic.takeDiagonal $ ArrMatrix.toVector a
      Omni.Banded _ -> Banded.takeDiagonal $ ArrMatrix.toVector a
      Omni.UnitBandedTriangular _ -> Banded.takeDiagonal $ ArrMatrix.toVector a
      Omni.BandedHermitian _ ->
         Vector.fromReal $ BandedHermitian.takeDiagonal $ ArrMatrix.toVector a


{- |
Build the identity matrix that has the given square matrix shape.

Each alternative delegates to a storage-specific identity constructor
through 'identityOmni'. That helper supplies the shape's order and square
size, then wraps the produced array shape with the same @Omni@ constructor
that was matched:

* full: 'Square.identityOrder'
* upper/lower triangular, symmetric, Hermitian: 'Packed.identity'
* banded, unit banded triangular: 'Banded.identityFatOrder'
* banded Hermitian: 'BandedHermitian.identityFatOrder'
-}
identityFromShape ::
   (Shape.C sh, Class.Floating a) =>
   MatrixShape.Quadratic pack property lower upper sh ->
   Quadratic pack property lower upper sh a
identityFromShape omni =
   ArrMatrix.Array $
   case omni of
      Omni.Full _ ->
         identityOmni Omni.Full Square.identityOrder omni
      Omni.UpperTriangular _ ->
         identityOmni Omni.UpperTriangular Packed.identity omni
      Omni.LowerTriangular _ ->
         identityOmni Omni.LowerTriangular Packed.identity omni
      Omni.Symmetric _ ->
         identityOmni Omni.Symmetric Packed.identity omni
      Omni.Hermitian _ ->
         identityOmni Omni.Hermitian Packed.identity omni
      Omni.Banded _ ->
         identityOmni Omni.Banded Banded.identityFatOrder omni
      Omni.UnitBandedTriangular _ ->
         identityOmni Omni.UnitBandedTriangular Banded.identityFatOrder omni
      Omni.BandedHermitian _ ->
         identityOmni Omni.BandedHermitian
            BandedHermitian.identityFatOrder omni

{- |
Helper for 'identityFromShape'. It calls the identity constructor @eye@ with
the order and square size taken from @omni@, then maps the shape of the
resulting array through @consOmni@.
-}
identityOmni ::
   (shape -> omni) ->
   (Layout.Order -> sh -> Array shape a) ->
   MatrixShape.Quadratic pack property lower upper sh -> Array omni a
identityOmni consOmni eye omni =
   Array.mapShape consOmni (eye order size)
   where
      order = Omni.order omni
      size = Omni.squareSize omni

{- |
Identity matrix that has the same shape as the given matrix. Only the shape
of the argument is inspected; its elements are ignored.
-}
identityFrom ::
   (Shape.C sh, Class.Floating a) =>
   Quadratic pack property lower upper sh a ->
   Quadratic pack property lower upper sh a
identityFrom m = identityFromShape (ArrMatrix.shape m)

{- |
Identity matrix of the requested storage order and size. The concrete shape
is obtained from 'Omni.quadratic' and depends on the result type's
@pack@, @property@, @lower@ and @upper@ parameters.
-}
identityOrder ::
   (Omni.Quadratic pack property lower upper, Shape.C sh, Class.Floating a) =>
   Layout.Order -> sh -> Quadratic pack property lower upper sh a
identityOrder order = identityFromShape . Omni.quadratic order


{- |
Sign factor @(-1)^n@ for a square shape of size @n@. It is @1@ when @n@ is
even and 'Scalar.minusOne' when @n@ is odd. Judging by the name, this is
presumably the factor relating the determinant of a negated matrix to the
original determinant.
-}
signNegativeDeterminant ::
   (Shape.C sh, Class.Floating a) =>
   MatrixShape.Quadratic pack property lower upper sh -> a
signNegativeDeterminant shape =
   if even (Shape.size (Omni.squareSize shape))
      then 1
      else Scalar.minusOne