{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Class (
   SquareShape(takeDiagonal, identityFrom), SquareShapeExtra, toSquare,
   MapSquareSize(mapSquareSize),
   MapSize(mapHeight, mapWidth),
   trace,
   Complex(conjugate, fromReal, toComplex),
   adjoint,
   Unpack(unpack), UnpackExtra, toFull,

   Homogeneous, HomogeneousExtra, Scale, ScaleExtra,
   zeroFrom, negate, scaleReal, scale, scaleRealReal, (.*#),
   Additive, AdditiveExtra, add, (#+#),
   Subtractive, SubtractiveExtra, sub, (#-#),
   ) where

import qualified Numeric.LAPACK.Matrix.Array.Basic as OmniMatrix
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Type.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Matrix.Permutation as Permutation
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Permutation as PermPub
import qualified Numeric.LAPACK.Vector as Vector
import qualified Numeric.LAPACK.Scalar as Scalar
import Numeric.LAPACK.Matrix.Type.Private (Matrix)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, ComplexOf)

import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Shape as Shape

import Prelude hiding (negate)

import GHC.Exts (Constraint)


-- | Complex-number operations on matrices of kind @typ@.
--
-- Every method keeps the matrix kind and all extent and shape parameters;
-- only the element type may change ('fromReal', 'toComplex').
class Complex typ where
   -- | Complex-conjugate the entries (no transposition, see 'adjoint').
   conjugate ::
      (Matrix typ xl xu lower upper meas vert horiz height width ~ matrix,
       Extent.Measure meas, Extent.C vert, Extent.C horiz,
       Shape.C height, Shape.C width, Class.Floating a) =>
      matrix a -> matrix a
   -- | Embed a matrix with elements of type 'RealOf' @a@ into element type @a@.
   fromReal ::
      (Matrix typ xl xu lower upper meas vert horiz height width ~ matrix,
       Extent.Measure meas, Extent.C vert, Extent.C horiz,
       Shape.C height, Shape.C width, Class.Floating a) =>
      matrix (RealOf a) -> matrix a
   -- | Convert to the complex counterpart 'ComplexOf' @a@ of the element type.
   toComplex ::
      (Matrix typ xl xu lower upper meas vert horiz height width ~ matrix,
       Extent.Measure meas, Extent.C vert, Extent.C horiz,
       Shape.C height, Shape.C width, Class.Floating a) =>
      matrix a -> matrix (ComplexOf a)

-- | Array matrices: each operation is applied to the underlying storage
-- vector and the result is re-wrapped in 'ArrMatrix.Array'.
instance Complex (ArrMatrix.Array pack property) where
   conjugate (ArrMatrix.Array storage) =
      ArrMatrix.Array (Vector.conjugate storage)
   fromReal (ArrMatrix.Array storage) =
      ArrMatrix.Array (Vector.fromReal storage)
   toComplex (ArrMatrix.Array storage) =
      ArrMatrix.Array (Vector.toComplex storage)

-- | Scale matrices hold a single scalar; the operation is applied to that
-- scalar while the shape is passed through unchanged.
instance Complex Matrix.Scale where
   conjugate (Matrix.Scale shape s) = Matrix.Scale shape (Scalar.conjugate s)
   fromReal (Matrix.Scale shape s) = Matrix.Scale shape (Scalar.fromReal s)
   toComplex (Matrix.Scale shape s) = Matrix.Scale shape (Scalar.toComplex s)

-- | A permutation carries no numeric entries of its own: conjugation returns
-- the matrix unchanged, and changing the element type merely re-wraps the
-- same permutation under the new element type.
instance Complex Matrix.Permutation where
   conjugate m = m
   fromReal (Matrix.Permutation perm) = Matrix.Permutation perm
   toComplex (Matrix.Permutation perm) = Matrix.Permutation perm

-- | Conjugate transpose.
--
-- First applies 'Matrix.transpose' (which swaps the extra, lower/upper,
-- vert/horiz and height/width parameters, as the signature shows) and then
-- conjugates the entries of the transposed matrix with 'conjugate'.
adjoint ::
   (Matrix.Transpose typ, Complex typ) =>
   (Matrix.TransposeExtra typ xl, Matrix.TransposeExtra typ xu) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   Matrix typ xu xl upper lower meas horiz vert width height a
adjoint m = conjugate (Matrix.transpose m)


-- | Matrix kinds that can be quadratic (square), and therefore have a main
-- diagonal and an identity matrix of the same shape.
class (Matrix.Box typ) => SquareShape typ where
   -- | Per-kind constraint on the extra parameters @xl@ and @xu@.
   -- Every instance in this module sets it to the empty constraint @()@.
   type SquareShapeExtra typ extra :: Constraint
   -- | Extract the main diagonal as a vector indexed by the square shape @sh@.
   takeDiagonal ::
      (SquareShapeExtra typ xl, SquareShapeExtra typ xu) =>
      (Shape.C sh, Class.Floating a) =>
      Matrix.Quadratic typ xl xu lower upper sh a -> Vector sh a
   -- | An identity matrix with the same kind and type parameters as the
   -- argument. The 'Matrix.Scale' and 'Matrix.Permutation' instances read
   -- only the shape/size of the argument, not its entries.
   identityFrom ::
      (SquareShapeExtra typ xl, SquareShapeExtra typ xu) =>
      (Shape.C sh, Class.Floating a) =>
      Matrix.Quadratic typ xl xu lower upper sh a ->
      Matrix.Quadratic typ xl xu lower upper sh a

-- | Array matrices delegate to the generic array implementations.
-- The otherwise unused constructor match is kept deliberately; presumably it
-- exposes type information carried by 'ArrMatrix.Array' that the delegated
-- functions need (NOTE(review): confirm before removing).
instance SquareShape (ArrMatrix.Array pack property) where
   type SquareShapeExtra (ArrMatrix.Array pack property) extra = ()
   takeDiagonal m@(ArrMatrix.Array _) = OmniMatrix.takeDiagonal m
   identityFrom m@(ArrMatrix.Array _) = OmniMatrix.identityFrom m

-- | A scale matrix has its stored scalar at every diagonal position. Its
-- identity is the same shape with the scalar 'Scalar.one'.
instance SquareShape Matrix.Scale where
   type SquareShapeExtra Matrix.Scale extra = ()
   takeDiagonal (Matrix.Scale shape s) = Vector.constant shape s
   identityFrom (Matrix.Scale shape _) = Matrix.Scale shape Scalar.one

-- | Permutation matrices: the diagonal is taken from the underlying
-- permutation, and the identity is the identity permutation of equal size.
instance SquareShape Matrix.Permutation where
   type SquareShapeExtra Matrix.Permutation extra = ()
   takeDiagonal m@(Matrix.Permutation _) =
      Perm.takeDiagonal (Permutation.toPermutation m)
   identityFrom (Matrix.Permutation p) =
      Matrix.Permutation (Perm.identity (Perm.size p))


-- | Trace of a square matrix: the sum of the entries returned by
-- 'takeDiagonal'.
trace ::
   (SquareShape typ, SquareShapeExtra typ xl, SquareShapeExtra typ xu) =>
   (Shape.C sh, Class.Floating a) =>
   Matrix.Quadratic typ xl xu lower upper sh a -> a
trace m = Vector.sum (takeDiagonal m)



-- | Square matrix kinds whose (shared) row and column shape can be replaced
-- by a different shape type.
class (SquareShape typ) => MapSquareSize typ where
   {- |
   The number of rows and columns
   must be maintained by the shape mapping function.

   Not available for `Block` matrices.

   The 'Matrix.Scale' instance routes the mapping through
   @Layout.mapChecked@, which presumably verifies this requirement at
   runtime (NOTE(review): confirm the failure mode).
   -}
   mapSquareSize ::
      (Shape.C shA, Shape.C shB) =>
      (shA -> shB) ->
      Matrix.Quadratic typ xl xu lower upper shA a ->
      Matrix.Quadratic typ xl xu lower upper shB a

-- | Array matrices delegate to the generic array implementation.
instance MapSquareSize (ArrMatrix.Array pack property) where
   mapSquareSize g m@(ArrMatrix.Array _) = OmniMatrix.mapSquareSize g m

-- | Scale matrices map only their shape; the stored scalar is kept as is.
instance MapSquareSize Matrix.Scale where
   mapSquareSize g (Matrix.Scale shape s) =
      Matrix.Scale (Layout.mapChecked "Scale.mapSquareSize" g shape) s

-- | Permutation matrices map the size of the underlying permutation.
instance MapSquareSize Matrix.Permutation where
   mapSquareSize g (Matrix.Permutation p) =
      Matrix.Permutation (Perm.mapSize g p)


-- | Matrix kinds measured by 'Extent.Size' whose row shape ('mapHeight') or
-- column shape ('mapWidth') can be replaced independently.
-- In this module only array matrices are instances.
class (Matrix.Box typ) => MapSize typ where
   {- |
   The number of rows and columns
   must be maintained by the shape mapping function.
   -}
   mapHeight ::
      (Extent.C vert, Extent.C horiz,
       Shape.C heightA, Shape.C heightB, Shape.C width) =>
      (heightA -> heightB) ->
      Matrix typ extraLower extraUpper lower upper
         Extent.Size vert horiz heightA width a ->
      Matrix typ extraLower extraUpper lower upper
         Extent.Size vert horiz heightB width a
   -- | Like 'mapHeight', but replaces the column shape; the mapping must
   -- likewise maintain the number of columns.
   mapWidth ::
      (Extent.C vert, Extent.C horiz,
       Shape.C height, Shape.C widthA, Shape.C widthB) =>
      (widthA -> widthB) ->
      Matrix typ extraLower extraUpper lower upper
         Extent.Size vert horiz height widthA a ->
      Matrix typ extraLower extraUpper lower upper
         Extent.Size vert horiz height widthB a

-- | Array matrices delegate both directions to the generic array
-- implementations.
instance MapSize (ArrMatrix.Array pack property) where
   mapHeight g m@(ArrMatrix.Array _) = OmniMatrix.mapHeight g m
   mapWidth g m@(ArrMatrix.Array _) = OmniMatrix.mapWidth g m


-- | Matrix kinds that can be converted into an unpacked array matrix.
class Unpack typ where
   -- | Per-kind constraint on the extra parameters @xl@ and @xu@.
   -- Every instance in this module requires them to be @()@.
   type UnpackExtra typ extra :: Constraint
   -- | Convert to an unpacked array matrix with the 'Omni.Arbitrary'
   -- property, keeping the @lower@/@upper@ and all extent/shape parameters.
   -- In contrast to OmniMatrix.unpack it cannot maintain the matrix property.
   unpack ::
      (UnpackExtra typ xl, UnpackExtra typ xu) =>
      (Omni.Strip lower, Omni.Strip upper) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Shape.C width, Class.Floating a) =>
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      ArrMatrix.ArrayMatrix Layout.Unpacked Omni.Arbitrary
         lower upper meas vert horiz height width a

-- | Array matrices are unpacked with 'OmniMatrix.unpack'. The result is then
-- passed through 'ArrMatrix.liftUnpacked1' @id@, presumably to re-label it
-- with the 'Omni.Arbitrary' property the class requires.
instance (Omni.Property prop) => Unpack (ArrMatrix.Array pack prop) where
   type UnpackExtra (ArrMatrix.Array pack prop) extra = extra ~ ()
   unpack m@(ArrMatrix.Array _) =
      ArrMatrix.liftUnpacked1 id (OmniMatrix.unpack m)

-- | A scale matrix unpacks to the diagonal matrix whose diagonal is the
-- constant scalar. It is built as a row-major banded diagonal matrix and
-- then expanded to full storage.
instance Unpack Matrix.Scale where
   type UnpackExtra Matrix.Scale extra = extra ~ ()
   unpack (Matrix.Scale shape s) =
      ArrMatrix.liftUnpacked0
         (Banded.toFull
            (Banded.diagonal Layout.RowMajor (Vector.constant shape s)))

-- | A permutation unpacks to the dense matrix produced by
-- 'PermPub.toMatrix'.
instance Unpack Matrix.Permutation where
   type UnpackExtra Matrix.Permutation extra = extra ~ ()
   unpack (Matrix.Permutation p) =
      ArrMatrix.liftUnpacked1 id (PermPub.toMatrix p)

-- | Convert any 'Unpack'-able matrix into a full array matrix.
-- It first applies 'unpack' and then 'OmniMatrix.toFull'.
toFull ::
   (Unpack typ, UnpackExtra typ xl, UnpackExtra typ xu) =>
   (Omni.Strip lower, Omni.Strip upper) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   ArrMatrix.Full meas vert horiz height width a
toFull m = OmniMatrix.toFull (unpack m)

-- | 'toFull' restricted to quadratic matrices, yielding an
-- 'ArrMatrix.Square'.
toSquare ::
   (Unpack typ, UnpackExtra typ xl, UnpackExtra typ xu) =>
   (Omni.Strip lower, Omni.Strip upper) =>
   (Shape.C sh, Class.Floating a) =>
   Matrix.Quadratic typ xl xu lower upper sh a -> ArrMatrix.Square sh a
toSquare m = toFull m



-- | Matrix kinds for which the zero matrix, negation and scaling by a real
-- factor all stay within the same kind and type parameters.
class Homogeneous typ where
   -- | Per-kind constraint on the extra parameters @xl@ and @xu@.
   -- Every instance in this module requires them to be @()@.
   type HomogeneousExtra typ extra :: Constraint
   -- | A zero matrix of the same shape as the argument.
   zeroFrom ::
      (HomogeneousExtra typ xl, HomogeneousExtra typ xu) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Shape.C width, Class.Floating a) =>
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      Matrix typ xl xu lower upper meas vert horiz height width a
   -- | Negate the matrix.
   negate ::
      (HomogeneousExtra typ xl, HomogeneousExtra typ xu) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Shape.C width, Class.Floating a) =>
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      Matrix typ xl xu lower upper meas vert horiz height width a
   -- | Multiply the matrix by a real factor of type 'RealOf' @a@.
   scaleReal ::
      (HomogeneousExtra typ xl, HomogeneousExtra typ xu) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Shape.C width, Class.Floating a) =>
      RealOf a ->
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      Matrix typ xl xu lower upper meas vert horiz height width a

