{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Extent.Private where

import Numeric.LAPACK.Shape.Private (Unchecked(deconsUnchecked))
import Numeric.LAPACK.Wrapper (Flip(Flip, getFlip))

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape ((::+)((::+)))

import Text.Printf (printf)

import Control.DeepSeq (NFData, rnf)
import Control.Applicative (Const(Const))

import Data.Maybe.HT (toMaybe)


-- | Height and width of a matrix, indexed by type-level tags.
--
-- * @meas@: 'Shape' when height and width are one and the same shape
--   (only produced by the 'Square' constructor), 'Size' when they are
--   stored separately (only produced by the 'Separate' constructor).
-- * @vert@, @horiz@: 'Small' or 'Big'; the aliases 'General', 'Tall',
--   'Wide' and 'LiberalSquare' name the combinations used with 'Size'.
data Extent meas vert horiz height width where
   -- | Square extent: a single shape serves as height and as width.
   Square :: size -> Extent Shape Small Small size size
   -- | Separate height and width, with arbitrary @vert@ and @horiz@ tags.
   Separate :: height -> width -> Extent Size vert horiz height width


-- | Deep evaluation of an extent: forces the stored shape(s).
-- A 'Square' holds one shape, 'Separate' forces height before width.
instance
   (Measure measure, C vertical, C horizontal, NFData height, NFData width) =>
      NFData (Extent measure vertical horizontal height width) where
   rnf extent =
      case extent of
         Square sh -> rnf sh
         Separate h w -> rnf h `seq` rnf w


-- | Tag for the @vert@ or @horiz@ parameter of 'Extent'.
-- 'Tall' has a 'Big' vertical tag, 'Wide' a 'Big' horizontal tag,
-- and 'General' has both tags 'Big'.
data Big = Big deriving (Eq,Show)
-- | Tag for the @vert@ or @horiz@ parameter of 'Extent'.
-- Square-tagged extents ('Square', 'LiberalSquare') have both tags 'Small'.
data Small = Small deriving (Eq,Show)

-- The tags carry no fields, so matching the constructor fully evaluates them.
instance NFData Big where rnf Big = ()
instance NFData Small where rnf Small = ()


-- | Measure tag of extents built with 'Separate' (independent height and
-- width).
data Size = Size deriving (Eq,Show)
-- | Measure tag of extents built with 'Square' (one shape for both
-- dimensions).
data Shape = Shape deriving (Eq,Show)

-- The tags carry no fields, so matching the constructor fully evaluates them.
instance NFData Size where rnf Size = ()
instance NFData Shape where rnf Shape = ()


-- | Extent with both tags 'Big'; any extent can be converted to it with
-- 'toGeneral'.
type General = Extent Size Big Big
-- | Extent with 'Big' vertical and 'Small' horizontal tag.
type Tall = Extent Size Big Small
-- | Extent with 'Small' vertical and 'Big' horizontal tag.
type Wide = Extent Size Small Big
-- | Extent with both tags 'Small' and an arbitrary measure.
type SquareMeas meas = Extent meas Small Small
-- | Square-tagged extent whose height and width are stored separately
-- (measure 'Size'); the two shape types may differ.
type LiberalSquare = SquareMeas Size
-- | Genuine square: the single shape @sh@ is used as height and width.
type Square sh = SquareMeas Shape sh sh


-- | Extent without restriction on the tags.
general :: height -> width -> General height width
general h w = Separate h w

-- | Extent tagged as tall.
tall :: height -> width -> Tall height width
tall h w = Separate h w

-- | Extent tagged as wide.
wide :: height -> width -> Wide height width
wide h w = Separate h w

-- | Square-tagged extent with separately stored height and width.
liberalSquare :: height -> width -> LiberalSquare height width
liberalSquare h w = Separate h w

-- | Genuine square extent built from a single shape.
square :: sh -> Square sh
square sh = Square sh


-- | Function type for converting an extent to other tags while keeping
-- the height and width types.
type Map measA vertA horizA measB vertB horizB height width =
         Extent measA vertA horizA height width ->
         Extent measB vertB horizB height width


-- | Class of the dimension tags 'Small' and 'Big'.
-- 'switchTag' picks the first alternative for 'Small' and the second for
-- 'Big', which allows dispatching on the tag at the type level.
class C tag where switchTag :: f Small -> f Big -> f tag
instance C Small where switchTag = const
instance C Big where switchTag = flip const

-- | Class of the measure tags 'Shape' and 'Size'.
-- 'switchMeasure' picks the first alternative for 'Shape' and the second
-- for 'Size'.
class Measure meas where switchMeasure :: f Shape -> f Size -> f meas
instance Measure Shape where switchMeasure = const
instance Measure Size where switchMeasure = flip const


-- | Choose one of four alternatives according to the @vert@ and @horiz@
-- tags: square-tagged (Small, Small), wide (Small, Big),
-- tall (Big, Small) and general (Big, Big).
--
-- The inner 'switchTag's dispatch on @horiz@ for a fixed @vert@.
-- 'Flip' moves @horiz@ out of the last type-argument position so that the
-- outer 'switchTag' can dispatch on @vert@.
-- NOTE(review): relies on 'Flip' wrapping @f a b@ as @Flip f b a@, which is
-- what the types here require; see "Numeric.LAPACK.Wrapper".
switchTagPair ::
   (C vert, C horiz) =>
   f Small Small -> f Small Big -> f Big Small -> f Big Big -> f vert horiz
switchTagPair fSquare fWide fTall fGeneral =
   getFlip $
   switchTag
      (Flip $ switchTag fSquare fWide)
      (Flip $ switchTag fTall fGeneral)


-- | Rotate the three type arguments of @f@ to the left: the first argument
-- of @f@ (the measure in this module) becomes the last one, so that
-- 'switchMeasure' can dispatch on it.
newtype RotLeft3 f b c a = RotLeft3 {getRotLeft3 :: f a b c}

-- | Dispatch on measure and both tags with only two cases: the genuine
-- square (Shape, Small, Small) and one polymorphic alternative used for
-- every 'Size' combination.
-- 'Shape' combined with a 'Big' tag cannot be built by 'Extent', so those
-- combinations are filled with 'errorTagTriple'.
switchMeasureExtent ::
   (Measure meas, C vert, C horiz) =>
   f Shape Small Small ->
   (forall vert0 horiz0. (C vert0, C horiz0) => f Size vert0 horiz0) ->
   f meas vert horiz
switchMeasureExtent fSquare fGeneral =
   getRotLeft3 $
      switchMeasure
         (RotLeft3 $ switchTagPair fSquare
                        errorTagTriple errorTagTriple errorTagTriple)
         (RotLeft3 $ switchTagPair fGeneral fGeneral fGeneral fGeneral)


-- | Abort with a message naming a tag combination that must not occur.
-- The three arguments carry the rendered names of the measure and the
-- vertical and horizontal tags.
errorTagTripleAux ::
   Const String meas -> Const String vert -> Const String horiz ->
   f meas vert horiz
errorTagTripleAux (Const measName) (Const vertName) (Const horizName) =
   error message
   where
      message :: String
      message =
         printf "forbidden Extent tag combination %s %s %s"
            measName vertName horizName

-- | Render a value with 'show' and keep its type as a phantom parameter.
showConst :: (Show a) => a -> Const String a
showConst = Const . show

-- | Filler for tag combinations that cannot occur (e.g. 'Shape' together
-- with a 'Big' tag). Forcing it calls 'error' with a message naming the
-- three tags, rendered via their 'Show' instances.
errorTagTriple :: (Measure meas, C vert, C horiz) => f meas vert horiz
errorTagTriple =
   errorTagTripleAux
      (switchMeasure (showConst Shape) (showConst Size))
      (switchTag (showConst Small) (showConst Big))
      (switchTag (showConst Small) (showConst Big))

-- | Choose one of five alternatives by measure and tag pair:
-- genuine square (Shape, Small, Small), then for measure 'Size':
-- liberal square, wide, tall and general.
-- 'Shape' with a 'Big' tag is unreachable and mapped to 'errorTagTriple'.
switchTagTriple ::
   (Measure meas, C vert, C horiz) =>
   f Shape Small Small -> f Size Small Small -> f Size Small Big ->
   f Size Big Small -> f Size Big Big -> f meas vert horiz
switchTagTriple fSquare fLiberalSquare fWide fTall fGeneral =
   getRotLeft3 $
      switchMeasure
         (RotLeft3 $ switchTagPair fSquare
                        errorTagTriple errorTagTriple errorTagTriple)
         (RotLeft3 $ switchTagPair fLiberalSquare fWide fTall fGeneral)


-- | Classify an extent as tall ('Left') or wide ('Right').
--
-- Genuine squares and liberal squares are reported as tall. Extents that
-- are already tagged tall or wide keep that classification. For general
-- extents the predicate @ge height width@ decides: 'True' gives tall,
-- 'False' gives wide.
caseTallWide ::
   (Measure meas, C vert, C horiz) =>
   (height -> width -> Bool) ->
   Extent meas vert horiz height width ->
   Either (Tall height width) (Wide height width)
caseTallWide _ (Square sh) = Left $ tall sh sh
caseTallWide ge x@(Separate _ _) =
   flip getAccessor x $
   switchTagPair
      -- liberal square (Small, Small): treated as tall
      (Accessor $ \(Separate h w) -> Left $ tall h w)
      -- already wide (Small, Big)
      (Accessor Right)
      -- already tall (Big, Small)
      (Accessor Left)
      -- general (Big, Big): the predicate decides
      (Accessor $ \(Separate h w) ->
         if ge h w
            then Left $ tall h w
            else Right $ wide h w)


-- | Wrapper so that 'switchMeasureExtent' can choose a square-building
-- function by the tags.
newtype GenSquare sh meas vert horiz =
   GenSquare {getGenSquare :: sh -> Extent meas vert horiz sh sh}

-- | Square extent with arbitrary tags: 'Square' if the measure is 'Shape',
-- otherwise 'Separate' with the same shape as height and width.
genSquare ::
   (Measure meas, C vert, C horiz) => sh -> Extent meas vert horiz sh sh
genSquare =
   getGenSquare $
   switchMeasureExtent
      (GenSquare square)
      (GenSquare (\sh -> Separate sh sh))

-- | 'Size'-measured extent with arbitrary tags from separate height and
-- width.
genLiberalSquare ::
   (C vert, C horiz) => height -> width -> Extent Size vert horiz height width
genLiberalSquare h w = Separate h w

-- | Turn any extent into a 'Size'-measured one.
-- A genuine 'Square' becomes 'Separate' with its shape as height and width.
relaxMeasure :: (Measure meas, C vert, C horiz) =>
   Extent meas vert horiz height width ->
   Extent Size vert horiz height width
relaxMeasure extent =
   case extent of
      Square sh -> genSquare sh
      Separate h w -> Separate h w

-- | Helper for 'generalizeTall': conversion from an extent with 'Small'
-- horizontal tag, selected by the target tags.
newtype GenTall height width meas vert horiz =
   GenTall {
      getGenTall ::
         Extent meas vert Small height width ->
         Extent Size vert horiz height width
   }

-- | Lift an extent with 'Small' horizontal tag to any horizontal tag,
-- keeping the vertical tag. The measure is first relaxed to 'Size' by
-- 'relaxMeasure', so a genuine 'Square' becomes a 'Separate' extent.
-- The horizontal tag is kept ('id') or raised to 'Big' ('wide', 'general').
generalizeTall :: (Measure meas, C vert, C horiz) =>
   Extent meas vert Small height width -> Extent Size vert horiz height width
generalizeTall =
   getGenTall
      (switchTagPair
         (GenTall id) (GenTall $ \(Separate h w) -> wide h w)
         (GenTall id) (GenTall $ \(Separate h w) -> general h w))
   .
   relaxMeasure

-- | Helper for 'generalizeWide': conversion from an extent with 'Small'
-- vertical tag, selected by the target tags.
newtype GenWide height width meas vert horiz =
   GenWide {
      getGenWide ::
         Extent meas Small horiz height width ->
         Extent Size vert horiz height width
   }

-- | Lift an extent with 'Small' vertical tag to any vertical tag,
-- keeping the horizontal tag. The measure is first relaxed to 'Size' by
-- 'relaxMeasure'. The vertical tag is kept ('id') or raised to 'Big'
-- ('tall', 'general').
generalizeWide :: (Measure meas, C vert, C horiz) =>
   Extent meas Small horiz height width -> Extent Size vert horiz height width
generalizeWide =
   getGenWide
      (switchTagPair
         (GenWide id)
         (GenWide id)
         (GenWide $ \(Separate h w) -> tall h w)
         (GenWide $ \(Separate h w) -> general h w))
   .
   relaxMeasure


-- | Helper for 'weakenTall': measure-preserving conversion from an extent
-- with 'Small' horizontal tag.
newtype WeakenTall height width meas vert horiz =
   WeakenTall {
      getWeakenTall ::
         Extent meas vert Small height width ->
         Extent meas vert horiz height width
   }

-- | Like 'generalizeTall', but keeps the measure: a genuine 'Square' is
-- rebuilt by 'fromSquareLiberal' with the target tags instead of being
-- turned into 'Separate'.
weakenTall :: (Measure meas, C vert, C horiz) =>
   Extent meas vert Small height width -> Extent meas vert horiz height width
weakenTall =
   getWeakenTall $
   switchTagTriple
      (WeakenTall fromSquareLiberal)
      (WeakenTall id) (WeakenTall $ \(Separate h w) -> wide h w)
      (WeakenTall id) (WeakenTall $ \(Separate h w) -> general h w)

-- | Helper for 'weakenWide': measure-preserving conversion from an extent
-- with 'Small' vertical tag.
newtype WeakenWide height width meas vert horiz =
   WeakenWide {
      getWeakenWide ::
         Extent meas Small horiz height width ->
         Extent meas vert horiz height width
   }

-- | Like 'generalizeWide', but keeps the measure: a genuine 'Square' is
-- rebuilt by 'fromSquareLiberal' with the target tags.
weakenWide :: (Measure meas, C vert, C horiz) =>
   Extent meas Small horiz height width -> Extent meas vert horiz height width
weakenWide =
   getWeakenWide $
   switchTagTriple
      (WeakenWide fromSquareLiberal)
      (WeakenWide id)
      (WeakenWide id)
      (WeakenWide $ \(Separate h w) -> tall h w)
      (WeakenWide $ \(Separate h w) -> general h w)


-- | Helper for 'genToTall': conversion to an extent with 'Big' vertical tag.
newtype GenToTall height width meas vert horiz =
   GenToTall {
      getGenToTall ::
         Extent meas vert horiz height width ->
         Extent Size Big horiz height width
   }

-- | Raise the vertical tag to 'Big', keeping the horizontal tag.
-- The result always has measure 'Size'; squares become 'tall' and wide
-- extents become 'general'.
genToTall :: (Measure meas, C vert, C horiz) =>
   Extent meas vert horiz height width -> Extent Size Big horiz height width
genToTall =
   getGenToTall $
   switchTagTriple
      (GenToTall $ \(Square s) -> tall s s)
      (GenToTall $ \(Separate h w) -> tall h w)
      (GenToTall $ \(Separate h w) -> general h w)
      (GenToTall id)
      (GenToTall id)

-- | Helper for 'genToWide': conversion to an extent with 'Big' horizontal
-- tag.
newtype GenToWide height width meas vert horiz =
   GenToWide {
      getGenToWide ::
         Extent meas vert horiz height width ->
         Extent Size vert Big height width
   }

-- | Raise the horizontal tag to 'Big', keeping the vertical tag.
-- The result always has measure 'Size'; squares become 'wide' and tall
-- extents become 'general'.
genToWide :: (Measure meas, C vert, C horiz) =>
   Extent meas vert horiz height width -> Extent Size vert Big height width
genToWide =
   getGenToWide $
   switchTagTriple
      (GenToWide $ \(Square s) -> wide s s)
      (GenToWide $ \(Separate h w) -> wide h w)
      (GenToWide id)
      (GenToWide $ \(Separate h w) -> general h w)
      (GenToWide id)


-- | Function consuming an extent, wrapped so that it can be selected by the
-- tags with 'switchTagPair' (used by 'caseTallWide' and the 'Show'
-- instance).
newtype Accessor a height width meas vert horiz =
   Accessor {getAccessor :: Extent meas vert horiz height width -> a}

-- | The single shape of a genuine square extent.
squareSize :: Square shape -> shape
squareSize extent =
   case extent of
      Square sh -> sh

-- | Height of an extent; for a genuine square this is its only shape.
height ::
   (Measure meas, C vert, C horiz) =>
   Extent meas vert horiz height width -> height
height extent =
   case extent of
      Square sh -> sh
      Separate h _ -> h

-- | Width of an extent; for a genuine square this is its only shape.
width ::
   (Measure meas, C vert, C horiz) =>
   Extent meas vert horiz height width -> width
width extent =
   case extent of
      Square sh -> sh
      Separate _ w -> w


-- | Height and width as a pair. The pair is built lazily; the extent is
-- only inspected when a component is demanded.
dimensions ::
   (Measure meas, C vert, C horiz) =>
   Extent meas vert horiz height width -> (height,width)
dimensions extent = (rows, cols)
   where
      rows = height extent
      cols = width extent


-- | Forget all tag information: rebuild the extent as 'General' from its
-- height and width.
toGeneral ::
   (Measure meas, C vert, C horiz) =>
   Extent meas vert horiz height width -> General height width
toGeneral extent = uncurry general (dimensions extent)

-- | Re-tag a genuine square with arbitrary tags (see 'genSquare').
fromSquare ::
   (Measure meas, C vert, C horiz) =>
   Square size -> Extent meas vert horiz size size
fromSquare extent =
   case extent of
      Square sh -> genSquare sh

-- | Like 'fromSquare', but typed with separate height and width
-- variables. Matching 'Square' establishes that they coincide.
fromSquareLiberal ::
   (Measure meas, C vert, C horiz) =>
   Extent Shape Small Small height width ->
   Extent meas vert horiz height width
fromSquareLiberal extent =
   case extent of
      Square sh -> genSquare sh

-- | Re-tag a liberal square with arbitrary vertical and horizontal tags.
fromLiberalSquare :: (C vert, C horiz) =>
   LiberalSquare height width ->
   Extent Size vert horiz height width
fromLiberalSquare extent =
   case extent of
      Separate h w -> genLiberalSquare h w

-- | Convert an extent with equal height and width into a genuine square.
-- Calls 'error' if the height differs from the width.
squareFromFull ::
   (Measure meas, C vert, C horiz, Eq size) =>
   Extent meas vert horiz size size -> Square size
squareFromFull extent
   | sh == width extent = square sh
   | otherwise = error "Extent.squareFromFull: no square shape"
   where
      sh = height extent

-- | Convert an extent into a liberal square. Height and width must have
-- the same number of elements ('Shape.size'); otherwise 'error' is called.
-- A genuine square is always accepted.
liberalSquareFromFull ::
   (Measure meas, C vert, C horiz, Shape.C height, Shape.C width) =>
   Extent meas vert horiz height width -> LiberalSquare height width
liberalSquareFromFull extent =
   case extent of
      Square sh -> Separate sh sh
      Separate h w
         | Shape.size h == Shape.size w -> liberalSquare h w
         | otherwise -> error "Extent.liberalSquareFromFull: no square shape"


-- | Swap height and width, and with them the vertical and horizontal tags.
-- A genuine square is unchanged.
transpose ::
   (Measure meas, C vert, C horiz) =>
   Extent meas vert horiz height width ->
   Extent meas horiz vert width height
transpose extent =
   case extent of
      Square sh -> Square sh
      Separate h w -> Separate w h


-- | Structural equality. Both operands share the measure tag, so they are
-- either both 'Square' or both 'Separate'. Heights are compared before
-- widths.
instance
   (Measure meas, C vert, C horiz, Eq height, Eq width) =>
      Eq (Extent meas vert horiz height width) where
   x == y =
      case (x, y) of
         (Square sx, Square sy) -> sx == sy
         (Separate hx wx, Separate hy wy) -> hx == hy && wx == wy


-- | Shows an extent as an application of the smart constructor matching its
-- tags: @Extent.square@, @Extent.liberalSquare@, @Extent.wide@,
-- @Extent.tall@ or @Extent.general@.
instance
   (Measure meas, C vert, C horiz, Show height, Show width) =>
      Show (Extent meas vert horiz height width) where
   showsPrec prec x@(Square _) = showsPrecSquare prec x
   showsPrec prec x@(Separate _ _) =
      flip getAccessor x $
      switchTagPair
         (Accessor $ showsPrecAny "Extent.liberalSquare" prec)
         (Accessor $ showsPrecAny "Extent.wide" prec)
         (Accessor $ showsPrecAny "Extent.tall" prec)
         (Accessor $ showsPrecAny "Extent.general" prec)

-- | Render a genuine square as @Extent.square sh@, parenthesized when
-- used as an argument (precedence above 10).
showsPrecSquare ::
   (Show height) =>
   Int -> Extent Shape Small Small height width -> ShowS
showsPrecSquare prec extent =
   showParen (prec > 10) $ \rest ->
      "Extent.square " ++ showsPrec 11 (height extent) rest

-- | Render an extent as @name height width@, parenthesized when used as an
-- argument (precedence above 10).
showsPrecAny ::
   (Measure meas, C vert, C horiz, Show height, Show width) =>
   String -> Int -> Extent meas vert horiz height width -> ShowS
showsPrecAny name prec extent =
   showParen (prec > 10) $
      foldr (.) id
         [showString name,
          showChar ' ', showsPrec 11 (height extent),
          showChar ' ', showsPrec 11 (width extent)]


-- | Replace the width of a wide or general extent.
widen ::
   (C vert) =>
   widthB ->
   Extent Size vert Big height widthA -> Extent Size vert Big height widthB
widen newWidth extent =
   case extent of
      Separate h _ -> Separate h newWidth

-- | Replace the height of a wide or general extent.
reduceWideHeight ::
   (C vert) =>
   heightB ->
   Extent Size vert Big heightA width -> Extent Size vert Big heightB width
reduceWideHeight newHeight extent =
   case extent of
      Separate _ w -> Separate newHeight w


-- | Replace height and width while keeping the constructor. For a genuine
-- square only the new height is used.
reduceConsistent ::
   (Measure meas, C vert, C horiz) =>
   height -> width ->
   Extent meas vert horiz height width -> Extent meas vert horiz height width
reduceConsistent h w extent =
   case extent of
      Square _ -> Square h
      Separate _ _ -> Separate h w

-- | Transform the height of a 'Size'-measured extent.
mapHeight ::
   (C vert, C horiz) =>
   (heightA -> heightB) ->
   Extent Size vert horiz heightA width -> Extent Size vert horiz heightB width
mapHeight f extent =
   case extent of
      Separate h w -> Separate (f h) w

-- | Transform the width of a 'Size'-measured extent.
mapWidth ::
   (C vert, C horiz) =>
   (widthA -> widthB) ->
   Extent Size vert horiz height widthA -> Extent Size vert horiz height widthB
mapWidth f extent =
   case extent of
      Separate h w -> Separate h (f w)

-- | Transform the shape of a genuine square.
mapSquareSize :: (shA -> shB) -> Square shA -> Square shB
mapSquareSize f extent =
   case extent of
      Square sh -> Square (f sh)


-- | Wrap height and width in a type constructor @f@. A genuine square only
-- uses the height wrapper, because height and width are the same shape.
mapWrap ::
   (Measure meas, C vert, C horiz) =>
   (height -> f height) ->
   (width -> f width) ->
   Extent meas vert horiz height width ->
   Extent meas vert horiz (f height) (f width)
mapWrap wrapH wrapW extent =
   case extent of
      Square sh -> Square (wrapH sh)
      Separate h w -> Separate (wrapH h) (wrapW w)

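{- only admissible since GHC-7.8.
Kept for reference only: it still uses former constructors (Wide, Tall,
General) and an 'Adapt' wrapper that this module does not define.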
mapUnwrap ::
   (Measure meas, C vert, C horiz) =>
   (f height -> height) ->
   (f width -> width) ->
   Extent meas vert horiz (f height) (f width) ->
   Extent meas vert horiz height width
mapUnwrap fh fw =
   getAdapt $
   switchTagPair
      (Adapt $ \(Square h) -> Square (fh h))
      (Adapt $ \(Wide h w) -> Wide (fh h) (fw w))
      (Adapt $ \(Tall h w) -> Tall (fh h) (fw w))
      (Adapt $ \(General h w) -> General (fh h) (fw w))
-}

-- | Strip the 'Unchecked' wrapper from height and width.
recheck ::
   (Measure meas, C vert, C horiz) =>
   Extent meas vert horiz (Unchecked height) (Unchecked width) ->
   Extent meas vert horiz height width
recheck extent =
   case extent of
      Square sh -> Square (deconsUnchecked sh)
      Separate h w -> Separate (deconsUnchecked h) (deconsUnchecked w)

-- | Like 'recheck', but for extents whose height and width are both
-- appended shapes (@::+@): strips 'Unchecked' from each of the parts.
-- The @::+@ pairs are matched directly in the patterns.
recheckAppend ::
   (Measure meas, C vert, C horiz) =>
   Extent meas vert horiz
      (Unchecked heightA ::+ Unchecked heightB)
      (Unchecked widthA  ::+ Unchecked widthB) ->
   Extent meas vert horiz (heightA::+heightB) (widthA::+widthB)
recheckAppend (Square (ha::+hb)) =
   Square (deconsUnchecked ha ::+ deconsUnchecked hb)
recheckAppend (Separate (ha::+hb) (wa::+wb)) =
   Separate
      (deconsUnchecked ha ::+ deconsUnchecked hb)
      (deconsUnchecked wa ::+ deconsUnchecked wb)

-- | Compose two extents that share the inner dimension @fuse@ (the width
-- of the first and the height of the second). Returns 'Nothing' if the
-- shared dimension does not match.
fuse ::
   (Measure meas, C vert, C horiz, Eq fuse) =>
   Extent meas vert horiz height fuse ->
   Extent meas vert horiz fuse width ->
   Maybe (Extent meas vert horiz height width)
fuse (Square sa) (Square sb)
   | sa == sb = Just (Square sa)
   | otherwise = Nothing
fuse (Separate h fa) (Separate fb w)
   | fa == fb = Just (Separate h w)
   | otherwise = Nothing

-- | Adjust the measure of the second extent to the product measure
-- 'MultiplyMeasure' @measA measB@. If the first extent is a genuine
-- 'Square', the product is @measB@ and the second extent is returned
-- unchanged. Otherwise the product is 'Size' and 'relaxMeasure' is applied.
-- Only the constructor of the first extent is inspected.
relaxMeasureWith ::
   (Measure measA, Measure measB,
    MultiplyMeasure measA measB ~ measC,
    C vert, C horiz) =>
   Extent measA vertA horizA heightA widthA ->
   Extent measB vert horiz height width ->
   Extent measC vert horiz height width
relaxMeasureWith (Square _) = id
relaxMeasureWith (Separate _ _) = relaxMeasure


-- | Extent of a Kronecker product: heights and widths are paired as tuples.
kronecker ::
   (Measure meas, C vert, C horiz) =>
   Extent meas vert horiz heightA widthA ->
   Extent meas vert horiz heightB widthB ->
   Extent meas vert horiz (heightA,heightB) (widthA,widthB)
kronecker extentA extentB = stackGen (,) (,) extentA extentB



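{-
Tag table for 'beside'.
Columns: vertA horizA  vertB horizB -> result vert horiz.
The result vertical tag is 'Append' vertA vertB; the result horizontal tag
is always Big.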

Small Small  Small Small -> Small Big
Small Small  Small Big   -> Small Big
Small Small  Big   Small -> Small Big
Small Small  Big   Big   -> Small Big
Small Big    Small Small -> Small Big
Small Big    Small Big   -> Small Big
Small Big    Big   Small -> Small Big
Small Big    Big   Big   -> Small Big
Big   Small  Small Small -> Small Big
Big   Small  Small Big   -> Small Big
Big   Small  Big   Small -> Big   Big
Big   Small  Big   Big   -> Big   Big
Big   Big    Small Small -> Small Big
Big   Big    Small Big   -> Small Big
Big   Big    Big   Small -> Big   Big
Big   Big    Big   Big   -> Big   Big
-}
-- | How to place two extents with 'Big' horizontal tag and a common height
-- side by side. The result width is the appended width @widthA::+widthB@,
-- and @vertC@ is the resulting vertical tag.
newtype AppendMode vertA vertB vertC height widthA widthB =
   AppendMode (
      Extent Size vertA Big height widthA ->
      Extent Size vertB Big height widthB ->
      Extent Size vertC Big height (widthA::+widthB)
   )

-- | Append mode that keeps the left operand's tags and height and only
-- widens it to the combined width.
appendLeftAux ::
   (C vertA, C vertB) => AppendMode vertA vertB vertA height widthA widthB
appendLeftAux =
   AppendMode $ \left right ->
      let combined = width left ::+ width right
      in widen combined left

-- | Append mode for operands with the same vertical tag.
appendSame :: (C vert) => AppendMode vert vert vert height widthA widthB
appendSame = appendLeftAux

-- | Append mode where the right operand is 'Big' vertically and the result
-- takes the left operand's vertical tag.
appendLeft :: (C vert) => AppendMode vert Big vert height widthA widthB
appendLeft = appendLeftAux

-- | Append mode where the left operand is 'Big' vertically. The result
-- keeps the right operand's tags and height, widened to the combined width.
appendRight :: (C vert) => AppendMode Big vert vert height widthA widthB
appendRight =
   AppendMode $ \left right ->
      let combined = width left ::+ width right
      in widen combined right

-- | Vertical tag of two extents placed side by side: 'Small' on the left
-- yields 'Small', 'Big' on the left yields the right tag.
type family Append a b
type instance Append Small b = Small
type instance Append Big   b = b

-- | 'AppendMode' wrapper with @vertA@ as the last type argument, so that
-- 'appendAny' can dispatch on it with 'switchTag'.
newtype
   AppendAny vertB height widthA widthB vertA =
      AppendAny {
         getAppendAny ::
            AppendMode vertA vertB (Append vertA vertB) height widthA widthB
      }

-- | Append mode chosen by the left vertical tag, consistent with 'Append':
-- 'Small' keeps the left operand ('appendLeftAux'), 'Big' keeps the right
-- operand ('appendRight').
appendAny ::
   (C vertA, C vertB) =>
   AppendMode vertA vertB (Append vertA vertB) height widthA widthB
appendAny =
   getAppendAny $ switchTag (AppendAny appendLeftAux) (AppendAny appendRight)


-- | Extent of a block-diagonal arrangement: heights and widths are
-- appended with @::+@.
stack ::
   (Measure meas, C vert, C horiz) =>
   Extent meas vert horiz heightA widthA ->
   Extent meas vert horiz heightB widthB ->
   Extent meas vert horiz (heightA::+heightB) (widthA::+widthB)
stack extentA extentB = stackGen (::+) (::+) extentA extentB

-- | Combine two extents with the same tags, merging heights with the first
-- function and widths with the second. For genuine squares only the height
-- combiner is used, because height and width coincide.
stackGen ::
   (Measure meas, C vert, C horiz) =>
   (heightA -> heightB -> f heightA heightB) ->
   (widthA -> widthB -> f widthA widthB) ->
   Extent meas vert horiz heightA widthA ->
   Extent meas vert horiz heightB widthB ->
   Extent meas vert horiz (f heightA heightB) (f widthA widthB)
stackGen combineH combineW extentA extentB =
   case (extentA, extentB) of
      (Square sa, Square sb) -> Square (combineH sa sb)
      (Separate ha wa, Separate hb wb) ->
         Separate (combineH ha hb) (combineW wa wb)


-- | Tag of a product of extents: 'Small' on the left yields the right tag,
-- 'Big' on the left yields 'Big'.
type family Multiply a b
type instance Multiply Small b = b
type instance Multiply Big   b = Big

-- | Measure of a product of extents: 'Shape' on the left yields the right
-- measure, 'Size' on the left yields 'Size'.
type family MultiplyMeasure a b
type instance MultiplyMeasure Shape b = b
type instance MultiplyMeasure Size  b = Size


-- | Evidence that @a@ is an instance of 'C'; matching 'TagFact' brings the
-- instance into scope.
data TagFact a = C a => TagFact

-- | Wrapper with @a@ last, so that 'multiplyTagLaw' can dispatch on it.
newtype MultiplyTagLaw b a =
   MultiplyTagLaw {
      getMultiplyTagLaw :: TagFact a -> TagFact b -> TagFact (Multiply a b)
   }

-- | 'Multiply' of two 'C' tags is again a 'C' tag.
-- For @a = 'Small'@ the product is @b@, so @b@'s evidence is returned
-- (@flip const@). For @a = 'Big'@ the product is 'Big', so @a@'s evidence
-- is returned ('const').
multiplyTagLaw :: TagFact a -> TagFact b -> TagFact (Multiply a b)
multiplyTagLaw a@TagFact =
   ($ a) $ getMultiplyTagLaw $
   switchTag
      (MultiplyTagLaw $ flip const)
      (MultiplyTagLaw const)

-- | Evidence for the vertical tag of an extent; the extent value is ignored.
heightFact :: (C vert) => Extent meas vert horiz height width -> TagFact vert
heightFact = const TagFact

-- | Evidence for the horizontal tag of an extent; the extent value is
-- ignored.
widthFact :: (C horiz) => Extent meas vert horiz height width -> TagFact horiz
widthFact = const TagFact


-- | Evidence that @a@ is an instance of 'Measure'; matching 'MeasureFact'
-- brings the instance into scope.
data MeasureFact a = Measure a => MeasureFact

-- | Wrapper with @a@ last, so that 'multiplyMeasureLaw' can dispatch on it.
newtype MultiplyMeasureLaw b a =
   MultiplyMeasureLaw {
      getMultiplyMeasureLaw ::
         MeasureFact a -> MeasureFact b -> MeasureFact (MultiplyMeasure a b)
   }

-- | 'MultiplyMeasure' of two 'Measure' tags is again a 'Measure' tag.
-- For @a = 'Shape'@ the product is @b@ (@flip const@). For @a = 'Size'@ the
-- product is 'Size', so @a@'s evidence is returned ('const').
multiplyMeasureLaw ::
   MeasureFact a -> MeasureFact b -> MeasureFact (MultiplyMeasure a b)
multiplyMeasureLaw a@MeasureFact =
   ($ a) $ getMultiplyMeasureLaw $
   switchMeasure
      (MultiplyMeasureLaw $ flip const)
      (MultiplyMeasureLaw const)

-- | Evidence for the measure tag of an extent; the extent value is ignored.
measureFact ::
   (Measure meas) => Extent meas vert horiz height width -> MeasureFact meas
measureFact = const MeasureFact


-- | Wrapper for 'unifyLeft' and 'unifyRight'. The tags of the first extent
-- come last, so that 'switchTagTriple' can dispatch on them. The result
-- carries the product tags of both extents.
newtype
   Unify height fuse width heightC widthC
      measB vertB horizB measA vertA horizA =
   Unify {
      getUnify ::
         Extent measA vertA horizA height fuse ->
         Extent measB vertB horizB fuse width ->
         Extent (MultiplyMeasure measA measB)
            (Multiply vertA vertB) (Multiply horizA horizB) heightC widthC
   }

-- | Re-tag the first (left) extent with the product tags of both extents.
-- The second extent only determines the target type; its value is ignored
-- (each alternative is @const . conversion@).
unifyLeft ::
   (Measure measA, Measure measB, C vertA, C horizA, C vertB, C horizB) =>
   Extent measA vertA horizA height fuse ->
   Extent measB vertB horizB fuse width ->
   Extent (MultiplyMeasure measA measB)
      (Multiply vertA vertB) (Multiply horizA horizB) height fuse
unifyLeft =
   getUnify $
   switchTagTriple
      (Unify $ const . fromSquareLiberal)
      (Unify $ const . fromLiberalSquare)
      (Unify $ const . generalizeWide)
      (Unify $ const . generalizeTall)
      (Unify $ const . toGeneral)

-- | Re-tag the second (right) extent with the product tags of both extents.
-- The conversion is chosen by the first extent's tags, but that extent's
-- value is ignored (each alternative is @const conversion@).
unifyRight ::
   (Measure measA, Measure measB, C vertA, C horizA, C vertB, C horizB) =>
   Extent measA vertA horizA height fuse ->
   Extent measB vertB horizB fuse width ->
   Extent (MultiplyMeasure measA measB)
      (Multiply vertA vertB) (Multiply horizA horizB) fuse width
unifyRight =
   getUnify $
   switchTagTriple
      (Unify $ const id)
      (Unify $ const relaxMeasure)
      (Unify $ const genToWide)
      (Unify $ const genToTall)
      (Unify $ const toGeneral)


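{-
Result tags of 'unifyLeft' and 'unifyRight', i.e. 'Multiply' applied
separately to the vertical and horizontal tags of two extents.
First by extent kind (left, right -> result), then by tags
(vertA horizA  vertB horizB -> result vert horiz).
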
Square  Square  -> Square
Square  Wide    -> Wide
Square  Tall    -> Tall
Square  General -> General
Wide    Square  -> Wide
Wide    Wide    -> Wide
Wide    Tall    -> General
Wide    General -> General
Tall    Square  -> Tall
Tall    Wide    -> General
Tall    Tall    -> Tall
Tall    General -> General
General Square  -> General
General Wide    -> General
General Tall    -> General
General General -> General

Small Small  Small Small -> Small Small
Small Small  Small Big   -> Small Big
Small Small  Big   Small -> Big   Small
Small Small  Big   Big   -> Big   Big
Small Big    Small Small -> Small Big
Small Big    Small Big   -> Small Big
Small Big    Big   Small -> Big   Big
Small Big    Big   Big   -> Big   Big
Big   Small  Small Small -> Big   Small
Big   Small  Small Big   -> Big   Big
Big   Small  Big   Small -> Big   Small
Big   Small  Big   Big   -> Big   Big
Big   Big    Small Small -> Big   Big
Big   Big    Small Big   -> Big   Big
Big   Big    Big   Small -> Big   Big
Big   Big    Big   Big   -> Big   Big
-}