{-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE UndecidableInstances #-} module Data.Tensor.Type where import Data.List (foldl') import Data.Proxy import Data.Type.Bool hiding (If) import GHC.Exts import GHC.TypeLits import qualified GHC.TypeLits as L import Unsafe.Coerce type Shape = [Int] type Index = [Int] newtype SShape (shape :: [Nat]) = SShape { unShape :: Shape } deriving Show class HasShape s where toShape :: SShape s toRank :: Proxy s -> Int toRank _ = length $ unShape (toShape :: SShape s) toSize :: Proxy s -> Int toSize _ = product $ unShape (toShape :: SShape s) instance HasShape '[] where toShape = SShape [] instance (KnownNat n, HasShape s) => HasShape (n:s) where toShape = SShape $ fromInteger (natVal (Proxy :: Proxy n)) : unShape (toShape :: SShape s) toNat :: KnownNat n => Proxy n -> Int toNat = unsafeCoerce . natVal viToti :: Shape -> Int -> Index viToti s i = snd $ foldl' (\(r,xs) si -> let (r',x) = divMod r si in (r', x:xs)) (i,[]) (reverse s) tiTovi :: Shape -> Index -> Int tiTovi s i = foldl' (\b (n,ind) -> b * n + ind) 0 $ zipWith (,) s i -- | Tensor Index, used to locate each point of tensor newtype TensorIndex (shape :: [Nat]) = TensorIndex Index deriving (Eq,Show,Ord) instance HasShape s => IsList (TensorIndex s) where type Item (TensorIndex s) = Int fromList v = let s = unShape (toShape :: SShape s) in if length v /= length s then error "length not match" else if or (zipWith (\i n-> i <0 || i >= n) v s) then error "index overflow" else TensorIndex v toList (TensorIndex v) = v -- | Tensor rank. type family TensorRank (s :: [Nat]) :: Nat where TensorRank '[] = 0 TensorRank (_:s) = TensorRank s + 1 -- | Tensor size. type family TensorSize (s :: [Nat]) :: Nat where TensorSize '[] = 1 TensorSize (n:s) = n L.* (TensorRank s) type family Reverse (a :: [k]) (b :: [k]) :: [k] where Reverse '[] b = b Reverse (a:as) b = Reverse as (a:b) type family If (b :: Bool) c d where If 'True c d = c If 'False c d = d type family Replicate (a :: k) (dim :: Nat) :: [k] where Replicate a 0 = '[] Replicate a n = a : Replicate a n type family Dimension (s :: [Nat]) (i :: Nat) :: Nat where Dimension (s:_) 0 = s Dimension (_:s) n = Dimension s (n-1) Dimension _ _ = TypeError ('Text "Index overflow") type CheckDimension dim s = IsIndex dim (TensorRank s) type CheckIndices i j s = IsIndices i j (TensorRank s) ~ 'True type IsIndex i n = (0 <=? i) && (i + 1 <=? n) type IsIndices i j n = (0 <=? i) && (i + 1 <=? j) && (j + 1 <=? n) type family Take (n :: Nat) (a :: [k]) :: [k] where Take 0 _ = '[] Take n (x:xs) = x : Take (n-1) xs type family Drop (n :: Nat) (a :: [k]) :: [k] where Drop 0 xs = xs Drop n (_:xs) = Take (n-1) xs type family Tail (a :: [k]) :: [k] where Tail '[] = TypeError ('Text "No tail") Tail (_:xs) = xs type family Init (a :: [k]) :: [k] where Init '[] = TypeError ('Text "No init") Init '[_] = '[] Init (x:xs) = x : Init xs type family Head (a :: [k]) :: k where Head '[] = TypeError ('Text "No head") Head (x:_) = x type family Last (a :: [k]) :: k where Last '[] = TypeError ('Text "No last") Last '[x] = x Last (_:xs) = Last xs type family (a :: [k]) ++ (b :: [k]) :: [k] where '[] ++ b = b (a:as) ++ b = a : (as ++ b) mult :: (Eq a, Num a) => a -> a -> a mult a b | a == 0 = 0 | b == 0 = 0 | otherwise = a * b ----------------------- -- Tensor Type Index ----------------------- i0 = Proxy :: Proxy 0 i1 = Proxy :: Proxy 1 i2 = Proxy :: Proxy 2 i3 = Proxy :: Proxy 3 i4 = Proxy :: Proxy 4 i5 = Proxy :: Proxy 5 i6 = Proxy :: Proxy 6 i7 = Proxy :: Proxy 7 i8 = Proxy :: Proxy 8 i9 = Proxy :: Proxy 9