{-# LANGUAGE UndecidableInstances #-}
module Data.Tensor.Index where
import Data.Proxy
import Data.Reflection
import Data.Singletons
import qualified Data.Singletons.Prelude.List as N
import Data.Tensor.Type
import GHC.Exts
import GHC.TypeLits
-- | An index into a tensor, tagged at the type level with the tensor's
-- @shape@ (a list of 'Nat' dimensions).
--
-- The @shape@ parameter is phantom: the wrapped value is a plain 'Index'.
-- 'Eq', 'Show' and 'Ord' are derived from that underlying 'Index'.
-- NOTE(review): 'Index' is not defined in this file (presumably it comes
-- from "Data.Tensor.Type"); the 'IsList' instance in this module treats it
-- as a list of 'Int'.
newtype TensorIndex (shape :: [Nat]) = TensorIndex Index deriving (Eq,Show,Ord)
-- | Bounds of a 'TensorIndex' expressed through its 'Enum' instance:
-- the smallest index is @'toEnum' 0@ and the largest is 'toEnum' applied to
-- the product of the values returned by @natsVal@ for the shape, minus one.
--
-- NOTE(review): @natsVal@ is defined elsewhere; presumably it reflects the
-- type-level shape into a list of dimension sizes -- confirm.
instance forall s. SingI s => Bounded (TensorIndex s) where
  minBound = toEnum 0
  maxBound = toEnum (volume - 1)
    where
      -- Product of everything @natsVal@ reports for the shape @s@.
      volume = product (natsVal (Proxy :: Proxy s))
-- | Enumerates tensor indices through an 'Int' position.
--
-- * 'toEnum' passes the reflected shape and the 'Int' to @viToti@ and wraps
--   the result.
-- * 'fromEnum' unwraps the index and passes it, with the reflected shape,
--   to @tiTovi@.
--
-- NOTE(review): @viToti@ / @tiTovi@ are defined elsewhere; presumably they
-- convert between a flat offset and a per-dimension index -- confirm.
--
-- Because 'TensorIndex' is also 'Bounded', 'enumFrom' and 'enumFromThen'
-- are given an implicit bound (as the Haskell Report requires for types that
-- are both 'Enum' and 'Bounded'). With the class defaults, @[x ..]@ would
-- keep calling 'toEnum' on offsets past 'maxBound' all the way up to the
-- 'Int' limit.
instance forall s. SingI s => Enum (TensorIndex s) where
  toEnum i = let s = natsVal (Proxy :: Proxy s) in TensorIndex $ viToti s i
  fromEnum (TensorIndex i) = let s = natsVal (Proxy :: Proxy s) in tiTovi s i
  -- Stop at the last valid index instead of running to @maxBound :: Int@.
  enumFrom x = enumFromTo x maxBound
  -- Ascending progressions stop at 'maxBound', descending ones at 'minBound'.
  enumFromThen x y = enumFromThenTo x y limit
    where
      limit
        | fromEnum y >= fromEnum x = maxBound
        | otherwise = minBound
-- | Allows a 'TensorIndex' to be written as a list of 'Int' coordinates.
--
-- 'fromList' validates the coordinates against the values @natsVal@ returns
-- for the shape @s@:
--
-- * calls 'error' with @"length not match"@ when the number of coordinates
--   differs from the number of shape entries;
-- * calls 'error' with @"index overflow"@ when any coordinate is negative or
--   not strictly smaller than the corresponding shape entry;
-- * otherwise wraps the list unchanged.
--
-- 'toList' simply unwraps the stored coordinates.
instance forall s. SingI s => IsList (TensorIndex s) where
  type Item (TensorIndex s) = Int
  fromList coords
    | length coords /= length dims = error "length not match"
    | any outOfRange (zip coords dims) = error "index overflow"
    | otherwise = TensorIndex coords
    where
      dims = natsVal (Proxy :: Proxy s)
      -- A coordinate is valid only inside the half-open range [0, extent).
      outOfRange (c, extent) = c < 0 || c >= extent
  toList (TensorIndex coords) = coords
-- | The rank of a shape: the type-level length of the shape list @s@,
-- computed with @Length@ from "Data.Singletons.Prelude.List".
type TensorRank (s :: [Nat]) = N.Length s
-- | A position among the dimensions of a tensor of the given @shape@.
--
-- The position is stored existentially: the constructor hides a type-level
-- natural @i@ and keeps only a 'Proxy' for it together with its 'KnownNat'
-- dictionary, so the value can be recovered at runtime. The @shape@
-- parameter is phantom -- the constructor itself does not relate @i@ to it.
data TensorRankIndex (shape :: [Nat]) = forall (i :: Nat). KnownNat i => TensorRankIndex (Proxy i)
-- | Renders a 'TensorRankIndex' as the decimal form of its 'fromEnum' value.
instance SingI s => Show (TensorRankIndex s) where
  show rankIdx = show (fromEnum rankIdx)
-- | Bounds of a 'TensorRankIndex': the lower bound wraps @i0@ and the upper
-- bound wraps a proxy for @'TensorRank' s - 1@.
--
-- The @KnownNat (TensorRank s - 1)@ constraint is what lets the upper bound
-- be built.
-- NOTE(review): @i0@ is defined elsewhere; presumably it is a proxy for the
-- type-level natural 0 -- confirm.
instance forall s. (SingI s, KnownNat (TensorRank s - 1)) => Bounded (TensorRankIndex s) where
  minBound = TensorRankIndex i0
  maxBound = TensorRankIndex lastPosition
    where
      lastPosition :: Proxy (TensorRank s - 1)
      lastPosition = Proxy
-- | Converts between 'Int' and 'TensorRankIndex'.
--
-- * 'toEnum' widens the 'Int' to an 'Integer', reifies it as a type-level
--   natural with 'reifyNat', and wraps the resulting proxy.
-- * 'fromEnum' reads the hidden natural back with 'natVal' and narrows it
--   to 'Int'.
--
-- NOTE(review): 'toEnum' performs no range check against the rank of @s@;
-- callers are presumably expected to stay within the 'Bounded' limits.
instance forall (s :: [Nat]). SingI s => Enum (TensorRankIndex s) where
  toEnum n = reifyNat (toInteger n) (\p -> TensorRankIndex p)
  fromEnum rankIdx = case rankIdx of
    TensorRankIndex p -> fromInteger (natVal p)