{-# LANGUAGE UndecidableInstances #-}
-- | Type-indexed positions inside a tensor ('TensorIndex') and positions
-- along a tensor's list of dimensions ('TensorRankIndex').
module Data.Tensor.Index where

import           Data.Proxy
import           Data.Reflection
import           Data.Singletons
import qualified Data.Singletons.Prelude.List as N
import           Data.Tensor.Type
import           GHC.Exts
import           GHC.TypeLits

-- | Tensor index, used to locate each point of a tensor.
--
-- The phantom @shape@ is the type-level shape of the tensor being indexed;
-- the wrapped 'Index' is the value-level position.
newtype TensorIndex (shape :: [Nat]) = TensorIndex Index deriving (Eq,Show,Ord)

-- | Bounds are expressed through the 'Enum' instance: the smallest index is
-- @toEnum 0@ and the largest is @toEnum (product shape - 1)@.
instance forall s. SingI s => Bounded (TensorIndex s) where
  minBound = toEnum 0
  maxBound = toEnum (product dims - 1)
    where
      dims = natsVal (Proxy :: Proxy s)

-- | Conversion between an 'Int' and a 'TensorIndex', delegated to 'viToti'
-- and 'tiTovi' together with the value-level shape.
instance forall s. SingI s => Enum (TensorIndex s) where
  toEnum n = TensorIndex (viToti dims n)
    where
      dims = natsVal (Proxy :: Proxy s)
  fromEnum (TensorIndex ix) = tiTovi dims ix
    where
      dims = natsVal (Proxy :: Proxy s)

-- | List-literal construction of an index.
--
-- 'fromList' calls 'error' with @"length not match"@ when the list length
-- differs from the number of dimensions, and with @"index overflow"@ when any
-- component is negative or not smaller than the matching dimension.
instance forall s. SingI s => IsList (TensorIndex s) where
  type Item (TensorIndex s) = Int
  fromList v
    | length v /= length dims        = error "length not match"
    | or (zipWith outOfRange v dims) = error "index overflow"
    | otherwise                      = TensorIndex v
    where
      dims = natsVal (Proxy :: Proxy s)
      -- A component is valid only inside the half-open range [0, n).
      outOfRange i n = i < 0 || i >= n
  toList (TensorIndex v) = v

-- | Tensor rank: the length of the type-level shape list.
type TensorRank (s :: [Nat]) = N.Length s

-- type TensorRankConstraint s i = N.And '[ (N.>=) i 0, (N.<) i (TensorRank s)]

-- | A position along the dimensions of a tensor of the given @shape@.
--
-- The position itself is an existentially quantified type-level 'Nat'
-- carried by a 'Proxy'; 'fromEnum' recovers it as an 'Int'.
data TensorRankIndex (shape :: [Nat]) = forall (i :: Nat). KnownNat i => TensorRankIndex (Proxy i)

-- | Shown as the plain integer obtained from 'fromEnum'.
instance SingI s => Show (TensorRankIndex s) where
  show = show . fromEnum

-- | Ranges from 'i0' up to @TensorRank s - 1@.
-- NOTE(review): 'i0' is imported, presumably the proxy for 0 -- confirm in
-- "Data.Tensor.Type".
instance forall s. (SingI s, KnownNat (TensorRank s - 1)) => Bounded (TensorRankIndex s) where
  minBound = TensorRankIndex i0
  maxBound = TensorRankIndex (Proxy :: Proxy (TensorRank s - 1))

-- | 'toEnum' reifies the given 'Int' as a type-level 'Nat'; 'fromEnum' reads
-- it back with 'natVal'.
--
-- 'toEnum' calls 'error' for a negative argument. A negative number is not a
-- 'Nat', so handing it to 'reifyNat' has no meaningful result; rejecting it
-- here reports the problem at the call site with a descriptive message.
--
-- NOTE(review): no upper bound (the rank of @s@) is enforced, matching the
-- previous behaviour -- confirm whether callers rely on that.
instance forall (s::[Nat]). (SingI s) => Enum (TensorRankIndex s) where
  toEnum i
    | i < 0     = error ("TensorRankIndex.toEnum: negative index " ++ show i)
    | otherwise = reifyNat (toInteger i) TensorRankIndex
  fromEnum (TensorRankIndex p) = fromInteger $ natVal p