-- Copyright 2020 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Array.Internal.Shape(module Data.Array.Internal.Shape) where
import Data.Proxy
import Type.Reflection
import GHC.TypeLits

import Data.Array.Internal(valueOf)

-- | Type-level division of @n@ by @m@, rounded up (for @m >= 1@).
-- It is computed as @Div (n+m-1) m@.
type DivRoundUp  n m = Div (n+m-1) m

-----------------
-- Type level properties.

-- | The rank of a type level shape, i.e., the number of dimensions it lists.
type family Rank (s :: [Nat]) :: Nat where
  Rank '[]       = 0
  Rank (d ': ds) = 1 + Rank ds

-- | The size of a type level shape, i.e., the product of all its dimensions.
type Size (s :: [Nat]) = Size' 1 s

-- | Worker for 'Size': multiplies each dimension into the accumulator @acc@.
-- An accumulating parameter is used because it generates fewer constraints.
type family Size' (acc :: Nat) (s :: [Nat]) :: Nat where
  Size' acc '[]       = acc
  Size' acc (d ': ds) = Size' (acc * d) ds

----------------------------------
-- Type level shape operations

-- | Concatenation of two type level shapes.
type family (++) (xs :: [Nat]) (ys :: [Nat]) :: [Nat] where
  '[]       ++ bs = bs
  (a ': as) ++ bs = a ': (as ++ bs)

{-
-- XXX O(n^2)
type family Reverse (xs :: [Nat]) :: [Nat] where
    Reverse '[] = '[]
    Reverse (x ': xs) = Reverse xs ++ '[x]
-}

-- | The first @n@ dimensions of a type level shape.
-- No equation covers a non-zero count paired with an empty list, so the
-- family does not reduce when the shape has fewer than @n@ elements.
type family Take (n :: Nat) (xs :: [Nat]) :: [Nat] where
  Take 0 ds        = '[]
  Take k (d ': ds) = d ': Take (k - 1) ds

-- | The type level shape that remains after discarding its first @n@ dimensions.
-- No equation covers a non-zero count paired with an empty list, so the
-- family does not reduce when the shape has fewer than @n@ elements.
type family Drop (n :: Nat) (xs :: [Nat]) :: [Nat] where
  Drop 0 ds        = ds
  Drop k (d ': ds) = Drop (k - 1) ds

-- | The last element of a non-empty type level shape.
-- There is no equation for @'[]@, so the family does not reduce on an empty list.
-- The result kind is spelled out, matching the sibling families such as 'Rank'.
type family Last (xs :: [Nat]) :: Nat where
  Last '[x] = x
  Last (x ': xs) = Last xs

-- | All but the last element of a non-empty type level shape.
-- There is no equation for @'[]@, so the family does not reduce on an empty list.
-- The result kind is spelled out, matching the sibling families such as 'Take'.
type family Init (xs :: [Nat]) :: [Nat] where
  Init '[x] = '[]
  Init (x ': xs) = x ': Init xs

-----------------

-- | @ValidStretch from to@ holds when the two shapes have the same rank and
-- every dimension of @from@ is either 1 or equal to the matching dimension
-- of @to@ (as decided by 'Stretch').
class ValidStretch (from :: [Nat]) (to :: [Nat]) where
  -- | One flag per dimension, each being the value of @Stretch s m@ for the
  -- corresponding pair of dimensions.
  stretching :: Proxy from -> Proxy to -> [Bool]
instance ValidStretch '[] '[] where
  stretching _ _ = []
instance (BoolVal (Stretch s m), ValidStretch ss ms) => ValidStretch (s ': ss) (m ': ms) where
  -- Reflect the flag for the head pair, then recurse on the tails.
  stretching _ _ = boolVal (Proxy :: Proxy (Stretch s m)) :
    stretching (Proxy :: Proxy ss) (Proxy :: Proxy ms)
-- | Per-dimension stretch decision for a source size @s@ and target size @m@:
-- @'True@ when @s@ is 1, @'False@ when the two sizes are equal, and a
-- compile-time 'TypeError' for any other combination.
-- Equations are tried in order, so @Stretch 1 1@ is @'True@.
type family Stretch (s::Nat) (m::Nat) :: Bool where
  Stretch 1 m = 'True
  Stretch m m = 'False
  Stretch s m = TypeError ('Text "Cannot stretch " ':<>: 'ShowType s ':<>: 'Text " to " ':<>: 'ShowType m)

-- | Reflection of a type level 'Bool' to a value.
class BoolVal (b :: Bool) where
  -- | The value level boolean corresponding to @b@.
  boolVal :: Proxy b -> Bool
instance BoolVal 'False where
  boolVal _ = False
instance BoolVal 'True where
  boolVal _ = True

-----------------

-- | @Padded ps sh sh'@ relates a list of per-dimension pairs @ps@ and a shape
-- @sh@ to the resulting shape @sh'@. For a pair @'(l,h)@ and a dimension @s@
-- the resulting dimension is @l+s+h@ (presumably @l@/@h@ are the low/high
-- padding amounts). Dimensions of @sh@ beyond the length of @ps@ are unchanged.
class Padded (ps :: [(Nat, Nat)]) (sh :: [Nat]) (sh' :: [Nat]) | ps sh -> sh' where
  -- | The pairs of @ps@ as values, in order.
  padded :: Proxy ps -> Proxy sh -> [(Int, Int)]
instance Padded '[] sh sh where
  padded _ _ = []
instance (KnownNat l, KnownNat h, (l+s+h) ~ s', Padded ps sh sh') =>
         Padded ('(l,h) ': ps) (s ': sh) (s' ': sh') where
  -- Emit the head pair as values, then recurse on the remaining pairs.
  padded _ _ = (valueOf @l, valueOf @h) : padded (Proxy :: Proxy ps) (Proxy :: Proxy sh)

-----------------

-- | @Permutation is@ holds when @is@ is a rearrangement of the indices
-- @0 .. Rank is - 1@.
--
-- Two membership checks are required. The first ensures every element of
-- @is@ is a valid index. The second ensures every valid index occurs in
-- @is@; since both lists have the same length, this rules out duplicates
-- such as @'[0,0]@, which the first check alone would accept.
class Permutation (is :: [Nat])
instance (AllElem is (Count 0 is), AllElem (Count 0 is) is) => Permutation is

-- | @Count i xs@ is the list of consecutive numbers starting at @i@ with one
-- entry per element of @xs@; the elements of @xs@ themselves are ignored.
type family Count (i :: Nat) (xs :: [Nat]) :: [Nat] where
  Count k '[]       = '[]
  Count k (y ': ys) = k ': Count (k + 1) ys

-- | @AllElem is ns@ holds when every element of @is@ occurs in @ns@.
class AllElem (is :: [Nat]) (ns :: [Nat])
instance AllElem '[] ns
instance (Elem i ns, AllElem is ns) => AllElem (i ': is) ns
-- | @Elem i ns@ holds when @i@ occurs in @ns@.
-- There is no instance for an empty @ns@, so a failed search shows up as an
-- unsolved @Elem@ constraint.
class Elem (i :: Nat) (ns :: [Nat])
-- Compare @i@ with the head of the list and let 'Elem'' decide how to go on.
instance (Elem' (CmpNat i n) i ns) => Elem i (n : ns)
-- | Helper for 'Elem', dispatching on the comparison of @i@ with the list head;
-- @ns@ is the tail still to be searched.
class Elem' (e :: Ordering) (i :: Nat) (ns :: [Nat])
-- Equal to the head: found.
instance Elem' 'EQ i ns
-- Different from the head (either ordering): keep searching the tail.
instance (Elem i ns) => Elem' 'LT i ns
instance (Elem i ns) => Elem' 'GT i ns

-- | @Permute is xs@ rearranges the first @Rank is@ elements of @xs@ according
-- to the indices @is@ (via @Permute'@) and appends the remaining elements of
-- @xs@ unchanged.
type Permute (is :: [Nat]) (xs :: [Nat]) = Permute' is (Take (Rank is) xs) ++ Drop (Rank is) xs

-- | @Permute' is xs@ picks, for each index @i@ in @is@ in order, the element
-- of @xs@ at position @i@ (see 'Index').
-- The result kind is spelled out, matching the sibling families such as 'Take'.
type family Permute' (is :: [Nat]) (xs :: [Nat]) :: [Nat] where
  Permute' '[] xs = '[]
  Permute' (i ': is) xs = Index xs i ': Permute' is xs

-- | @Index xs i@ is the element of @xs@ at zero-based position @i@.
-- There is no equation for an empty list, so the family does not reduce
-- when @i@ is out of range.
-- The result kind is spelled out, matching the sibling families such as 'Rank'.
type family Index (xs :: [Nat]) (i :: Nat) :: Nat where
  Index (x : xs) 0 = x
  Index (x : xs) i = Index xs (i-1)

-- | @ValidDims rs sh@ holds when every element of @rs@ is one of the indices
-- @0, 1, ...@ enumerated by @Count 0 sh@, i.e., is a valid zero-based
-- dimension index into @sh@.
class ValidDims (rs :: [Nat]) (sh :: [Nat])
instance (AllElem rs (Count 0 sh)) => ValidDims rs sh

-----------------

-- | @Window ws ss rs@ relates window sizes @ws@ and a shape @ss@ to the
-- resulting shape @rs@, which is determined by @ws@ and @ss@.
-- The work is done by @Window'@, which receives @ws@ twice: once to keep
-- and once to consume.
class Window (ws :: [Nat]) (ss :: [Nat]) (rs :: [Nat]) | ws ss -> rs
instance (Window' ws ws ss rs) => Window ws ss rs

-- | Worker for @Window@. @ows@ is the complete list of window sizes, @ws@ the
-- window sizes still to be processed, @ss@ the remaining shape and @rs@ the result.
class Window' (ows :: [Nat]) (ws :: [Nat]) (ss :: [Nat]) (rs :: [Nat]) | ows ws ss -> rs
-- All window sizes consumed: the rest of the result is the full window
-- shape followed by the unprocessed dimensions.
instance ((ows ++ ss) ~ rs) => Window' ows '[] ss rs
-- A window of size @w@ over a dimension of size @s@ requires @w <= s@ and
-- contributes a result dimension of @(s+1)-w@.
instance (Window' ows ws ss rs, w <= s, ((s+1)-w) ~ r) => Window' ows (w ': ws) (s ': ss) (r ': rs)

-----------------

-- | @Stride ts ss rs@ relates strides @ts@ and a shape @ss@ to the resulting
-- shape @rs@, which is determined by @ts@ and @ss@.
class Stride (ts :: [Nat]) (ss :: [Nat]) (rs :: [Nat]) | ts ss -> rs
-- No strides left: the remaining dimensions are unchanged.
instance Stride '[] ss ss
-- A stride @t@ over a dimension of size @s@ yields @DivRoundUp s t@.
instance (Stride ts ss rs, DivRoundUp s t ~ r) => Stride (t ': ts) (s ': ss) (r ': rs)

-----------------

-- | @Slice ls ss rs@ relates a list of per-dimension pairs @ls@ and a shape
-- @ss@ to the resulting shape @rs@. For a pair @'(o,n)@ and a dimension @s@
-- it requires @o+n <= s@ and the resulting dimension is @n@ (presumably @o@
-- is the offset and @n@ the length of the slice). Dimensions of @ss@ beyond
-- the length of @ls@ are unchanged.
class Slice (ls :: [(Nat,Nat)]) (ss :: [Nat]) (rs :: [Nat]) | ls ss -> rs where
  -- | The first components (@o@) of the pairs in @ls@ as values, in order.
  sliceOffsets :: Proxy ls -> Proxy ss -> [Int]
instance Slice '[] ss ss where
  sliceOffsets _ _ = []
instance (Slice ls ss rs, (o+n) <= s, KnownNat o) => Slice ('(o,n) ': ls) (s ': ss) (n ': rs) where
  -- Emit the head offset as a value, then recurse on the remaining pairs.
  sliceOffsets _ _ = valueOf @o : sliceOffsets (Proxy :: Proxy ls) (Proxy :: Proxy ss)


-----------------
-- Shape extraction

-- | Type level shapes whose dimensions can be recovered as values.
class (Typeable s) => Shape (s :: [Nat]) where
  -- | The dimensions of the shape as a list.
  shapeP :: Proxy s -> [Int]
  -- | The product of the dimensions of the shape (1 for the empty shape).
  sizeP  :: Proxy s -> Int

-- The empty shape has no dimensions and a size of 1.
instance Shape '[] where
  {-# INLINE shapeP #-}
  shapeP _ = []
  {-# INLINE sizeP #-}
  sizeP  _ = 1

-- A non-empty shape: the head dimension @n@ is reflected to a value with
-- @valueOf@ and combined with the result for the tail @s@.
instance forall n s . (Shape s, KnownNat n) => Shape (n ': s) where
  {-# INLINE shapeP #-}
  shapeP _ = valueOf @n : shapeP (Proxy :: Proxy s)
  {-# INLINE sizeP #-}
  sizeP  _ = valueOf @n * sizeP  (Proxy :: Proxy s)

-- | The dimensions of the type level shape @sh@ as a list.
-- This is the proxy-free form of 'shapeP'; @sh@ is supplied by type application.
{-# INLINE shapeT #-}
shapeT :: forall sh . (Shape sh) => [Int]
shapeT = shapeP (Proxy :: Proxy sh)

-- | The size of the type level shape @sh@.
-- This is the proxy-free form of 'sizeP'; @sh@ is supplied by type application.
{-# INLINE sizeT #-}
sizeT :: forall sh . (Shape sh) => Int
sizeT = sizeP (Proxy :: Proxy sh)

-- | Turn a dynamic shape back into a type level shape.
-- @withShapeP sh shapeP == sh@
--
-- Each dimension is promoted to the type level with 'someNatVal'. A dimension
-- for which it returns 'Nothing' aborts with 'error'.
withShapeP :: [Int] -> (forall sh . (Shape sh) => Proxy sh -> r) -> r
withShapeP [] f = f (Proxy :: Proxy ('[] :: [Nat]))
withShapeP (n:ns) f =
  case someNatVal (toInteger n) of
    -- Promote the tail first, then hand the continuation the full shape.
    Just (SomeNat (_ :: Proxy n)) -> withShapeP ns (\ (_ :: Proxy ns) -> f (Proxy :: Proxy (n ': ns)))
    _ -> error $ "withShape: bad size " ++ show n

-- | Like 'withShapeP', but the continuation receives the type level shape
-- only as a type variable (to be used with type application) rather than
-- through a 'Proxy' argument.
withShape :: [Int] -> (forall sh . (Shape sh) => r) -> r
withShape sh f = withShapeP sh (\ (_ :: Proxy sh) -> f @sh)

-----------------

-- | Using the dimension indices /ds/, can /sh/ be broadcast into shape /sh'/?
class Broadcast (ds :: [Nat]) (sh :: [Nat]) (sh' :: [Nat]) where
  -- | One flag per dimension of /sh'/, as computed by @Broadcast'@ starting
  -- at result dimension index 0.
  broadcasting :: [Bool]
instance (Broadcast' 0 ds sh sh') => Broadcast ds sh sh' where
  broadcasting = broadcasting' @0 @ds @sh @sh'

-- | Worker for @Broadcast@. @i@ is the index of the result dimension currently
-- being examined, @ds@ the dimension indices not yet matched, @sh@ the source
-- dimensions not yet placed and @sh'@ the result dimensions not yet examined.
class Broadcast' (i :: Nat) (ds :: [Nat]) (sh :: [Nat]) (sh' :: [Nat]) where
  -- | One flag per remaining result dimension.
  broadcasting' :: [Bool]
-- Everything consumed: done.
instance Broadcast' i '[] '[] '[] where
  broadcasting' = []
-- Source exhausted but result dimensions remain: each one is flagged 'True'.
instance (Broadcast' i '[] '[] sh') => Broadcast' i '[] '[] (s : sh') where
  broadcasting' = True : broadcasting' @i @'[] @'[] @sh'
-- The remaining instances with a 'TypeError' context reject ill-formed
-- combinations at compile time; their method bodies are never meant to run.
instance (TypeError ('Text "Too few dimension indices")) => Broadcast' i '[] (s ': sh) sh' where
  broadcasting' = undefined
instance (TypeError ('Text "Too many dimensions indices")) => Broadcast' i (d ': ds) '[] sh' where
  broadcasting' = undefined
instance (TypeError ('Text "Too few result dimensions")) => Broadcast' i (d ': ds) (s ': sh) '[] where
  broadcasting' = undefined
-- General case: compare the current result index with the next dimension
-- index and let @Broadcast''@ decide.
instance (Broadcast'' (CmpNat i d) i d ds (s ': sh) (s' ': sh')) => Broadcast' i (d ': ds) (s ': sh) (s' ': sh') where
  broadcasting' = broadcasting'' @(CmpNat i d) @i @d @ds @(s ': sh) @(s' ': sh')

-- | Helper for @Broadcast'@, dispatching on @o@, the comparison of the current
-- result dimension index @i@ with the next dimension index @d@.
class Broadcast'' (o :: Ordering) (i :: Nat) (d :: Nat) (ds :: [Nat]) (sh :: [Nat]) (sh' :: [Nat]) where
  -- | One flag per remaining result dimension.
  broadcasting'' :: [Bool]
-- @i == d@: the head source dimension lands here, so it must equal the head
-- result dimension (both are @s@). Flag 'False' and move on with the
-- remaining indices and source dimensions.
instance (Broadcast' (i+1)       ds  sh rsh) => Broadcast'' 'EQ i d ds (s ': sh) (s ': rsh) where
  broadcasting'' = False : broadcasting' @(i+1) @ds @sh @rsh
-- @i < d@: no source dimension lands here. Flag 'True' and keep @d@ and the
-- whole of @sh@ for a later result dimension.
instance (Broadcast' (i+1) (d ': ds) sh rsh) => Broadcast'' 'LT i d ds sh (s' ': rsh) where
  broadcasting'' = True : broadcasting' @(i+1) @(d ': ds) @sh @rsh
-- @i > d@: the dimension indices are not increasing; rejected at compile time,
-- so the method body is never meant to run.
instance (TypeError ('Text "unordered dimensions")) => Broadcast'' 'GT i d ds sh rsh where
  broadcasting'' = undefined