{-# LANGUAGE OverloadedStrings #-}
{-| This module describes the expression structure of an INetwork instance.
The INetwork can be structured into a data structure called CNetwork, from which
code for external languages can later be generated.
-}
module TensorSafe.Compile.Expr (
    DLayer (..),
    CNetwork (..),
    JavaScript (..),
    Python (..),
    Generator,
    generate,
    generateFile
) where

import           Data.Map
import           Data.Text.Lazy as T
import           Formatting
import           Text.Casing    (camel, quietSnake)

-- | Target-independent tag for each kind of layer that can be compiled.
--
-- Each 'LayerGenerator' instance in this module maps every constructor to the
-- layer name used by its target language.
--
-- IMPORTANT: whenever a new layer definition is added to `TensorSafe.Layers`,
-- a matching constructor must be added here too.
data DLayer
    = DActivation
    | DAdd
    | DBatchNormalization
    | DConv2D
    | DDense
    | DDropout
    | DFlatten
    | DGlobalAvgPooling2D
    | DInput
    | DLSTM
    | DMaxPooling
    | DRelu
    | DZeroPadding2D
    deriving Show

-- | Compilable representation of a network: layers chained together by
-- sequencing constructors, ready to be rendered by a 'Generator'.
data CNetwork
    = CNSequence CNetwork
      -- ^ Opens a sequential model; the generators emit the model constructor
      -- before rendering the wrapped network.
    | CNAdd CNetwork CNetwork
      -- ^ Combines two sub-networks. NOTE(review): the generators in this
      -- module do not translate this constructor — confirm intended semantics.
    | CNCons CNetwork CNetwork
      -- ^ Renders the first network, then the second.
    | CNLayer DLayer (Map String String)
      -- ^ A single layer together with its parameters (name -> value).
    | CNReturn
      -- ^ End of the initial sequence network (renders nothing).
    | CNNil
      -- ^ End of possible nested sequence networks (renders nothing).
    deriving Show

-- | Marker value selecting JavaScript (TensorFlow.js) as the compilation target.
data JavaScript
    = JavaScript
    deriving Show

-- | Marker value selecting Python (TensorFlow / Keras) as the compilation target.
data Python
    = Python
    deriving Show

-- | Per-target naming of layers: each supported language provides the
-- identifier under which a 'DLayer' is written in generated code.
class LayerGenerator l where
    -- | Identifier of the given layer in target language @l@.
    generateName :: l -> DLayer -> String

-- TensorFlow.js layer factory names; the generator emits them as
-- @tf.layers.<name>(...)@. TF.js spells 2-D factories with a lowercase
-- trailing @d@ (conv2d, maxPooling2d, zeroPadding2d, globalAveragePooling2d).
instance LayerGenerator JavaScript where
    generateName _ DActivation         = "activation"
    generateName _ DAdd                = "addStrict"
    generateName _ DBatchNormalization = "batchNormalization"
    generateName _ DConv2D             = "conv2d"
    generateName _ DDense              = "dense"
    generateName _ DDropout            = "dropout"
    generateName _ DFlatten            = "flatten"
    generateName _ DGlobalAvgPooling2D = "globalAveragePooling2d"
    generateName _ DInput              = "input"
    generateName _ DLSTM               = "lstm"
    generateName _ DMaxPooling         = "maxPooling2d"
    generateName _ DRelu               = "reLU"
    generateName _ DZeroPadding2D      = "zeroPadding2d"

-- Keras layer names as exposed under @tf.keras.layers@ (classes, plus the
-- functional @add@ helper).
instance LayerGenerator Python where
    generateName _ DActivation         = "Activation"
    generateName _ DAdd                = "add"
    generateName _ DBatchNormalization = "BatchNormalization"
    generateName _ DConv2D             = "Conv2D"
    generateName _ DDense              = "Dense"
    generateName _ DDropout            = "Dropout"
    generateName _ DFlatten            = "Flatten"
    generateName _ DGlobalAvgPooling2D = "GlobalAveragePooling2D"
    generateName _ DInput              = "Input"
    generateName _ DLSTM               = "LSTM"
    generateName _ DMaxPooling         = "MaxPool2D"
    generateName _ DRelu               = "ReLU"
    generateName _ DZeroPadding2D      = "ZeroPadding2D"

-- | Target languages into which a 'CNetwork' can be rendered as source text.
class Generator l where
    -- | Renders the network for target @l@ as 'Text' (model construction
    -- statements only, without any file-level boilerplate).
    generate :: l -> CNetwork -> Text
    -- | Like 'generate', but additionally wraps the output in the header and
    -- footer lines needed to use it as a standalone source file.
    generateFile :: l -> CNetwork -> Text

instance Generator JavaScript where
    -- Emits one statement per line: CNSequence declares a tf.sequential()
    -- model, and every CNLayer becomes a model.add(tf.layers.<name>({...})) call.
    generate l =
        T.intercalate "\n" . generateJS
        where
            generateJS :: CNetwork -> [Text]
            generateJS (CNSequence cn)  = "var model = tf.sequential();" : generateJS cn
            generateJS (CNCons cn1 cn2) = generateJS cn1 ++ generateJS cn2
            generateJS CNNil            = []
            generateJS CNReturn         = []
            generateJS (CNLayer layer params) =
                [format
                    ("model.add(tf.layers." % string % "(" % string % "));")
                    (generateName l layer)
                    (paramsToJS params)
                ]
            -- CNAdd has no JavaScript translation yet; fail with a descriptive
            -- message instead of an anonymous pattern-match failure.
            generateJS (CNAdd _ _) =
                error "TensorSafe.Compile.Expr.generate (JavaScript): CNAdd is not supported"

    -- Wraps the generated statements in a `model` function whose result is
    -- exported via module.exports (CommonJS, using require for tfjs).
    generateFile l cn =
        startCode `append` generate l cn `append` endCode
        where
            startCode :: Text
            startCode = T.intercalate "\n"
                [ "// Autogenerated code"
                , "var tf = require(\"@tensorflow/tfjs\");"
                , "function model() {"
                , "\n"
                ]

            endCode :: Text
            endCode = T.intercalate "\n"
                [ "\n"
                , "return model;"
                , "}"
                , "\n"
                , "module.exports = model();"
                ]

-- | Renders a parameter map as a JavaScript object literal.
--
-- Keys are converted to camelCase and values are spliced in verbatim
-- (presumably already JavaScript expressions — verify against callers).
-- Entries appear in descending key order, each followed by @", "@,
-- e.g. @{ b: 2, a: 1, }@; an empty map yields @{ }@.
paramsToJS :: Map String String -> String
paramsToJS m = "{ " ++ entries ++ "}"
    where
        entries :: String
        entries = toDescList m >>= uncurry entry

        entry :: String -> String -> String
        entry key value = camel key ++ ": " ++ value ++ ", "

instance Generator Python where
    -- Emits one statement per line: CNSequence creates a Keras Sequential
    -- model, and every CNLayer becomes a model.add(tf.keras.layers.<Name>(...)).
    generate l =
        T.intercalate "\n" . generatePy
        where
            generatePy :: CNetwork -> [Text]
            generatePy (CNSequence cn)  = "model = tf.keras.models.Sequential()" : generatePy cn
            generatePy (CNCons cn1 cn2) = generatePy cn1 ++ generatePy cn2
            generatePy CNNil            = []
            generatePy CNReturn         = []
            -- Layer classes are taken from tf.keras.layers, matching the
            -- tf.keras.models.Sequential model above; the TF1 tf.layers namespace
            -- does not provide e.g. Activation, LSTM or ZeroPadding2D.
            generatePy (CNLayer layer params) =
                [format
                    ("model.add(tf.keras.layers." % string % "(" % string % "))")
                    (generateName l layer)
                    (paramsToPython params)]
            -- CNAdd has no Python translation yet; fail with a descriptive
            -- message instead of an anonymous pattern-match failure.
            generatePy (CNAdd _ _) =
                error "TensorSafe.Compile.Expr.generate (Python): CNAdd is not supported"

    -- Prepends the comment header and the tensorflow import to the statements.
    generateFile l cn =
        startCode `append` generate l cn
        where
            startCode :: Text
            startCode = T.intercalate "\n"
                -- Python comments start with '#'; "//" would be a syntax error.
                [ "# Autogenerated code"
                , "import tensorflow as tf"
                , "\n"
                ]

-- | Renders a parameter map as Python keyword arguments.
--
-- Each entry becomes @key=value, @ (value spliced verbatim), in descending
-- key order; an empty map yields the empty string.
paramsToPython :: Map String String -> String
paramsToPython params = toDescList params >>= uncurry keyword
    where
        keyword :: String -> String -> String
        keyword key value = pythonKey key ++ "=" ++ value ++ ", "

        -- Translates parameter keys to the Python (Keras) spelling.
        --
        -- A few keys are renamed with respect to the JS spelling; those special
        -- cases are declared here. Every other key is converted to snake case.
        pythonKey :: String -> String
        pythonKey "inputDim" = "input_shape"
        pythonKey other      = quietSnake other