module NN.Backend.Caffe(caffePasses, middleEnd, backend, addConnection, addLabels, optimizeInPlaceLayer) where
import Gen.Caffe.NetParameter as NP
import Data.Graph.Inductive.Query
import Gen.Caffe.LayerParameter as LP
import Control.Lens
import Data.Char
import qualified Data.Foldable as F
import Data.Graph.Inductive.Graph hiding ((&))
import qualified Data.Graph.Inductive.Graph as G
import Data.Maybe
import qualified Data.Sequence as S
import Text.Printf
import Text.ProtocolBuffers as P
import NN.DSL
import NN.Passes
-- | Passes used by 'middleEnd', listed in the order handed to 'optimizeWith':
-- name/connect layers, wire the @label@ blob, then the in-place rewrites
-- for ReLU layers followed by those for Dropout layers.
caffePasses :: [Pass]
caffePasses = wiringPasses ++ concatMap optimizeInPlaceLayer [ReLU, Dropout]
  where
    wiringPasses = [addConnection, addLabels]
-- | Apply the Caffe-specific 'caffePasses' to a network graph via
-- 'optimizeWith' (how the passes are sequenced is decided there).
middleEnd :: Net -> Net
middleEnd net = optimizeWith caffePasses net
-- | Build a name of the form @<type>_<node>@ from the layer's lower-cased
-- type string and its graph node index (e.g. type @ReLU@ at node 3 gives
-- @relu_3@).
--
-- Fails with a descriptive error if the layer has no type set, instead of
-- the uninformative failure of 'fromJust'.
layerName :: LayerParameter -> Int -> Utf8
layerName l i = s (printf "%s_%d" (map toLower tyName) i)
  where
    tyName = maybe (error "layerName: layer has no type set") toString (type' l)
-- | Lower a network graph to a 'NetParameter': the graph's layer labels, in
-- the order returned by 'topsort'', are appended to the layer list of the
-- default message.
backend :: Net -> NetParameter
backend gr = over _layer (`mappend` orderedLayers) def
  where
    orderedLayers = S.fromList (topsort' gr)
-- | Wire the implicit @label@ blob: 'Data' layers get it appended to their
-- tops, 'SoftmaxWithLoss' and 'Accuracy' layers get it appended to their
-- bottoms, and every other layer is returned unchanged.
addLabels :: Pass
addLabels (_, _, lp) =
  case layerTy lp of
    Data            -> lp & LP._top <>~ labelBlob
    SoftmaxWithLoss -> lp & LP._bottom <>~ labelBlob
    Accuracy        -> lp & LP._bottom <>~ labelBlob
    _               -> lp
  where
    labelBlob = S.singleton (s "label")
-- | Two passes that make layers of type @layerTy'@ compute in place.
--
-- * The first rewrites each such layer so that its top blobs are its bottom
--   blobs; it requires the layer to have exactly one top.
-- * The second rewrites a layer whose predecessors include such layers: its
--   bottoms are replaced by those predecessors' tops, which requires every
--   bottom to come from an in-place predecessor.
--
-- Violations are reported with 'error', naming the offending layer.
optimizeInPlaceLayer :: LayerTy -> [Pass]
optimizeInPlaceLayer layerTy' = [updateIfInPlace, updateIfParentInPlace]
  where
    -- Whether a layer is of the type being optimized.
    inPlace lp = layerTy lp == layerTy'

    -- Labels of the predecessors of node @i@ that are of the in-place type.
    inPlaceParents gr i = filter inPlace . map fst $ pres gr i

    updateIfInPlace :: Pass
    updateIfInPlace (_, i, lp) =
      case (inPlace lp, F.toList (top lp)) of
        (True, [_]) -> lp & LP._top .~ bottom lp
        (True, _) ->
          error $
            printf
              "Can only have one output for an in-place layer: %s"
              (show (layerName lp i))
        (False, _) -> lp

    updateIfParentInPlace :: Pass
    updateIfParentInPlace (gr, i, lp) =
      either error id (updateFromParents (gr, i, lp))

    updateFromParents :: (Net, Node, LayerParameter) -> Either String LayerParameter
    updateFromParents (gr, i, lp) =
      case (inPlaceParents gr i, F.toList (bottom lp)) of
        ([], _) -> Right lp
        (parents, bottoms)
          -- Mixing in-place and regular inputs is not supported.
          | length parents /= length bottoms ->
              Left $
                printf
                  "Must have all parents in-place for in-place optimizations: %s"
                  (show (layerName lp i))
          -- Each in-place parent must contribute exactly one blob so the
          -- replacement bottoms line up one-to-one with the old ones.
          | length parentTops /= length bottoms ->
              Left $
                printf
                  "In-place parents must each have exactly one output: %s"
                  (show (layerName lp i))
          | otherwise ->
              Right (lp & LP._bottom .~ S.fromList parentTops)
          where
            parentTops = F.concatMap (F.toList . LP.top) parents
-- | Pair each given node with its label in the graph.  'context' fails if a
-- node is not present in the graph.
labelled :: Net -> [Node] -> [(LayerParameter, Node)]
labelled gr nodes = [(lab' (context gr n), n) | n <- nodes]
-- | The predecessors of a node (as given by 'G.pre'), each paired with its
-- label.
pres :: Net -> Node -> [(LayerParameter, Node)]
pres gr = labelled gr . G.pre gr
-- | Name the layer after its type and node (see 'layerName'), set its
-- bottoms to the names of its predecessor layers, and append its own name
-- to its tops.
addConnection :: Pass
addConnection (gr, i, lp) =
  lp & LP._name ?~ ownName
     & LP._bottom .~ S.fromList parentNames
     & LP._top <>~ S.singleton ownName
  where
    ownName = layerName lp i
    parentNames = [layerName parent node | (parent, node) <- pres gr i]