{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wall #-}

-- | Functions for manipulating shape. The module tends to supply equivalent
-- functionality at type-level and value-level with functions of the same
-- name (except for capitalization).
module NumHask.Array.Shape where

import Data.List ((!!))
import Data.Type.Bool
import GHC.TypeLits as L
import NumHask.Prelude as P hiding (Last, minimum)

-- | A shape known at both levels: the phantom index @s@ is the type-level
-- @[Nat]@, and 'shapeVal' holds the matching value-level @[Int]@.
newtype Shape (s :: [Nat]) = Shape {shapeVal :: [Int]}
  deriving (Show)

-- | Type-level shapes that can be reflected into a 'Shape' value.
class HasShape s where
  toShape :: Shape s

-- | The empty shape reflects to the empty list.
instance HasShape '[] where
  toShape = Shape []

-- | A non-empty shape reflects its head 'Nat', then recurses on the tail.
instance (KnownNat n, HasShape s) => HasShape (n : s) where
  toShape = Shape (fromInteger (natVal (Proxy @n)) : shapeVal (toShape :: Shape s))

-- | Number of dimensions (the length of the shape list).
rank :: [a] -> Int
rank xs = length xs
{-# INLINE rank #-}

-- | Type-level 'rank'.
type family Rank (s :: [a]) :: Nat where
  Rank '[] = 0
  Rank (_ : rest) = Rank rest + 1

-- | The shape of a list of element indexes: the 'rank' of each inner list.
ranks :: [[a]] -> [Int]
ranks xss = rank <$> xss
{-# INLINE ranks #-}

-- | Type-level 'ranks'.
type family Ranks (s :: [[a]]) :: [Nat] where
  Ranks '[] = '[]
  Ranks (r : rs) = Rank r : Ranks rs

-- | Number of elements: the product of the dimensions (1 for the empty shape).
size :: [Int] -> Int
size dims = case dims of
  [] -> 1
  [d] -> d
  _ -> P.product dims
{-# INLINE size #-}

-- | Type-level 'size'.
type family Size (s :: [Nat]) :: Nat where
  Size '[] = 1
  Size (d : ds) = d L.* Size ds

-- | convert from n-dim shape index to a flat index
--
-- >>> flatten [2,3,4] [1,1,1]
-- 17
--
-- >>> flatten [] [1,1,1]
-- 0
flatten :: [Int] -> [Int] -> Int
flatten dims idx = case (dims, idx) of
  ([], _) -> 0
  (_, [i]) -> i
  _ -> sum (zipWith (*) idx strides)
  where
    -- the stride of each axis is the product of the dimensions to its right
    strides = drop 1 (scanr (*) one dims)
{-# INLINE flatten #-}

-- | convert from a flat index to a shape index
--
-- >>> shapen [2,3,4] 17
-- [1,1,1]
--
-- Note: the general case reduces the outermost index modulo the first
-- dimension, while the rank-1 and rank-2 cases do not; the results only
-- differ for out-of-range flat indexes.
shapen :: [Int] -> Int -> [Int]
shapen dims flat = case dims of
  [] -> []
  [_] -> [flat]
  [_, inner] -> let (q, r) = flat `divMod` inner in [q, r]
  _ -> fst (foldr step ([], flat) dims)
  where
    -- walk axes from innermost to outermost: the remainder is this axis'
    -- index and the quotient is carried outward
    step d (idxs, rest) = let (q, r) = rest `divMod` d in (r : idxs, q)
{-# INLINE shapen #-}

-- | /checkIndex i n/ checks if /i/ is a valid index of a list of length /n/
checkIndex :: Int -> Int -> Bool
checkIndex i n
  | i < zero = False
  | otherwise = i + one <= n

-- | Type-level 'checkIndex'; an out-of-range index is a type error.
type family CheckIndex (i :: Nat) (n :: Nat) :: Bool where
  CheckIndex i n =
    If ((0 <=? i) && (i + 1 <=? n)) 'True (L.TypeError ('Text "index outside range"))

-- | /checkIndexes is n/ check if /is/ are valid indexes of a list of length /n/
checkIndexes :: [Int] -> Int -> Bool
checkIndexes is n = all (\i -> checkIndex i n) is

-- | Type-level 'checkIndexes'.
type family CheckIndexes (i :: [Nat]) (n :: Nat) :: Bool where
  CheckIndexes '[] n = 'True
  CheckIndexes (i : is) n = CheckIndex i n && CheckIndexes is n

-- | dimension i is the i'th dimension of a Shape; throws a 'NumHaskException'
-- ("dimension overflow") when the shape has no such dimension.
dimension :: [Int] -> Int -> Int
dimension dims i = case dims of
  [] -> throw (NumHaskException "dimension overflow")
  (d : rest)
    | i == 0 -> d
    | otherwise -> dimension rest (i - 1)

-- | Type-level 'dimension'.
type family Dimension (s :: [Nat]) (i :: Nat) :: Nat where
  Dimension (d : _) 0 = d
  Dimension (_ : ds) i = Dimension ds (i - 1)
  Dimension _ _ = L.TypeError ('Text "dimension overflow")

-- | minimum value in a list; throws a 'NumHaskException'
-- ("dimension underflow") on an empty list.
minimum :: [Int] -> Int
minimum xs = case xs of
  [] -> throw (NumHaskException "dimension underflow")
  [x] -> x
  (x : rest) -> P.min x (minimum rest)

-- | Type-level 'minimum'.
type family Minimum (s :: [Nat]) :: Nat where
  Minimum '[] = L.TypeError ('Text "zero dimension")
  Minimum '[m] = m
  Minimum (m : ms) = If (m <=? Minimum ms) m (Minimum ms)

-- | Type-level 'take' (no equation for a non-zero count on an empty list).
type family Take (n :: Nat) (a :: [k]) :: [k] where
  Take 0 _ = '[]
  Take n (y : ys) = y : Take (n - 1) ys

-- | Type-level 'drop' (no equation for a non-zero count on an empty list).
type family Drop (n :: Nat) (a :: [k]) :: [k] where
  Drop 0 ys = ys
  Drop n (_ : ys) = Drop (n - 1) ys

-- | Type-level tail; a type error on the empty list.
type family Tail (a :: [k]) :: [k] where
  Tail '[] = L.TypeError ('Text "No tail")
  Tail (_ : ys) = ys

-- | Type-level init; a type error on the empty list.
type family Init (a :: [k]) :: [k] where
  Init '[] = L.TypeError ('Text "No init")
  Init '[_] = '[]
  Init (y : ys) = y : Init ys

-- | Type-level head; a type error on the empty list.
type family Head (a :: [k]) :: k where
  Head '[] = L.TypeError ('Text "No head")
  Head (y : _) = y

-- | Type-level last; a type error on the empty list.
type family Last (a :: [k]) :: k where
  Last '[] = L.TypeError ('Text "No last")
  Last '[y] = y
  Last (_ : ys) = Last ys

