-- |
-- Module      : Grenade.Core.Network
-- Description : Core definition of a neural network
--
-- A 'Network' is a heterogeneous list of layers, indexed both by the layer
-- types and by the shapes of the data passed between consecutive layers.
-- This module provides the forward pass ('runNetwork'), the backward pass
-- ('runGradient'), gradient application ('applyUpdate') and random
-- initialisation ('randomNetwork').
module Grenade.Core.Network (
    Network (..)
  , Gradients (..)
  , Tapes (..)
    -- Exported with its method so callers can write 'CreatableNetwork'
    -- constraints in their own signatures; this also exports 'randomNetwork'.
  , CreatableNetwork (..)

  , runNetwork
  , runGradient
  , applyUpdate
  ) where
import Control.Monad.Random ( MonadRandom )
import Data.Singletons
import Data.Singletons.Prelude
import Data.Serialize
import Grenade.Core.Layer
import Grenade.Core.LearningParameters
import Grenade.Core.Shape
-- | A neural network: a heterogeneous list of layers.
--
--   The first index (@[*]@) lists the types of the layers, in order.
--
--   The second index (@[Shape]@) lists the shapes of the data passed between
--   layers. It always has one more element than the layer list: its head is
--   the network's input shape, its last element is the output shape, and each
--   layer @layer@ maps one shape to the next through its
--   @'Layer' layer inp mid@ instance.
data Network :: [*] -> [Shape] -> * where
    -- The empty network; it still carries the single shape @inp@ that flows
    -- through it.
    NNil  :: SingI inp
          => Network '[] '[inp]

    -- A layer (strict) in front of the rest of the network (strict). The
    -- layer's output shape must be the input shape of the remainder.
    (:~>) :: (SingI inp, SingI mid, Layer layer inp mid)
          => !layer
          -> !(Network layers (mid ': rest))
          -> Network (layer ': layers) (inp ': mid ': rest)
infixr 5 :~>
-- The empty network is rendered as its constructor name.
instance Show (Network '[] '[shape]) where
  show NNil = "NNil"

-- A non-empty network renders its first layer with 'show', then a "~>" on
-- its own line, then the rest of the network.
instance (Show x, Show (Network xs rs)) => Show (Network (x ': xs) (i ': rs)) where
  show (layer :~> rest) = concat [show layer, "\n~>\n", show rest]
-- | Per-layer gradients for a network.
--
--   Indexed by the same list of layer types as the 'Network' they belong to,
--   so the @n@-th entry is the 'Gradient' of the @n@-th layer. 'runGradient'
--   produces a value of this type and 'applyUpdate' consumes it.
data Gradients :: [*] -> * where
    -- No gradients, matching the empty network.
    GNil  :: Gradients '[]

    -- The gradient of the first layer, followed by the gradients of the rest.
    -- Unlike the fields of ':~>', these fields are lazy.
    (:/>) :: UpdateLayer x
          => Gradient x
          -> Gradients xs
          -> Gradients (x ': xs)
-- Give ':/>' the same fixity as ':~>'. Without a declaration it defaults to
-- infixl 9, so a chain such as  g1 :/> g2 :/> GNil  would associate to the
-- left and fail to typecheck.
infixr 5 :/>
-- | The tapes recorded by a forward pass, one per layer.
--
--   Indexed by the same layer types and shapes as the 'Network' that produced
--   them. 'runNetwork' builds them from each layer's 'runForwards', and
--   'runGradient' hands them back to each layer's 'runBackwards'.
data Tapes :: [*] -> [Shape] -> * where
    -- No tapes, matching the empty network.
    TNil  :: SingI i
          => Tapes '[] '[i]

    -- The tape of the first layer (strict), followed by the tapes of the
    -- rest of the network (strict).
    (:\>) :: (SingI i, SingI h, Layer x i h)
          => !(Tape x i h)
          -> !(Tapes xs (h ': hs))
          -> Tapes (x ': xs) (i ': h ': hs)
-- Give the tape constructor the same fixity as ':~>'. Without a declaration
-- it defaults to infixl 9, so an unparenthesised chain of tapes would
-- associate to the left and fail to typecheck.
infixr 5 :\>
-- | Run a network forwards on one input.
--
--   Each layer's 'runForwards' is applied in order, feeding every layer the
--   previous layer's output. The result pairs the tapes recorded along the way
--   (for use by 'runGradient') with the output of the last layer. For the
--   empty network the input is returned unchanged.
runNetwork :: forall layers shapes.
              Network layers shapes
           -> S (Head shapes)
           -> (Tapes layers shapes, S (Last shapes))
runNetwork = forwardPass
  where
    -- The equality constraint ties the remaining sub-network's output shape
    -- to that of the whole network.
    forwardPass :: forall js ss. (Last js ~ Last shapes)
                => Network ss js
                -> S (Head js)
                -> (Tapes ss js, S (Last js))
    forwardPass NNil !input =
      (TNil, input)
    forwardPass (layer :~> rest) !input =
      (tape :\> tapes, output)
      where
        (tape, activation) = runForwards layer input
        (tapes, output)    = forwardPass rest activation
-- | Run a network backwards.
--
--   Takes the network, the tapes recorded by 'runNetwork', and a value of the
--   network's output shape (presumably the gradient at the output; confirm
--   against callers). Each layer's 'runBackwards' is applied from the last
--   layer to the first, feeding each one the value produced by the layer after
--   it. The result pairs every layer's gradient with the value propagated back
--   to the input shape.
runGradient :: forall layers shapes.
               Network layers shapes
            -> Tapes layers shapes
            -> S (Last shapes)
            -> (Gradients layers, S (Head shapes))
runGradient network recorded dOut = backwardPass network recorded
  where
    -- The equality constraint lets the empty case return 'dOut', whose shape
    -- is the whole network's output shape, as its input-shape result.
    backwardPass :: forall js ss. (Last js ~ Last shapes)
                 => Network ss js
                 -> Tapes ss js
                 -> (Gradients ss, S (Head js))
    backwardPass NNil TNil =
      (GNil, dOut)
    backwardPass (layer :~> rest) (tape :\> restTapes) =
      (layerGrad :/> restGrads, dLayerIn)
      where
        (restGrads, dLayerOut) = backwardPass rest restTapes
        (layerGrad, dLayerIn)  = runBackwards layer tape dLayerOut
-- | Update every layer of a network with its corresponding gradient.
--
--   Walks the network and the gradients in lockstep, replacing each layer by
--   the result of 'runUpdate' with the given learning parameters.
applyUpdate :: LearningParameters
            -> Network layers shapes
            -> Gradients layers
            -> Network layers shapes
applyUpdate _ NNil GNil = NNil
applyUpdate params (layer :~> layers) (grad :/> grads) =
  runUpdate params layer grad :~> applyUpdate params layers grads
-- | Networks, identified by their layer types and shapes, that can be
--   created with randomly initialised layers.
class CreatableNetwork (layers :: [*]) (shapes :: [Shape]) where
  -- | Build a network whose layers are randomly initialised.
  randomNetwork :: MonadRandom m => m (Network layers shapes)
-- Base case: the empty network has no layers to initialise.
instance SingI shape => CreatableNetwork '[] '[shape] where
  randomNetwork = pure NNil
-- Inductive case: create the first layer with 'createRandom', then the rest
-- of the network, in that order.
instance ( SingI i, SingI o, Layer x i o
         , CreatableNetwork xs (o ': rs)
         ) => CreatableNetwork (x ': xs) (i ': o ': rs) where
  randomNetwork = do
    layer <- createRandom
    rest  <- randomNetwork
    return (layer :~> rest)
-- The empty network writes nothing and is read back without consuming input.
instance SingI shape => Serialize (Network '[] '[shape]) where
  put NNil = return ()
  get      = pure NNil
-- Layers are written first-to-last and read back in the same order. Only the
-- layers' own encodings are written; the network's structure comes from its
-- type.
instance ( SingI i, SingI o, Layer x i o, Serialize x
         , Serialize (Network xs (o ': rs))
         ) => Serialize (Network (x ': xs) (i ': o ': rs)) where
  put (layer :~> rest) = do
    put layer
    put rest
  get = do
    layer <- get
    rest  <- get
    return (layer :~> rest)
-- A whole network is itself an 'UpdateLayer'. Its gradient is the per-layer
-- 'Gradients', it is updated with 'applyUpdate', and it is created with
-- 'randomNetwork'.
instance CreatableNetwork layers shapes => UpdateLayer (Network layers shapes) where
  type Gradient (Network layers shapes) = Gradients layers
  createRandom = randomNetwork
  runUpdate    = applyUpdate
-- A whole network is itself a 'Layer' from its first shape to its last.
-- Its tape is the per-layer 'Tapes', and it runs with 'runNetwork' and
-- 'runGradient'.
instance ( CreatableNetwork layers shapes
         , i ~ Head shapes
         , o ~ Last shapes
         ) => Layer (Network layers shapes) i o where
  type Tape (Network layers shapes) i o = Tapes layers shapes
  runBackwards = runGradient
  runForwards  = runNetwork