{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
module Numeric.LAPACK.Matrix.Type.Private where

import qualified Numeric.LAPACK.Matrix.Plain.Format as ArrFormat
import qualified Numeric.LAPACK.Output as Output
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Extent.Strict as ExtentStrict
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Layout.Private (Empty, Filled)
import Numeric.LAPACK.Matrix.Extent.Private (Extent, Shape, Small)
import Numeric.LAPACK.Output (Output)

import qualified Numeric.Netlib.Class as Class

import qualified Hyper

import qualified Control.DeepSeq as DeepSeq
import Control.Applicative ((<$>))

import qualified Data.Array.Comfort.Boxed as BoxedArray
import qualified Data.Array.Comfort.Shape as Shape

import qualified Data.Foldable as Fold
import Data.Function.HT (Id)
import Data.Monoid (Monoid, mempty, mappend)
import Data.Semigroup (Semigroup, (<>))
import Data.Foldable (Foldable)
import Data.Maybe (fromMaybe)
import Data.Tuple.HT (mapSnd)

import GHC.Exts (Constraint)



{- |
Matrices in various representations.

@typ@ is a representation tag such as 'Scale', 'Identity', 'Permutation'
or @'Product' fuse@; every tag gets its own @data instance@.
@extraLower@ and @extraUpper@ are representation-specific extra parameters:
@()@ for 'Scale', 'Identity' and 'Permutation',
tuples describing the two factors for 'Product'.
@lower@ and @upper@ describe the structure below and above the diagonal
(both 'Empty' for 'Scale' and 'Identity').
@meas@, @vert@ and @horiz@ are the extent parameters
(see "Numeric.LAPACK.Matrix.Extent.Private"),
@height@ and @width@ are the shape types of the two dimensions
and @a@ is the element type.
-}
data family
   Matrix typ extraLower extraUpper lower upper meas vert horiz height width a

-- | Square matrix: height and width share the shape type @sh@
-- and the extent measure is fixed to 'Shape'.
type Quadratic typ extraLower extraUpper lower upper sh =
      QuadraticMeas typ extraLower extraUpper lower upper Shape sh sh
-- | Matrix whose vertical and horizontal extent parameters are both 'Small'.
-- The measure @meas@ and the height and width types remain free.
type QuadraticMeas typ extraLower extraUpper lower upper meas =
      Matrix typ extraLower extraUpper lower upper meas Small Small


{- |
Type-restricted identity.
It returns its argument unchanged and only fixes the argument's type
to a 'QuadraticMeas' matrix, which helps type inference at call sites.
-}
asQuadratic ::
   Id (QuadraticMeas typ extraLower extraUpper lower upper meas height width a)
asQuadratic matrix = matrix


{- |
Representation tag for the unevaluated product of two matrices.
@fuse@ is the shape type of the inner dimension:
the width of the left factor and the height of the right factor.
-}
data Product fuse
{-
The extra parameters of a product record those type parameters of the
factors that do not appear in the product's own type:
@xl@ holds @(typA,xlA,xuA,lowerA,upperA)@ of the left factor and
@xu@ holds @(typB,xuB,xlB,upperB,lowerB)@ of the right factor.
The right factor's components are listed in swapped order
(@xuB@ before @xlB@, @upperB@ before @lowerB@).
A transposed factor has exactly this form, so the 'Transpose' instance
for 'Product' can swap and transpose the factors and exchange @xl@ and @xu@
without rearranging the tuples.

The 'Omni.MultipliedBands' constraints are stated for both argument orders.
Presumably this keeps the result strips consistent when transposition
swaps the factors.
-}
data instance
   Matrix (Product fuse) xl xu lower upper meas vert horiz height width a where
      Product ::
         (Omni.MultipliedBands lowerA lowerB ~ lowerC,
          Omni.MultipliedBands lowerB lowerA ~ lowerC,
          Omni.MultipliedBands upperA upperB ~ upperC,
          Omni.MultipliedBands upperB upperA ~ upperC) =>
         Matrix typA xlA xuA lowerA upperA meas vert horiz height fuse a ->
         Matrix typB xlB xuB lowerB upperB meas vert horiz fuse width a ->
         Matrix (Product fuse)
            (typA,xlA,xuA,lowerA,upperA) (typB,xuB,xlB,upperB,lowerB)
            lowerC upperC meas vert horiz height width a

{- |
Projections of the 5-tuples that a 'Product' matrix stores in its
@extraLower@ and @extraUpper@ parameters.
The component order is @(typ, xl, xu, lower, upper)@.
The right factor's tuple is stored in transposed order,
so for it 'ProductExtraL' yields the factor's @xu@
and 'ProductLower' yields its @upper@.
These families let constraints such as 'BoxExtra' and 'TransposeExtra'
for 'Product' refer to the factors' own parameters.
-}
type family ProductType   extra
type family ProductExtraL extra
type family ProductExtraU extra
type family ProductLower  extra
type family ProductUpper  extra
type instance ProductType   (typ,xl,xu,lower,upper) = typ
type instance ProductExtraL (typ,xl,xu,lower,upper) = xl
type instance ProductExtraU (typ,xl,xu,lower,upper) = xu
type instance ProductLower  (typ,xl,xu,lower,upper) = lower
type instance ProductUpper  (typ,xl,xu,lower,upper) = upper


{- |
Representation tag for a square matrix that is a scalar multiple of the
identity.
The 'Scale' constructor stores the shape and the scalar.
Both strips are 'Empty' and both extra parameters are @()@.
-}
data Scale
data instance
   Matrix Scale xl xu lower upper meas vert horiz height width a where
      Scale :: sh -> a -> Quadratic Scale () () Empty Empty sh a

deriving instance
   (Shape.C height, Show height, Show a) =>
   Show (Matrix Scale xl xu lower upper meas vert horiz height width a)


{- |
Representation tag for an identity matrix.
Only the extent is stored.
The vertical and horizontal extent parameters are 'Small',
but the height and width types may differ.
-}
data Identity
data instance
   Matrix Identity xl xu lower upper meas vert horiz height width a where
      Identity ::
         (Extent.Measure meas) =>
         Extent meas Small Small height width ->
         QuadraticMeas Identity () () Empty Empty meas height width a


{- |
Representation tag for a permutation matrix that wraps a 'Perm.Permutation'.
The constructor leaves the strip parameters @lower@ and @upper@ free.
'Show' and 'Eq' are derived from the wrapped permutation.
-}
data Permutation
data instance
   Matrix Permutation xl xu lower upper meas vert horiz height width a where
   Permutation ::
      Perm.Permutation sh -> Quadratic Permutation () () lower upper sh a

deriving instance
   (Shape.C height, Show height) =>
   Show (Matrix Permutation xl xu lower upper meas vert horiz height width a)

deriving instance
   (Shape.C height, Eq height) =>
   Eq (Matrix Permutation xl xu lower upper meas vert horiz height width a)


-- | Deep evaluation of every matrix whose representation tag @typ@
-- has an instance of the module-local 'NFData' class.
instance
   (NFData typ,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    DeepSeq.NFData height, DeepSeq.NFData width, DeepSeq.NFData a) =>
   DeepSeq.NFData (Matrix typ xl xu lower upper meas vert horiz height width a)
      where
   -- The unqualified 'rnf' on the right-hand side is the method of the
   -- module-local 'NFData' class, because "Control.DeepSeq" is only
   -- imported qualified. This is not a self-recursive definition.
   rnf = rnf

{- |
Representation-specific deep evaluation, selected by the tag @typ@.
The 'DeepSeq.NFData' instance for 'Matrix' forwards to 'rnf' of this class.
-}
class NFData typ where
   rnf ::
      (Extent.Measure meas, Extent.C vert, Extent.C horiz,
       DeepSeq.NFData height, DeepSeq.NFData width, DeepSeq.NFData a) =>
      Matrix typ xl xu lower upper meas vert horiz height width a -> ()



-- | Display a matrix via 'Output.hyper',
-- formatted with the default configuration 'ArrFormat.defltConfig'.
instance
   (Format typ, FormatExtra typ xl, FormatExtra typ xu,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
      Hyper.Display
         (Matrix typ xl xu lower upper meas vert horiz height width a) where
   display matrix = Output.hyper $ format ArrFormat.defltConfig matrix


{- |
Render a matrix to any 'Output' target, using the given configuration.
'FormatExtra' lets a representation require constraints
on its extra parameters @xl@ and @xu@.
-}
class Format typ where
   type FormatExtra typ extra :: Constraint
   format ::
      (FormatExtra typ xl, FormatExtra typ xu,
       Extent.Measure meas, Extent.C vert, Extent.C horiz,
       Shape.C height, Shape.C width, Class.Floating a, Output out) =>
      ArrFormat.Config ->
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      out

{- |
Default implementation of 'format' for representations with a 'Layout'.
The steps are:
lay out the elements, render each cell with 'formatCell',
split the result into rows, attach the separators to the cells,
and emit the rows as a table.

Some representations need more than one array for display,
e.g. @Householder@ and @LowerUpper@, and so implement 'format' differently.
The 'Layout' class is still needed for formatting @Block@ matrices.
-}
formatWithLayout ::
   (Layout typ, LayoutExtra typ xl, LayoutExtra typ xu,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a, Output out) =>
   ArrFormat.Config ->
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   out
formatWithLayout fmt matrix =
   Output.formatTable $
   map (concatMap (uncurry attachSeparators)) $
   ArrFormat.toRows $
   fmap (mapSnd (formatCell fmt)) $
   layout matrix

{- |
Turn a group of styled cells into separator-tagged triples.
The first cell gets the separator @sep0@ and every later cell gets
'Output.Empty'.
-}
attachSeparators :: (Foldable f) =>
   Output.Separator -> f (style,str) -> [(Output.Separator, style, str)]
attachSeparators sep0 cells =
   zipWith
      (\sep (style, str) -> (sep, style, str))
      (sep0 : repeat Output.Empty)
      (Fold.toList cells)

{- |
Render a single cell.
An absent cell ('Nothing') is filled with the configured placeholder text
'ArrFormat.configEmpty' and marked as 'Output.Stored'.
A present cell is printed with 'ArrFormat.printfFloating' and keeps its style.
-}
formatCell ::
   (Class.Floating a, Output out) =>
   ArrFormat.Config -> Maybe (Output.Style, a) ->
   ArrFormat.Tuple a (Output.Style, out)
formatCell fmt cell =
   case cell of
      Nothing ->
         ArrFormat.fillTuple
            (Output.Stored, Output.text $ ArrFormat.configEmpty fmt)
      Just (style, x) ->
         fmap (\str -> (style, Output.text str)) $
         ArrFormat.printfFloating fmt x


-- | 'Scale' matrices are formatted from their 'Layout'.
instance Format Scale where
   type FormatExtra Scale extra = ()
   format = formatWithLayout

-- | Permutation matrices delegate to 'Perm.format';
-- the format configuration is ignored.
instance Format Permutation where
   type FormatExtra Permutation extra = ()
   format _config (Permutation p) = Perm.format p


{- |
Lay out the matrix elements for formatting, in particular as part of a
block matrix.
Ideally an instance is also reused for 'format' via 'formatWithLayout',
but that is not always possible.
-}
class (Box typ) => Layout typ where
   type LayoutExtra typ extra :: Constraint
   {-
   We use the constraint @(Class.Floating a)@ rather than @(Format a)@
   because it allows us to align the components of complex numbers.

   We return a BoxedArray instead of a nested list,
   although both the underlying formatters
   and the frontend work with nested lists.
   This gives us a little more type safety for block matrices.

   Every element is paired with an 'Output.Separator'.
   'Nothing' marks a cell without an element;
   'formatCell' renders such a cell with 'ArrFormat.configEmpty'.
   -}
   layout ::
      (LayoutExtra typ xl, LayoutExtra typ xu,
       Extent.Measure meas, Extent.C vert, Extent.C horiz,
       Shape.C height, Shape.C width, Class.Floating a) =>
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      BoxedArray.Array (height, width)
         (Output.Separator, Maybe (Output.Style, a))

-- | A 'Scale' matrix shows its scalar on the diagonal
-- and no element anywhere else.
instance Layout Scale where
   type LayoutExtra Scale extra = ()
   layout (Scale shape a) =
      let n = Shape.size shape
          cell i j =
             if i == j
               then (Output.Space, Just (Output.Stored, a))
               else (Output.Space, Nothing)
      in BoxedArray.fromList (shape,shape)
            [cell i j | i <- [0 .. n-1], j <- [0 .. n-1]]

{- |
A permutation matrix is laid out from the nested list returned by
'Perm.layout', flattened in order.
Every entry that is present is tagged 'Output.Stored',
and every cell is paired with the separator 'Output.Space'.
-}
instance Layout Permutation where
   type LayoutExtra Permutation extra = ()
   layout (Permutation perm) =
      let sh = Perm.size perm in
      BoxedArray.fromList (sh,sh) $
      concatMap (map ((,) Output.Space . fmap ((,) Output.Stored))) $
      Perm.layout perm



-- | Square matrices of one representation combine via 'multiplySame'.
instance
   (MultiplySame typ, MultiplySameExtra typ xl, MultiplySameExtra typ xu,
    MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, height ~ width, Class.Floating a) =>
      Semigroup (Matrix typ xl xu lower upper meas vert horiz height width a)
         where
   x <> y = multiplySame x y

{- |
Multiply two matrices of identical type,
whose height and width share the shape type @sh@.
'MultiplySameExtra' lets a representation restrict its extra parameters.
-}
class (Box typ) => MultiplySame typ where
   type MultiplySameExtra typ extra :: Constraint
   multiplySame ::
      (matrix ~ Matrix typ xl xu lower upper meas vert horiz sh sh a,
       MultiplySameExtra typ xl, MultiplySameExtra typ xu,
       MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper,
       Extent.Measure meas, Extent.C vert, Extent.C horiz,
       Shape.C sh, Eq sh, Class.Floating a) =>
      matrix -> matrix -> matrix

instance MultiplySame Scale where
   type MultiplySameExtra Scale extra = extra ~ ()
   -- Product of two scaled identities: multiply the scalars.
   -- 'scaleWithCheck' ensures both operands have the same shape.
   multiplySame =
      scaleWithCheck "Scale.multiplySame" height
         (\s (Scale sh t) -> Scale sh (s*t))

instance MultiplySame Permutation where
   type MultiplySameExtra Permutation extra = extra ~ ()
   -- NOTE(review): the operands are handed to 'Perm.multiply' in swapped
   -- order, presumably to match its composition convention - confirm in
   -- "Numeric.LAPACK.Permutation.Private".
   multiplySame (Permutation p) (Permutation q) =
      Permutation (Perm.multiply q p)


-- | Statically sized square matrices: 'mempty' is 'staticIdentity'
-- and 'mappend' is the 'Semigroup' operation.
instance
   (MultiplySame typ, StaticIdentity typ,
    MultiplySameExtra typ xl, MultiplySameExtra typ xu,
    StaticIdentityExtra typ xl, StaticIdentityStrip typ lower,
    StaticIdentityExtra typ xu, StaticIdentityStrip typ upper,
    MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper,
    meas ~ Shape, vert ~ Small, horiz ~ Small,
    Shape.Static height, Eq height, height ~ width, Class.Floating a) =>
      Monoid (Matrix typ xl xu lower upper meas vert horiz height width a) where
   mempty = staticIdentity
   -- kept in the canonical form to avoid GHC's noncanonical-monoid warning
   mappend = (<>)

{- |
Identity matrix of a statically known square shape ('Shape.Static').
'StaticIdentityExtra' and 'StaticIdentityStrip' let a representation
restrict the extra and strip parameters for which it provides an identity.
-}
class StaticIdentity typ where
   type StaticIdentityExtra typ extra :: Constraint
   type StaticIdentityStrip typ strip :: Constraint
   staticIdentity ::
      (StaticIdentityExtra typ xl, StaticIdentityStrip typ lower) =>
      (StaticIdentityExtra typ xu, StaticIdentityStrip typ upper) =>
      (Shape.Static sh, Class.Floating a) =>
      Quadratic typ xl xu lower upper sh a

-- | The identity is the scalar one on the static shape.
instance StaticIdentity Scale where
   type StaticIdentityExtra Scale extra = extra ~ ()
   type StaticIdentityStrip Scale strip = strip ~ Empty
   staticIdentity = Scale Shape.static 1

-- | The identity permutation on the static shape, typed with 'Filled' strips.
instance StaticIdentity Permutation where
   type StaticIdentityExtra Permutation extra = extra ~ ()
   type StaticIdentityStrip Permutation strip = strip ~ Filled
   staticIdentity = Permutation (Perm.identity Shape.static)


{- |
Combine the scalar of a 'Scale' matrix with an operand using @f@.
Before that, check that @getSize@ of the operand equals the shape of the
'Scale' matrix.
If the shapes differ, call 'error' with a message prefixed by @name@.
-}
scaleWithCheck :: (Eq shape) =>
   String -> (b -> shape) ->
   (a -> b -> c) ->
   Matrix Scale xl xu lower upper meas vert horiz shape shape a ->
   b -> c
scaleWithCheck name getSize f (Scale sh s) operand
   | sh == getSize operand = f s operand
   | otherwise = error $ name ++ ": dimensions mismatch"


{- |
Dimensions of a matrix.
Instances must define 'extent'.
'height' and 'width' default to the corresponding component of the extent;
instances may override them, as 'Scale' and 'Permutation' do.
-}
class Box typ where
   type BoxExtra typ extra :: Constraint
   extent ::
      (BoxExtra typ xl, BoxExtra typ xu) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      Extent.Extent meas vert horiz height width
   height ::
      (BoxExtra typ xl, BoxExtra typ xu) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      Matrix typ xl xu lower upper meas vert horiz height width a -> height
   height m = Extent.height (extent m)
   width ::
      (BoxExtra typ xl, BoxExtra typ xu) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      Matrix typ xl xu lower upper meas vert horiz height width a -> width
   width m = Extent.width (extent m)

-- | A 'Scale' matrix is square with its stored shape on both sides.
instance Box Scale where
   type BoxExtra Scale extra = ()
   extent (Scale sh _) = Extent.square sh
   height (Scale sh _) = sh
   width (Scale sh _) = sh

-- | An 'Identity' matrix stores its extent directly.
instance Box Identity where
   type BoxExtra Identity extra = ()
   extent (Identity ext) = ext

-- | A permutation matrix is square with the permutation's size on both sides.
instance Box Permutation where
   type BoxExtra Permutation extra = ()
   extent (Permutation p) = Extent.square (Perm.size p)
   height (Permutation p) = Perm.size p
   width (Permutation p) = Perm.size p

{- |
The extent of a product combines the extents of both factors
with 'Extent.fuse'.
The product takes its height from the left factor and its width from
the right factor.
If 'Extent.fuse' returns 'Nothing' (presumably because the inner
dimensions disagree), 'error' is called.
-}
instance (Eq fuse) => Box (Product fuse) where
   type BoxExtra (Product fuse) extra =
         (Box (ProductType extra),
          BoxExtra (ProductType extra) (ProductExtraL extra),
          BoxExtra (ProductType extra) (ProductExtraU extra))
   extent (Product left right) =
      let fused = Extent.fuse (extent left) (extent right)
      in fromMaybe (error "Matrix.Product: shapes mismatch") fused

-- | Shape of a square matrix, shared by both dimensions;
-- taken from its height.
squareSize ::
   (Box typ, BoxExtra typ xl, BoxExtra typ xu) =>
   Quadratic typ xl xu lower upper sh a -> sh
squareSize matrix = height matrix

-- | All index pairs of a matrix.
-- They come from 'Shape.indices' of the pair shape @(height, width)@.
indices ::
   (Box typ, BoxExtra typ xl, BoxExtra typ xu,
    Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.Indexed height, Shape.Indexed width) =>
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   [(Shape.Index height, Shape.Index width)]
indices matrix =
   let pairShape = (height matrix, width matrix)
   in Shape.indices pairShape


{- |
Convert a 'QuadraticMeas' matrix into a 'Quadratic' one.
The result uses the original height shape for both dimensions
('heightToQuadratic') or the original width shape for both
('widthToQuadratic').
-}
class (Box typ) => ToQuadratic typ where
   heightToQuadratic ::
      (Extent.Measure meas) =>
      QuadraticMeas typ xl xu lower upper meas height width a ->
      Quadratic typ xl xu lower upper height a
   widthToQuadratic ::
      (Extent.Measure meas) =>
      QuadraticMeas typ xl xu lower upper meas height width a ->
      Quadratic typ xl xu lower upper width a

-- | 'Scale' matrices are already 'Quadratic'; only the type is refined.
instance ToQuadratic Scale where
   heightToQuadratic (Scale sh x) = Scale sh x
   widthToQuadratic (Scale sh x) = Scale sh x

-- | Rebuild a square extent from the height or the width of the stored one.
instance ToQuadratic Identity where
   heightToQuadratic (Identity ext) =
      Identity (Extent.square (Extent.height ext))
   widthToQuadratic (Identity ext) =
      Identity (Extent.square (Extent.width ext))

-- | Permutation matrices are already 'Quadratic'; only the type is refined.
instance ToQuadratic Permutation where
   heightToQuadratic (Permutation p) = Permutation p
   widthToQuadratic (Permutation p) = Permutation p


{- |
Change the extent parameters (measure, vertical and horizontal) of a matrix
as described by an 'ExtentStrict.Map'.
The height and width types are kept.
'MapExtentExtra' and 'MapExtentStrip' let a representation restrict
the extra and strip parameters it supports.
-}
class (Box typ) => MapExtent typ where
   type MapExtentExtra typ extra :: Constraint
   type MapExtentStrip typ strip :: Constraint
   mapExtent ::
      (MapExtentExtra typ xl, MapExtentStrip typ lower) =>
      (MapExtentExtra typ xu, MapExtentStrip typ upper) =>
      (Extent.Measure measA, Extent.C vertA, Extent.C horizA) =>
      (Extent.Measure measB, Extent.C vertB, Extent.C horizB) =>
      ExtentStrict.Map measA vertA horizA measB vertB horizB height width ->
      Matrix typ xl xu lower upper measA vertA horizA height width a ->
      Matrix typ xl xu lower upper measB vertB horizB height width a


{- |
Transpose a matrix.
The type-level parameters are swapped accordingly:
@xl@ with @xu@, @lower@ with @upper@, @vert@ with @horiz@,
and @height@ with @width@.
-}
class (Box typ) => Transpose typ where
   type TransposeExtra typ extra :: Constraint
   transpose ::
      (TransposeExtra typ xl, TransposeExtra typ xu) =>
      (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
      (Shape.C height, Shape.C width, Class.Floating a) =>
      Matrix typ xl xu lower upper meas vert horiz height width a ->
      Matrix typ xu xl upper lower meas horiz vert width height a

-- | A scaled identity equals its transpose; only the type changes.
instance Transpose Scale where
   type TransposeExtra Scale extra = ()
   transpose (Scale sh x) = Scale sh x

-- | Transposing an identity matrix transposes its extent.
instance Transpose Identity where
   type TransposeExtra Identity extra = ()
   transpose (Identity ext) = Identity (Extent.transpose ext)

-- | Transposing a permutation matrix transposes the permutation.
instance Transpose Permutation where
   type TransposeExtra Permutation extra = ()
   transpose (Permutation p) = Permutation (Perm.transpose p)

{- |
Transposing a product reverses the order of the transposed factors,
following @(A·B)^T = B^T·A^T@.
-}
instance (Shape.C fuse, Eq fuse) => Transpose (Product fuse) where
   type TransposeExtra (Product fuse) extra =
         (Transpose (ProductType extra),
          TransposeExtra (ProductType extra) (ProductExtraL extra),
          TransposeExtra (ProductType extra) (ProductExtraU extra))
   transpose (Product left right) = Product (transpose right) (transpose left)


{- |
Build an operation with the matrix @a@ on the left from an operation
@multiplyTrans@ that takes the other operand @b@ first and works on
transposed matrices.
The result is @transpose (multiplyTrans b (transpose a))@.
-}
swapMultiply ::
   (Transpose typA, Transpose typB) =>
   (TransposeExtra typA xlA, TransposeExtra typA xuA) =>
   (TransposeExtra typB xlB, TransposeExtra typB xuB) =>
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA) =>
   (Extent.Measure measB, Extent.C vertB, Extent.C horizB) =>
   (Shape.C heightA, Shape.C widthA) =>
   (Shape.C heightB, Shape.C widthB) =>
   (Class.Floating a) =>
   (matrix ->
    Matrix typA xuA xlA upperA lowerA measA horizA vertA widthA heightA a ->
    Matrix typB xuB xlB upperB lowerB measB horizB vertB widthB heightB a) ->
   Matrix typA xlA xuA lowerA upperA measA vertA horizA heightA widthA a ->
   matrix ->
   Matrix typB xlB xuB lowerB upperB measB vertB horizB heightB widthB a
swapMultiply multiplyTrans a b =
   transpose (multiplyTrans b (transpose a))

-- | Singleton witnesses for the lower and upper strip types of a matrix.
-- The matrix argument only fixes the types; it is never inspected.
powerStrips ::
   (MatrixShape.PowerStrip lower, MatrixShape.PowerStrip upper) =>
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   (MatrixShape.PowerStripSingleton lower,
    MatrixShape.PowerStripSingleton upper)
powerStrips =
   const (MatrixShape.powerStripSingleton, MatrixShape.powerStripSingleton)

-- | Singleton witnesses for the lower and upper strip types of a matrix.
-- The matrix argument only fixes the types; it is never inspected.
strips ::
   (MatrixShape.Strip lower, MatrixShape.Strip upper) =>
   Matrix typ xl xu lower upper meas vert horiz height width a ->
   (MatrixShape.StripSingleton lower, MatrixShape.StripSingleton upper)
strips = const (MatrixShape.stripSingleton, MatrixShape.stripSingleton)