-- | Array matrices delegate to the array-level operations; the zero matrix
-- is built from the argument's shape alone.
instance
   (ArrMatrix.Homogeneous property) =>
      Homogeneous (ArrMatrix.Array pack property) where
   type HomogeneousExtra (ArrMatrix.Array pack property) extra = extra ~ ()
   zeroFrom m = ArrMatrix.zero (ArrMatrix.shape m)
   negate m = ArrMatrix.negate m
   scaleReal k m = ArrMatrix.scaleReal k m

-- | Scale matrices act on their single stored scalar and keep the shape.
instance Homogeneous Matrix.Scale where
   type HomogeneousExtra Matrix.Scale extra = extra ~ ()
   zeroFrom (Matrix.Scale shape _) = Matrix.Scale shape Scalar.zero
   negate (Matrix.Scale shape s) = Matrix.Scale shape (-s)
   scaleReal k (Matrix.Scale shape s) =
      Matrix.Scale shape (Scalar.fromReal k * s)

-- | Wrapper that lets 'Class.switchReal' choose a scaling function
-- @a -> f a -> f a@ separately for each real element type.
newtype ScaleReal f a = ScaleReal {getScaleReal :: a -> f a -> f a}

-- | Scale a matrix with real elements by a factor of the same real type.
--
-- 'scaleReal' expects a factor of type 'RealOf' @a@. Dispatching through
-- 'Class.switchReal' makes @a@ concrete in each branch, where 'RealOf' is
-- expected to reduce to the element type itself. Hence the same
-- 'scaleReal' is supplied for both branches.
scaleRealReal ::
   (Homogeneous typ, HomogeneousExtra typ xl, HomogeneousExtra typ xu) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Real a) =>
   a ->
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   Matrix typ xl xu lower upper meas vert horiz height width a
scaleRealReal factor =
   getScaleReal
      (Class.switchReal (ScaleReal scaleReal) (ScaleReal scaleReal))
      factor


