{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE TypeApplications #-}
#if __GLASGOW_HASKELL__ >= 802
#else
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
#endif
module Numeric.Dimensions.Idxs
(
Idx (..), Idxs
, idxFromWord, unsafeIdxFromWord, idxToWord
, listIdxs, idxsFromWords
, module Numeric.Dimensions.Dims
) where
import Control.Arrow (first)
import Data.Data (Data)
import Foreign.Storable (Storable)
import GHC.Base
import GHC.Enum
import GHC.Generics (Generic, Generic1)
import Numeric.Dimensions.Dims
-- | A one-based index into a dimension of size @n@.
--
--   The parameter @n@ is phantom: at runtime an index is just the wrapped
--   'Word'. The constructor performs no range check; 'idxFromWord' always
--   checks, while 'unsafeIdxFromWord' and the 'Num' / 'Enum' instances
--   check only when the @UNSAFE_INDICES@ CPP flag is not defined.
newtype Idx n = Idx { unIdx :: Word }
    deriving ( Data, Generic, Generic1, Integral, Real, Storable, Eq, Ord )
-- | Parses a bare 'Word' (the format written by the 'Show' instance) and
--   wraps it. No range check is done: the instance has no 'KnownDim'
--   constraint, so the dimension size is not available here.
instance Read (Idx n) where
    readsPrec prec str = first Idx <$> readsPrec prec str
-- | Shows only the wrapped one-based 'Word', e.g. @show (Idx 3) == "3"@.
instance Show (Idx n) where
    showsPrec prec idx = showsPrec prec (unIdx idx)
-- | Valid indices of a dimension of size @n@ run from @1@ to @n@.
instance KnownDim n => Bounded (Idx n) where
    -- The largest index is the dimension size itself.
    maxBound = Idx (unsafeCoerce# (dim @_ @n))
    {-# INLINE maxBound #-}
    -- Goes through the Num instance, exactly like the literal 1 would.
    minBound = fromInteger 1
    {-# INLINE minBound #-}
-- | Enumeration of the indices of a dimension of size @n@.
--
--   'toEnum' and 'fromEnum' use zero-based 'Int' offsets (offset 0 is
--   @Idx 1@), while the wrapped 'Word' is one-based. The range-checked
--   variants are compiled unless the @UNSAFE_INDICES@ CPP flag is defined.
instance KnownDim n => Enum (Idx n) where
#ifdef UNSAFE_INDICES
    succ = unsafeCoerce# ((+ 1) :: Word -> Word)
#else
    succ x@(Idx i)
        | x /= maxBound = Idx (i + 1)
        | otherwise     = succError $ "Idx " ++ show (dim @_ @n)
#endif
    {-# INLINE succ #-}
#ifdef UNSAFE_INDICES
    -- Adding (-1) wraps around in Word, i.e. it subtracts one.
    pred = unsafeCoerce# ((+ (-1)) :: Word -> Word)
#else
    -- 1 is the smallest valid index, so only larger indices have a
    -- predecessor; the result is one less than the input.
    pred (Idx i)
        | i > 1     = Idx (i - 1)
        | otherwise = predError $ "Idx " ++ show (dim @_ @n)
#endif
    {-# INLINE pred #-}
#ifdef UNSAFE_INDICES
    toEnum (I# i#) = unsafeCoerce# (W# (int2Word# (i# +# 1#)))
#else
    -- Valid offsets are 0 .. d-1 inclusive; offset k maps to Idx (k + 1).
    toEnum i@(I# i#)
        | i >= 0 && i <= dm = unsafeCoerce# (W# (int2Word# (i# +# 1#)))
        | otherwise         = toEnumError ("Idx " ++ show d) i (0, dm)
      where
        d  = unsafeCoerce# (dim @_ @n) :: Word
        dm = fromIntegral d - 1
#endif
    {-# INLINE toEnum #-}
#ifdef UNSAFE_INDICES
    fromEnum (Idx (W# w#)) = I# (word2Int# w# -# 1#)
#else
    -- Reject words larger than maxInt, which cannot be represented as Int.
    fromEnum (Idx x@(W# w#))
        | x <= maxIntWord = I# (word2Int# w# -# 1#)
        | otherwise       = fromEnumError ("Idx " ++ show (dim @_ @n)) x
      where
        maxIntWord = W# (case maxInt of I# i -> int2Word# i)
#endif
    {-# INLINE fromEnum #-}
    -- Every index from the start up to the dimension size.
    enumFrom (Idx lo)
        = unsafeCoerce# (enumFromTo lo (unsafeCoerce# (dim @_ @n)))
    {-# INLINE enumFrom #-}
    -- Ascending steps stop at the dimension size, descending steps at 1,
    -- and a zero step repeats the start index forever.
    enumFromThen (Idx n0) (Idx n1)
        = case compare n0 n1 of
            LT -> unsafeCoerce# (enumFromThenTo n0 n1 (unsafeCoerce# (dim @_ @n)))
            EQ -> unsafeCoerce# (repeat n0)
            GT -> unsafeCoerce# (enumFromThenTo n0 n1 1)
    {-# INLINE enumFromThen #-}
    enumFromTo
        = unsafeCoerce# (enumFromTo :: Word -> Word -> [Word])
    {-# INLINE enumFromTo #-}
    enumFromThenTo
        = unsafeCoerce# (enumFromThenTo :: Word -> Word -> Word -> [Word])
    {-# INLINE enumFromThenTo #-}
-- | Arithmetic on indices of dimension @n@.
--
--   Unless @UNSAFE_INDICES@ is defined, '+', '-', '*' and 'fromInteger'
--   raise an error instead of producing an index outside @[1 .. dim n]@.
--   With @UNSAFE_INDICES@ defined they are plain (wrapping) 'Word'
--   arithmetic. 'negate' always fails, 'abs' is the identity and 'signum'
--   is always @Idx 1@.
instance KnownDim n => Num (Idx n) where
#ifdef UNSAFE_INDICES
    (+) = unsafeCoerce# ((+) :: Word -> Word -> Word)
#else
    (Idx a) + (Idx b)
        -- r < a or r < b detects Word wrap-around.
        | r > d || r < a || r < b
          = errorWithoutStackTrace
          $ "Num.(+){Idx " ++ show d ++ "}: sum of "
          ++ show a ++ " and " ++ show b
          ++ " is outside of index bounds."
        | otherwise = Idx r
      where
        r = a + b
        d = unsafeCoerce# (dim @_ @n)
#endif
    {-# INLINE (+) #-}
#ifdef UNSAFE_INDICES
    (-) = unsafeCoerce# ((-) :: Word -> Word -> Word)
#else
    (Idx a) - (Idx b)
        | b >= a
          = errorWithoutStackTrace
          $ "Num.(-){Idx " ++ show (dim @_ @n) ++ "}: difference of "
          ++ show a ++ " and " ++ show b
          ++ " is not positive."
        | otherwise = Idx (a - b)
#endif
    {-# INLINE (-) #-}
#ifdef UNSAFE_INDICES
    (*) = unsafeCoerce# ((*) :: Word -> Word -> Word)
#else
    (Idx a) * (Idx b)
        -- For b >= 1, a * b > d holds exactly when a > d `quot` b. Testing
        -- it this way cannot be fooled by Word overflow, unlike comparing
        -- the wrapped product against a and b. Zero factors are rejected
        -- first; this also guards the division.
        | a == 0 || b == 0 || a > d `quot` b
          = errorWithoutStackTrace
          $ "Num.(*){Idx " ++ show d ++ "}: product of "
          ++ show a ++ " and " ++ show b
          ++ " is outside of index bounds."
        | otherwise = Idx (a * b)
      where
        d = unsafeCoerce# (dim @_ @n) :: Word
#endif
    {-# INLINE (*) #-}
    negate _ = errorWithoutStackTrace
             $ "Num.negate{Idx " ++ show (dim @_ @n) ++ "}: cannot negate index."
    {-# INLINE negate #-}
    abs = id
    {-# INLINE abs #-}
    signum _ = Idx 1
    {-# INLINE signum #-}
#ifdef UNSAFE_INDICES
    fromInteger = unsafeCoerce# (fromInteger :: Integer -> Word)
#else
    fromInteger i
        | i > 0 && i <= d = Idx $ fromInteger i
        | otherwise       = errorWithoutStackTrace
                          $ "Num.fromInteger{Idx "
                          ++ show d ++ "}: integer "
                          ++ show i ++ " is outside of index bounds."
      where
        d = toInteger (unsafeCoerce# (dim @_ @n) :: Word)
#endif
    {-# INLINE fromInteger #-}
-- | Wrap a one-based 'Word' as an index of dimension @d@.
--
--   Unless @UNSAFE_INDICES@ is defined, raises an error when the word is 0
--   or larger than the dimension size. With @UNSAFE_INDICES@ it is a plain
--   coercion with no check.
unsafeIdxFromWord :: forall d . KnownDim d => Word -> Idx d
#ifdef UNSAFE_INDICES
unsafeIdxFromWord = unsafeCoerce#
#else
unsafeIdxFromWord w
    | w == 0 || w > upper = errorWithoutStackTrace
        $ "idxFromWord{Idx "
        ++ show upper ++ "}: word "
        ++ show w ++ " is outside of index bounds."
    | otherwise = Idx w
  where
    upper = unsafeCoerce# (dim @_ @d) :: Word
#endif
{-# INLINE unsafeIdxFromWord #-}
-- | Checked conversion from a one-based 'Word'. Returns 'Just' the index
--   when the word lies in @[1 .. dim d]@ and 'Nothing' otherwise. The
--   check is done regardless of @UNSAFE_INDICES@.
idxFromWord :: forall d . KnownDim d => Word -> Maybe (Idx d)
idxFromWord w
    | w == 0                        = Nothing
    | w > unsafeCoerce# (dim @_ @d) = Nothing
    | otherwise                     = Just (Idx w)
{-# INLINE idxFromWord #-}
-- | Extract the one-based 'Word' stored in an index. This is the record
--   field accessor, so it costs nothing at runtime.
idxToWord :: Idx d -> Word
idxToWord = unIdx
{-# INLINE idxToWord #-}
-- Rewrite 'fromIntegral' used at type @Idx d -> Word@ into 'idxToWord',
-- which just unwraps the newtype. 'Integral' for 'Idx' is newtype-derived
-- from 'Word', so 'toInteger' of an index equals 'toInteger' of the wrapped
-- word. The rewrite therefore does not change results.
{-# RULES
"fromIntegral/idxToWord"
    fromIntegral = idxToWord
  #-}
-- | A multi-dimensional index: a 'TypedList' holding one 'Idx' for each
--   dimension in @xs@.
type Idxs (xs :: [k]) = TypedList Idx xs
-- | The one-based components of a multi-index as a plain list of words.
--
--   Implemented as a zero-cost coercion. It relies on 'Idx' having the
--   representation of 'Word' (a newtype) and on 'TypedList' sharing the
--   runtime representation of an ordinary list.
--   NOTE(review): 'TypedList' is defined in Numeric.Dimensions.Dims; confirm
--   that the list-representation assumption still holds when changing it.
listIdxs :: Idxs xs -> [Word]
listIdxs = unsafeCoerce#
{-# INLINE listIdxs #-}
-- | Build a multi-index from one-based components, checking each against
--   the size of the corresponding dimension.
--
--   Returns 'Nothing' if the number of components differs from the number
--   of dimensions in @ds@, or if any component is 0 or exceeds its
--   dimension size.
--
--   The validated @[Word]@ is reinterpreted as @Idxs ds@, the same
--   representation assumption that 'listIdxs' makes. The result type must
--   therefore be the multi-index 'Idxs' and not the single-word 'Idx',
--   which has a different runtime representation.
idxsFromWords :: forall ds . Dimensions ds => [Word] -> Maybe (Idxs ds)
idxsFromWords = unsafeCoerce# . go (listDims (dims @_ @ds))
  where
    -- Walk dimension sizes and components in lockstep; both lists must
    -- run out at the same time.
    go :: [Word] -> [Word] -> Maybe [Word]
    go [] [] = Just []
    go (d : dRest) (i : iRest)
      | i > 0 && i <= d = (i :) <$> go dRest iRest
    go _ _ = Nothing
-- | Two multi-indices are equal when their component lists are equal.
instance Eq (Idxs xs) where
    lhs == rhs = listIdxs lhs == listIdxs rhs
    {-# INLINE (==) #-}
-- | Multi-indices are ordered with the last component most significant.
--   The reversed component lists are compared lexicographically.
instance Ord (Idxs xs) where
    compare lhs rhs = revComponents lhs `compare` revComponents rhs
      where
        revComponents = reverse . listIdxs
    {-# INLINE compare #-}
-- | Renders as @Idxs [i1, i2, ...]@, listing the one-based components in
--   the order returned by 'listIdxs'.
--
--   Follows the standard 'showsPrec' convention. Parentheses are added only
--   above function-application precedence (@p > 10@), and the argument of
--   the application is shown at precedence 11. 'show' uses the class
--   default, which gives the same string at precedence 0.
instance Show (Idxs xs) where
    showsPrec p ds
        = showParen (p > 10)
        $ showString "Idxs " . showsPrec 11 (listIdxs ds)
-- | Arithmetic on single-dimension multi-indices. Every method forwards to
--   the 'Num' instance of 'Idx', including its bounds checks.
instance KnownDim n => Num (Idxs '[n]) where
    (a :* U) + (b :* U) = (a + b) :* U
    {-# INLINE (+) #-}
    (a :* U) - (b :* U) = (a - b) :* U
    {-# INLINE (-) #-}
    (a :* U) * (b :* U) = (a * b) :* U
    {-# INLINE (*) #-}
    -- Forward explicitly. The class default (negate x = 0 - x) would go
    -- through fromInteger 0, which (with bounds checks enabled) fails with
    -- an unrelated out-of-bounds message.
    negate (a :* U) = negate a :* U
    {-# INLINE negate #-}
    signum (a :* U) = signum a :* U
    {-# INLINE signum #-}
    abs (a :* U) = abs a :* U
    {-# INLINE abs #-}
    fromInteger i = fromInteger i :* U
    {-# INLINE fromInteger #-}
-- | The smallest multi-index has every component at 1. The largest has
--   every component at its dimension size.
instance Dimensions ds => Bounded (Idxs ds) where
    maxBound = sizes (dims @_ @ds)
      where
        sizes :: forall ns . Dims ns -> Idxs ns
        sizes U          = U
        sizes (dh :* dt) = Idx (dimVal dh) :* sizes dt
    {-# INLINE maxBound #-}
    minBound = ones (dims @_ @ds)
      where
        ones :: forall ns . Dims ns -> Idxs ns
        ones U         = U
        ones (_ :* dt) = Idx 1 :* ones dt
    {-# INLINE minBound #-}
-- | Enumeration of multi-indices with the first dimension varying fastest.
--   'fromEnum' weights the first component by 1 and each later component by
--   the product of the preceding dimension sizes. This agrees with the
--   'Ord' instance, which compares reversed component lists.
--
--   NOTE(review): several methods differ from the usual 'Enum' conventions,
--   for example the 'Idx' instance, which delegates to 'Word':
--
--   * 'enumFromTo' @x y@ with @x > y@ returns a descending list instead of
--     an empty one.
--   * 'enumFromThen' and 'enumFromThenTo' whose first two elements are
--     equal return just the first element, instead of an infinite (or
--     empty) list.
--
--   Confirm whether callers rely on this before changing it.
instance Dimensions ds => Enum (Idxs ds) where
    -- Increment the first component. A component already at its dimension
    -- size resets to 1 and carries into the rest; carrying past the last
    -- dimension raises succError.
    succ = go (dims @_ @ds)
      where
        go :: forall ns . Dims ns -> Idxs ns -> Idxs ns
        go U U = succError $ "Idxs " ++ show (listDims $ dims @_ @ds)
        go (d :* ds) (Idx i :* is)
          | i == dimVal d = Idx 1 :* go ds is
          | otherwise     = Idx (i+1) :* is
    {-# INLINE succ #-}
    -- Mirror image of succ: a component at 1 wraps to its dimension size
    -- and borrows from the rest.
    pred = go (dims @_ @ds)
      where
        go :: forall ns . Dims ns -> Idxs ns -> Idxs ns
        go U U = predError $ "Idxs " ++ show (listDims $ dims @_ @ds)
        go (d :* ds) (Idx i :* is)
          | i == 1    = Idx (dimVal d) :* go ds is
          | otherwise = Idx (i-1) :* is
    {-# INLINE pred #-}
    -- Decompose the zero-based offset with divMod by each dimension size.
    -- An offset left over after the last dimension means it is out of range.
    toEnum i = go dds $ fromIntegral i
      where
        dds = dims @_ @ds
        go :: forall ns . Dims ns -> Word -> Idxs ns
        go U 0 = U
        go U _ = toEnumError ("Idxs " ++ show (listDims dds))
                             i (0, totalDim dds - 1)
        go (d :* ds) off = case divMod off (dimVal d) of
          (off', j) -> Idx (j+1) :* go ds off'
    {-# INLINE toEnum #-}
    -- Mixed-radix sum: m is the product of the sizes of the dimensions
    -- already visited.
    fromEnum = fromIntegral . go 1 (dims @_ @ds)
      where
        go :: forall ns . Word -> Dims ns -> Idxs ns -> Word
        go _ U U = 0
        go m (d :* ds) (Idx i :* is) = m * (i - 1) + go (m * dimVal d) ds is
    {-# INLINE fromEnum #-}
    -- diffIdx maxBound x is the number of steps from x to maxBound.
    enumFrom x = take (diffIdx (dims @_ @ds) maxBound x + 1) $ iterate succ x
    {-# INLINE enumFrom #-}
    -- NOTE(review): descends from x to y when x > y (see instance note).
    enumFromTo x y | x >= y    = take (diffIdx ds x y + 1) $ iterate pred x
                   | otherwise = take (diffIdx ds y x + 1) $ iterate succ x
      where
        ds = dims @_ @ds
    {-# INLINE enumFromTo #-}
    -- Count how many steps of size dn fit before maxBound (dn > 0) or
    -- minBound (dn < 0). A zero step yields only x.
    enumFromThen x x' = take n $ iterate (stepIdx ds dn) x
      where
        ds = dims @_ @ds
        dn = diffIdx ds x' x
        n  = 1 + if dn == 0
                 then 0
                 else if dn > 0
                      then diffIdx ds maxBound x `div` dn
                      else diffIdx ds x minBound `div` negate dn
    {-# INLINE enumFromThen #-}
    -- When y lies on the opposite side of x from the step direction, the
    -- flooring div makes n <= 0 and the result is empty. A zero step
    -- yields only x.
    enumFromThenTo x x' y = take n $ iterate (stepIdx ds dn) x
      where
        ds = dims @_ @ds
        dn = diffIdx ds x' x
        n  = 1 + if dn == 0 then 0
                 else diffIdx ds y x `div` dn
    {-# INLINE enumFromThenTo #-}
-- | Signed number of enumeration steps from the second multi-index to the
--   first, with the first dimension varying fastest. Each level contributes
--   its component difference, plus its dimension size times the distance
--   between the remaining components.
diffIdx :: Dims xs -> Idxs xs -> Idxs xs -> Int
diffIdx U U U = 0
diffIdx (dh :* dt) (Idx a :* aRest) (Idx b :* bRest)
    = delta + size * diffIdx dt aRest bRest
  where
    delta = fromIntegral a - fromIntegral b
    size  = fromIntegral (dimVal dh)
{-# INLINE diffIdx #-}
-- | Move a multi-index by a (possibly negative) number of steps, with the
--   first dimension varying fastest.
--
--   Each component absorbs the step via 'divMod' by its dimension size and
--   passes the quotient on as a carry. A carry left over after the last
--   dimension is silently dropped, so callers must keep the step in range.
stepIdx :: Dims ds -> Int -> Idxs ds -> Idxs ds
stepIdx U _ U = U
stepIdx (dh :* dt) delta (Idx i :* rest)
    = case divMod (delta + fromIntegral i - 1) (fromIntegral (dimVal dh)) of
        (carry, j) ->
            Idx (fromIntegral (j + 1))
                :* if carry == 0 then rest else stepIdx dt carry rest
{-# INLINE stepIdx #-}