module NN.Backend.Caffe(caffePasses, middleEnd, backend, addConnection, addLabels, optimizeInPlaceLayer) where

import           Gen.Caffe.NetParameter     as NP

import           Data.Graph.Inductive.Query
import           Gen.Caffe.LayerParameter   as LP

import           Control.Lens
import           Data.Char
import qualified Data.Foldable              as F
import           Data.Graph.Inductive.Graph hiding ((&))
import qualified Data.Graph.Inductive.Graph as G
import           Data.Maybe
import qualified Data.Sequence              as S

import           Text.Printf
import           Text.ProtocolBuffers       as P

import           NN.DSL
import           NN.Passes

-- | The Caffe-specific optimisation pipeline: wire up blob names, add the
-- implicit "label" blob, then enable in-place execution for ReLU and
-- Dropout layers (in that order).
caffePasses :: [Pass]
caffePasses = [addConnection, addLabels] ++ concatMap optimizeInPlaceLayer [ReLU, Dropout]

-- | Run every pass in 'caffePasses' over the network graph.
middleEnd :: Net -> Net
middleEnd net = optimizeWith caffePasses net


-- | Canonical name for a layer: its type string, lower-cased, followed by
-- an underscore and the graph node index (e.g. type "Conv" at node 3 gives
-- @conv_3@).
--
-- Calls 'error' with a descriptive message if the layer has no type set.
-- The original code used 'fromJust' here, which failed with no context.
layerName :: LayerParameter -> Int -> Utf8
layerName l i = s (printf "%s_%d" (map toLower tyName) i)
  where
    tyName = maybe (error missingType) toString (type' l)
    missingType = printf "layerName: layer at node %d has no type set" i

-- | Lower a network graph to a Caffe 'NetParameter'. The layers are
-- appended to the default parameter's layer list in topological order.
backend :: Net -> NetParameter
backend gr = over _layer (`mappend` layers) def
  where
    layers = S.fromList (topsort' gr)

-- | Thread the implicit "label" blob through the net.
--
-- * 'Data' layers produce it as an extra @top@.
-- * 'SoftmaxWithLoss' and 'Accuracy' layers consume it as an extra
--   @bottom@.
--
-- All other layers are left unchanged.
addLabels :: Pass
addLabels (_, _, lp) =
    case layerTy lp of
      Data            -> lp & LP._top <>~ label
      SoftmaxWithLoss -> lp & LP._bottom <>~ label
      Accuracy        -> lp & LP._bottom <>~ label
      _               -> lp
  where
    label = S.singleton (s "label")

-- | Build the two passes that let layers of type @layerTy'@ run in place.
--
-- * @updateIfInPlace@: if a layer's type is @layerTy'@, its @top@ is
--   replaced by its @bottom@. Such a layer must have exactly one @top@,
--   otherwise 'error' is called.
-- * @updateIfParentInPlace@: if any of a layer's graph parents have type
--   @layerTy'@, the layer's @bottom@ is replaced by those parents' @top@s.
--   It calls 'error' when that rewrite cannot be done consistently.
optimizeInPlaceLayer :: LayerTy -> [Pass]
optimizeInPlaceLayer layerTy' = [updateIfInPlace, updateIfParentInPlace] where
    inPlace lp = layerTy lp == layerTy'
    inPlaceParents gr i = filter inPlace . map fst $ pres gr i

    -- Make this layer's single output blob the same as its input blob(s).
    updateIfInPlace (_, i, lp) =
        case (inPlace lp, F.toList (top lp)) of
          (True, [_]) -> lp & LP._top .~ bottom lp
          (True, _) -> error $ "Can only have one output for an in-place layer: " ++ show (layerName lp i)
          (False, _) -> lp

    updateIfParentInPlace :: Pass
    updateIfParentInPlace = either error id . updateFromParents

    -- Point this layer's inputs at the tops of its in-place parents.
    -- TODO this is super dodgy and incorrect in the general case (there are
    -- some weird invariants we rely on), but it works for now. In
    -- particular, it requires every bottom to come from an in-place parent.
    updateFromParents :: (Net, Node, LayerParameter) -> Either String LayerParameter
    updateFromParents (gr, i, lp) =
       case (inPlaceParents gr i, F.toList (bottom lp)) of
         ([], _) -> Right lp
         (parents, bottoms) ->
             let parentTops = F.concatMap (F.toList . LP.top) parents
                 lname = show (layerName lp i)
             in if length parents /= length bottoms
                then Left $ "Must have all parents in-place for in-place optimizations: " ++ lname
                else if length parentTops /= length bottoms
                then Left $ "In-place parents' tops do not match the bottoms of layer: " ++ lname
                else Right $ lp & LP._bottom .~ S.fromList parentTops

-- | Pair each given node with its layer label from the graph.
labelled :: Net -> [Node] -> [(LayerParameter, Node)]
labelled gr js = [(lab' (context gr j), j) | j <- js]

-- | The predecessors of a node, each paired with its layer label.
pres :: Net -> Node -> [(LayerParameter, Node)]
pres gr = labelled gr . G.pre gr

-- | Name a layer after its type and node index, set its @bottom@ blobs to
-- the names of its graph predecessors, and append its own name to its
-- @top@ blobs.
addConnection :: Pass
addConnection (gr, i, lp) =
    lp & LP._name ?~ ownName
       & LP._bottom .~ parentNames
       & LP._top <>~ S.singleton ownName
  where
    ownName = layerName lp i
    parentNames = S.fromList [layerName p j | (p, j) <- pres gr i]