{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Array.Indexed where

import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Shape.Omni as Omni
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import Numeric.LAPACK.Matrix.Array.Private (ArrayMatrix)
import Numeric.LAPACK.Scalar (conjugate, zero)

import qualified Numeric.Netlib.Class as Class

import qualified Type.Data.Num.Unary as Unary

import qualified Data.Array.Comfort.Storable.Unchecked as UArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable ((!))

import Foreign.Storable (Storable)

import Data.Maybe.HT (toMaybe)
import Data.Maybe (fromMaybe)
import Data.Tuple.HT (swap)


infixl 9 #!

{- |
Read the element at position @(row, column)@ of an array-backed matrix.

How the element is found depends on the shape constructor:

* Full: read directly from the dense layout with the checked '!'.

* Upper/lower triangular: positions the packed storage does not hold
  evaluate to 'zero' (via 'checkedZero').

* Symmetric / Hermitian: positions the packed storage does not hold are
  read from the transposed position; the Hermitian case conjugates it.

* Banded / unit banded triangular: delegated to 'accessBanded'.

* Banded Hermitian: positions the band storage does not hold are read
  from the transposed position and conjugated; if that is not stored
  either, they evaluate to 'zero' (via 'checkedZero').

Indices outside the matrix dimensions end in a runtime error, raised either
by the checked '!' or by 'checkedZero'.
-}
(#!) ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.Indexed height, Shape.Indexed width, Class.Floating a) =>
   ArrayMatrix pack property lower upper meas vert horiz height width a ->
   (Shape.Index height, Shape.Index width) -> a
(#!) m ix =
   case shape of
      Omni.Full fullShape ->
         UArray.reshape fullShape (ArrMatrix.unwrap m) ! ix
      Omni.UpperTriangular _ ->
         accessAlt (checkedZero "UpperTriangular" shape ix) m ix
      Omni.LowerTriangular _ ->
         accessAlt (checkedZero "LowerTriangular" shape ix) m ix
      -- Unstored positions are mirrored across the diagonal.
      Omni.Symmetric _ ->
         accessAlt (ArrMatrix.toVector m ! swap ix) m ix
      Omni.Hermitian _ ->
         accessAlt (conjugate (ArrMatrix.toVector m ! swap ix)) m ix
      Omni.Banded _ -> accessBanded m ix
      Omni.UnitBandedTriangular _ -> accessBanded m ix
      -- Try the stored mirror position first (conjugated),
      -- and fall back to a checked zero if it is not stored either.
      Omni.BandedHermitian _ ->
         accessAlt
            (maybe (checkedZero "BandedHermitian" shape ix) conjugate
               (accessMaybe (ArrMatrix.toVector m) (boxIx (swap ix))))
            m (boxIx ix)
   where
      shape = ArrMatrix.shape m

-- | Read an element of a matrix with a banded plain layout.
--
-- The pair index is wrapped with 'boxIx'. Positions that the band storage
-- does not hold evaluate to 'zero' when they lie within the matrix
-- dimensions. Otherwise 'checkedZero' raises an error.
accessBanded ::
   (Omni.ToPlain pack prop lower upper meas vert horiz height width) =>
   (Omni.Plain pack prop lower upper meas vert horiz height width ~ shape) =>
   (Layout.Banded sub super meas vert horiz height width ~ shape) =>
   (Unary.Natural sub, Unary.Natural super) =>
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.Indexed height, Shape.Indexed width, Class.Floating a) =>
   ArrayMatrix pack prop lower upper meas vert horiz height width a ->
   (Shape.Index height, Shape.Index width) -> a
accessBanded m ix =
   accessAlt outsideBand m (boxIx ix)
   where
      outsideBand = checkedZero "Banded" (ArrMatrix.shape m) ix

-- | Turn a @(row, column)@ pair into a 'Layout.InsideBox' banded index.
--
-- The pair is deliberately not pattern-matched, so the function stays
-- as lazy in its argument as @uncurry Layout.InsideBox@.
boxIx :: (row, column) -> Layout.BandedIndex row column
boxIx rc = Layout.InsideBox (fst rc) (snd rc)

-- | Look up an index in the plain storage vector of a matrix.
--
-- Returns the stored element, or the given fallback if the storage shape
-- does not contain that index. The fallback is only evaluated in the
-- latter case. This matters because callers pass fallbacks that may
-- raise an error.
accessAlt ::
   (Omni.ToPlain pack prop lower upper meas vert horiz height width) =>
   (Omni.Plain pack prop lower upper meas vert horiz height width ~ shape) =>
   (Shape.Indexed shape, Shape.Index shape ~ ix, Storable a) =>
   a -> ArrayMatrix pack prop lower upper meas vert horiz height width a ->
   ix -> a
accessAlt fallback m ix =
   fromMaybe fallback (accessMaybe (ArrMatrix.toVector m) ix)

-- | Bounds-checked array lookup.
--
-- Returns 'Just' the element if the index is within the array's shape
-- according to 'Shape.inBounds', and 'Nothing' otherwise.
accessMaybe ::
   (Shape.Indexed sh, Storable a) =>
   UArray.Array sh a -> Shape.Index sh -> Maybe a
accessMaybe arr ix =
   let inside = Shape.inBounds (UArray.shape arr) ix
       -- Read via the Unchecked module (presumably no bounds check).
       -- 'toMaybe' only forces it when 'inside' is True.
       element = arr UArray.! ix
   in toMaybe inside element

-- | Value for a matrix position that has no storage.
--
-- Returns 'zero' if the index lies within the matrix height and width.
-- Otherwise it calls 'error' with a message tagged by the given name.
checkedZero ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.Indexed height, Shape.Indexed width, Class.Floating a) =>
   String ->
   Omni.Omni pack property lower upper meas vert horiz height width ->
   (Shape.Index height, Shape.Index width) -> a
checkedZero name sh ix
   | Shape.inBounds dims ix = zero
   | otherwise = error $ "Matrix.Indexed." ++ name ++ ": index out of range"
   where
      dims = (Omni.height sh, Omni.width sh)