{-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-| This module declares all Shape related functions and data structures, as well as all singleton -- instances for the Shape data type. This module was highly influenciated by Grenade, a Haskell -- library for deep learning with dependent types. See: https://github.com/HuwCampbell/grenade -} module TensorSafe.Shape where import Data.Singletons import GHC.TypeLits as N import TensorSafe.Core -- -- Shape definition as in Haskell's Grenade library -- -- | The current shapes we accept. -- at the moment this is just one, two, and three dimensional -- Vectors/Matricies. -- -- These are only used with DataKinds, as Kind `Shape`, with Types 'D1, 'D2, 'D3. data Shape = D1 Nat -- ^ One dimensional vector | D2 Nat Nat -- ^ Two dimensional matrix. Row, Column. | D3 Nat Nat Nat -- ^ Three dimensional matrix. Row, Column, Channels. -- | Concrete data structures for a Shape. -- -- All shapes are held in contiguous memory. -- 3D is held in a matrix (usually row oriented) which has height depth * rows. data S (n :: Shape) where S1D :: ( KnownNat len ) => R len -> S ('D1 len) S2D :: ( KnownNat rows, KnownNat columns ) => L rows columns -> S ('D2 rows columns) S3D :: ( KnownNat rows , KnownNat columns , KnownNat depth , KnownNat (rows N.* depth)) => L (rows N.* depth) columns -> S ('D3 rows columns depth) deriving instance Show (S n) -- Singleton instances. -- Check: http://hackage.haskell.org/package/singletons -- -- These could probably be derived with template haskell, but this seems -- clear and makes adding the KnownNat constraints simple. -- We can also keep our code TH free, which is great. data instance Sing (n :: Shape) where D1Sing :: KnownNat a => Sing a -> Sing ('D1 a) D2Sing :: (KnownNat a, KnownNat b) => Sing a -> Sing b -> Sing ('D2 a b) D3Sing :: (KnownNat a, KnownNat b, KnownNat c) => Sing a -> Sing b -> Sing c -> Sing ('D3 a b c) instance KnownNat a => SingI ('D1 a) where sing = D1Sing sing instance (KnownNat a, KnownNat b) => SingI ('D2 a b) where sing = D2Sing sing sing instance (KnownNat a, KnownNat b, KnownNat c) => SingI ('D3 a b c) where sing = D3Sing sing sing sing -- | Compares two Shapes at kinds level and returns a Bool kind type family ShapeEquals (sIn :: Shape) (sOut :: Shape) :: Bool where ShapeEquals s s = 'True ShapeEquals _ _ = 'False -- | Same as ShapeEquals, which compares two Shapes at kinds level, but raises a TypeError exception -- if the Shapes are not the equal. type family ShapeEquals' (sIn :: Shape) (sOut :: Shape) :: Bool where ShapeEquals' s s = 'True ShapeEquals' s1 s2 = TypeError ( 'Text "Couldn't match the Shape " ':<>: 'ShowType s1 ':<>: 'Text " with the Shape " ':<>: 'ShowType s2)