{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE EmptyDataDecls #-}
module Numeric.LAPACK.Matrix.Type where

import qualified Numeric.LAPACK.Matrix.Plain.Format as ArrFormat
import qualified Numeric.LAPACK.Output as Output
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import Numeric.LAPACK.Output (Output)

import qualified Numeric.Netlib.Class as Class

import qualified Hyper

import qualified Control.DeepSeq as DeepSeq

import qualified Data.Array.Comfort.Shape as Shape

import Data.Semigroup (Semigroup, (<>))

-- | A matrix, indexed by a type-level tag @tag@ that selects the kind of
-- matrix and by the element type @element@.  Every matrix kind supplies
-- its own @data instance@ or @newtype instance@ of this family (e.g.
-- 'Scale' and 'Perm.Permutation'), which fixes its representation.
data family Matrix tag element

-- | Type-level tag for scalar matrices over the shape @sh@.
data Scale sh

-- | A matrix represented by its shape and one scalar.  Height and width
-- are both the stored shape, and 'formatMatrix' renders it as a diagonal
-- matrix with the scalar repeated along the diagonal.
data instance Matrix (Scale sh) a =
   Scale sh a

-- | A permutation viewed as a matrix.  It merely wraps a
-- 'Perm.Permutation'; the element type @a@ is a phantom parameter
-- that does not occur in the representation.
newtype instance Matrix (Perm.Permutation sh) a
   = Permutation (Perm.Permutation sh)
   deriving (Show)

-- | Deep evaluation of any 'Matrix' is delegated to this module's own
-- per-tag class 'NFData'.
--
-- NOTE: @rnf = rnf@ below is NOT a recursive definition.  The left-hand
-- side binds the method of 'DeepSeq.NFData', whereas "Control.DeepSeq" is
-- imported qualified, so the unqualified @rnf@ on the right-hand side
-- resolves to the method of the local 'NFData' class.
instance (NFData typ, DeepSeq.NFData a) => DeepSeq.NFData (Matrix typ a) where
   rnf = rnf

-- | Per-tag deep evaluation for matrices.  'Matrix' is a data family, so
-- the 'DeepSeq.NFData' instance for @Matrix typ a@ dispatches on the tag
-- @typ@ through this class.
class NFData typ where
   rnf :: (DeepSeq.NFData a) => Matrix typ a -> ()

-- | Display a matrix in Hyper: format it with the default element format
-- 'ArrFormat.deflt' and convert the result with 'Output.hyper'.
--
-- (The @instance@ keyword was missing here, which made the whole
-- declaration a syntax error in the preceding class body.)
instance
   (FormatMatrix typ, Class.Floating a) =>
      Hyper.Display (Matrix typ a) where
   display = Output.hyper . formatMatrix ArrFormat.deflt

-- | Matrix kinds that can be rendered to any 'Output' target.
class FormatMatrix typ where
   -- | Render a matrix.  The 'String' argument is a format specification
   -- for the elements (e.g. 'ArrFormat.deflt'); instances may ignore it.
   --
   -- We use constraint @(Class.Floating a)@ and not @(Format a)@
   -- because it allows us to align the components of complex numbers.
   formatMatrix ::
      (Class.Floating a, Output out) => String -> Matrix typ a -> out

instance (Shape.C sh) => FormatMatrix (Scale sh) where
   -- Render as a row-major diagonal matrix whose diagonal holds the
   -- scalar once per element of the shape.
   formatMatrix fmt (Scale extent x) =
         ArrFormat.formatDiagonal fmt MatrixShape.RowMajor extent diag
      where
         diag = replicate (Shape.size extent) x

instance (Shape.C sh) => FormatMatrix (Perm.Permutation sh) where
   -- The element format is irrelevant here; rendering is delegated
   -- entirely to 'Perm.format'.
   formatMatrix _ (Permutation p) = Perm.format p

-- | '<>' multiplies two matrices of the same kind via 'multiplySame'.
instance (MultiplySame typ, Class.Floating a) => Semigroup (Matrix typ a) where
   x <> y = multiplySame x y

-- | Matrix kinds whose product with a matrix of the same kind is again
-- a matrix of that kind.
class MultiplySame typ where
   -- | Multiply two matrices of the same kind.
   multiplySame ::
      (Class.Floating a) =>
      Matrix typ a ->
      Matrix typ a ->
      Matrix typ a

instance (Eq shape) => MultiplySame (Scale shape) where
   -- Shapes must agree (checked by 'scaleWithCheck' against the second
   -- operand's height); the product keeps the second operand's shape and
   -- multiplies the two scalars.
   multiplySame x y =
         scaleWithCheck "Scale.multiplySame" height multiplyScalars x y
      where
         multiplyScalars s (Scale sh t) = Scale sh (s * t)

instance (Shape.C sh, Eq sh) => MultiplySame (Perm.Permutation sh) where
   -- NOTE: the operands are handed to 'Perm.multiply' in swapped order
   -- (second, then first); presumably this matches Perm.multiply's
   -- composition convention -- confirm before changing.
   multiplySame (Permutation p) (Permutation q) =
      Permutation (Perm.multiply q p)

-- | Combine a scalar matrix with another operand after a size check.
--
-- The size of the second operand is obtained with the supplied accessor
-- and compared with the scalar matrix's shape.  On a match, the
-- combining function receives the scalar and the second operand.
-- Otherwise this calls 'error' with the given name followed by
-- @": dimensions mismatch"@.
scaleWithCheck ::
   (Eq shape) =>
   String -> (b -> shape) ->
   (a -> b -> c) -> Matrix (Scale shape) a -> b -> c
scaleWithCheck label sizeOf combine (Scale sh s) other
   | sh == sizeOf other = combine s other
   | otherwise = error $ label ++ ": dimensions mismatch"

-- | Matrix kinds with an explicit height shape and width shape.
class Box typ where
   -- | Shape type of the first dimension (presumably rows -- confirm).
   type HeightOf typ
   -- | The height shape of a matrix.
   height :: Matrix typ a -> HeightOf typ

   -- | Shape type of the second dimension (presumably columns -- confirm).
   type WidthOf typ
   -- | The width shape of a matrix.
   width :: Matrix typ a -> WidthOf typ

instance Box (Scale sh) where
   -- Scale matrices are square: both dimensions are the stored shape.
   type HeightOf (Scale sh) = sh
   type WidthOf (Scale sh) = sh
   height (Scale s _) = s
   width = height

instance Box (Perm.Permutation sh) where
   -- Permutation matrices are square: both dimensions are the size
   -- reported by 'Perm.size'.
   type HeightOf (Perm.Permutation sh) = sh
   type WidthOf (Perm.Permutation sh) = sh
   height (Permutation p) = Perm.size p
   width = height

-- | All index pairs of a matrix, as enumerated by 'Shape.indices' on the
-- pair of its height shape and width shape.
indices ::
   (Box typ,
    HeightOf typ ~ height, Shape.Indexed height,
    WidthOf typ ~ width, Shape.Indexed width) =>
   Matrix typ a -> [(Shape.Index height, Shape.Index width)]
indices matrix = Shape.indices dims
   where
      dims = (height matrix, width matrix)