{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE RankNTypes #-}
#if MIN_VERSION_base(4,12,0)
{-# LANGUAGE NoStarIsType #-}
#endif
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Torch.Models.Vision.LeNet where
import Data.Function ((&))
import Data.Generics.Product.Fields (field)
import Data.Generics.Product.Typed (typed)
import Data.List (intercalate)
import Data.Singletons.Prelude (SBool, sing)
import GHC.Generics (Generic)
import Lens.Micro (Lens', (^.))
import Numeric.Backprop (Backprop, BVar, Reifies, W, (^^.))
import GHC.TypeLits (KnownNat)
import qualified Numeric.Backprop as Bp
import qualified GHC.TypeLits
#ifdef CUDA
import Numeric.Dimensions
import Torch.Cuda.Double as Torch
import Torch.Cuda.Double.NN.Linear
import qualified Torch.Cuda.Double.NN.Conv2d as Conv2d
import qualified Torch.Cuda.Double.NN.Linear as Linear
#else
import Torch.Double as Torch
import Torch.Double.NN.Linear
import qualified Torch.Double.NN.Conv2d as Conv2d
import qualified Torch.Double.NN.Linear as Linear
#endif
import Torch.Initialization
-- | Length of the feature vector that '_fc1' consumes: 16 (the number of
-- output channels of '_conv2') times @ker * ker@.
--
-- NOTE(review): this presumes the spatial map left after the two
-- 'lenetLayer' stages is exactly @ker@ x @ker@. Each stage is a stride-1,
-- unpadded @ker@ x @ker@ convolution followed by a 2x2, stride-2 max-pool.
-- For example, a 32x32 input with @ker = 5@ goes 28 -> 14 -> 10 -> 5. For
-- other input sizes the flattened length presumably won't equal this and
-- 'lenet' will fail to typecheck rather than fail at runtime — confirm.
type Flattened ker = (16*ker*ker)
-- | Parameters of a LeNet-style network: two convolutional layers followed
-- by three fully-connected layers.
--
-- @ch@ is the number of input channels. @ker@ is the side length of the
-- square kernel shared by both convolutions. All fields are strict.
data LeNet ch ker = LeNet
  { _conv1 :: !(Conv2d ch 6 '(ker, ker))
    -- ^ First convolution: @ch@ input channels to 6 output channels.
  , _conv2 :: !(Conv2d 6 16 '(ker,ker))
    -- ^ Second convolution: 6 input channels to 16 output channels.
  , _fc1 :: !(Linear (Flattened ker) 120)
    -- ^ First dense layer: flattened convolutional features to 120 units.
  , _fc2 :: !(Linear 120 84)
    -- ^ Second dense layer: 120 units to 84 units.
  , _fc3 :: !(Linear 84 10)
    -- ^ Output layer: 84 units to 10 outputs, which 'lenet' / 'lenetBatch'
    -- pass through a softmax.
  } deriving (Generic)
-- | Lens onto the first convolutional layer ('_conv1').
conv1 :: Lens' (LeNet ch ker) (Conv2d ch 6 '(ker, ker))
conv1 = field @"_conv1"
-- | Lens onto the second convolutional layer ('_conv2').
conv2 :: Lens' (LeNet ch ker) (Conv2d 6 16 '(ker,ker))
conv2 = field @"_conv2"
-- | Lens onto the first fully-connected layer ('_fc1').
--
-- Unlike its siblings, this lens is derived with 'typed' (which locates the
-- field by its type) rather than 'field'. It therefore relies on
-- @Linear (Flattened ker) 120@ being unique among LeNet's field types.
--
-- NOTE(review): 'typed' was presumably chosen because this field's type
-- mentions the 'Flattened' type family, which the 'field'-based derivation
-- may not handle. Confirm before switching to 'field'.
--
-- The explicit @forall@ lets the type application mention @ker@. That needs
-- ScopedTypeVariables, which is not among this file's LANGUAGE pragmas, so it
-- is presumably enabled project-wide.
fc1 :: forall ch ker . Lens' (LeNet ch ker) (Linear (Flattened ker) 120)
fc1 = typed @(Linear (Flattened ker) 120)
-- | Lens onto the second fully-connected layer ('_fc2').
fc2 :: Lens' (LeNet ch ker) (Linear 120 84)
fc2 = field @"_fc2"
-- | Lens onto the output fully-connected layer ('_fc3').
fc3 :: Lens' (LeNet ch ker) (Linear 84 10)
fc3 = field @"_fc3"
-- | Multi-line, human-readable dump of every layer.
--
-- The header line names the backend the module was compiled for:
-- @CudaLeNet@ when the CUDA flag is defined, plain @LeNet@ otherwise.
instance (KnownDim (Flattened ker), KnownDim ch, KnownDim ker) => Show (LeNet ch ker) where
  show (LeNet convA convB denseA denseB denseC) =
    intercalate "\n" (header : layerLines ++ ["}"])
    where
#ifdef CUDA
      header = "CudaLeNet {"
#else
      header = "LeNet {"
#endif
      -- one line per layer, in declaration order
      layerLines =
        [ " conv1 :: " ++ show convA
        , " conv2 :: " ++ show convB
        , " fc1 :: " ++ show denseA
        , " fc2 :: " ++ show denseB
        , " fc3 :: " ++ show denseC
        ]
-- | Field-wise 'Backprop' instance.
--
-- Each method is applied independently to the five layers, and the results
-- are reassembled into a 'LeNet'. The patterns are irrefutable (@~@), so a
-- layer is only extracted when the per-layer method demands it. This is the
-- same laziness as reading the layer through a record selector or lens.
instance (KnownDim (Flattened ker), KnownDim ch, KnownDim ker) => Backprop (LeNet ch ker) where
  add ~(LeNet cx1 cx2 fx1 fx2 fx3) ~(LeNet cy1 cy2 fy1 fy2 fy3) =
    LeNet
      (Bp.add cx1 cy1)
      (Bp.add cx2 cy2)
      (Bp.add fx1 fy1)
      (Bp.add fx2 fy2)
      (Bp.add fx3 fy3)
  one ~(LeNet c1 c2 f1 f2 f3) =
    LeNet (Bp.one c1) (Bp.one c2) (Bp.one f1) (Bp.one f2) (Bp.one f3)
  zero ~(LeNet c1 c2 f1 f2 f3) =
    LeNet (Bp.zero c1) (Bp.zero c2) (Bp.zero f1) (Bp.zero f2) (Bp.zero f3)
-- | Allocate a fresh 'LeNet'.
--
-- Every layer is initialised from the same 'Generator' @g@, in the order
-- conv1, conv2, fc1, fc2, fc3.
newLeNet :: All KnownDim '[ch,ker,Flattened ker, ker*ker] => Generator -> IO (LeNet ch ker)
newLeNet g = do
  firstConv <- newConv2d g
  secondConv <- newConv2d g
  hidden1 <- newLinear g
  hidden2 <- newLinear g
  output <- newLinear g
  pure (LeNet firstConv secondConv hidden1 hidden2 output)
-- | Pure parameter update that returns a new network.
--
-- Each layer of @net@ is combined with the matching layer of @grad@ using
-- learning rate @lr@, via 'Conv2d.update' for the convolutions and
-- 'Linear.update' for the dense layers.
update net lr grad = LeNet conv1' conv2' fc1' fc2' fc3'
  where
    conv1' = Conv2d.update (net ^. conv1) lr (grad ^. conv1)
    conv2' = Conv2d.update (net ^. conv2) lr (grad ^. conv2)
    fc1' = Linear.update (net ^. fc1) lr (grad ^. fc1)
    fc2' = Linear.update (net ^. fc2) lr (grad ^. fc2)
    fc3' = Linear.update (net ^. fc3) lr (grad ^. fc3)
-- | Effectful counterpart of 'update'.
--
-- Runs 'Conv2d.update_' / 'Linear.update_' on each layer in sequence
-- (conv1, conv2, fc1, fc2, fc3), pairing the layer of @net@ with the
-- matching layer of @grad@. The result is that of the last step.
--
-- NOTE(review): by the trailing-underscore convention this presumably
-- mutates @net@'s parameters in place. Confirm against the layer modules.
update_ net lr grad =
  Conv2d.update_ (net ^. conv1) lr (grad ^. conv1)
    >> Conv2d.update_ (net ^. conv2) lr (grad ^. conv2)
    >> Linear.update_ (net ^. fc1) lr (grad ^. fc1)
    >> Linear.update_ (net ^. fc2) lr (grad ^. fc2)
    >> Linear.update_ (net ^. fc3) lr (grad ^. fc3)
-- | Forward pass for a single (unbatched) input.
--
-- The pipeline, read bottom to top, is:
--
-- 1. Two 'lenetLayer' stages: convolution, ReLU, max-pool.
-- 2. Flatten.
-- 3. fc1 followed by ReLU.
-- 4. fc2 followed by ReLU.
-- 5. fc3 followed by softmax.
--
-- @lr@ is only threaded into the two 'lenetLayer' calls.
lenet lr arch inp =
  softmax
    . linear (arch ^^. fc3)
    . relu
    . linear (arch ^^. fc2)
    . relu
    . linear (arch ^^. fc1)
    . flattenBP
    . lenetLayer lr (arch ^^. conv2)
    $ lenetLayer lr (arch ^^. conv1) inp
-- | One LeNet feature-extraction stage for a single (unbatched) input.
--
-- The stage applies, in order:
--
-- * a stride-1, unpadded convolution;
-- * ReLU;
-- * 2x2 max-pooling with stride 2 and no padding.
--
-- The pooling's final flag is the type-level @'True@ that also appears in
-- the 'SpatialDilationC' constraint. NOTE(review): presumably this is
-- ceil-mode output sizing; confirm.
--
-- @lr@ is passed straight through to 'Conv2d.conv2d'.
lenetLayer
  :: forall inp h w ker ow oh s out mow moh step pad
  . Reifies s W
  => All KnownDim '[inp,out,ker,(ker*ker)*inp]
  => pad ~ 0
  => step ~ 1
  => SpatialConvolutionC inp h w ker ker step step pad pad oh ow
  => SpatialDilationC oh ow 2 2 2 2 pad pad mow moh 1 1 'True
  => Double
  -> BVar s (Conv2d inp out '(ker,ker))
  -> BVar s (Tensor '[inp, h, w])
  -> BVar s (Tensor '[out, moh, mow])
lenetLayer lr conv inp =
  Conv2d.conv2d unitStride noPadding lr conv inp
    & relu
    & maxPooling2d poolWindow poolStride noPadding ceilMode
  where
    unitStride = Step2d :: Step2d '(1,1)
    poolWindow = Kernel2d :: Kernel2d '(2,2)
    poolStride = Step2d :: Step2d '(2,2)
    noPadding = Padding2d :: Padding2d '(0,0)
    ceilMode = sing :: SBool 'True
-- | Forward pass for a batch of inputs (leading batch dimension).
--
-- The pipeline, read bottom to top, is:
--
-- 1. Two 'lenetLayerBatch' stages: convolution, ReLU, max-pool.
-- 2. Flatten each sample.
-- 3. fc1 followed by ReLU.
-- 4. fc2 followed by ReLU.
-- 5. fc3 followed by a softmax taken over dimension 1.
--
-- Dimension 1 is presumably the per-sample output dimension after the
-- batch dimension 0; confirm.
lenetBatch lr arch inp =
  softmaxN (dim :: Dim 1)
    . linearBatch (arch ^^. fc3)
    . relu
    . linearBatch (arch ^^. fc2)
    . relu
    . linearBatch (arch ^^. fc1)
    . flattenBPBatch
    . lenetLayerBatch lr (arch ^^. conv2)
    $ lenetLayerBatch lr (arch ^^. conv1) inp
-- | Batched counterpart of 'lenetLayer'. The operations are identical, but
-- the tensors carry a leading @batch@ dimension.
--
-- The stage applies, in order:
--
-- * a stride-1, unpadded convolution;
-- * ReLU;
-- * 2x2 max-pooling with stride 2 and no padding.
--
-- The pooling's final flag is the type-level @'True@ from the
-- 'SpatialDilationC' constraint. NOTE(review): presumably this is
-- ceil-mode output sizing; confirm.
--
-- @lr@ is passed straight through to 'Conv2d.conv2dBatch'.
lenetLayerBatch
  :: forall inp h w ker ow oh s out mow moh step pad batch
  . Reifies s W
  => All KnownDim '[batch,inp,out,ker,(ker*ker)*inp]
  => pad ~ 0
  => step ~ 1
  => SpatialConvolutionC inp h w ker ker step step pad pad oh ow
  => SpatialDilationC oh ow 2 2 2 2 pad pad mow moh 1 1 'True
  => Double
  -> BVar s (Conv2d inp out '(ker,ker))
  -> BVar s (Tensor '[batch, inp, h, w])
  -> BVar s (Tensor '[batch, out, moh, mow])
lenetLayerBatch lr conv inp =
  Conv2d.conv2dBatch unitStride noPadding lr conv inp
    & relu
    & maxPooling2dBatch poolWindow poolStride noPadding ceilMode
  where
    unitStride = Step2d :: Step2d '(1,1)
    poolWindow = Kernel2d :: Kernel2d '(2,2)
    poolStride = Step2d :: Step2d '(2,2)
    noPadding = Padding2d :: Padding2d '(0,0)
    ceilMode = sing :: SBool 'True