{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE Rank2Types #-}
module Numeric.LAPACK.Matrix.Extent.Strict where

import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import Numeric.LAPACK.Matrix.Extent.Private
         (C, Extent, Measure, switchTag,
          Shape, Size, Small, Big, errorTagTriple)

import qualified Data.Array.Comfort.Shape as Shape


{- |
Newtype wrapper around 'Extent.Map',
mapping an extent with tags @measA vertA horizA@
to an extent with tags @measB vertB horizB@.
-}
newtype Map measA vertA horizA measB vertB horizB height width =
   Map {
      apply ::
         Extent.Map measA vertA horizA measB vertB horizB height width
   }

{- |
Conjugate a map with 'Extent.transpose':
transpose the input extent, apply the wrapped map,
then transpose the result back.
Accordingly the vertical and horizontal tags,
as well as height and width, swap places in the type.
-}
transpose ::
   (Measure measA, C vertA, C horizA) =>
   (Measure measB, C vertB, C horizB) =>
   Map measA vertA horizA measB vertB horizB height width ->
   Map measA horizA vertA measB horizB vertB width height
transpose (Map f) =
   Map $ \extent -> Extent.transpose (f (Extent.transpose extent))


{- |
Admissible tag combinations are:

> meas  vert  horiz
> Shape Small Small - Square
> Size  Small Small - LiberalSquare
> Size  Big   Small - Tall
> Size  Small Big   - Wide
> Size  Big   Big   - General

We can enforce this set with the constraints

> (Extent.Measured meas vert, Extent.Measured meas horiz)

However, in some cases this leads to constraints
like @Measured meas Small@ or @Measured meas Big@.
The former is morally equivalent to @Measure meas@
and the latter is morally equivalent to @meas ~ Size@,
but in order to convince the compiler of this
you would have to go through 'switchMeasured'.

In order to circumvent this trouble
we use internal functions with weaker constraints:

> (Extent.Measure meas, Extent.C vert, Extent.C horiz)

This is type-safe whenever the input
is based on one of the five admissible extent types.
We only need the strict constraints
when constructing matrices of arbitrary extent type,
i.e. this concerns almost exclusively 'Numeric.LAPACK.Matrix.Extent.fromSquare'.
-}
class (Measure meas, C tag) => Measured meas tag where
   -- | Choose one of three alternatives according to @meas@ and @tag@:
   -- the first one for @Shape Small@, the second one for @Size Small@
   -- and the third one for @Size Big@.
   switchMeasured :: f Shape Small -> f Size Small -> f Size Big -> f meas tag

-- | With 'Shape' the tag is forced to be 'Small'.
instance (tag ~ Small) => Measured Shape tag where
   switchMeasured shapeSmall _sizeSmall _sizeBig = shapeSmall

-- | With 'Size' both tags are allowed; the choice is left to 'switchTag'.
instance (C tag) => Measured Size tag where
   switchMeasured _shapeSmall sizeSmall sizeBig = switchTag sizeSmall sizeBig

{-
Alternative set of instances:

instance (Measure meas) => Measured meas Small where
instance (meas ~ Size) => Measured meas Big where
-}


{- |
Rotate the three trailing type arguments of @f@ to the right:
@RotRight3 f z x y@ wraps a value of type @f x y z@.
This moves the last argument of @f@ to the front,
so that a function like 'switchMeasured'
can dispatch on the first two arguments of @f@
while the third one stays fixed.
-}
newtype RotRight3 f z x y =
   RotRight3 {getRotRight3 :: f x y z}

{- |
Select one of five alternatives according to the tag triple
@meas vert horiz@.
In argument order the alternatives are for
@Shape Small Small@, @Size Small Small@, @Size Small Big@,
@Size Big Small@ and @Size Big Big@.
-}
switchTagTriple ::
   (Measured meas vert, Measured meas horiz) =>
   f Shape Small Small -> f Size Small Small -> f Size Small Big ->
   f Size Big Small -> f Size Big Big -> f meas vert horiz
switchTagTriple fSquare fLiberalSquare fWide fTall fGeneral =
   getRotRight3 (switchMeasured shapeSmall sizeSmall sizeBig)
   where
      -- Each row fixes @meas@ and @vert@ and dispatches on @horiz@.
      -- 'Shape' combined with a 'Big' horizontal tag is not admissible.
      shapeSmall = RotRight3 (switchTag fSquare errorTagTriple)
      sizeSmall = RotRight3 (switchTag fLiberalSquare fWide)
      sizeBig = RotRight3 (switchTag fTall fGeneral)


{- |
Type on which height and width have to agree in 'consChecked':
the shape type itself for 'Shape' and 'Int' for 'Size'.
Thus @MeasureTarget meas height ~ MeasureTarget meas width@
forces equal height and width types for 'Shape'
and holds trivially for 'Size'.
-}
type family MeasureTarget meas shape
type instance MeasureTarget Shape shape = shape
type instance MeasureTarget Size  shape = Int

{- |
Dimension argument accepted by 'consChecked':
just the height shape for 'Shape'
(the width shape is not passed separately),
and a @(height, width)@ pair of shapes for 'Size'.
-}
type family Dimension meas heightShape widthShape
type instance Dimension Shape heightShape widthShape = heightShape
type instance Dimension Size  heightShape widthShape = (heightShape, widthShape)


{- |
Extent constructor taking a 'Dimension', used by 'consChecked'.
The type parameters are ordered such that @Cons_ heightSh widthSh@
fits the @f meas vert horiz@ pattern expected by 'switchTagTriple'.
-}
data Cons_ heightSh widthSh meas vert horiz =
   Cons {
      getCons ::
         (MeasureTarget meas heightSh ~ MeasureTarget meas widthSh) =>
         Dimension meas heightSh widthSh ->
         Extent meas vert horiz heightSh widthSh
   }

{- |
Construct an extent of any admissible tag combination from its 'Dimension'.

For the 'Size'-measured, non-general cases the shape sizes are checked,
and 'error' is called if the check fails:

* liberal square: height and width must have equal size,

* wide: the height size must not exceed the width size,

* tall: the height size must be at least the width size.
-}
consChecked ::
   (Measured meas vert, Measured meas horiz) =>
   (Shape.C height, Shape.C width) =>
   (MeasureTarget meas height ~ MeasureTarget meas width) =>
   Dimension meas height width ->
   Extent meas vert horiz height width
consChecked =
   getCons $
   switchTagTriple
      (Cons Extent.Square)
      (Cons $ checkSizes (==)
         "Extent.liberalSquare: height and width size differ"
         Extent.liberalSquare)
      (Cons $ checkSizes (<=)
         "Extent.wide: width smaller than height"
         Extent.wide)
      (Cons $ checkSizes (>=)
         "Extent.tall: height smaller than width"
         Extent.tall)
      (Cons $ uncurry Extent.general)
   where
      -- Apply the constructor only if the sizes satisfy the predicate,
      -- otherwise fail with the given message.
      checkSizes ::
         (Shape.C heightSh, Shape.C widthSh) =>
         (Int -> Int -> Bool) -> String ->
         (heightSh -> widthSh -> ext) -> (heightSh, widthSh) -> ext
      checkSizes admissible msg construct (heightSh, widthSh) =
         if admissible (Shape.size heightSh) (Shape.size widthSh)
            then construct heightSh widthSh
            else error msg


{- |
For a product of an extent @a@ (height × fuse) and an extent @b@
(fuse × width), return

* the 'Extent.MeasureFact' and 'Extent.TagFact' values for the combined tags,
  obtained via 'Extent.multiplyMeasureLaw' and 'Extent.multiplyTagLaw', and

* maps converting each factor's extent to the combined tags,
  built from 'Extent.unifyLeft' (with @b@) and 'Extent.unifyRight' (with @a@).
-}
unifiers ::
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA,
    Extent.Measure measB, Extent.C vertB, Extent.C horizB) =>
   (Extent.MultiplyMeasure measA measB ~ measC) =>
   (Extent.Multiply vertA vertB ~ vertC) =>
   (Extent.Multiply horizA horizB ~ horizC) =>
   Extent measA vertA horizA height fuse ->
   Extent measB vertB horizB fuse width ->
   ((Extent.MeasureFact measC, Extent.TagFact vertC, Extent.TagFact horizC),
    (Map measA vertA horizA measC vertC horizC height fuse,
     Map measB vertB horizB measC vertC horizC fuse width))
unifiers a b = (laws, (Map toCommonA, Map toCommonB))
   where
      laws =
         (Extent.multiplyMeasureLaw (Extent.measureFact a) (Extent.measureFact b),
          Extent.multiplyTagLaw (Extent.heightFact a) (Extent.heightFact b),
          Extent.multiplyTagLaw (Extent.widthFact a) (Extent.widthFact b))
      toCommonA extA = Extent.unifyLeft extA b
      toCommonB = Extent.unifyRight a