{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE EmptyDataDecls #-}
{- |
Matrices that are backed by a storable array.

The 'Array' tag selects an instance of the 'Matrix' data family
that wraps an array of a given shape type.
Besides that wrapper and its type synonyms
this module exports helpers for converting functions
between plain arrays and array matrices,
and shape-indexed classes for scaling, reordering and adding.
-}
module Numeric.LAPACK.Matrix.Array (
   -- * Matrix type
   Matrix(Array),
   ArrayMatrix,
   Array,

   -- * Synonyms for full matrices
   Full,
   General,
   Tall,
   Wide,
   Square,

   -- * Access and lifting
   shape,
   reshape,
   mapShape,
   toVector,
   fromVector,
   lift0,
   lift1,
   lift2,
   lift3,
   unlift1,
   unlift2,
   unliftRow,
   unliftColumn,

   -- * Shape-indexed classes
   Homogeneous(zero, negate, scaleReal), scale, scaleRealReal, (.*#),
   ShapeOrder(forceOrder, shapeOrder), adaptOrder,
   Additive(add, sub), (#+#), (#-#),
   Complex,
   Multiply.MultiplyLeft,
   Multiply.MultiplyRight,
   Multiply.MultiplySquare,
   Multiply.Multiply,
   ) where

import qualified Numeric.LAPACK.Matrix.Array.Multiply as Multiply
import qualified Numeric.LAPACK.Matrix.Type as Type
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Box as Box
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Matrix.Triangular.Basic as Triangular
import qualified Numeric.LAPACK.Matrix.Hermitian.Basic as Hermitian
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Array.Format (FormatArray, formatArray)
import Numeric.LAPACK.Matrix.Type (Matrix)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf, ComplexOf)

import qualified Numeric.Netlib.Class as Class

import qualified Type.Data.Num.Unary as Unary

import qualified Control.DeepSeq as DeepSeq

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape

import Prelude hiding (negate)


{- |
Type-level tag without data constructors.
It selects the instance of the 'Matrix' data family declared below;
@shape@ is the shape type of the wrapped array.
-}
data Array shape
{- |
An array matrix is a newtype wrapper around a storable array
with the same shape and element type.
-}
newtype instance Matrix (Array shape) a = Array (Array.Array shape a)
   deriving (Show)

{- | Shorthand for a 'Matrix' that is backed by an array of shape @sh@. -}
type ArrayMatrix sh = Matrix (Array sh)


{- | Array matrix with a 'MatrixShape.Full' shape. -}
type Full vert horiz h w =
   ArrayMatrix (MatrixShape.Full vert horiz h w)

{- | Array matrix with a 'MatrixShape.General' shape. -}
type General h w = ArrayMatrix (MatrixShape.General h w)

{- | Array matrix with a 'MatrixShape.Tall' shape. -}
type Tall h w = ArrayMatrix (MatrixShape.Tall h w)

{- | Array matrix with a 'MatrixShape.Wide' shape. -}
type Wide h w = ArrayMatrix (MatrixShape.Wide h w)

{- | Array matrix with a 'MatrixShape.Square' shape. -}
type Square size = ArrayMatrix (MatrixShape.Square size)


{- |
Evaluation of an array matrix is delegated
to 'DeepSeq.rnf' of the wrapped array.
-}
instance (DeepSeq.NFData sh) => Type.NFData (Array sh) where
   rnf matrix = DeepSeq.rnf (toVector matrix)

{- |
Height and width of an array matrix
are the height and width of the shape of the wrapped array.
-}
instance (Box.Box sh) => Type.Box (Array sh) where
   type HeightOf (Array sh) = Box.HeightOf sh
   type WidthOf (Array sh) = Box.WidthOf sh
   height matrix = Box.height (shape matrix)
   width matrix = Box.width (shape matrix)


{- | Shape of the array that the matrix wraps. -}
shape :: ArrayMatrix sh a -> sh
shape = Array.shape . toVector

{- |
Replace the shape of the wrapped array by the given one
using the checked 'CheckedArray.reshape'.
-}
reshape ::
   (Shape.C sh0, Shape.C sh1) =>
   sh1 -> ArrayMatrix sh0 a -> ArrayMatrix sh1 a
reshape newShape matrix = lift1 (CheckedArray.reshape newShape) matrix

{- |
Transform the shape of the wrapped array with the given function
using the checked 'CheckedArray.mapShape'.
-}
mapShape ::
   (Shape.C sh0, Shape.C sh1) =>
   (sh0 -> sh1) -> ArrayMatrix sh0 a -> ArrayMatrix sh1 a
mapShape changeShape matrix = lift1 (CheckedArray.mapShape changeShape) matrix


{- | Unwrap the array that stores the matrix. -}
toVector :: ArrayMatrix sh a -> Array.Array sh a
toVector matrix =
   case matrix of
      Array arr -> arr

{- | Wrap an array as matrix; the inverse of 'toVector'. -}
fromVector :: Array.Array sh a -> ArrayMatrix sh a
fromVector arr = Array arr

{- |
Nullary member of the @lift@ family:
wraps an array as matrix without touching it.
-}
lift0 :: Array.Array shA a -> ArrayMatrix shA a
lift0 arr = Array arr

{- | Turn a function on arrays into a function on array matrices. -}
lift1 ::
   (Array.Array shA a -> Array.Array shB b) ->
   ArrayMatrix shA a -> ArrayMatrix shB b
lift1 f matrix = Array (f (toVector matrix))

{- | Turn a binary function on arrays into one on array matrices. -}
lift2 ::
   (Array.Array shA a -> Array.Array shB b -> Array.Array shC c) ->
   ArrayMatrix shA a -> ArrayMatrix shB b -> ArrayMatrix shC c
lift2 f x y = Array (f (toVector x) (toVector y))

{- | Turn a ternary function on arrays into one on array matrices. -}
lift3 ::
   (Array.Array shA a -> Array.Array shB b ->
    Array.Array shC c -> Array.Array shD d) ->
   ArrayMatrix shA a -> ArrayMatrix shB b ->
   ArrayMatrix shC c -> ArrayMatrix shD d
lift3 f x y z = Array (f (toVector x) (toVector y) (toVector z))


{- |
Inverse direction of 'lift1':
turn a function on array matrices into a function on arrays.
-}
unlift1 ::
   (ArrayMatrix shA a -> ArrayMatrix shB b) ->
   Array.Array shA a -> Array.Array shB b
unlift1 f arr = toVector (f (fromVector arr))

{- |
Inverse direction of 'lift2':
turn a binary function on array matrices into one on arrays.
-}
unlift2 ::
   (ArrayMatrix shA a -> ArrayMatrix shB b -> ArrayMatrix shC c) ->
   Array.Array shA a -> Array.Array shB b -> Array.Array shC c
unlift2 f x y = toVector (f (fromVector x) (fromVector y))


{- |
Turn a function on general matrices with unit height @()@
into a function on vectors.
The matrix function is unwrapped with 'unlift1'
and then handed to 'Basic.unliftRow' together with the given order.
-}
unliftRow ::
   MatrixShape.Order ->
   (General () height0 a -> General () height1 b) ->
   Vector height0 a -> Vector height1 b
unliftRow order f = Basic.unliftRow order (unlift1 f)

{- |
Turn a function on general matrices with unit width @()@
into a function on vectors.
The matrix function is unwrapped with 'unlift1'
and then handed to 'Basic.unliftColumn' together with the given order.
-}
unliftColumn ::
   MatrixShape.Order ->
   (General height0 () a -> General height1 () b) ->
   Vector height0 a -> Vector height1 b
unliftColumn order f = Basic.unliftColumn order (unlift1 f)


{- | Formatting is delegated to 'formatArray' on the wrapped array. -}
instance (FormatArray sh) => Type.FormatMatrix (Array sh) where
   formatMatrix fmt matrix = formatArray fmt (toVector matrix)

{- | Multiplication of equally shaped matrices via 'Multiply.same'. -}
instance (Multiply.MultiplySame sh) => Type.MultiplySame (Array sh) where
   multiplySame x y = lift2 Multiply.same x y


{-
The methods defined here belong to 'Type.Complex',
which is only imported qualified.
The unqualified names on the right-hand sides therefore refer
to the methods of this module's own 'Complex' class,
so these definitions are not recursive.
-}
instance (Complex sh) => Type.Complex (Array sh) where
   conjugate matrix = conjugate matrix
   fromReal matrix = fromReal matrix
   toComplex matrix = toComplex matrix

{- |
Shapes whose matrices support conversions
between real and complex element types.
Every default method applies the corresponding "Numeric.LAPACK.Vector"
function to the wrapped array.
-}
class (Shape.C sh) => Complex sh where
   {- | Default: 'Vector.conjugate' on the wrapped array. -}
   conjugate ::
      (Class.Floating a) => ArrayMatrix sh a -> ArrayMatrix sh a
   conjugate matrix = Array (Vector.conjugate (toVector matrix))

   {- | Default: 'Vector.fromReal' on the wrapped array. -}
   fromReal ::
      (Class.Floating a) =>
      ArrayMatrix sh (RealOf a) -> ArrayMatrix sh a
   fromReal matrix = Array (Vector.fromReal (toVector matrix))

   {- | Default: 'Vector.toComplex' on the wrapped array. -}
   toComplex ::
      (Class.Floating a) =>
      ArrayMatrix sh a -> ArrayMatrix sh (ComplexOf a)
   toComplex matrix = Array (Vector.toComplex (toVector matrix))

{-
None of the following instances overrides a method,
that is, all of them use the defaults of the 'Complex' class.
-}

instance
   (Extent.C vert, Extent.C horiz, Shape.C h, Shape.C w) =>
   Complex (MatrixShape.Full vert horiz h w)

instance (Shape.C sh) => Complex (MatrixShape.Hermitian sh)

instance
   (MatrixShape.Content lo, MatrixShape.TriDiag diag, MatrixShape.Content up,
    Shape.C sh) =>
   Complex (MatrixShape.Triangular lo diag up sh)

instance
   (Unary.Natural sub, Unary.Natural super, Extent.C vert, Extent.C horiz,
    Shape.C h, Shape.C w) =>
   Complex (MatrixShape.Banded sub super vert horiz h w)

instance
   (Unary.Natural off, Shape.C sh) =>
   Complex (MatrixShape.BandedHermitian off sh)


{- |
Shapes whose matrices can be created as zero matrix,
negated and scaled by a real factor.
Every default method applies the corresponding "Numeric.LAPACK.Vector"
function to the wrapped array.
-}
class (Shape.C sh) => Homogeneous sh where
   {- | Default: 'Vector.zero' for the given shape. -}
   zero :: (Class.Floating a) => sh -> ArrayMatrix sh a
   zero sh = lift0 (Vector.zero sh)

   {- | Default: 'Vector.negate' on the wrapped array. -}
   negate :: (Class.Floating a) => ArrayMatrix sh a -> ArrayMatrix sh a
   negate matrix = lift1 Vector.negate matrix

   {- | Default: 'Vector.scaleReal' on the wrapped array. -}
   scaleReal :: (Class.Floating a) =>
      RealOf a -> ArrayMatrix sh a -> ArrayMatrix sh a
   scaleReal factor matrix = lift1 (Vector.scaleReal factor) matrix


{- |
Scale a matrix by a factor of the element type,
as implemented by 'Multiply.scale' for the respective shape.
-}
scale ::
   (Multiply.Scale shape, Class.Floating a) =>
   a -> ArrayMatrix shape a -> ArrayMatrix shape a
scale factor matrix = lift1 (Multiply.scale factor) matrix

{- | Infix synonym of 'scale'. -}
(.*#) ::
   (Multiply.Scale shape, Class.Floating a) =>
   a -> ArrayMatrix shape a -> ArrayMatrix shape a
factor .*# matrix = scale factor matrix

infixl 7 .*#

{-
Wrapper that allows 'Class.switchReal' to select an implementation
of type @a -> f a -> f a@ depending on the real element type @a@.
-}
newtype ScaleReal f a = ScaleReal {getScaleReal :: a -> f a -> f a}

{- |
Variant of 'scaleReal' for real element types,
where the factor has the same type as the matrix elements.
-}
scaleRealReal ::
   (Homogeneous shape, Class.Real a) =>
   a -> ArrayMatrix shape a -> ArrayMatrix shape a
scaleRealReal =
   -- Both alternatives are 'scaleReal'.
   -- Presumably there is one per concrete real type
   -- and each type-checks because @RealOf a@ equals @a@ there;
   -- TODO confirm against the definition of 'Class.switchReal'.
   getScaleReal
      (Class.switchReal
         (ScaleReal scaleReal)
         (ScaleReal scaleReal))


{-
None of the following instances overrides a method,
that is, all of them use the defaults of the 'Homogeneous' class.
The triangular instance requires the diagonal tag
to be 'MatrixShape.NonUnit'.
-}

instance
   (Extent.C vert, Extent.C horiz, Shape.C h, Shape.C w) =>
   Homogeneous (MatrixShape.Full vert horiz h w)

instance (Shape.C sh) => Homogeneous (MatrixShape.Hermitian sh)

instance
   (MatrixShape.Content lo, MatrixShape.NonUnit ~ diag, MatrixShape.Content up,
    Shape.C sh) =>
   Homogeneous (MatrixShape.Triangular lo diag up sh)

instance
   (Unary.Natural sub, Unary.Natural super, Extent.C vert, Extent.C horiz,
    Shape.C h, Shape.C w) =>
   Homogeneous (MatrixShape.Banded sub super vert horiz h w)

instance
   (Unary.Natural off, Shape.C sh) =>
   Homogeneous (MatrixShape.BandedHermitian off sh)


{- |
Shapes that carry a 'MatrixShape.Order'
and whose matrices can be converted to a requested order.
-}
class (Shape.C shape) => ShapeOrder shape where
   {- |
   @forceOrder order a@ contains the data of @a@ in the layout @order@
   (cf. 'adaptOrder', which is defined in terms of this method).
   -}
   forceOrder ::
      (Class.Floating a) =>
      MatrixShape.Order -> ArrayMatrix shape a -> ArrayMatrix shape a
   {- | The order that is stored in a shape. -}
   shapeOrder :: shape -> MatrixShape.Order

{- |
@adaptOrder x y@ yields the data of @y@ in the layout of @x@.
Only the order stored in the shape of @x@ is inspected.
-}
adaptOrder ::
   (ShapeOrder shape, Class.Floating a) =>
   ArrayMatrix shape a -> ArrayMatrix shape a -> ArrayMatrix shape a
adaptOrder template matrix =
   forceOrder (shapeOrder (shape template)) matrix

{- | Full matrices: 'Basic.forceOrder' and 'MatrixShape.fullOrder'. -}
instance
   (Extent.C vert, Extent.C horiz, Shape.C h, Shape.C w) =>
      ShapeOrder (MatrixShape.Full vert horiz h w) where
   forceOrder order matrix = lift1 (Basic.forceOrder order) matrix
   shapeOrder sh = MatrixShape.fullOrder sh

{- |
Hermitian matrices:
'Hermitian.forceOrder' and 'MatrixShape.hermitianOrder'.
-}
instance (Shape.C sh) => ShapeOrder (MatrixShape.Hermitian sh) where
   forceOrder order matrix = lift1 (Hermitian.forceOrder order) matrix
   shapeOrder herm = MatrixShape.hermitianOrder herm

{- |
Triangular matrices:
'Triangular.forceOrder' and 'MatrixShape.triangularOrder'.
-}
instance
   (MatrixShape.Content lo, MatrixShape.TriDiag diag, MatrixShape.Content up,
    Shape.C sh) =>
      ShapeOrder (MatrixShape.Triangular lo diag up sh) where
   forceOrder order matrix = lift1 (Triangular.forceOrder order) matrix
   shapeOrder tri = MatrixShape.triangularOrder tri


{- |
Shapes whose matrices can be added and subtracted.
There is no default for 'add';
'sub' defaults to adding the negated second operand.
-}
class (Homogeneous sh) => Additive sh where
   add, sub ::
      (Class.Floating a) =>
      ArrayMatrix sh a -> ArrayMatrix sh a -> ArrayMatrix sh a
   sub x y = add x (negate y)

{- |
Full matrices use 'addGen' and 'subGen';
thus 'sub' does not go through the default that negates first.
-}
instance
   (Extent.C vert, Extent.C horiz,
    Shape.C h, Eq h, Shape.C w, Eq w) =>
      Additive (MatrixShape.Full vert horiz h w) where
   add x y = addGen x y
   sub x y = subGen x y

{- |
Hermitian matrices use 'addGen' and 'subGen';
thus 'sub' does not go through the default that negates first.
-}
instance (Shape.C sh, Eq sh) => Additive (MatrixShape.Hermitian sh) where
   add x y = addGen x y
   sub x y = subGen x y

{- |
Triangular matrices with 'MatrixShape.NonUnit' diagonal tag
use 'addGen' and 'subGen';
thus 'sub' does not go through the default that negates first.
-}
instance
   (MatrixShape.Content lo, Eq lo,
    MatrixShape.NonUnit ~ diag,
    MatrixShape.Content up, Eq up,
    Shape.C sh, Eq sh) =>
      Additive (MatrixShape.Triangular lo diag up sh) where
   add x y = addGen x y
   sub x y = subGen x y

{-
Generic addition for shapes with an order:
The first operand is brought to the layout of the second one
and then both wrapped arrays are combined with 'Vector.add'.
-}
addGen ::
   (ShapeOrder shape, Eq shape, Class.Floating a) =>
   ArrayMatrix shape a -> ArrayMatrix shape a -> ArrayMatrix shape a
addGen x y =
   let xLikeY = adaptOrder y x
   in  lift2 Vector.add xLikeY y

{-
Generic subtraction; same scheme as 'addGen' but with 'Vector.sub'.
-}
subGen ::
   (ShapeOrder shape, Eq shape, Class.Floating a) =>
   ArrayMatrix shape a -> ArrayMatrix shape a -> ArrayMatrix shape a
subGen x y =
   let xLikeY = adaptOrder y x
   in  lift2 Vector.sub xLikeY y


infixl 6 #+#, #-#, `add`, `sub`

{- | Infix synonym of 'add'. -}
(#+#) ::
   (Additive shape, Class.Floating a) =>
   ArrayMatrix shape a -> ArrayMatrix shape a -> ArrayMatrix shape a
x #+# y = add x y

{- | Infix synonym of 'sub'. -}
(#-#) ::
   (Additive shape, Class.Floating a) =>
   ArrayMatrix shape a -> ArrayMatrix shape a -> ArrayMatrix shape a
x #-# y = sub x y