{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE EmptyDataDecls #-}
module Numeric.LAPACK.Matrix.Array (



   Plain.Homogeneous, zero, negate, scaleReal, scale, scaleRealReal, (.*#),
   Plain.ShapeOrder, forceOrder, Plain.shapeOrder, adaptOrder,
   Plain.Additive, add, sub, (#+#), (#-#),
   ) where

import qualified Numeric.LAPACK.Matrix.Plain.Divide as Divide
import qualified Numeric.LAPACK.Matrix.Plain.Multiply as Multiply
import qualified Numeric.LAPACK.Matrix.Plain.Class as Plain
import qualified Numeric.LAPACK.Matrix.Type as Type
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Box as Box
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import Numeric.LAPACK.Matrix.Plain.Format (FormatArray, formatArray)
import Numeric.LAPACK.Matrix.Type (Matrix)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Scalar (RealOf)

import qualified Numeric.Netlib.Class as Class

import qualified Control.DeepSeq as DeepSeq

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape

import Prelude hiding (negate)

{- |
Uninhabited tag type that selects the array-backed 'Matrix' instance below.
-}
data Array shape

{- |
Array-backed matrix: a wrapper around an 'Array.Array' of the given shape.
-}
newtype instance Matrix (Array shape) a =
   Array (Array.Array shape a)
      deriving Show

{- | A 'Matrix' that is stored as an array of shape @sh@. -}
type ArrayMatrix sh = Matrix (Array sh)

{- | Array-backed matrix with a 'MatrixShape.Full' shape. -}
type Full vert horiz height width =
   ArrayMatrix (MatrixShape.Full vert horiz height width)

{- | Array-backed matrix with a 'MatrixShape.General' shape. -}
type General height width =
   ArrayMatrix (MatrixShape.General height width)

{- | Array-backed matrix with a 'MatrixShape.Tall' shape. -}
type Tall height width =
   ArrayMatrix (MatrixShape.Tall height width)

{- | Array-backed matrix with a 'MatrixShape.Wide' shape. -}
type Wide height width =
   ArrayMatrix (MatrixShape.Wide height width)

{- | Array-backed matrix with a 'MatrixShape.Square' shape. -}
type Square size =
   ArrayMatrix (MatrixShape.Square size)

{- |
Full evaluation of an array-backed matrix
is full evaluation of the wrapped array.
-}
instance (DeepSeq.NFData shape) => Type.NFData (Array shape) where
   rnf = DeepSeq.rnf . toVector

{- |
Height and width of an array-backed matrix
are read from the shape of the wrapped array.
-}
instance (Box.Box sh) => Type.Box (Array sh) where
   type HeightOf (Array sh) = Box.HeightOf sh
   type WidthOf (Array sh) = Box.WidthOf sh
   height = Box.height . shape
   width = Box.width . shape

{- | Shape of the array wrapped by the matrix. -}
shape :: ArrayMatrix sh a -> sh
shape = Array.shape . toVector

{- |
Replace the shape of the matrix by the given one.
The work is delegated to 'CheckedArray.reshape' on the wrapped array.
-}
reshape ::
   (Shape.C sh0, Shape.C sh1) =>
   sh1 -> ArrayMatrix sh0 a -> ArrayMatrix sh1 a
reshape newShape = lift1 (CheckedArray.reshape newShape)

{- |
Compute a new shape from the current one.
The work is delegated to 'CheckedArray.mapShape' on the wrapped array.
-}
mapShape ::
   (Shape.C sh0, Shape.C sh1) =>
   (sh0 -> sh1) -> ArrayMatrix sh0 a -> ArrayMatrix sh1 a
mapShape adjust = lift1 (CheckedArray.mapShape adjust)

{- | Remove the 'Array' wrapper and return the underlying array. -}
toVector :: ArrayMatrix sh a -> Array.Array sh a
toVector matrix =
   case matrix of
      Array arr -> arr

{- |
Wrap an array as a matrix after running 'Plain.check' on it.
If the check returns a message, then 'error' is called with that message
prefixed by the name of this function.
-}
fromVector ::
   (Plain.Admissible sh, Class.Floating a) =>
   Array.Array sh a -> ArrayMatrix sh a
fromVector arr =
   Array $ maybe arr reject $ Plain.check arr
   where
      reject msg = error ("Matrix.Array.fromVector: " ++ msg)

{- |
'lift0' is a synonym for 'fromVector' but lacks the admissibility check.
You may thus fool the type tags.
This applies to the other lift functions, too.
-}
lift0 :: Array.Array shA a -> ArrayMatrix shA a
lift0 = Array

{- |
Turn a function on arrays into a function on array-backed matrices.
The result is wrapped without any admissibility check.
-}
lift1 ::
   (Array.Array shA a -> Array.Array shB b) ->
   ArrayMatrix shA a -> ArrayMatrix shB b
lift1 f = Array . f . toVector

{- |
Turn a binary function on arrays
into a binary function on array-backed matrices.
The result is wrapped without any admissibility check.
-}
lift2 ::
   (Array.Array shA a -> Array.Array shB b -> Array.Array shC c) ->
   ArrayMatrix shA a -> ArrayMatrix shB b -> ArrayMatrix shC c
lift2 f ma mb = Array (f (toVector ma) (toVector mb))

{- |
Turn a ternary function on arrays
into a ternary function on array-backed matrices.
The result is wrapped without any admissibility check.
-}
lift3 ::
   (Array.Array shA a -> Array.Array shB b ->
    Array.Array shC c -> Array.Array shD d) ->
   ArrayMatrix shA a -> ArrayMatrix shB b ->
   ArrayMatrix shC c -> ArrayMatrix shD d
lift3 f ma mb mc =
   Array (f (toVector ma) (toVector mb) (toVector mc))

{- |
Turn a function of four arrays
into a function of four array-backed matrices.
The result is wrapped without any admissibility check.
-}
lift4 ::
   (Array.Array shA a -> Array.Array shB b ->
    Array.Array shC c -> Array.Array shD d ->
    Array.Array shE e) ->
   ArrayMatrix shA a -> ArrayMatrix shB b ->
   ArrayMatrix shC c -> ArrayMatrix shD d ->
   ArrayMatrix shE e
lift4 f ma mb mc md =
   Array (f (toVector ma) (toVector mb) (toVector mc) (toVector md))

{- |
Inverse direction of 'lift1':
turn a function on array-backed matrices into a function on arrays.
The argument is wrapped without any admissibility check.
-}
unlift1 ::
   (ArrayMatrix shA a -> ArrayMatrix shB b) ->
   Array.Array shA a -> Array.Array shB b
unlift1 f = toVector . f . Array

{- |
Inverse direction of 'lift2':
turn a binary function on array-backed matrices
into a binary function on arrays.
The arguments are wrapped without any admissibility check.
-}
unlift2 ::
   (ArrayMatrix shA a -> ArrayMatrix shB b -> ArrayMatrix shC c) ->
   Array.Array shA a -> Array.Array shB b -> Array.Array shC c
unlift2 f left right = toVector (f (Array left) (Array right))

{- |
Turn a function on 'General' matrices with unit height
into a function on vectors.
The matrix function is unwrapped with 'unlift1'
and then passed to 'Basic.unliftRow' together with the given order.
-}
unliftRow ::
   MatrixShape.Order ->
   (General () height0 a -> General () height1 b) ->
   Vector height0 a -> Vector height1 b
unliftRow order f = Basic.unliftRow order (unlift1 f)

{- |
Turn a function on 'General' matrices with unit width
into a function on vectors.
The matrix function is unwrapped with 'unlift1'
and then passed to 'Basic.unliftColumn' together with the given order.
-}
unliftColumn ::
   MatrixShape.Order ->
   (General height0 () a -> General height1 () b) ->
   Vector height0 a -> Vector height1 b
unliftColumn order f = Basic.unliftColumn order (unlift1 f)

{- |
An array-backed matrix is formatted
by applying 'formatArray' to the wrapped array.
-}
instance (FormatArray sh) => Type.FormatMatrix (Array sh) where
   formatMatrix fmt = formatArray fmt . toVector

{- |
Multiplication of two array-backed matrices of the same shape
is 'Multiply.same' applied to the wrapped arrays.
-}
instance (Multiply.MultiplySame sh) => Type.MultiplySame (Array sh) where
   multiplySame a b = lift2 Multiply.same a b

{- |
Wrap the array that 'Plain.zero' creates for the given shape.
-}
zero ::
   (Plain.Homogeneous shape, Class.Floating a) => shape -> ArrayMatrix shape a
zero sh = lift0 (Plain.zero sh)

{- |
Apply 'Plain.negate' to the array wrapped by the matrix.
This name replaces the Prelude function, which this module hides.
-}
negate ::
   (Plain.Homogeneous shape, Class.Floating a) =>
   ArrayMatrix shape a -> ArrayMatrix shape a
negate matrix = lift1 Plain.negate matrix

{- |
Apply 'Plain.scaleReal' with the given factor of type @RealOf a@
to the array wrapped by the matrix.
-}
scaleReal ::
   (Plain.Homogeneous shape, Class.Floating a) =>
   RealOf a -> ArrayMatrix shape a -> ArrayMatrix shape a
scaleReal factor = lift1 (Plain.scaleReal factor)

{- |
Wrapper around a scaling function
where the factor has the same type as the elements.
It makes the element type the last type argument,
which is the form that 'scaleRealReal' passes to 'Class.switchReal'.
-}
newtype ScaleReal f a =
   ScaleReal {
      getScaleReal :: a -> f a -> f a
   }

{- |
Variant of 'scaleReal' for real element types,
where the factor has the element type itself.
'scaleReal' is wrapped in 'ScaleReal' once per argument of 'Class.switchReal'
and the selected function is unwrapped again.
-}
scaleRealReal ::
   (Plain.Homogeneous shape, Class.Real a) =>
   a -> ArrayMatrix shape a -> ArrayMatrix shape a
scaleRealReal =
   getScaleReal
      (Class.switchReal
         (ScaleReal scaleReal)
         (ScaleReal scaleReal))

infixl 7 .*#

{- |
Apply 'Multiply.scale' with the given factor
to the array wrapped by the matrix.
'(.*#)' is the infix form of 'scale'.
-}
scale, (.*#) ::
   (Multiply.Scale shape, Class.Floating a) =>
   a -> ArrayMatrix shape a -> ArrayMatrix shape a
scale factor = lift1 (Multiply.scale factor)
factor .*# matrix = scale factor matrix

{- |
Apply 'Plain.forceOrder' with the given order
to the array wrapped by the matrix.
-}
forceOrder ::
   (Plain.ShapeOrder shape, Class.Floating a) =>
   MatrixShape.Order -> ArrayMatrix shape a -> ArrayMatrix shape a
forceOrder order = lift1 (Plain.forceOrder order)

{- |
@adaptOrder x y@ contains the data of @y@ with the layout of @x@.
-}
adaptOrder ::
   (Plain.ShapeOrder shape, Class.Floating a) =>
   ArrayMatrix shape a -> ArrayMatrix shape a -> ArrayMatrix shape a
adaptOrder = lift2 Plain.adaptOrder

infixl 6 #+#, #-#, `add`, `sub`

{- |
Sum and difference of two array-backed matrices of the same shape type,
computed by 'Plain.add' and 'Plain.sub' on the wrapped arrays.
'(#+#)' and '(#-#)' are the infix forms of 'add' and 'sub', respectively.
-}
add, sub, (#+#), (#-#) ::
   (Plain.Additive shape, Class.Floating a) =>
   ArrayMatrix shape a -> ArrayMatrix shape a -> ArrayMatrix shape a
add x y = lift2 Plain.add x y
sub x y = lift2 Plain.sub x y
x #+# y = add x y
x #-# y = sub x y