-- Copyright 2020 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Array.Internal.Shape(module Data.Array.Internal.Shape) where
import Data.Proxy
import Type.Reflection
import GHC.TypeLits
import Data.Array.Internal(valueOf)

-- | Ceiling division on type-level naturals: for a non-zero divisor @den@
-- this is the smallest @q@ with @q * den >= num@.  It is expressed via the
-- floor division 'Div' as @Div (num + den - 1) den@.
type DivRoundUp (num :: Nat) (den :: Nat) = Div (num + den - 1) den

-----------------
-- Type level properties.

-- | The rank of a type-level shape, i.e. how many dimensions it lists.
type family Rank (s :: [Nat]) :: Nat where
  Rank '[]         = 0
  Rank (_ ': dims) = 1 + Rank dims

-- | The size of a type-level shape, i.e. the product of all of its
-- dimensions (1 for the rank-0 shape @'[]@).
type Size (s :: [Nat]) = Size' 1 s

-- | Worker for 'Size': multiplies every dimension into the accumulator
-- @acc@ and returns it once the list is exhausted.
type family Size' (acc :: Nat) (s :: [Nat]) :: Nat where
  Size' acc '[]         = acc
  Size' acc (d ': dims) = Size' (acc * d) dims
-- Threading an accumulator through the recursion (instead of computing
-- @d * Size dims@) generates fewer constraints.
----------------------------------
-- Type level shape operations

-- | Type-level list append.
type family (++) (xs :: [Nat]) (ys :: [Nat]) :: [Nat] where
  (++) '[] ys = ys
  (++) (x ': xs) ys = x ': (xs ++ ys)

{-
-- XXX O(n^2)
type family Reverse (xs :: [Nat]) :: [Nat] where
  Reverse '[] = '[]
  Reverse (x ': xs) = Reverse xs ++ '[x]
-}

-- | The first @n@ elements of @xs@.  There is no equation for a non-zero
-- @n@ applied to @'[]@, so the family does not reduce when @n@ exceeds the
-- length of @xs@.
type family Take (n :: Nat) (xs :: [Nat]) :: [Nat] where
  Take 0 xs = '[]
  Take n (x ': xs) = x ': Take (n-1) xs

-- | @xs@ without its first @n@ elements.  Like 'Take', it does not reduce
-- when @n@ exceeds the length of @xs@.
type family Drop (n :: Nat) (xs :: [Nat]) :: [Nat] where
  Drop 0 xs = xs
  Drop n (x ': xs) = Drop (n-1) xs

-- | The last element of a non-empty list (there is no equation for @'[]@).
type family Last (xs :: [Nat]) where
  Last '[x] = x
  Last (x ': xs) = Last xs

-- | All but the last element of a non-empty list (there is no equation
-- for @'[]@).
type family Init (xs :: [Nat]) where
  Init '[x] = '[]
  Init (x ': xs) = x ': Init xs

-----------------

-- | @from@ can be stretched to @to@ when both have the same length and each
-- dimension of @from@ is either 1 or equal to the matching dimension of
-- @to@.  'stretching' reports, per dimension, 'True' when the source
-- dimension is 1 (stretched) and 'False' when it is kept as is.
class ValidStretch (from :: [Nat]) (to :: [Nat]) where
  stretching :: Proxy from -> Proxy to -> [Bool]

instance ValidStretch '[] '[] where
  stretching _ _ = []

instance (BoolVal (Stretch s m), ValidStretch ss ms)
      => ValidStretch (s ': ss) (m ': ms) where
  stretching _ _ =
    boolVal (Proxy :: Proxy (Stretch s m))
      : stretching (Proxy :: Proxy ss) (Proxy :: Proxy ms)

-- | Whether source dimension @s@ is stretched to reach @m@.  The equations
-- are tried in order, so a size-1 source is reported as stretched even
-- when @m@ is also 1; any other mismatch is a type error.
type family Stretch (s::Nat) (m::Nat) :: Bool where
  Stretch 1 m = 'True
  Stretch m m = 'False
  Stretch s m = TypeError ('Text "Cannot stretch " ':<>: 'ShowType s ':<>: 'Text " to " ':<>: 'ShowType m)

-- | Reflect a type-level 'Bool' to a term-level one.
class BoolVal (b :: Bool) where
  boolVal :: Proxy b -> Bool

instance BoolVal 'False where
  boolVal _ = False

instance BoolVal 'True where
  boolVal _ = True

-----------------

-- | Padding the leading dimensions of @sh@ by the @(low, high)@ pairs in
-- @ps@ yields @sh'@.  Each padded dimension @s@ becomes @l + s + h@, and
-- dimensions beyond the length of @ps@ are unchanged.  'padded' returns the
-- @(low, high)@ amounts at the term level.
class Padded (ps :: [(Nat, Nat)]) (sh :: [Nat]) (sh' :: [Nat]) | ps sh -> sh' where
  padded :: Proxy ps -> Proxy sh -> [(Int, Int)]

instance Padded '[] sh sh where
  padded _ _ = []

instance (KnownNat l, KnownNat h, (l+s+h) ~ s', Padded ps sh sh')
      => Padded ('(l,h) ': ps) (s ': sh) (s' ': sh') where
  padded _ _ = (valueOf @l, valueOf @h) : padded (Proxy :: Proxy ps) (Proxy :: Proxy sh)

-----------------

-- | @Permutation is@ holds when @is@ is a permutation of @[0 .. n-1]@,
-- where @n@ is the length of @is@.
--
-- Both inclusions are checked.  Every element of @is@ must lie in
-- @[0 .. n-1]@, and every element of @[0 .. n-1]@ must occur in @is@.
-- Checking only the first inclusion would also accept lists with repeated
-- indices such as @'[0, 0]@, which are not permutations.
class Permutation (is :: [Nat])
instance (AllElem is (Count 0 is), AllElem (Count 0 is) is) => Permutation is

-- | @Count i xs@ is @[i, i+1, ...]@, with the same length as @xs@.
type family Count (i :: Nat) (xs :: [Nat]) :: [Nat] where
  Count i '[] = '[]
  Count i (x ': xs) = i ': Count (i+1) xs

-- | Every element of @is@ occurs in @ns@.
class AllElem (is :: [Nat]) (ns :: [Nat])
instance AllElem '[] ns
instance (Elem i ns, AllElem is ns) => AllElem (i ': is) ns

-- | @i@ occurs in @ns@.  There is no instance for @'[]@, so an absent
-- element shows up as a missing-instance error.
class Elem (i :: Nat) (ns :: [Nat])
instance (Elem' (CmpNat i n) i ns) => Elem i (n : ns)

-- | Helper for 'Elem' that dispatches on @CmpNat i n@ for the current head
-- @n@.  @'EQ@ means found; any other result keeps searching the tail.
class Elem' (e :: Ordering) (i :: Nat) (ns :: [Nat])
instance Elem' 'EQ i ns
instance (Elem i ns) => Elem' 'LT i ns
instance (Elem i ns) => Elem' 'GT i ns

-- | Reorder the first @Rank is@ elements of @xs@ according to @is@
-- (position @k@ of the result is element number @is !! k@ of @xs@).  The
-- remaining elements of @xs@ follow unchanged.
type Permute (is :: [Nat]) (xs :: [Nat]) = Permute' is (Take (Rank is) xs) ++ Drop (Rank is) xs

-- | The elements of @xs@ at the indices @is@, in the order of @is@.
type family Permute' (is :: [Nat]) (xs :: [Nat]) where
  Permute' '[] xs = '[]
  Permute' (i ': is) xs = Index xs i ': Permute' is xs

-- | The element of @xs@ at 0-based index @i@.
type family Index (xs :: [Nat]) (i :: Nat) where
  Index (x : xs) 0 = x
  Index (x : xs) i = Index xs (i-1)

-- | Every element of @rs@ is a valid dimension index of @sh@, i.e. it lies
-- in @[0 .. Rank sh - 1]@.
class ValidDims (rs :: [Nat]) (sh :: [Nat])
instance (AllElem rs (Count 0 sh)) => ValidDims rs sh

-----------------

-- | Sliding windows of sizes @ws@ over the leading dimensions of @ss@.
-- Each windowed dimension @s@ (which requires @w <= s@) becomes
-- @s + 1 - w@.  These are followed by the window sizes @ws@ and then by the
-- remaining, unwindowed dimensions of @ss@.
class Window (ws :: [Nat]) (ss :: [Nat]) (rs :: [Nat]) | ws ss -> rs
instance (Window' ws ws ss rs) => Window ws ss rs

-- | Worker for 'Window'.  @ows@ keeps the original window sizes so they can
-- be spliced into the result once @ws@ has been consumed.
class Window' (ows :: [Nat]) (ws :: [Nat]) (ss :: [Nat]) (rs :: [Nat]) | ows ws ss -> rs
instance ((ows ++ ss) ~ rs) => Window' ows '[] ss rs
instance (Window' ows ws ss rs, w <= s, ((s+1)-w) ~ r) => Window' ows (w ': ws) (s ': ss) (r ': rs)

-----------------

-- | Striding the leading dimensions of @ss@ by @ts@.  Each dimension @s@
-- with stride @t@ becomes @DivRoundUp s t@ (the ceiling of @s / t@), and the
-- remaining dimensions are unchanged.
class Stride (ts :: [Nat]) (ss :: [Nat]) (rs :: [Nat]) | ts ss -> rs
instance Stride '[] ss ss
instance (Stride ts ss rs, DivRoundUp s t ~ r) => Stride (t ': ts) (s ': ss) (r ': rs)

-----------------

-- | Slicing the leading dimensions of @ss@ with the @(offset, length)@
-- pairs in @ls@.  Each sliced dimension @s@ (which requires
-- @o + n <= s@) becomes @n@, and the remaining dimensions are unchanged.
-- 'sliceOffsets' returns the offsets at the term level.
class Slice (ls :: [(Nat,Nat)]) (ss :: [Nat]) (rs :: [Nat]) | ls ss -> rs where
  sliceOffsets :: Proxy ls -> Proxy ss -> [Int]

instance Slice '[] ss ss where
  sliceOffsets _ _ = []

instance (Slice ls ss rs, (o+n) <= s, KnownNat o)
      => Slice ('(o,n) ': ls) (s ': ss) (n ': rs) where
  sliceOffsets _ _ = valueOf @o : sliceOffsets (Proxy :: Proxy ls) (Proxy :: Proxy ss)

-----------------
-- Shape extraction

-- | Type-level shapes whose dimensions can be read back at runtime.
class (Typeable s) => Shape (s :: [Nat]) where
  -- | The dimensions of the shape.
  shapeP :: Proxy s -> [Int]
  -- | The number of elements: the product of the dimensions (1 for @'[]@).
  sizeP :: Proxy s -> Int

instance Shape '[] where
  {-# INLINE shapeP #-}
  shapeP _ = []
  {-# INLINE sizeP #-}
  sizeP _ = 1

instance forall n s . (Shape s, KnownNat n) => Shape (n ': s) where
  {-# INLINE shapeP #-}
  shapeP _ = valueOf @n : shapeP (Proxy :: Proxy s)
  {-# INLINE sizeP #-}
  sizeP _ = valueOf @n * sizeP (Proxy :: Proxy s)

{-# INLINE shapeT #-}
-- | 'shapeP' with the shape supplied by a type application.
shapeT :: forall sh . (Shape sh) => [Int]
shapeT = shapeP (Proxy :: Proxy sh)

{-# INLINE sizeT #-}
-- | 'sizeP' with the shape supplied by a type application.
sizeT :: forall sh . (Shape sh) => Int
sizeT = sizeP (Proxy :: Proxy sh)

-- | Turn a dynamic shape back into a type level shape.  The continuation is
-- called with a 'Proxy' for the type-level shape whose dimensions are the
-- given list, so that
--
-- @withShapeP sh shapeP == sh@
--
-- Calls 'error' when 'someNatVal' rejects a dimension (a negative size).
withShapeP :: [Int] -> (forall sh . (Shape sh) => Proxy sh -> r) -> r
withShapeP [] f = f (Proxy :: Proxy ('[] :: [Nat]))
withShapeP (n:ns) f =
  case someNatVal (toInteger n) of
    Just (SomeNat (_ :: Proxy n)) ->
      -- Convert the tail first, then cons this dimension onto its shape.
      withShapeP ns (\ (_ :: Proxy ns) -> f (Proxy :: Proxy (n ': ns)))
    _ -> error $ "withShape: bad size " ++ show n

-- | Like 'withShapeP', but the continuation receives the type-level shape as
-- a type argument instead of a 'Proxy'.
withShape :: [Int] -> (forall sh . (Shape sh) => r) -> r
withShape sh f = withShapeP sh (\ (_ :: Proxy sh) -> f @sh)

-----------------

-- | Using the dimension indices /ds/, can /sh/ be broadcast into shape /sh'/?
--
-- 'broadcasting' returns one 'Bool' per dimension of /sh'/.  The value is
-- 'False' where that dimension comes from /sh/ (its index is listed in
-- /ds/) and 'True' where it is a new, broadcast dimension.  The following
-- are reported as type errors:
--
-- * indices in /ds/ that are not strictly increasing;
-- * a number of indices in /ds/ that differs from the rank of /sh/;
-- * an index that does not fit in /sh'/;
-- * a listed dimension of /sh'/ that differs from the matching
--   dimension of /sh/.
class Broadcast (ds :: [Nat]) (sh :: [Nat]) (sh' :: [Nat]) where
  broadcasting :: [Bool]

instance (Broadcast' 0 ds sh sh') => Broadcast ds sh sh' where
  broadcasting = broadcasting' @0 @ds @sh @sh'

-- | Worker for 'Broadcast'.  @i@ is the index, in the result shape, of the
-- dimension currently being examined.
class Broadcast' (i :: Nat) (ds :: [Nat]) (sh :: [Nat]) (sh' :: [Nat]) where
  broadcasting' :: [Bool]

-- Everything has been consumed.
instance Broadcast' i '[] '[] '[] where
  broadcasting' = []

-- All source dimensions are placed, so the remaining result dimensions
-- are broadcast.
instance (Broadcast' i '[] '[] sh') => Broadcast' i '[] '[] (s ': sh') where
  broadcasting' = True : broadcasting' @i @'[] @'[] @sh'

-- The error instances below never run: their 'TypeError' context rejects
-- the program at compile time, so their method bodies are placeholders.
instance (TypeError ('Text "Too few dimension indices")) => Broadcast' i '[] (s ': sh) sh' where
  broadcasting' = undefined

instance (TypeError ('Text "Too many dimension indices")) => Broadcast' i (d ': ds) '[] sh' where
  broadcasting' = undefined

instance (TypeError ('Text "Too few result dimensions")) => Broadcast' i (d ': ds) (s ': sh) '[] where
  broadcasting' = undefined

-- Compare the current result index with the next requested index.
instance (Broadcast'' (CmpNat i d) i d ds (s ': sh) (s' ': sh'))
      => Broadcast' i (d ': ds) (s ': sh) (s' ': sh') where
  broadcasting' = broadcasting'' @(CmpNat i d) @i @d @ds @(s ': sh) @(s' ': sh')

-- | Step of 'Broadcast'' that dispatches on @o = CmpNat i d@.
class Broadcast'' (o :: Ordering) (i :: Nat) (d :: Nat) (ds :: [Nat]) (sh :: [Nat]) (sh' :: [Nat]) where
  broadcasting'' :: [Bool]

-- Result dimension @i@ is the next source dimension.  Its sizes must agree.
-- The equality is stated as a constraint rather than by repeating a
-- variable in the instance head, so GHC can unify an unknown result
-- dimension, and a mismatch is reported as a type mismatch instead of a
-- missing instance.
instance (s ~ s', Broadcast' (i+1) ds sh rsh)
      => Broadcast'' 'EQ i d ds (s ': sh) (s' ': rsh) where
  broadcasting'' = False : broadcasting' @(i+1) @ds @sh @rsh

-- Result dimension @i@ precedes the next requested index, so it is broadcast.
instance (Broadcast' (i+1) (d ': ds) sh rsh) => Broadcast'' 'LT i d ds sh (s' ': rsh) where
  broadcasting'' = True : broadcasting' @(i+1) @(d ': ds) @sh @rsh

-- The requested index lies behind the current position (unordered or
-- duplicated indices).
instance (TypeError ('Text "unordered dimensions")) => Broadcast'' 'GT i d ds sh rsh where
  broadcasting'' = undefined