{-# LANGUAGE AllowAmbiguousTypes        #-}
{-# LANGUAGE CPP                        #-}
{-# LANGUAGE ConstraintKinds            #-}
{-# LANGUAGE DataKinds                  #-}
{-# LANGUAGE DeriveDataTypeable         #-}
{-# LANGUAGE DeriveGeneric              #-}
{-# LANGUAGE ExistentialQuantification  #-}
{-# LANGUAGE ExplicitNamespaces         #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash                  #-}
{-# LANGUAGE PatternSynonyms            #-}
{-# LANGUAGE PolyKinds                  #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE TypeApplications           #-}
{-# LANGUAGE UnboxedTuples              #-}
{-# LANGUAGE UndecidableInstances       #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.Dimensions.Idx
-- Copyright   :  (c) Artem Chirkin
-- License     :  BSD3
--
--
-- Provides a data type `Idx` to index a single `Dim`, and `Idxs`
--   to enumerate through multiple dimensions.
--
-- More significant indices go first, i.e. for indices @i1, ..., ik@ over
--   dimensions @n1, ..., nk@ the assumed enumeration (offset) is
--          i = i1*n2*n3*...*nk + ... + i(k-2)*n(k-1)*nk + i(k-1)*nk + ik
-- This corresponds to the row-major layout of matrices and multidimensional arrays.
--
-----------------------------------------------------------------------------

module Numeric.Dimensions.Idx
  ( -- * Data types
    Idx (Idx), Idxs
  , idxFromWord, unsafeIdxFromWord, idxToWord
  , listIdxs, idxsFromWords
  ) where


import           Data.Coerce
import           Data.Constraint  (Dict (..))
import           Data.Data        (Data)
import           Foreign.Storable (Storable)
import           GHC.Enum
import           GHC.Generics     (Generic)
import qualified Text.Read        as P

#ifdef UNSAFE_INDICES
import GHC.Base (Int (..), Type, Word (..), int2Word#, unsafeCoerce#, word2Int#)
#else
import GHC.Base (Int (..), Type, Word (..), int2Word#, maxInt, plusWord2#,
                 timesWord2#, unsafeCoerce#, word2Int#)
#endif

import Numeric.Dimensions.Dim
import Numeric.TypedList      (typedListReadPrec, typedListShowsPrec)


-- | An index into a single dimension @n@; valid values range from @0@ to @n-1@.
--
--   The raw constructor @Idx'@ is not exported: values are built via the
--   'Idx' pattern, 'unsafeIdxFromWord', 'idxFromWord', or the 'Num' instance.
newtype Idx (n :: k) = Idx' Word
  deriving ( Eq, Ord, Real, Integral, Storable, Data, Generic )


-- | Bidirectional view of an @Idx@ as a plain `Word`.
--
--   Matching just extracts the underlying word. Construction goes through
--   'unsafeIdxFromWord', so an out-of-bounds word fails with an error
--     (unless @unsafeindices@ flag is turned on).
--
pattern Idx :: forall (k :: Type) (n :: k) . BoundedDim n => Word -> Idx n
pattern Idx w <- Idx' w
  where
    Idx w = unsafeIdxFromWord w
{-# COMPLETE Idx #-}

-- | Type-level dimensional indexing with arbitrary Word values inside:
--   one @Idx@ per dimension in @xs@, stored in a `TypedList`.
--   Most of the operations on it require a `BoundedDims` or `Dimensions`
--   constraint, because the @Idxs@ itself does not store info about
--   dimension bounds.
type Idxs (xs :: [k]) = TypedList Idx xs

-- | Convert an arbitrary Word to @Idx@.
--
--   If the word is outside of the bounds, fails with an error
--     (unless @unsafeindices@ flag is turned on, in which case the word is
--     wrapped without any check).
--   See 'idxFromWord' for a total variant returning 'Maybe'.
--
unsafeIdxFromWord :: forall (k :: Type) (d :: k) . BoundedDim d => Word -> Idx d
#ifdef UNSAFE_INDICES
unsafeIdxFromWord = coerce
#else
-- The error names this function, not the total 'idxFromWord' (which never fails).
unsafeIdxFromWord w
  | w < bound = coerce w
  | otherwise = errorWithoutStackTrace
              $ "unsafeIdxFromWord{" ++ showIdxType @k @d ++ "}: word "
              ++ show w ++ " is outside of index bounds."
  where
    bound = dimVal (dimBound @k @d)
#endif
{-# INLINE unsafeIdxFromWord #-}

-- | Convert an arbitrary Word to @Idx@ without failing:
--   yields 'Nothing' when the word is not below the dimension bound.
idxFromWord :: forall (k :: Type) (d :: k) . BoundedDim d => Word -> Maybe (Idx d)
idxFromWord w
  | w >= bound = Nothing
  | otherwise  = Just (Idx' w)
  where
    bound = dimVal (dimBound @k @d)
{-# INLINE idxFromWord #-}

-- | Unwrap an @Idx@ to its underlying `Word` (no checks, zero cost).
idxToWord :: forall (k :: Type) (d :: k) . Idx d -> Word
idxToWord (Idx' w) = w
{-# INLINE idxToWord #-}

-- Rewrite 'fromIntegral' into the zero-cost 'idxToWord' where the types match;
-- the right-hand side's type restricts the rule to @Idx d -> Word@.
{-# RULES
"fromIntegral/idxToWord"
  fromIntegral = idxToWord
  #-}

-- | View an @Idxs@ as a plain list of its index words, first dimension first.
--   This reinterprets the value with 'unsafeCoerce#', relying on @Idxs@
--   having the same runtime representation as @[Word]@.
listIdxs :: forall (k :: Type) (xs :: [k]) . Idxs xs -> [Word]
listIdxs idxs = unsafeCoerce# idxs
{-# INLINE listIdxs #-}

-- | Build an @Idxs@ from a list of words, one per dimension.
--   Returns 'Nothing' if the list length differs from the number of
--   dimensions, or if any word is not below its dimension bound.
idxsFromWords :: forall (k :: Type) (xs :: [k])
               . BoundedDims xs => [Word] -> Maybe (Idxs xs)
idxsFromWords = unsafeCoerce# . check (listDims (dimsBound @k @xs))
  where
    -- Walk bounds and words in lock-step, rejecting the first violation.
    check :: [Word] -> [Word] -> Maybe [Word]
    check (bound : bounds) (w : ws)
      | w >= bound = Nothing
      | otherwise  = (w :) <$> check bounds ws
    check [] [] = Just []
    check _  _  = Nothing



-- | Parses a plain `Word` and fails the parse if it is out of bounds.
instance BoundedDim x => Read (Idx (x :: k)) where
    readPrec = P.readPrec >>= \w ->
      if w >= dimVal (dimBound @k @x)
        then P.pfail
        else return (Idx' w)
    readListPrec = P.readListPrecDefault
    readList = P.readListDefault

-- | Shows just the underlying `Word`.
instance Show (Idx (x :: k)) where
    showsPrec p (Idx' w) = showsPrec p w

-- | Indices range from @0@ to the dimension bound minus one.
instance BoundedDim n => Bounded (Idx (n :: k)) where
    maxBound = Idx' (subtract 1 (dimVal (dimBound @k @n)))
    {-# INLINE maxBound #-}
    -- goes through the checked 'fromInteger' of the Num instance
    minBound = 0
    {-# INLINE minBound #-}

-- | Enumeration within @[0, n-1]@; unless @unsafeindices@ is on,
--   'succ', 'pred', 'toEnum' and 'fromEnum' fail on out-of-range values.
instance BoundedDim n => Enum (Idx (n :: k)) where

#ifdef UNSAFE_INDICES
    succ (Idx' i) = Idx' (i + 1)
#else
    succ x@(Idx' i)
      | x >= maxBound = succError (showIdxType @k @n)
      | otherwise     = Idx' (i + 1)
#endif
    {-# INLINE succ #-}

#ifdef UNSAFE_INDICES
    pred (Idx' i) = Idx' (i - 1)
#else
    pred x@(Idx' i)
      | x <= minBound = predError (showIdxType @k @n)
      | otherwise     = Idx' (i - 1)
#endif
    {-# INLINE pred #-}

#ifdef UNSAFE_INDICES
    toEnum (I# n#) = Idx' (W# (int2Word# n#))
#else
    toEnum i
        | i < 0 || w >= bound = toEnumError (showIdxType @k @n) i (0, bound - 1)
        | otherwise           = Idx' w
      where
        bound = dimVal (dimBound @k @n)
        w     = fromIntegral i :: Word
#endif
    {-# INLINE toEnum #-}

#ifdef UNSAFE_INDICES
    fromEnum (Idx' (W# w#)) = I# (word2Int# w#)
#else
    fromEnum (Idx' x@(W# w#))
        | x > intMaxAsWord = fromEnumError (showIdxType @k @n) x
        | otherwise        = I# (word2Int# w#)
        where
          -- largest Word that still fits into an Int
          intMaxAsWord = W# (case maxInt of I# m# -> int2Word# m#)
#endif
    {-# INLINE fromEnum #-}

    enumFrom (Idx' start) = coerce (enumFromTo start lastIdx)
      where
        lastIdx = dimVal (dimBound @k @n) - 1
    {-# INLINE enumFrom #-}
    enumFromThen (Idx' a) (Idx' b) = coerce (enumFromThenTo a b stop)
      where
        -- descending sequences stop at 0, others at the last index
        stop | b < a     = 0
             | otherwise = dimVal (dimBound @k @n) - 1
    {-# INLINE enumFromThen #-}
    enumFromTo (Idx' a) (Idx' b)
      = coerce (enumFromTo a b)
    {-# INLINE enumFromTo #-}
    enumFromThenTo (Idx' a) (Idx' b) (Idx' c)
      = coerce (enumFromThenTo a b c)
    {-# INLINE enumFromThenTo #-}

-- | Arithmetic on indices of a single dimension.
--
--   Unless the @unsafeindices@ flag is turned on, '(+)', '(-)', '(*)' and
--   'fromInteger' check that the result lies within @[0, n-1]@ (including
--   overflow of the underlying `Word`) and fail with an error otherwise.
--   'negate' always fails, because indices cannot be negative.
instance BoundedDim n => Num (Idx (n :: k)) where

#ifdef UNSAFE_INDICES
    (+) = coerce ((+) :: Word -> Word -> Word)
#else
    (Idx' a@(W# a#)) + b@(Idx' (W# b#))
        | ovf || r >= d
          = errorWithoutStackTrace
          $ "Num.(+){" ++ showIdxType @k @n ++ "}: sum of "
            ++ show a ++ " and " ++ show b
            ++ " is outside of index bounds."
        | otherwise = coerce r
      where
        -- plusWord2# yields (# high, low #); a nonzero high word means overflow.
        (ovf, r) = case plusWord2# a# b# of
          (# r2#, r1# #) -> ( W# r2# > 0 , W# r1# )
        d = dimVal (dimBound @k @n)
#endif
    {-# INLINE (+) #-}

#ifdef UNSAFE_INDICES
    (-) = coerce ((-) :: Word -> Word -> Word)
#else
    (Idx' a) - (Idx' b)
        | b > a
          = errorWithoutStackTrace
          $ "Num.(-){" ++ showIdxType @k @n ++ "}: difference of "
            ++ show a ++ " and " ++ show b
            ++ " is negative."
        | otherwise = coerce (a - b)
#endif
    {-# INLINE (-) #-}

#ifdef UNSAFE_INDICES
    (*) = coerce ((*) :: Word -> Word -> Word)
#else
    (Idx' a@(W# a#)) * b@(Idx' (W# b#))
        | ovf || r >= d
          = errorWithoutStackTrace
          $ "Num.(*){" ++ showIdxType @k @n ++ "}: product of "
            ++ show a ++ " and " ++ show b
            ++ " is outside of index bounds."
        | otherwise = coerce r
      where
        -- timesWord2# yields (# high, low #); a nonzero high word means overflow.
        (ovf, r) = case timesWord2# a# b# of
          (# r2#, r1# #) -> ( W# r2# > 0 , W# r1# )
        d = dimVal (dimBound @k @n)
#endif
    {-# INLINE (*) #-}

    negate = errorWithoutStackTrace
           $ "Num.negate{" ++ showIdxType @k @n ++ "}: cannot negate index."
    {-# INLINE negate #-}
    abs = id
    {-# INLINE abs #-}
    -- Same as Word's signum: 0 for the zero index, 1 otherwise.
    -- (A constant 1 is wrong for index 0 and out of bounds when n == 1.)
    signum = coerce (signum :: Word -> Word)
    {-# INLINE signum #-}

#ifdef UNSAFE_INDICES
    fromInteger = coerce (fromInteger :: Integer -> Word)
#else
    fromInteger i
      | i >= 0 && i < d = Idx' $ fromInteger i
      | otherwise       = errorWithoutStackTrace
                        $ "Num.fromInteger{" ++ showIdxType @k @n ++ "}: integer "
                        ++ show i ++ " is outside of index bounds."
      where
        d = toInteger $ dimVal (dimBound @k @n)
#endif
    {-# INLINE fromInteger #-}




-- | Two multi-indices are equal iff their index words are equal.
instance Eq (Idxs (xs :: [k])) where
    a == b = listIdxs a == listIdxs b
    {-# INLINE (==) #-}

-- | Compare indices by their importance in lexicographic order
--   from the first dimension to the last dimension
--   (the first dimension is the most significant one).
--
--   Literally,
--
--   > compare a b = compare (listIdxs a) (listIdxs b)
--
--   This is the same @compare@ rule as for `Dims`.
--   This is also consistent with offsets:
--
--   > sort == sortOn fromEnum
--
instance Ord (Idxs (xs :: [k])) where
    compare a b = compare (listIdxs a) (listIdxs b)
    {-# INLINE compare #-}

-- | Shows each index using the @Show (Idx x)@ instance via 'typedListShowsPrec'.
instance Show (Idxs (xs :: [k])) where
    showsPrec p idxs = typedListShowsPrec @k @Idx @xs showsPrec p idxs

-- | Parses indices separated by @:*@, checking each one against its bound.
instance BoundedDims xs => Read (Idxs (xs :: [k])) where
    readPrec
      -- bring @All BoundedDim xs@ into scope for the per-element parser
      | Dict <- inferAllBoundedDims @k @xs
      = typedListReadPrec @k @BoundedDim ":*" P.readPrec (tList @k @xs)
    readListPrec = P.readListPrecDefault
    readList = P.readListDefault

-- | With this instance we can slightly reduce indexing expressions, e.g.
--
--   > x ! (1 :* 2 :* 4) == x ! (1 :* 2 :* 4 :* U)
--
--   Each defined method lifts the corresponding @Num (Idx n)@ method over
--   the single element.
instance BoundedDim n => Num (Idxs '[(n :: k)]) where
    (x :* U) + (y :* U) = (x + y) :* U
    {-# INLINE (+) #-}
    (x :* U) - (y :* U) = (x - y) :* U
    {-# INLINE (-) #-}
    (x :* U) * (y :* U) = (x * y) :* U
    {-# INLINE (*) #-}
    signum (x :* U)     = signum x :* U
    {-# INLINE signum #-}
    abs (x :* U)        = abs x :* U
    {-# INLINE abs #-}
    fromInteger         = (:* U) . fromInteger
    {-# INLINE fromInteger #-}

-- | The smallest multi-index has all zeros; the largest has, in every
--   position, the corresponding dimension (from 'minDims') minus one.
instance BoundedDims ds => Bounded (Idxs (ds :: [k])) where
    maxBound = lastIdxs (minDims @k @ds)
      where
        lastIdxs :: forall (ns :: [k]) . Dims ns -> Idxs ns
        lastIdxs (d :* rest) = Idx' (dimVal d - 1) :* lastIdxs rest
        lastIdxs U           = U
    {-# INLINE maxBound #-}
    minBound = zeroIdxs (minDims @k @ds)
      where
        zeroIdxs :: forall (ns :: [k]) . Dims ns -> Idxs ns
        zeroIdxs (_ :* rest) = Idx' 0 :* zeroIdxs rest
        zeroIdxs U           = U
    {-# INLINE minBound #-}

-- | Enumerates multi-indices in row-major order: the last dimension changes
--   fastest, and 'fromEnum' / 'toEnum' convert to and from the flat offset.
--
-- @ds@ must be @[Nat]@ for @Enum (Idxs ds)@,
--   because succ and pred would break otherwise
instance Dimensions ds => Enum (Idxs (ds :: [Nat])) where

    -- Increment like an odometer. @go@ handles the tail first and returns
    -- (carry, result); a position that reaches its dimension resets to 0 and
    -- carries. A carry out of the first position means the input was the last
    -- index, so 'succError' is raised.
    succ idx = case go dds idx of
        (True , _ ) -> succError $ showIdxsType dds
        (False, i') -> i'
      where
        dds = dims @ds
        go :: forall (ns :: [Nat]) . Dims ns -> Idxs ns -> (Bool, Idxs ns)
        go U U = (True, U)
        go (d :* ds) (Idx' i :* is) = case go ds is of
          (True , is')
            | i + 1 == dimVal d -> (True , Idx'  0    :* is')
            | otherwise         -> (False, Idx' (i+1) :* is')
          (False, is')          -> (False, Idx'  i    :* is')
    {-# INLINE succ #-}

    -- Decrement with borrow: a position at 0 wraps to (dimension - 1) and
    -- borrows from the previous position; a borrow out of the first position
    -- means the input was all zeros, so 'predError' is raised.
    pred idx = case go dds idx of
        (True , _ ) -> predError $ showIdxsType dds
        (False, i') -> i'
      where
        dds = dims @ds
        go :: forall (ns :: [Nat]) . Dims ns -> Idxs ns -> (Bool, Idxs ns)
        go U U = (True, U)
        go (d :* ds) (Idx' i :* is) = case go ds is of
          (True , is')
            | i == 0    -> (True , Idx' (dimVal d - 1) :* is')
            | otherwise -> (False, Idx' (i-1)          :* is')
          (False, is')  -> (False, Idx'  i             :* is')
    {-# INLINE pred #-}

    -- Split the offset into digits by repeated 'quotRem', starting from the
    -- last dimension. A nonzero leftover quotient means the offset is out of
    -- range. NOTE(review): a negative @off0@ wraps to a large Word through
    -- 'fromIntegral' and is rejected only because its leftover quotient is
    -- nonzero — confirm this holds for very large total dimensions.
    toEnum off0 = case go dds of
        (0, i) -> i
        _      -> toEnumError (showIdxsType dds) off0 (0, totalDim dds - 1)
      where
        dds = dims @ds
        go :: forall (ns :: [Nat]) . Dims ns -> (Word, Idxs ns)
        go  U = (fromIntegral off0, U)
        go (d :* ds)
          | (off , is) <- go ds
          , (off', i ) <- quotRem off (dimVal d)
              = (off', Idx' i :* is)
    {-# INLINE toEnum #-}

    -- Row-major offset. The right fold starts from the last dimension; in
    -- @f@, @td@ is the product of the dimensions after the current one.
    -- NOTE(review): the final Word -> Int conversion is not overflow-checked.
    fromEnum = fromIntegral . snd
             . foldr f (1, 0)
             . zip (listDims $ dims @ds) . listIdxs
      where
        f :: (Word, Word) -> (Word, Word) -> (Word, Word)
        f (d, i) (td, off) = (d * td, off + td * i)
    {-# INLINE fromEnum #-}

    -- @go@ works on plain [Word] lists; 'unsafeCoerce#' reinterprets it as a
    -- function over Dims/Idxs, which is only valid if both share [Word]'s
    -- runtime representation. @b@ is True while the prefix built so far
    -- equals the start index, so this position must start from @i@.
    enumFrom = unsafeCoerce# go True (dims @ds)
      where
        go :: Bool -> [Word] -> [Word] -> [[Word]]
        go b (d:ds) (i:is) =
          [ i' : is' | (b', i') <- zip (b : repeat False)
                                     $ enumFromTo (if b then i else 0) (d - 1)
                     , is' <- go b' ds is ]
        go _ _ _  = [[]]
    {-# INLINE enumFrom #-}

    -- Same list-level trick as 'enumFrom'. @bl@ / @bu@ are True while the
    -- prefix equals the lower (@x@) / upper (@y@) bound's prefix, so this
    -- position is bounded below by @x@ / above by @y@ respectively.
    -- @prepapp@ tags the first value with @bl@ and the last with @bu@
    -- (a single value gets both); values in between get (False, False).
    enumFromTo = unsafeCoerce# go True True (dims @ds)
      where
        go :: Bool -> Bool -> [Word] -> [Word] -> [Word] -> [[Word]]
        go bl bu (d:ds) (x:xs) (y:ys) =
          [ i : is | (bl', bu', i) <- prepapp bl bu
                                    $ enumFromTo (if bl then x else 0)
                                                 (if bu then y else d - 1)
                   , is <- go bl' bu' ds xs ys ]
        go _ _ _ _ _ = [[]]
        prepapp _  _  []     = []
        prepapp bl bu [i]    = [(bl, bu, i)]
        prepapp bl bu (i:is) = (bl, False, i :: Word) : app bu is
        app _  []     = []
        app bu [i]    = [(False, bu, i :: Word)]
        app bu (i:is) = (False, False, i) : app bu is
    {-# INLINE enumFromTo #-}

    -- Direction follows the comparison of the first two elements:
    -- ascending runs up to 'maxBound', descending down to 'minBound'.
    enumFromThen x0 x1 = case compare x1 x0 of
      EQ -> repeat x0
      GT -> enumFromThenTo x0 x1 maxBound
      LT -> enumFromThenTo x0 x1 minBound
    {-# INLINE enumFromThen #-}

    -- The step is the mixed-radix difference between @x1@ and @x0@.
    -- Stepping stops when the next value passes @y@, or when adding /
    -- subtracting the step carries / borrows out of the first dimension.
    enumFromThenTo x0 x1 y = case dir of
        EQ -> if allYs >= allX0s then repeat x0 else []
        GT -> let (_, allDXs) = idxMinus allDs allX0s allX1s
                  repeatStep is
                    = if is <= allYs
                      then is : case idxPlus allDs is allDXs of
                        (0, is') -> repeatStep is'
                        _        -> []
                      else []
              in unsafeCoerce# (repeatStep allX0s)
        LT -> let (_, allDXs) = idxMinus allDs allX1s allX0s
                  repeatStep is
                    = if is >= allYs
                      then is : case idxMinus allDs allDXs is of
                        (0, is') -> repeatStep is'
                        _        -> []
                      else []
              in unsafeCoerce# (repeatStep allX0s)
      where
        allDs  = listDims $ dims @ds
        allX0s = listIdxs x0
        allX1s = listIdxs x1
        allYs  = listIdxs y
        dir    = compare allX1s allX0s -- succ or pred?
        -- second arg minus first arg:
        -- @idxMinus ds as bs@ computes @bs - as@ digit-wise in mixed radix @ds@;
        -- the first component is the final borrow (0 means bs >= as).
        idxMinus :: [Word] -> [Word] -> [Word] -> (Word, [Word])
        idxMinus (d:ds) (a:as) (b:bs)
          = let (one , xs ) = idxMinus ds as bs
                (one', x  ) = quotRem (d + b - a - one) d
            in  (1 - one', x : xs)
        idxMinus _ _ _ = (0, [])
        -- @idxPlus ds as bs@ computes @as + bs@ in mixed radix @ds@;
        -- the first component is the final carry (0 means no overflow).
        idxPlus :: [Word] -> [Word] -> [Word] -> (Word, [Word])
        idxPlus (d:ds) (a:as) (b:bs)
          = let (one , xs ) = idxPlus ds as bs
                (one', x  ) = quotRem (a + b + one) d
            in  (one', x : xs)
        idxPlus _ _ _ = (0, [])
    {-# INLINE enumFromThenTo #-}



-- | Show type of Idx (for displaying nice errors), e.g. @"Idx 5"@.
showIdxType :: forall (k :: Type) (x :: k) . BoundedDim x => String
showIdxType = concat ["Idx ", show bound]
  where
    bound = dimVal (dimBound @k @x)

-- | Show type of Idxs (for displaying nice errors), e.g. @"Idxs '[2,3]"@.
showIdxsType :: Dims ns -> String
showIdxsType = ("Idxs '" ++) . show . listDims