{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GADTs #-}
module Numeric.LAPACK.Matrix.Extent.Private where

import qualified Numeric.LAPACK.Matrix.Extent.Kind as EK
import Numeric.LAPACK.Shape.Private (Unchecked(deconsUnchecked))
import Numeric.LAPACK.Wrapper (Flip(Flip, getFlip))

import Data.Array.Comfort.Shape ((:+:)((:+:)))

import Control.DeepSeq (NFData, rnf)

import Data.Maybe.HT (toMaybe)
import Data.Tuple.HT (swap)
import Data.Eq.HT (equating)


data Extent vertical horizontal height width =
   Extent {
      extentDir :: (vertical,horizontal),
      extentDim :: Dimensions vertical horizontal height width
   }

instance
   (C vertical, C horizontal, NFData height, NFData width) =>
      NFData (Extent vertical horizontal height width) where
   rnf =
      getAccessor $
      switchTagPair
         (Accessor $ \(Extent o (EK.Square s)) -> rnf (o,s))
         (Accessor $ \(Extent o (EK.Wide h w)) -> rnf (o,(h,w)))
         (Accessor $ \(Extent o (EK.Tall h w)) -> rnf (o,(h,w)))
         (Accessor $ \(Extent o (EK.General h w)) -> rnf (o,(h,w)))


data Big = Big deriving (Eq,Show)
data Small = Small deriving (Eq,Show)

instance NFData Big where rnf Big = ()
instance NFData Small where rnf Small = ()

type General = Extent Big Big
type Tall = Extent Big Small
type Wide = Extent Small Big
type Square sh = Extent Small Small sh sh


type family Dimensions vertical horizontal :: * -> * -> *

type instance Dimensions Big Big = EK.General
type instance Dimensions Big Small = EK.Tall
type instance Dimensions Small Big = EK.Wide
type instance Dimensions Small Small = EK.Square


general :: height -> width -> General height width
general h w = Extent (Big,Big) $ EK.General h w

tall :: height -> width -> Tall height width
tall h w = Extent (Big,Small) $ EK.Tall h w

wide :: height -> width -> Wide height width
wide h w = Extent (Small,Big) $ EK.Wide h w

square :: sh -> Square sh
square sh = Extent (Small,Small) $ EK.Square sh


newtype Map vertA horizA vertB horizB height width =
   Map {
      apply ::
         Extent vertA horizA height width ->
         Extent vertB horizB height width
   }


class C tag where switchTag :: f Small -> f Big -> f tag
instance C Small where switchTag f _ = f
instance C Big where switchTag _ f = f


switchTagPair ::
   (C vert, C horiz) =>
   f Small Small -> f Small Big -> f Big Small -> f Big Big -> f vert horiz
switchTagPair fSquare fWide fTall fGeneral =
   getFlip $
   switchTag
      (Flip $ switchTag fSquare fWide)
      (Flip $ switchTag fTall fGeneral)


newtype CaseTallWide height width vert horiz =
   CaseTallWide {
      getCaseTallWide ::
         Extent vert horiz height width ->
         Either (Tall height width) (Wide height width)
   }

caseTallWide ::
   (C vert, C horiz) =>
   (height -> width -> Bool) ->
   Extent vert horiz height width ->
   Either (Tall height width) (Wide height width)
caseTallWide ge =
   getCaseTallWide $
   switchTagPair
      (CaseTallWide $ \(Extent _ (EK.Square sh)) -> Left $ tall sh sh)
      (CaseTallWide Right)
      (CaseTallWide Left)
      (CaseTallWide $ \(Extent _ (EK.General h w)) ->
         if ge h w
            then Left $ tall h w
            else Right $ wide h w)


newtype GenSquare sh vert horiz =
   GenSquare {getGenSquare :: sh -> Extent vert horiz sh sh}

genSquare :: (C vert, C horiz) => sh -> Extent vert horiz sh sh
genSquare =
   getGenSquare $
   switchTagPair
      (GenSquare square)
      (GenSquare (\sh -> wide sh sh))
      (GenSquare (\sh -> tall sh sh))
      (GenSquare (\sh -> general sh sh))

newtype GenTall height width vert horiz =
   GenTall {
      getGenTall ::
         Extent vert Small height width -> Extent vert horiz height width
   }

generalizeTall :: (C vert, C horiz) =>
   Extent vert Small height width -> Extent vert horiz height width
generalizeTall =
   getGenTall $
   switchTagPair
      (GenTall id) (GenTall $ \(Extent _ (EK.Square s)) -> wide s s)
      (GenTall id) (GenTall $ \(Extent _ (EK.Tall h w)) -> general h w)

newtype GenWide height width vert horiz =
   GenWide {
      getGenWide ::
         Extent Small horiz height width -> Extent vert horiz height width
   }

generalizeWide :: (C vert, C horiz) =>
   Extent Small horiz height width -> Extent vert horiz height width
generalizeWide =
   getGenWide $
   switchTagPair
      (GenWide id)
      (GenWide id)
      (GenWide $ \(Extent _ (EK.Square s)) -> tall s s)
      (GenWide $ \(Extent _ (EK.Wide h w)) -> general h w)


newtype GenToTall height width vert horiz =
   GenToTall {
      getGenToTall ::
         Extent vert horiz height width -> Extent Big horiz height width
   }

genToTall :: (C vert, C horiz) =>
   Extent vert horiz height width -> Extent Big horiz height width
genToTall =
   getGenToTall $
   switchTagPair
      (GenToTall $ \(Extent _ (EK.Square s)) -> tall s s)
      (GenToTall $ \(Extent _ (EK.Wide h w)) -> general h w)
      (GenToTall id)
      (GenToTall id)


newtype GenToWide height width vert horiz =
   GenToWide {
      getGenToWide ::
         Extent vert horiz height width -> Extent vert Big height width
   }

genToWide :: (C vert, C horiz) =>
   Extent vert horiz height width -> Extent vert Big height width
genToWide =
   getGenToWide $
   switchTagPair
      (GenToWide $ \(Extent _ (EK.Square s)) -> wide s s)
      (GenToWide id)
      (GenToWide $ \(Extent _ (EK.Tall h w)) -> general h w)
      (GenToWide id)


squareSize :: Square sh -> sh
squareSize (Extent (Small,Small) (EK.Square sh)) = sh


newtype Accessor a height width vert horiz =
   Accessor {getAccessor :: Extent vert horiz height width -> a}

height :: (C vert, C horiz) => Extent vert horiz height width -> height
height =
   getAccessor $
   switchTagPair
      (Accessor (\(Extent _ (EK.Square s)) -> s))
      (Accessor (EK.wideHeight . extentDim))
      (Accessor (EK.tallHeight . extentDim))
      (Accessor (EK.generalHeight . extentDim))

width :: (C vert, C horiz) => Extent vert horiz height width -> width
width =
   getAccessor $
   switchTagPair
      (Accessor (\(Extent _ (EK.Square s)) -> s))
      (Accessor (EK.wideWidth . extentDim))
      (Accessor (EK.tallWidth . extentDim))
      (Accessor (EK.generalWidth . extentDim))


dimensions ::
   (C vert, C horiz) => Extent vert horiz height width -> (height,width)
dimensions x = (height x, width x)


toGeneral ::
   (C vert, C horiz) => Extent vert horiz height width -> General height width
toGeneral x = general (height x) (width x)

fromSquare :: (C vert, C horiz) => Square size -> Extent vert horiz size size
fromSquare = genSquare . squareSize

fromSquareLiberal :: (C vert, C horiz) =>
   Extent Small Small height width -> Extent vert horiz height width
fromSquareLiberal x@(Extent _ (EK.Square _)) = genSquare $ height x

squareFromGeneral ::
   (C vert, C horiz, Eq size) =>
   Extent vert horiz size size -> Square size
squareFromGeneral x =
   let size = height x
   in if size == width x
        then square size
        else error "Extent.squareFromGeneral: no square shape"


newtype Transpose height width vert horiz =
   Transpose {
      getTranspose ::
         Extent vert horiz height width ->
         Extent horiz vert width height
   }

transpose ::
   (C vert, C horiz) =>
   Extent vert horiz height width ->
   Extent horiz vert width height
transpose =
   getTranspose $
   switchTagPair
      (Transpose $ \(Extent o (EK.Square s)) -> Extent o (EK.Square s))
      (Transpose $ \(Extent o (EK.Wide h w)) -> Extent (swap o) (EK.Tall w h))
      (Transpose $ \(Extent o (EK.Tall h w)) -> Extent (swap o) (EK.Wide w h))
      (Transpose $ \(Extent o (EK.General h w)) -> Extent o (EK.General w h))


newtype Equal height width vert horiz =
   Equal {
      getEqual ::
         Extent vert horiz height width ->
         Extent vert horiz height width -> Bool
   }

instance
   (C vert, C horiz, Eq height, Eq width) =>
      Eq (Extent vert horiz height width) where
   (==) =
      getEqual $
      switchTagPair
         (Equal $ equating extentDim)
         (Equal $ equating extentDim)
         (Equal $ equating extentDim)
         (Equal $ equating extentDim)


instance
   (C vert, C horiz, Show height, Show width) =>
      Show (Extent vert horiz height width) where
   showsPrec prec =
      getAccessor $
      switchTagPair
         (Accessor $ showsPrecSquare prec)
         (Accessor $ showsPrecAny "Extent.wide" prec)
         (Accessor $ showsPrecAny "Extent.tall" prec)
         (Accessor $ showsPrecAny "Extent.general" prec)

showsPrecSquare ::
   (Show height) =>
   Int -> Extent Small Small height width -> ShowS
showsPrecSquare p x =
   showParen (p>10) $
   showString "Extent.square " . showsPrec 11 (height x)

showsPrecAny ::
   (C vert, C horiz, Show height, Show width) =>
   String -> Int -> Extent vert horiz height width -> ShowS
showsPrecAny name p x =
   showParen (p>10) $
   showString name .
   showString " " . showsPrec 11 (height x) .
   showString " " . showsPrec 11 (width x)


newtype Widen heightA widthA heightB widthB vert =
   Widen {
      getWiden ::
         Extent vert Big heightA widthA ->
         Extent vert Big heightB widthB
   }

widen ::
   (C vert) =>
   widthB -> Extent vert Big height widthA -> Extent vert Big height widthB
widen w =
   getWiden $
   switchTag
      (Widen (\(Extent o x) -> Extent o (x{EK.wideWidth = w})))
      (Widen (\(Extent o x) -> Extent o (x{EK.generalWidth = w})))

reduceWideHeight ::
   (C vert) =>
   heightB -> Extent vert Big heightA width -> Extent vert Big heightB width
reduceWideHeight h =
   getWiden $
   switchTag
      (Widen (\(Extent o x) -> Extent o (x{EK.wideHeight = h})))
      (Widen (\(Extent o x) -> Extent o (x{EK.generalHeight = h})))


newtype Adapt heightA widthA heightB widthB vert horiz =
   Adapt {
      getAdapt ::
         Extent vert horiz heightA widthA ->
         Extent vert horiz heightB widthB
   }

reduceConsistent ::
   (C vert, C horiz) =>
   height -> width ->
   Extent vert horiz height width -> Extent vert horiz height width
reduceConsistent h w =
   getAdapt $
   switchTagPair
      (Adapt $ \(Extent o (EK.Square _)) -> Extent o (EK.Square h))
      (Adapt $ \(Extent o (EK.Wide _ _)) -> Extent o (EK.Wide h w))
      (Adapt $ \(Extent o (EK.Tall _ _)) -> Extent o (EK.Tall h w))
      (Adapt $ \(Extent o (EK.General _ _)) -> Extent o (EK.General h w))


class (C vert, C horiz) => GeneralTallWide vert horiz where
   switchTagGTW :: f Small Big -> f Big Small -> f Big Big -> f vert horiz

instance GeneralTallWide Small Big where switchTagGTW f _ _ = f
instance GeneralTallWide Big Small where switchTagGTW _ f _ = f
instance GeneralTallWide Big Big where switchTagGTW _ _ f = f

mapHeight ::
   (GeneralTallWide vert horiz) =>
   (heightA -> heightB) ->
   Extent vert horiz heightA width -> Extent vert horiz heightB width
mapHeight f =
   getAdapt $
   switchTagGTW
      (Adapt $ \(Extent o (EK.Wide h w)) -> Extent o (EK.Wide (f h) w))
      (Adapt $ \(Extent o (EK.Tall h w)) -> Extent o (EK.Tall (f h) w))
      (Adapt $ \(Extent o (EK.General h w)) -> Extent o (EK.General (f h) w))

mapWidth ::
   (GeneralTallWide vert horiz) =>
   (widthA -> widthB) ->
   Extent vert horiz height widthA -> Extent vert horiz height widthB
mapWidth f =
   getAdapt $
   switchTagGTW
      (Adapt $ \(Extent o (EK.Wide h w)) -> Extent o (EK.Wide h (f w)))
      (Adapt $ \(Extent o (EK.Tall h w)) -> Extent o (EK.Tall h (f w)))
      (Adapt $ \(Extent o (EK.General h w)) -> Extent o (EK.General h (f w)))

mapSquareSize :: (shA -> shB) -> Square shA -> Square shB
mapSquareSize f (Extent o (EK.Square s)) = Extent o (EK.Square (f s))


mapWrap ::
   (C vert, C horiz) =>
   (height -> f height) ->
   (width -> f width) ->
   Extent vert horiz height width ->
   Extent vert horiz (f height) (f width)
mapWrap fh fw =
   getAdapt $
   switchTagPair
      (Adapt $ \(Extent o (EK.Square h)) -> Extent o (EK.Square (fh h)))
      (Adapt $ \(Extent o (EK.Wide h w)) -> Extent o (EK.Wide (fh h) (fw w)))
      (Adapt $ \(Extent o (EK.Tall h w)) -> Extent o (EK.Tall (fh h) (fw w)))
      (Adapt $ \(Extent o (EK.General h w)) ->
                  Extent o (EK.General (fh h) (fw w)))

{- only admissible since GHC-7.8
mapUnwrap ::
   (C vert, C horiz) =>
   (f height -> height) ->
   (f width -> width) ->
   Extent vert horiz (f height) (f width) ->
   Extent vert horiz height width
mapUnwrap fh fw =
   getAdapt $
   switchTagPair
      (Adapt $ \(Extent o (EK.Square h)) -> Extent o (EK.Square (fh h)))
      (Adapt $ \(Extent o (EK.Wide h w)) -> Extent o (EK.Wide (fh h) (fw w)))
      (Adapt $ \(Extent o (EK.Tall h w)) -> Extent o (EK.Tall (fh h) (fw w)))
      (Adapt $ \(Extent o (EK.General h w)) ->
                  Extent o (EK.General (fh h) (fw w)))
-}

recheck ::
   (C vert, C horiz) =>
   Extent vert horiz (Unchecked height) (Unchecked width) ->
   Extent vert horiz height width
recheck =
   getAdapt $
   switchTagPair
      (Adapt $ \(Extent o (EK.Square h)) ->
         Extent o (EK.Square (deconsUnchecked h)))
      (Adapt $ \(Extent o (EK.Wide h w)) ->
         Extent o (EK.Wide (deconsUnchecked h) (deconsUnchecked w)))
      (Adapt $ \(Extent o (EK.Tall h w)) ->
         Extent o (EK.Tall (deconsUnchecked h) (deconsUnchecked w)))
      (Adapt $ \(Extent o (EK.General h w)) ->
         Extent o (EK.General (deconsUnchecked h) (deconsUnchecked w)))



newtype Fuse height fuse width vert horiz =
   Fuse {
      getFuse ::
         Extent vert horiz height fuse ->
         Extent vert horiz fuse width ->
         Maybe (Extent vert horiz height width)
   }

fuse ::
   (C vert, C horiz, Eq fuse) =>
   Extent vert horiz height fuse ->
   Extent vert horiz fuse width ->
   Maybe (Extent vert horiz height width)
fuse =
   getFuse $
   switchTagPair
      (Fuse $
       \(Extent o (EK.Square s0)) (Extent _ (EK.Square s1)) ->
         toMaybe (s0==s1) $ Extent o (EK.Square s0))
      (Fuse $
       \(Extent o (EK.Wide h f0)) (Extent _ (EK.Wide f1 w)) ->
         toMaybe (f0==f1) $ Extent o (EK.Wide h w))
      (Fuse $
       \(Extent o (EK.Tall h f0)) (Extent _ (EK.Tall f1 w)) ->
         toMaybe (f0==f1) $ Extent o (EK.Tall h w))
      (Fuse $
       \(Extent o (EK.General h f0)) (Extent _ (EK.General f1 w)) ->
         toMaybe (f0==f1) $ Extent o (EK.General h w))


newtype Kronecker heightA widthA heightB widthB vert horiz =
   Kronecker {
      getKronecker ::
         Extent vert horiz heightA widthA ->
         Extent vert horiz heightB widthB ->
         Extent vert horiz (heightA,heightB) (widthA,widthB)
   }

kronecker ::
   (C vert, C horiz) =>
   Extent vert horiz heightA widthA ->
   Extent vert horiz heightB widthB ->
   Extent vert horiz (heightA,heightB) (widthA,widthB)
kronecker =
   getKronecker $
   switchTagPair
      (Kronecker $
       \(Extent o (EK.Square s0)) (Extent _ (EK.Square s1)) ->
         Extent o (EK.Square (s0,s1)))
      (Kronecker $
       \(Extent o (EK.Wide h0 w0)) (Extent _ (EK.Wide h1 w1)) ->
         Extent o (EK.Wide (h0,h1) (w0,w1)))
      (Kronecker $
       \(Extent o (EK.Tall h0 w0)) (Extent _ (EK.Tall h1 w1)) ->
         Extent o (EK.Tall (h0,h1) (w0,w1)))
      (Kronecker $
       \(Extent o (EK.General h0 w0)) (Extent _ (EK.General h1 w1)) ->
         Extent o (EK.General (h0,h1) (w0,w1)))



{-
Tag table for 'beside'.

Small Small  Small Small -> Small Big
Small Small  Small Big   -> Small Big
Small Small  Big   Small -> Small Big
Small Small  Big   Big   -> Small Big
Small Big    Small Small -> Small Big
Small Big    Small Big   -> Small Big
Small Big    Big   Small -> Small Big
Small Big    Big   Big   -> Small Big
Big   Small  Small Small -> Small Big
Big   Small  Small Big   -> Small Big
Big   Small  Big   Small -> Big   Big
Big   Small  Big   Big   -> Big   Big
Big   Big    Small Small -> Small Big
Big   Big    Small Big   -> Small Big
Big   Big    Big   Small -> Big   Big
Big   Big    Big   Big   -> Big   Big
-}
newtype AppendMode vertA vertB vertC height widthA widthB =
   AppendMode (
      Extent vertA Big height widthA ->
      Extent vertB Big height widthB ->
      Extent vertC Big height (widthA:+:widthB)
   )

appendLeftAux ::
   (C vertA, C vertB) => AppendMode vertA vertB vertA height widthA widthB
appendLeftAux =
   AppendMode $ \extentA extentB ->
      widen (width extentA :+: width extentB) extentA

appendSame :: (C vert) => AppendMode vert vert vert height widthA widthB
appendSame = appendLeftAux

appendLeft :: (C vert) => AppendMode vert Big vert height widthA widthB
appendLeft = appendLeftAux

appendRight :: (C vert) => AppendMode Big vert vert height widthA widthB
appendRight =
   AppendMode $ \extentA extentB ->
      widen (width extentA :+: width extentB) extentB

type family Append a b
type instance Append Small b = Small
type instance Append Big   b = b

newtype
   AppendAny vertB height widthA widthB vertA =
      AppendAny {
         getAppendAny ::
            AppendMode vertA vertB (Append vertA vertB) height widthA widthB
      }

appendAny ::
   (C vertA, C vertB) =>
   AppendMode vertA vertB (Append vertA vertB) height widthA widthB
appendAny =
   getAppendAny $ switchTag (AppendAny appendLeftAux) (AppendAny appendRight)



type family Multiply a b
type instance Multiply Small b = b
type instance Multiply Big   b = Big


data TagFact a = C a => TagFact

newtype MultiplyTagLaw b a =
   MultiplyTagLaw {
      getMultiplyTagLaw :: TagFact a -> TagFact b -> TagFact (Multiply a b)
   }

multiplyTagLaw :: TagFact a -> TagFact b -> TagFact (Multiply a b)
multiplyTagLaw a@TagFact =
   ($a) $ getMultiplyTagLaw $
   switchTag
      (MultiplyTagLaw $ flip const)
      (MultiplyTagLaw const)

heightFact :: (C vert) => Extent vert horiz height width -> TagFact vert
heightFact _ = TagFact

widthFact :: (C horiz) => Extent vert horiz height width -> TagFact horiz
widthFact _ = TagFact


newtype Unify height fuse width heightC widthC vertB horizB vertA horizA =
   Unify {
      getUnify ::
         Extent vertA horizA height fuse ->
         Extent vertB horizB fuse width ->
         Extent (Multiply vertA vertB) (Multiply horizA horizB) heightC widthC
   }

unifyLeft ::
   (C vertA, C horizA, C vertB, C horizB) =>
   Extent vertA horizA height fuse ->
   Extent vertB horizB fuse width ->
   Extent (Multiply vertA vertB) (Multiply horizA horizB) height fuse
unifyLeft =
   getUnify $
   switchTagPair
      (Unify $ const . fromSquareLiberal)
      (Unify $ const . generalizeWide)
      (Unify $ const . generalizeTall)
      (Unify $ const . toGeneral)

unifyRight ::
   (C vertA, C horizA, C vertB, C horizB) =>
   Extent vertA horizA height fuse ->
   Extent vertB horizB fuse width ->
   Extent (Multiply vertA vertB) (Multiply horizA horizB) fuse width
unifyRight =
   getUnify $
   switchTagPair
      (Unify $ const id)
      (Unify $ const genToWide)
      (Unify $ const genToTall)
      (Unify $ const toGeneral)


{-
Square  Square  -> Square
Square  Wide    -> Wide
Square  Tall    -> Tall
Square  General -> General
Wide    Square  -> Wide
Wide    Wide    -> Wide
Wide    Tall    -> General
Wide    General -> General
Tall    Square  -> Tall
Tall    Wide    -> General
Tall    Tall    -> Tall
Tall    General -> General
General Square  -> General
General Wide    -> General
General Tall    -> General
General General -> General

Small Small  Small Small -> Small Small
Small Small  Small Big   -> Small Big
Small Small  Big   Small -> Big   Small
Small Small  Big   Big   -> Big   Big
Small Big    Small Small -> Small Big
Small Big    Small Big   -> Small Big
Small Big    Big   Small -> Big   Big
Small Big    Big   Big   -> Big   Big
Big   Small  Small Small -> Big   Small
Big   Small  Small Big   -> Big   Big
Big   Small  Big   Small -> Big   Small
Big   Small  Big   Big   -> Big   Big
Big   Big    Small Small -> Big   Big
Big   Big    Small Big   -> Big   Big
Big   Big    Big   Small -> Big   Big
Big   Big    Big   Big   -> Big   Big
-}