-- | Type-level list append.
type family (a :: [k]) ++ (b :: [k]) :: [k] where
  '[] ++ rhs = rhs
  (y : ys) ++ rhs = y : (ys ++ rhs)

-- | drop the i'th dimension from a shape
--
-- >>> dropIndex [2, 3, 4] 1
-- [2,4]
dropIndex :: [Int] -> Int -> [Int]
dropIndex s i = take i s ++ drop (i + 1) s

-- | Type-level 'dropIndex'.
type DropIndex s i = Take i s ++ Drop (i + 1) s

-- | /addIndex s i d/ adds a new dimension to shape /s/ at position /i/
--
-- >>> addIndex [2,4] 1 3
-- [2,3,4]
addIndex :: [Int] -> Int -> Int -> [Int]
addIndex s i d = take i s ++ d : drop i s

-- | Type-level 'addIndex'.
type AddIndex s i d = Take i s ++ (d : Drop i s)

-- | Type-level list reversal.
type Reverse (a :: [k]) = ReverseGo a '[]

-- | Worker for 'Reverse': moves elements of the first list onto the accumulator.
type family ReverseGo (a :: [k]) (b :: [k]) :: [k] where
  ReverseGo '[] acc = acc
  ReverseGo (y : ys) acc = ReverseGo ys (y : acc)