-- | Homogeneous matrix kinds that can also be multiplied by a factor of
-- the element type itself, as opposed to the real factor of 'scaleReal'.
class (Homogeneous typ) => Scale typ where
   -- | Per-kind constraint on the extra parameters @xl@ and @xu@.
   -- Every instance in this module requires them to be @()@.
   type ScaleExtra typ extra :: Constraint
   -- | Multiply the matrix by a factor of the element type @a@.
   scale ::
      (ScaleExtra typ xl, ScaleExtra typ xu) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Shape.C width, Class.Floating a) =>
      a ->
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      Matrix typ xl xu lower upper meas vert horiz height width a

-- | Array matrices delegate to 'ArrMatrix.scale'.
instance
   (ArrMatrix.Scale property) =>
      Scale (ArrMatrix.Array pack property) where
   type ScaleExtra (ArrMatrix.Array pack property) extra = extra ~ ()
   scale k m = ArrMatrix.scale k m

-- | Scale matrices multiply their stored scalar by the factor.
instance Scale Matrix.Scale where
   type ScaleExtra Matrix.Scale extra = extra ~ ()
   scale k (Matrix.Scale shape s) = Matrix.Scale shape (k * s)

-- | Infix synonym for 'scale': @k .*# m@ scales @m@ by @k@.
-- It has the same fixity as @(*)@, namely @infixl 7@.
(.*#) ::
   (Scale typ, ScaleExtra typ xl, ScaleExtra typ xu) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   a ->
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   Matrix typ xl xu lower upper meas vert horiz height width a
k .*# m = scale k m

infixl 7 .*#


-- | Matrix kinds whose values can be added entry-wise within the same kind
-- and type parameters.
class Additive typ where
   -- | Per-kind constraint on the extra parameters @xl@ and @xu@.
   -- Every instance in this module requires them to be @()@.
   type AdditiveExtra typ extra :: Constraint
   -- | Sum of two matrices. The @Eq height@ and @Eq width@ constraints let
   -- instances compare shapes; the 'Matrix.Scale' instance calls 'error'
   -- when the shapes differ.
   add ::
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (AdditiveExtra typ xl, AdditiveExtra typ xu,
       Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      Matrix typ xl xu lower upper meas vert horiz height width a

-- | Array matrices delegate to 'ArrMatrix.add'.
instance
   (ArrMatrix.Additive property) =>
      Additive (ArrMatrix.Array pack property) where
   type AdditiveExtra (ArrMatrix.Array pack property) extra = extra ~ ()
   add x y = ArrMatrix.add x y

-- | Two scale matrices are added by adding their scalars. Shapes must be
-- equal; otherwise 'error' is raised.
instance Additive Matrix.Scale where
   type AdditiveExtra Matrix.Scale extra = extra ~ ()
   add (Matrix.Scale shapeA x) (Matrix.Scale shapeB y)
      | shapeA == shapeB = Matrix.Scale shapeA (x+y)
      | otherwise = error "Matrix.add Scale: dimensions mismatch"

-- | Additive matrix kinds that also support entry-wise subtraction within
-- the same kind and type parameters.
class (Additive typ) => Subtractive typ where
   -- | Per-kind constraint on the extra parameters @xl@ and @xu@.
   -- Every instance in this module requires them to be @()@.
   type SubtractiveExtra typ extra :: Constraint
   -- | Difference of two matrices (first minus second). The 'Matrix.Scale'
   -- instance calls 'error' when the shapes differ.
   sub ::
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (SubtractiveExtra typ xl, SubtractiveExtra typ xu,
       Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      Matrix typ xl xu lower upper meas vert horiz height width a

-- | Array matrices delegate to 'ArrMatrix.sub'.
instance
   (ArrMatrix.Subtractive property) =>
      Subtractive (ArrMatrix.Array pack property) where
   type SubtractiveExtra (ArrMatrix.Array pack property) extra = extra ~ ()
   sub x y = ArrMatrix.sub x y

-- | Two scale matrices are subtracted by subtracting their scalars. Shapes
-- must be equal; otherwise 'error' is raised.
instance Subtractive Matrix.Scale where
   type SubtractiveExtra Matrix.Scale extra = extra ~ ()
   sub (Matrix.Scale shapeA x) (Matrix.Scale shapeB y)
      | shapeA == shapeB = Matrix.Scale shapeA (x-y)
      | otherwise = error "Matrix.sub Scale: dimensions mismatch"

infixl 6 #+#, #-#, `add`, `sub`

-- | Infix synonym for 'add'; it has the same fixity as @(+)@, namely
-- @infixl 6@.
(#+#) ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Additive typ, AdditiveExtra typ xl, AdditiveExtra typ xu,
    Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   Matrix typ xl xu lower upper meas vert horiz height width a
x #+# y = add x y

-- | Infix synonym for 'sub': @x #-# y@ computes @x@ minus @y@.
-- It has the same fixity as @(-)@, namely @infixl 6@.
(#-#) ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Subtractive typ, SubtractiveExtra typ xl, SubtractiveExtra typ xu,
    Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   Matrix typ xl xu lower upper meas vert horiz height width a
x #-# y = sub x y