{-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Extra.Solver #-} {-| This module is the core of TensorSafe. It defines all Network data structures -- and types functions that respresent Layers modifications of shapes, as well as -- all needed information for compiling the Network structures to CNetworks for later code -- generation. -} -- module TensorSafe.Network ( -- Network (..), -- INetwork (..), -- MkINetwork, -- ValidNetwork, -- mkINetwork, -- toCNetwork -- ) where module TensorSafe.Network where import Data.Kind (Type) import Data.Singletons import GHC.TypeLits as N import GHC.TypeLits.Extra (Div) import TensorSafe.Compile.Expr import TensorSafe.Layer (Layer, compile, layer) import TensorSafe.Layers import TensorSafe.Shape -- | A network that defines a specific sequence of layers data Network :: [Type] -> Type where NNil :: Network '[] (:~~) :: Layer x => !x -> !(Network xs) -> Network (x ': xs) infixr 5 :~~ instance Show (Network '[]) where show NNil = "NNil" instance (Show x, Show (Network xs)) => Show (Network (x ': xs)) where show (x :~~ xs) = show x ++ "\n :~~ " ++ show xs -- | A network that defines a specific sequence of layers with the corresponding shape -- transformation along the network. It's an Instance of a Network: given a Network and a initial -- Shape, this type structure can be generated automatically using the type functions defined in -- this module, like `Out` and `MkINetwork`. data INetwork :: [Type] -> [Shape] -> Type where INNil :: SingI i => INetwork '[] '[i] (:~>) :: (SingI i, SingI h, Layer x) => !x -> !(INetwork xs (h ': hs)) -> INetwork (x ': xs) (i ': h ': hs) infixr 5 :~> instance Show (INetwork '[] '[i]) where show INNil = "NNil" instance (Show x, Show (INetwork xs rs)) => Show (INetwork (x ': xs) (i ': rs)) where show (x :~> xs) = show x ++ "\n :~> " ++ show xs -- | This instance of INetwork as a Layer makes possible nesting INetworks instance ValidNetwork ls ss => Layer (INetwork ls ss) where layer = mkINetwork compile n i = toCNetwork' n True i -- -- COMPUTING RESULTING SHAPES FROM A LIST OF LAYERS. -- -- | Returns the result of applying all the layers transformation to a specific shape. -- Given a list of layers, this returns the expected output for the computation of each layer -- starting with the first layer transforming the `Shape` s. -- For example, if the initial Shape is [28, 28] and the layers are [Relu, Flatten], the result -- will be [784]. type family ComputeOut (layers :: [Type]) (s :: Shape) :: Shape where ComputeOut '[] s = s ComputeOut (l : ls) s = ComputeOut ls (Out l s) -- | Returns a list of shapes describing ALL the transformations applied to a specific shape. -- Given a list of layers return a type with all the Shapes from the initial Shape until the -- last one. In theory, the last Shape should be the same than the ComputeOut function applied -- to this same parameters. type family ComposeOut' (layers :: [Type]) (s :: Shape) :: [Shape] where ComposeOut' '[] s = '[] ComposeOut' (l : ls) s = ((Out l s) ': (ComposeOut' ls (Out l s))) -- | Same than ComposeOut' but the Shape list includes the initial Shape type family ComposeOut (layers :: [Type]) (s :: Shape) :: [Shape] where ComposeOut '[] s = '[] ComposeOut ls s = s ': (ComposeOut' ls s) -- | Compares the layers shape computation and the expected output type family ValidateOutput (layers :: [Type]) (sIn :: Shape) (sOut :: Shape) :: Bool where ValidateOutput ls sIn sOut = ShapeEquals' (ComputeOut ls sIn) sOut -- -- CREATE INETWORK TYPE INSTANCES FROM LIST OF LAYERS AND INTIAL AND ENDING SHAPES -- -- | Creates an INetwork type, and by "unconstrained" I mean that I don't check for an -- expected output type family MkINetworkUnconstrained (layers :: [Type]) (s :: Shape) :: Type where MkINetworkUnconstrained ls s = INetwork ls (ComposeOut ls s) -- | If the second type argument is 'True, then it returns the type t, otherwise it returns -- a default type. Note that for this example, ValidateOutput would raise an exception -- if the expected output and the actual one do not match. type family MaybeType (t :: Type) (b :: Bool) :: Type where MaybeType t 'False = Type -- HACK: ValidateOutput should raise an exception on this case MaybeType t 'True = t -- | Creates an INetwork type validating the the expected output and the computed one match. type family MkINetwork (layers :: [Type]) (sIn :: Shape) (sOut :: Shape) :: Type where MkINetworkUnconstrained ls sIn sOut = MaybeType (INetwork ls (ComposeOut ls sIn)) (ValidateOutput ls sIn sOut) -- -- MAPPING TRANSFORMATIONS OF LAYERS AND SHAPES -- type family MaybeShape (s :: Shape) (b :: Bool) :: Shape where MaybeType s 'False = 'D1 0 -- HACK: ShapeEquals' should raise an exception on this case MaybeType s 'True = s type family Add' (layers :: [Type]) (layers2 :: [Type]) (shape :: Shape) where Add' ls1 _ sIn = ComputeOut ls1 sIn -- | Defines the expected output of a layer -- This type function should be instanciated for each of the Layers defined. type family Out (l :: Type) (s :: Shape) :: Shape where -- -- -- Out (INetwork ls (s : ss)) s = ComputeOut ls s -- -- -- Out (Add ls1 ls2) sIn = Add' ls1 ls2 sIn -- MaybeShape -- (ComputeOut ls1 sIn) -- (ShapeEquals' (ComputeOut ls1 sIn) (ComputeOut ls2 sIn)) -- Validation that computes the same -- Out (Add (INetwork ls (s : ss))) s = ComputeOut ls s -- -- -- Out (BatchNormalization _ _ _) s = s -- -- -- Out (Conv2D 1 1 k k' s s') ('D2 inputRows inputColumns) = ('D2 (1 + (Div (inputRows - k) s)) (1 + (Div (inputColumns - k') s')) ) Out (Conv2D 1 filters k k' s s') ('D2 inputRows inputColumns) = ('D3 (1 + (Div (inputRows - k) s)) (1 + (Div (inputColumns - k') s')) filters ) Out (Conv2D channels 1 k k' s s') ('D3 inputRows inputColumns channels) = ('D2 (1 + (Div (inputRows - k) s)) (1 + (Div (inputColumns - k') s')) ) Out (Conv2D channels filters k k' s s') ('D3 inputRows inputColumns channels) = ('D3 (1 + (Div (inputRows - k) s)) (1 + (Div (inputColumns - k') s')) filters ) -- -- -- Out (Dense i o) ('D1 i) = 'D1 o -- -- -- Out (Dropout rate seed) s = s -- -- -- Out Flatten ('D1 x) = 'D1 x Out Flatten ('D2 x y) = 'D1 (x N.* y) Out Flatten ('D3 x y z) = 'D1 (x N.* y N.* z) -- -- -- Out GlobalAvgPooling2D ('D3 _ _ z) = 'D1 z -- -- -- Out Input s = s -- -- -- Out (LSTM units 'False) _ = 'D1 units Out (LSTM units 'True) ('D2 x _) = 'D2 x units Out (LSTM units 'True) ('D3 x _ _) = 'D2 x units -- -- -- Out (MaxPooling k k' s s') ('D2 inputRows inputColumns) = ('D2 (1 + (Div (inputRows - k) s)) (1 + (Div (inputColumns - k') s')) ) Out (MaxPooling k k' s s') ('D3 inputRows inputColumns channels) = ('D3 (1 + (Div (inputRows - k) s)) (1 + (Div (inputColumns - k') s')) channels ) -- -- -- Out Relu s = s -- -- -- Out Sigmoid s = s -- -- -- Out (ZeroPadding2D padding_rows padding_cols) ('D2 inputRows inputColumns) = ('D2 (inputRows + (2 N.* padding_rows)) (inputColumns + (2 N.* padding_cols))) Out (ZeroPadding2D padding_rows padding_cols) ('D3 inputRows inputColumns channels) = ('D3 (inputRows + (2 N.* padding_rows)) (inputColumns + (2 N.* padding_cols)) channels) -- -- Edge case or not defined raise an error -- Out l sIn = TypeError ( 'Text "Couldn't apply the Layer \"" ':<>: 'ShowType l ':<>: 'Text "\" with the input Shape \"" ':<>: 'ShowType sIn ':<>: 'Text "\"") -- -- INETWORK VALIDATION -- -- | Instanciates a Network after defining a type definition, -- using MkINetworkUnconstrained or MkINetwork, for example. -- After defining a variable with INetwork type, you can instanciate that variable like this: -- ``` -- myNet :: MNIST -- myNet = mkINetwork -- ``` class ValidNetwork (xs :: [Type]) (ss :: [Shape]) where -- | Makes a valid instance of INetwork mkINetwork :: INetwork xs ss {-# MINIMAL mkINetwork #-} instance (SingI i) => ValidNetwork '[] '[i] where mkINetwork = INNil instance ( SingI i , SingI o , Layer x , ValidNetwork xs (o ': rs) , (Out x i) ~ o -- IMPORTANT: validation that the output and the computation of the layer -- will match. Without this constraint we could be able to create an -- instance of ValidNetwork that doesn't satisfies the type constraints -- of MkINetwork for example. ) => ValidNetwork (x ': xs) (i ': o ': rs) where mkINetwork = layer :~> mkINetwork -- -- INETWORK MAPPING TO CNETWORK -- -- | Compilation: Gets the initial shape using Singleton instances. Since this is the function we -- run for transforming an INetwork to CNetwork, the nested argument of `toCNetwork'` is set -- to False. toCNetwork :: forall i x xs ss. ( SingI i , Layer x , ValidNetwork (x ': xs) (i ': ss)) => INetwork (x ': xs) (i ': ss) -> CNetwork toCNetwork n = case (sing :: Sing i) of D1Sing a -> CNSequence (toCNetwork' n False (Just $ show [ natVal a])) D2Sing a b -> CNSequence (toCNetwork' n False (Just $ show [ natVal a , natVal b])) D3Sing a b c -> CNSequence (toCNetwork' n False (Just $ show [ natVal a , natVal b , natVal c])) -- | Helper function for `toCNetwork` toCNetwork' :: INetwork xs ss -> Bool -> Maybe String -> CNetwork toCNetwork' INNil nested _ = if nested then CNNil else CNReturn toCNetwork' (l :~> n) nested inputShape = let compilatedLayer = compile l inputShape compilatedNetwork = toCNetwork' n nested Nothing in CNCons compilatedLayer compilatedNetwork