-- | convert a list of positions that reference a final shape to one that references positions relative to an accumulator. Deletions are from the left and additions are from the right.
--
-- deletions
--
-- >>> posRelative [0,1]
-- [0,0]
--
-- additions
--
-- >>> reverse (posRelative (reverse [1,0]))
-- [0,0]
posRelative :: [Int] -> [Int]
posRelative as = reverse (go [] as)
  where
    go r [] = r
    -- once position x is consumed, every later position that is not below x
    -- shifts down by one; positions below x are unaffected
    go r (x : xs) = go (x : r) ((\y -> bool (y - one) y (y < x)) <$> xs)

-- | Type-level 'posRelative'.
type family PosRelative (s :: [Nat]) where
  PosRelative s = PosRelativeGo s '[]

-- | Worker for 'PosRelative'. The first argument is the remaining input
-- positions; the second is the accumulator of already-adjusted positions,
-- kept in reverse order and flipped by 'Reverse' at the end.
type family PosRelativeGo (s :: [Nat]) (r :: [Nat]) where
  PosRelativeGo '[] r = Reverse r
  PosRelativeGo (x : xs) r = PosRelativeGo (DecMap x xs) (x : r)

-- | /DecMap x ys/ decrements every element of /ys/ that is not less than /x/
-- and leaves smaller elements unchanged (type-level counterpart of the
-- lambda in 'posRelative').
type family DecMap (x :: Nat) (ys :: [Nat]) :: [Nat] where
  DecMap _ '[] = '[]
  DecMap x (y : ys) = If (y + 1 <=? x) y (y - 1) : DecMap x ys

-- | drop dimensions of a shape according to a list of positions (where position refers to the initial shape)
--
-- >>> dropIndexes [2, 3, 4] [1, 0]
-- [4]
dropIndexes :: [Int] -> [Int] -> [Int]
dropIndexes s i = foldl' dropIndex s (posRelative i)

-- | Type-level 'dropIndexes'.
type family DropIndexes (s :: [Nat]) (i :: [Nat]) where
  DropIndexes s i = DropIndexesGo s (PosRelative i)

-- | Worker for 'DropIndexes': drops each (already relative) position in turn.
type family DropIndexesGo (s :: [Nat]) (i :: [Nat]) where
  DropIndexesGo s '[] = s
  DropIndexesGo s (i : is) = DropIndexesGo (DropIndex s i) is

-- | insert a list of dimensions according to position and dimension lists. Note that the list of positions references the final shape and not the initial shape.
--
-- >>> addIndexes [4] [1,0] [3,2]
-- [2,3,4]
--
-- Throws a 'NumHaskException' ("mismatched ranks") when there are more
-- positions than dimensions; surplus dimensions are ignored.
addIndexes :: [Int] -> [Int] -> [Int] -> [Int]
addIndexes as xs = addIndexesGo as (reverse (posRelative (reverse xs)))
  where
    -- insert one dimension at a time, consuming positions and dimensions in lock-step
    addIndexesGo as' [] _ = as'
    addIndexesGo as' (x : xs') (y : ys') = addIndexesGo (addIndex as' x y) xs' ys'
    addIndexesGo _ _ _ = throw (NumHaskException "mismatched ranks")

-- | Type-level 'addIndexes'.
type family AddIndexes (as :: [Nat]) (xs :: [Nat]) (ys :: [Nat]) where
  AddIndexes as xs ys = AddIndexesGo as (Reverse (PosRelative (Reverse xs))) ys

