{-# LANGUAGE BangPatterns           #-}
{-# LANGUAGE CPP                    #-}
{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE GADTs                  #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE PatternSynonyms        #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TypeFamilies           #-}
{-# LANGUAGE TypeOperators          #-}
{-# LANGUAGE UndecidableInstances   #-}

#if __GLASGOW_HASKELL__ >= 800

{-# LANGUAGE TypeFamilyDependencies #-}

#endif
-- |
-- Module      : Data.Massiv.Core.Index.Ix
-- Copyright   : (c) Alexey Kuleshevich 2018
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <lehins@yandex.ru>
-- Stability   : experimental
-- Portability : non-portable
--
module Data.Massiv.Core.Index.Ix where

import           Control.DeepSeq
import           Control.Monad               (liftM)
import           Data.Massiv.Core.Index.Class
import           Data.Monoid                 ((<>))
import           Data.Proxy
import qualified Data.Vector.Generic         as V
import qualified Data.Vector.Generic.Mutable as VM
import qualified Data.Vector.Unboxed         as VU
import           GHC.TypeLits


infixr 5 :>, :.

type Ix1 = Int

pattern Ix1 :: Int -> Ix1
pattern Ix1 i = i

data Ix2 = (:.) {-# UNPACK #-} !Int {-# UNPACK #-} !Int
pattern Ix2 :: Int -> Int -> Ix2
pattern Ix2 i j = i :. j

type Ix3 = IxN 3
pattern Ix3 :: Int -> Int -> Int -> Ix3
pattern Ix3 i j k = i :> j :. k

type Ix4 = IxN 4
pattern Ix4 :: Int -> Int -> Int -> Int -> Ix4
pattern Ix4 i j k l = i :> j :> k :. l

type Ix5 = IxN 5
pattern Ix5 :: Int -> Int -> Int -> Int -> Int -> Ix5
pattern Ix5 i j k l m = i :> j :> k :> l :. m


#if __GLASGOW_HASKELL__ >= 800

data IxN (n :: Nat) where
  (:>) :: {-# UNPACK #-} !Int -> !(Ix (n - 1)) -> IxN n

type family Ix (n :: Nat) = r | r -> n where
  Ix 0 = Ix0
  Ix 1 = Ix1
  Ix 2 = Ix2
  Ix n = IxN n

#else

data IxN (n :: Nat) where
  (:>) :: Rank (Ix (n - 1)) ~ (n - 1) => {-# UNPACK #-} !Int -> !(Ix (n - 1)) -> IxN n

type family Ix (n :: Nat) where
  Ix 0 = Ix0
  Ix 1 = Ix1
  Ix 2 = Ix2
  Ix n = IxN n

#endif


type instance Lower Ix2 = Ix1
type instance Lower (IxN n) = Ix (n - 1)


instance Show Ix2 where
  show (i :. j)  = show i ++ " :. " ++ show j

instance Show (Ix (n - 1)) => Show (IxN n) where
  show (i :> ix) = show i ++ " :> " ++ show ix


instance Num Ix2 where
  (+) = liftIndex2 (+)
  {-# INLINE [1] (+) #-}
  (-) = liftIndex2 (-)
  {-# INLINE [1] (-) #-}
  (*) = liftIndex2 (*)
  {-# INLINE [1] (*) #-}
  negate = liftIndex negate
  {-# INLINE [1] negate #-}
  abs = liftIndex abs
  {-# INLINE [1] abs #-}
  signum = liftIndex signum
  {-# INLINE [1] signum #-}
  fromInteger = pureIndex . fromInteger
  {-# INLINE [1] fromInteger #-}

instance Num Ix3 where
  (+) = liftIndex2 (+)
  {-# INLINE [1] (+) #-}
  (-) = liftIndex2 (-)
  {-# INLINE [1] (-) #-}
  (*) = liftIndex2 (*)
  {-# INLINE [1] (*) #-}
  negate = liftIndex negate
  {-# INLINE [1] negate #-}
  abs = liftIndex abs
  {-# INLINE [1] abs #-}
  signum = liftIndex signum
  {-# INLINE [1] signum #-}
  fromInteger = pureIndex . fromInteger
  {-# INLINE [1] fromInteger #-}


instance {-# OVERLAPPABLE #-} (4 <= n,
          KnownNat n,
          Index (Ix (n - 1)),
#if __GLASGOW_HASKELL__ < 800
          Rank (Ix ((n - 1) - 1)) ~ ((n - 1) - 1),
#endif
          IxN (n - 1) ~ Ix (n - 1)
          ) => Num (IxN n) where
  (+) = liftIndex2 (+)
  {-# INLINE [1] (+) #-}
  (-) = liftIndex2 (-)
  {-# INLINE [1] (-) #-}
  (*) = liftIndex2 (*)
  {-# INLINE [1] (*) #-}
  negate = liftIndex negate
  {-# INLINE [1] negate #-}
  abs = liftIndex abs
  {-# INLINE [1] abs #-}
  signum = liftIndex signum
  {-# INLINE [1] signum #-}
  fromInteger = pureIndex . fromInteger
  {-# INLINE [1] fromInteger #-}



instance Bounded Ix2 where
  minBound = pureIndex minBound
  {-# INLINE minBound #-}
  maxBound = pureIndex maxBound
  {-# INLINE maxBound #-}

instance Bounded Ix3 where
  minBound = pureIndex minBound
  {-# INLINE minBound #-}
  maxBound = pureIndex maxBound
  {-# INLINE maxBound #-}

instance {-# OVERLAPPABLE #-} (4 <= n,
          KnownNat n,
          Index (Ix (n - 1)),
#if __GLASGOW_HASKELL__ < 800
          Rank (Ix ((n - 1) - 1)) ~ ((n - 1) - 1),
#endif
          IxN (n - 1) ~ Ix (n - 1)
          ) => Bounded (IxN n) where
  minBound = pureIndex minBound
  {-# INLINE minBound #-}
  maxBound = pureIndex maxBound
  {-# INLINE maxBound #-}

instance NFData Ix2 where
  rnf ix = ix `seq` ()

instance NFData (IxN n) where
  rnf ix = ix `seq` ()


instance Eq Ix2 where
  (i1 :. j1)  == (i2 :. j2) = i1 == i2 && j1 == j2

instance Eq (Ix (n - 1)) => Eq (IxN n) where
  (i1 :> ix1) == (i2 :> ix2) = i1 == i2 && ix1 == ix2


instance Ord Ix2 where
  compare (i1 :. j1) (i2 :. j2) = compare i1 i2 <> compare j1 j2

instance Ord (Ix (n - 1)) => Ord (IxN n) where
  compare (i1 :> ix1) (i2 :> ix2) = compare i1 i2 <> compare ix1 ix2


toIx2 :: Ix2T -> Ix2
toIx2 (i, j) = i :. j
{-# INLINE toIx2 #-}

fromIx2 :: Ix2 -> Ix2T
fromIx2 (i :. j) = (i, j)
{-# INLINE fromIx2 #-}

toIx3 :: Ix3T -> Ix3
toIx3 (i, j, k) = i :> j :. k
{-# INLINE toIx3 #-}

fromIx3 :: Ix3 -> Ix3T
fromIx3 (i :> j :. k) = (i, j, k)
{-# INLINE fromIx3 #-}

toIx4 :: Ix4T -> Ix4
toIx4 (i, j, k, l) = i :> j :> k :. l
{-# INLINE toIx4 #-}

fromIx4 :: Ix4 -> Ix4T
fromIx4 (i :> j :> k :. l) = (i, j, k, l)
{-# INLINE fromIx4 #-}

toIx5 :: Ix5T -> Ix5
toIx5 (i, j, k, l, m) = i :> j :> k :> l :. m
{-# INLINE toIx5 #-}

fromIx5 :: Ix5 -> Ix5T
fromIx5 (i :> j :> k :> l :. m) = (i, j, k, l, m)
{-# INLINE fromIx5 #-}


instance {-# OVERLAPPING #-} Index Ix2 where
  type Rank Ix2 = 2
  rank _ = 2
  {-# INLINE [1] rank #-}
  totalElem (m :. n) = m * n
  {-# INLINE [1] totalElem #-}
  isSafeIndex (m :. n) (i :. j) = 0 <= i && 0 <= j && i < m && j < n
  {-# INLINE [1] isSafeIndex #-}
  toLinearIndex (_ :. n) (i :. j) = n * i + j
  {-# INLINE [1] toLinearIndex #-}
  fromLinearIndex (_ :. n) k = case k `quotRem` n of
                                 (i, j) -> i :. j
  {-# INLINE [1] fromLinearIndex #-}
  consDim = (:.)
  {-# INLINE [1] consDim #-}
  unconsDim (i :. ix) = (i, ix)
  {-# INLINE [1] unconsDim #-}
  snocDim i j = i :. j
  {-# INLINE [1] snocDim #-}
  unsnocDim (i :. j) = (i, j)
  {-# INLINE [1] unsnocDim #-}
  getIndex (i :. _) 2 = Just i
  getIndex (_ :. j) 1 = Just j
  getIndex _        _ = Nothing
  {-# INLINE [1] getIndex #-}
  setIndex (_ :. j) 2 i = Just (i :. j)
  setIndex (i :. _) 1 j = Just (i :. j)
  setIndex _        _ _ = Nothing
  {-# INLINE [1] setIndex #-}
  dropDim (_ :. j) 2 = Just j
  dropDim (i :. _) 1 = Just i
  dropDim _      _   = Nothing
  {-# INLINE [1] dropDim #-}
  pureIndex i = i :. i
  {-# INLINE [1] pureIndex #-}
  liftIndex f (i :. j) = f i :. f j
  {-# INLINE [1] liftIndex #-}
  liftIndex2 f (i0 :. j0) (i1 :. j1) = f i0 i1 :. f j0 j1
  {-# INLINE [1] liftIndex2 #-}
  repairIndex (n :. szL) (i :. ixL) rBelow rOver =
    repairIndex n i rBelow rOver :. repairIndex szL ixL rBelow rOver
  {-# INLINE [1] repairIndex #-}


instance {-# OVERLAPPING #-} Index (IxN 3) where
  type Rank Ix3 = 3
  rank _ = 3
  {-# INLINE [1] rank #-}
  totalElem (m :> n :. o) = m * n * o
  {-# INLINE [1] totalElem #-}
  isSafeIndex (m :> n :. o) (i :> j :. k) =
    0 <= i && 0 <= j && 0 <= k && i < m && j < n && k < o
  {-# INLINE [1] isSafeIndex #-}
  toLinearIndex (_ :> n :. o) (i :> j :. k) = (n * i + j) * o + k
  {-# INLINE [1] toLinearIndex #-}
  fromLinearIndex (_ :> ix) k = let !(q, ixL) = fromLinearIndexAcc ix k in q :> ixL
  {-# INLINE [1] fromLinearIndex #-}
  consDim = (:>)
  {-# INLINE [1] consDim #-}
  unconsDim (i :> ix) = (i, ix)
  {-# INLINE [1] unconsDim #-}
  snocDim (i :. j) k = i :> j :. k
  {-# INLINE [1] snocDim #-}
  unsnocDim (i :> j :. k) = (i :. j, k)
  {-# INLINE [1] unsnocDim #-}
  getIndex (i :> _ :. _) 3 = Just i
  getIndex (_ :> j :. _) 2 = Just j
  getIndex (_ :> _ :. k) 1 = Just k
  getIndex _             _ = Nothing
  {-# INLINE [1] getIndex #-}
  setIndex (_ :> j :. k) 3 i = Just (i :> j :. k)
  setIndex (i :> _ :. k) 2 j = Just (i :> j :. k)
  setIndex (i :> j :. _) 1 k = Just (i :> j :. k)
  setIndex _             _ _ = Nothing
  {-# INLINE [1] setIndex #-}
  dropDim (_ :> j :. k) 3 = Just (j :. k)
  dropDim (i :> _ :. k) 2 = Just (i :. k)
  dropDim (i :> j :. _) 1 = Just (i :. j)
  dropDim _             _ = Nothing
  {-# INLINE [1] dropDim #-}
  pureIndex i = i :> i :. i
  {-# INLINE [1] pureIndex #-}
  liftIndex f (i :> j :. k) = f i :> f j :. f k
  {-# INLINE [1] liftIndex #-}
  liftIndex2 f (i0 :> j0 :. k0) (i1 :> j1 :. k1) = f i0 i1 :> f j0 j1 :. f k0 k1
  {-# INLINE [1] liftIndex2 #-}
  repairIndex (n :> szL) (i :> ixL) rBelow rOver =
    repairIndex n i rBelow rOver :> repairIndex szL ixL rBelow rOver
  {-# INLINE [1] repairIndex #-}

instance {-# OVERLAPPABLE #-} (4 <= n,
          KnownNat n,
          Index (Ix (n - 1)),
#if __GLASGOW_HASKELL__ < 800
          Rank (Ix ((n - 1) - 1)) ~ ((n - 1) - 1),
#endif
          IxN (n - 1) ~ Ix (n - 1)
          ) => Index (IxN n) where
  type Rank (IxN n) = n
  rank _ = fromInteger $ natVal (Proxy :: Proxy n)
  {-# INLINE [1] rank #-}
  totalElem (i :> ix) = i * totalElem ix
  {-# INLINE [1] totalElem #-}
  consDim = (:>)
  {-# INLINE [1] consDim #-}
  unconsDim (i :> ix) = (i, ix)
  {-# INLINE [1] unconsDim #-}
  snocDim (i :> ix) k = i :> snocDim ix k
  {-# INLINE [1] snocDim #-}
  unsnocDim (i :> ix) = case unsnocDim ix of
                          (jx, j) -> (i :> jx, j)
  {-# INLINE [1] unsnocDim #-}
  getIndex ix@(j :> jx) k | k == rank ix = Just j
                          | otherwise = getIndex jx k
  {-# INLINE [1] getIndex #-}
  setIndex ix@(j :> jx) k o | k == rank ix = Just (o :> jx)
                            | otherwise = (j :>) <$> setIndex jx k o
  {-# INLINE [1] setIndex #-}
  dropDim ix@(j :> jx) k | k == rank ix = Just jx
                           | otherwise = (j :>) <$> dropDim jx k
  {-# INLINE [1] dropDim #-}
  pureIndex i = i :> (pureIndex i :: Ix (n - 1))
  {-# INLINE [1] pureIndex #-}
  liftIndex f (i :> ix) = f i :> liftIndex f ix
  {-# INLINE [1] liftIndex #-}
  liftIndex2 f (i1 :> ix1) (i2 :> ix2) = f i1 i2 :> liftIndex2 f ix1 ix2
  {-# INLINE [1] liftIndex2 #-}
  repairIndex (n :> szL) (i :> ixL) rBelow rOver =
    repairIndex n i rBelow rOver :> repairIndex szL ixL rBelow rOver
  {-# INLINE [1] repairIndex #-}



---- Unbox Ix

-- | Unboxing of a `Ix2`.
instance VU.Unbox Ix2

newtype instance VU.MVector s Ix2 = MV_Ix2 (VU.MVector s Ix2T)

instance VM.MVector VU.MVector Ix2 where
  basicLength (MV_Ix2 mvec) = VM.basicLength mvec
  {-# INLINE basicLength #-}
  basicUnsafeSlice idx len (MV_Ix2 mvec) = MV_Ix2 (VM.basicUnsafeSlice idx len mvec)
  {-# INLINE basicUnsafeSlice #-}
  basicOverlaps (MV_Ix2 mvec) (MV_Ix2 mvec') = VM.basicOverlaps mvec mvec'
  {-# INLINE basicOverlaps #-}
  basicUnsafeNew len = MV_Ix2 `liftM` VM.basicUnsafeNew len
  {-# INLINE basicUnsafeNew #-}
  basicUnsafeReplicate len val = MV_Ix2 `liftM` VM.basicUnsafeReplicate len (fromIx2 val)
  {-# INLINE basicUnsafeReplicate #-}
  basicUnsafeRead (MV_Ix2 mvec) idx = toIx2 `liftM` VM.basicUnsafeRead mvec idx
  {-# INLINE basicUnsafeRead #-}
  basicUnsafeWrite (MV_Ix2 mvec) idx val = VM.basicUnsafeWrite mvec idx (fromIx2 val)
  {-# INLINE basicUnsafeWrite #-}
  basicClear (MV_Ix2 mvec) = VM.basicClear mvec
  {-# INLINE basicClear #-}
  basicSet (MV_Ix2 mvec) val = VM.basicSet mvec (fromIx2 val)
  {-# INLINE basicSet #-}
  basicUnsafeCopy (MV_Ix2 mvec) (MV_Ix2 mvec') = VM.basicUnsafeCopy mvec mvec'
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeMove (MV_Ix2 mvec) (MV_Ix2 mvec') = VM.basicUnsafeMove mvec mvec'
  {-# INLINE basicUnsafeMove #-}
  basicUnsafeGrow (MV_Ix2 mvec) len = MV_Ix2 `liftM` VM.basicUnsafeGrow mvec len
  {-# INLINE basicUnsafeGrow #-}
#if MIN_VERSION_vector(0,11,0)
  basicInitialize (MV_Ix2 mvec) = VM.basicInitialize mvec
  {-# INLINE basicInitialize #-}
#endif


newtype instance VU.Vector Ix2 = V_Ix2 (VU.Vector Ix2T)

instance V.Vector VU.Vector Ix2 where
  basicUnsafeFreeze (MV_Ix2 mvec) = V_Ix2 `liftM` V.basicUnsafeFreeze mvec
  {-# INLINE basicUnsafeFreeze #-}
  basicUnsafeThaw (V_Ix2 vec) = MV_Ix2 `liftM` V.basicUnsafeThaw vec
  {-# INLINE basicUnsafeThaw #-}
  basicLength (V_Ix2 vec) = V.basicLength vec
  {-# INLINE basicLength #-}
  basicUnsafeSlice idx len (V_Ix2 vec) = V_Ix2 (V.basicUnsafeSlice idx len vec)
  {-# INLINE basicUnsafeSlice #-}
  basicUnsafeIndexM (V_Ix2 vec) idx = toIx2 `liftM` V.basicUnsafeIndexM vec idx
  {-# INLINE basicUnsafeIndexM #-}
  basicUnsafeCopy (MV_Ix2 mvec) (V_Ix2 vec) = V.basicUnsafeCopy mvec vec
  {-# INLINE basicUnsafeCopy #-}
  elemseq _ = seq
  {-# INLINE elemseq #-}



---- Unbox Ix



-- | Unboxing of a `IxN`.
instance (3 <= n,
#if __GLASGOW_HASKELL__ < 800
          Rank (Ix (n - 1)) ~ (n - 1),
#endif
          VU.Unbox (Ix (n-1))) => VU.Unbox (IxN n)

newtype instance VU.MVector s (IxN n) = MV_IxN (VU.MVector s Int, VU.MVector s (Ix (n-1)))

instance (3 <= n,
#if __GLASGOW_HASKELL__ < 800
          Rank (Ix (n - 1)) ~ (n - 1),
#endif
          VU.Unbox (Ix (n - 1))) =>
         VM.MVector VU.MVector (IxN n) where
  basicLength (MV_IxN (_, mvec)) = VM.basicLength mvec
  {-# INLINE basicLength #-}
  basicUnsafeSlice idx len (MV_IxN (mvec1, mvec)) =
    MV_IxN (VM.basicUnsafeSlice idx len mvec1, VM.basicUnsafeSlice idx len mvec)
  {-# INLINE basicUnsafeSlice #-}
  basicOverlaps (MV_IxN (mvec1, mvec)) (MV_IxN (mvec1', mvec')) =
    VM.basicOverlaps mvec1 mvec1' && VM.basicOverlaps mvec mvec'
  {-# INLINE basicOverlaps #-}
  basicUnsafeNew len = do
    iv <- VM.basicUnsafeNew len
    ivs <- VM.basicUnsafeNew len
    return $ MV_IxN (iv, ivs)
  {-# INLINE basicUnsafeNew #-}
  basicUnsafeReplicate len (i :> ix) = do
    iv <- VM.basicUnsafeReplicate len i
    ivs <- VM.basicUnsafeReplicate len ix
    return $ MV_IxN (iv, ivs)
  {-# INLINE basicUnsafeReplicate #-}
  basicUnsafeRead (MV_IxN (mvec1, mvec)) idx = do
    i <- VM.basicUnsafeRead mvec1 idx
    ix <- VM.basicUnsafeRead mvec idx
    return (i :> ix)
  {-# INLINE basicUnsafeRead #-}
  basicUnsafeWrite (MV_IxN (mvec1, mvec)) idx (i :> ix) = do
    VM.basicUnsafeWrite mvec1 idx i
    VM.basicUnsafeWrite mvec idx ix
  {-# INLINE basicUnsafeWrite #-}
  basicClear (MV_IxN (mvec1, mvec)) = VM.basicClear mvec1 >> VM.basicClear mvec
  {-# INLINE basicClear #-}
  basicSet (MV_IxN (mvec1, mvec)) (i :> ix) = VM.basicSet mvec1 i >> VM.basicSet mvec ix
  {-# INLINE basicSet #-}
  basicUnsafeCopy (MV_IxN (mvec1, mvec)) (MV_IxN (mvec1', mvec')) =
    VM.basicUnsafeCopy mvec1 mvec1' >> VM.basicUnsafeCopy mvec mvec'
  {-# INLINE basicUnsafeCopy #-}
  basicUnsafeMove (MV_IxN (mvec1, mvec)) (MV_IxN (mvec1', mvec')) =
    VM.basicUnsafeMove mvec1 mvec1' >> VM.basicUnsafeMove mvec mvec'
  {-# INLINE basicUnsafeMove #-}
  basicUnsafeGrow (MV_IxN (mvec1, mvec)) len = do
    iv <- VM.basicUnsafeGrow mvec1 len
    ivs <- VM.basicUnsafeGrow mvec len
    return $ MV_IxN (iv, ivs)
  {-# INLINE basicUnsafeGrow #-}
#if MIN_VERSION_vector(0,11,0)
  basicInitialize (MV_IxN (mvec1, mvec)) =
    VM.basicInitialize mvec1 >> VM.basicInitialize mvec
  {-# INLINE basicInitialize #-}
#endif


newtype instance VU.Vector (IxN n) = V_IxN (VU.Vector Int, VU.Vector (Ix (n-1)))

instance (3 <= n,
#if __GLASGOW_HASKELL__ < 800
          Rank (Ix (n - 1)) ~ (n - 1),
#endif
          VU.Unbox (Ix (n-1))) => V.Vector VU.Vector (IxN n) where
  basicUnsafeFreeze (MV_IxN (mvec1, mvec)) = do
    iv <- V.basicUnsafeFreeze mvec1
    ivs <- V.basicUnsafeFreeze mvec
    return $ V_IxN (iv, ivs)
  {-# INLINE basicUnsafeFreeze #-}
  basicUnsafeThaw (V_IxN (vec1, vec)) = do
    imv <- V.basicUnsafeThaw vec1
    imvs <- V.basicUnsafeThaw vec
    return $ MV_IxN (imv, imvs)
  {-# INLINE basicUnsafeThaw #-}
  basicLength (V_IxN (_, vec)) = V.basicLength vec
  {-# INLINE basicLength #-}
  basicUnsafeSlice idx len (V_IxN (vec1, vec)) = do
    V_IxN (V.basicUnsafeSlice idx len vec1, V.basicUnsafeSlice idx len vec)
  {-# INLINE basicUnsafeSlice #-}
  basicUnsafeIndexM (V_IxN (vec1, vec)) idx = do
    i <- V.basicUnsafeIndexM vec1 idx
    ix <- V.basicUnsafeIndexM vec idx
    return (i :> ix)
  {-# INLINE basicUnsafeIndexM #-}
  basicUnsafeCopy (MV_IxN (mvec1, mvec)) (V_IxN (vec1, vec)) =
    V.basicUnsafeCopy mvec1 vec1 >> V.basicUnsafeCopy mvec vec
  {-# INLINE basicUnsafeCopy #-}
  elemseq _ = seq
  {-# INLINE elemseq #-}