-- | Worker for 'AddIndexes'.
type family AddIndexesGo (as :: [Nat]) (xs :: [Nat]) (ys :: [Nat]) where
  AddIndexesGo as' '[] _ = as'
  AddIndexesGo as' (x : xs') (y : ys') = AddIndexesGo (AddIndex as' x y) xs' ys'
  AddIndexesGo _ _ _ = L.TypeError ('Text "mismatched ranks")

-- | take list of dimensions according to position lists.
-- Partial: an out-of-range position makes 'Data.List.!!' fail.
--
-- >>> takeIndexes [2,3,4] [2,0]
-- [4,2]
takeIndexes :: [Int] -> [Int] -> [Int]
takeIndexes s i = (s !!) <$> i

-- | Type-level 'takeIndexes'.
type family TakeIndexes (s :: [Nat]) (i :: [Nat]) where
  TakeIndexes '[] _ = '[]
  TakeIndexes _ '[] = '[]
  TakeIndexes s (i : is) = (s !! i) ': TakeIndexes s is

-- | Type-level list indexing; a type error when the index runs past the end
-- of the list.
type family (a :: [k]) !! (b :: Nat) :: k where
  (!!) '[] i = L.TypeError ('Text "Index Overflow")
  (!!) (x : _) 0 = x
  (!!) (_ : xs) i = (!!) xs (i - 1)

-- | Type-level @[0 .. n - 1]@ in ascending order (mirrors the value-level
-- enumeration used by 'exclude').
type family Enumerate (n :: Nat) where
  Enumerate n = Reverse (EnumerateGo n)

-- | Worker for 'Enumerate': produces @n - 1, n - 2, .., 0@ (descending).
type family EnumerateGo (n :: Nat) where
  EnumerateGo 0 = '[]
  EnumerateGo n = (n - 1) : EnumerateGo (n - 1)

-- | turn a list of included positions for a given rank into a list of excluded positions
--
-- >>> exclude 3 [1,2]
-- [0]
exclude :: Int -> [Int] -> [Int]
exclude r = dropIndexes [0 .. (r - 1)]

-- | Type-level 'exclude'. The positions must be enumerated in ascending
-- order, exactly like @[0 .. (r - 1)]@ in 'exclude'; enumerating them in
-- descending order would yield @r - 1 - p@ instead of @p@ for each surviving
-- position @p@ (e.g. @'[2]@ instead of @'[0]@ for @Exclude 3 '[1,2]@).
type family Exclude (r :: Nat) (i :: [Nat]) where
  Exclude r i = DropIndexes (Enumerate r) i

-- | /concatenate' i s0 s1/ is the shape of /s0/ joined with /s1/ along
-- dimension /i/: dimension /i/ becomes the sum of both, and every other
-- dimension is taken from /s0/.
concatenate' :: Int -> [Int] -> [Int] -> [Int]
concatenate' i s0 s1 = take i s0 ++ (dimension s0 i + dimension s1 i : drop (i + 1) s0)

-- | Type-level 'concatenate''.
type Concatenate i s0 s1 = Take i s0 ++ (Dimension s0 i + Dimension s1 i : Drop (i + 1) s0)

-- | Constraint that /i/ is a valid dimension of /s0/, that /s0/ and /s1/
-- agree everywhere except at dimension /i/, and that they have equal rank.
-- The last parameter /s/ is not used by the check.
type CheckConcatenate i s0 s1 s =
  ( CheckIndex i (Rank s0)
      && DropIndex s0 i == DropIndex s1 i
      && Rank s0 == Rank s1
  )
    ~ 'True

-- | Constraint that /d/ is a valid dimension of /s/ and /i/ is a valid index
-- into that dimension.
type CheckInsert d i s =
  (CheckIndex d (Rank s) && CheckIndex i (Dimension s d)) ~ 'True

-- | Shape /s/ with dimension /d/ increased by one.
type Insert d s = Take d s ++ (Dimension s d + 1 : Drop (d + 1) s)

-- | /incAt d s/ increments the index at /d/ of shape /s/ by one.
incAt :: Int -> [Int] -> [Int]
incAt d s = take d s ++ (dimension s d + 1 : drop (d + 1) s)

-- | /decAt d s/ decrements the index at /d/ of shape /s/ by one.
decAt :: Int -> [Int] -> [Int]
decAt d s = take d s ++ (dimension s d - 1 : drop (d + 1) s)

-- | /reorder' s i/ reorders the dimensions of shape /s/ according to a list of positions /i/
--
-- >>> reorder' [2,3,4] [2,0,1]
-- [4,2,3]
reorder' :: [Int] -> [Int] -> [Int]
reorder' [] _ = []
reorder' _ [] = []
reorder' s (d : ds) = dimension s d : reorder' s ds

-- | Type-level 'reorder''.
type family Reorder (s :: [Nat]) (ds :: [Nat]) :: [Nat] where
  Reorder '[] _ = '[]
  Reorder _ '[] = '[]
  Reorder s (d : ds) = Dimension s d : Reorder s ds

-- | Constraint that /ds/ has the same rank as /s/ and every entry of /ds/ is
-- a valid dimension index of /s/; otherwise a "bad dimensions" type error.
type family CheckReorder (ds :: [Nat]) (s :: [Nat]) where
  CheckReorder ds s =
    ( If
        (Rank ds == Rank s && CheckIndexes ds (Rank s))
        'True
        (L.TypeError ('Text "bad dimensions"))
    )
      ~ 'True

-- | Remove every dimension equal to 1 from a shape.
squeeze' :: (Eq a, Num a) => [a] -> [a]
squeeze' = filter (/= 1)

-- | Type-level 'squeeze''.
type family Squeeze (a :: [Nat]) where
  Squeeze '[] = '[]
  Squeeze a = Filter '[] a 1

-- | /Filter r xs i/ removes every occurrence of /i/ from /xs/, accumulating
-- the survivors (reversed) in /r/.
type family Filter (r :: [Nat]) (xs :: [Nat]) (i :: Nat) where
  Filter r '[] _ = Reverse r
  Filter r (x : xs) i = Filter (If (x == i) r (x : r)) xs i

-- unused but useful type-level functions
--
-- NOTE(review): 'Sort' relies on the open family 'Cmp', which has no
-- instances in this module; add instances before using it.
type family Sort (xs :: [k]) :: [k] where
  Sort '[] = '[]
  Sort (x ': xs) = (Sort (SFilter 'FMin x xs) ++ '[x]) ++ Sort (SFilter 'FMax x xs)

-- | Selects which side of the pivot 'SFilter' keeps.
data Flag = FMin | FMax

-- | Open comparison family used by 'Sort'.
type family Cmp (a :: k) (b :: k) :: Ordering

-- | Keep the elements strictly below ('FMin') or at-or-above ('FMax') the pivot /p/.
type family SFilter (f :: Flag) (p :: k) (xs :: [k]) :: [k] where
  SFilter f p '[] = '[]
  SFilter 'FMin p (x ': xs) = If (Cmp x p == 'LT) (x ': SFilter 'FMin p xs) (SFilter 'FMin p xs)
  SFilter 'FMax p (x ': xs) = If (Cmp x p == 'GT || Cmp x p == 'EQ) (x ': SFilter 'FMax p xs) (SFilter 'FMax p xs)

-- | Type-level zip into promoted pairs.
type family Zip lst lst' where
  Zip lst lst' = ZipWith '(,) lst lst' -- Implemented as TF because #11375

-- | Type-level zipWith; stops at the shorter list.
type family ZipWith f lst lst' where
  ZipWith f '[] lst = '[]
  ZipWith f lst '[] = '[]
  ZipWith f (l ': ls) (n ': ns) = f l n ': ZipWith f ls ns

-- | First component of a promoted pair.
type family Fst a where
  Fst '(a, _) = a

-- | Second component of a promoted pair.
type family Snd a where
  Snd '(_, a) = a

-- | Type-level map.
type family FMap f lst where
  FMap f '[] = '[]
  FMap f (l ': ls) = f l ': FMap f ls

-- | Reflect a list of Nats
class KnownNats (ns :: [Nat]) where
  natVals :: Proxy ns -> [Int]

instance KnownNats '[] where
  natVals _ = []

instance (KnownNat n, KnownNats ns) => KnownNats (n : ns) where
  natVals _ = fromInteger (natVal (Proxy @n)) : natVals (Proxy @ns)

-- | Reflect a list of list of Nats
class KnownNatss (ns :: [[Nat]]) where
  natValss :: Proxy ns -> [[Int]]

instance KnownNatss '[] where
  natValss _ = []

instance (KnownNats n, KnownNatss ns) => KnownNatss (n : ns) where
  natValss _ = natVals (Proxy @n) : natValss (Proxy @ns)