{-# LANGUAGE TemplateHaskell #-}

module MXNet.NN.EvalMetric where



import Data.IORef

import Control.Exception.Base (Exception)

import Control.Monad.Trans.Resource (MonadThrow(..))

import Data.Typeable (Typeable)

import Control.Monad

import Control.Monad.IO.Class (MonadIO, liftIO)

import Text.Printf (printf)

import qualified Data.Vector.Storable as SV

import Control.Lens (makeLenses)

import MXNet.Core.Base

import qualified MXNet.Core.Base.NDArray as A

import qualified MXNet.Core.Base.Internal.TH.NDArray as A



-- | Running state of an evaluation metric.
--
-- @dtype@ is the element type of the accumulated sum. @method@ is a phantom
-- tag (no field mentions it) that selects the 'EvalMetricMethod' instance
-- used to update the metric.
data Metric dtype method = Metric
    { _metric_name      :: String      -- ^ display name, printed by 'formatMetric'
    , _metric_labelname :: [String]    -- ^ label names (not read anywhere in this module)
    , _metric_instance  :: IORef Int   -- ^ number of instances accumulated so far
    , _metric_sum       :: IORef dtype -- ^ running total that 'getMetric' divides by the instance count
    }

makeLenses ''Metric



-- | Create a metric whose instance count and running sum both start at zero.
--
-- The first argument is ignored at runtime; it only fixes the @method@ type
-- parameter of the resulting 'Metric'.
newMetric :: (DType dtype, MonadIO m) => method -> String -> [String] -> m (Metric dtype method)
newMetric _ name labels = liftIO $
    -- IO's applicative runs left to right: the counter ref is allocated
    -- before the sum ref.
    Metric name labels <$> newIORef 0 <*> newIORef 0



-- | Zero the running sum and the instance count, returning the metric to the
-- state produced by 'newMetric'.
resetMetric :: (DType dtype, MonadIO m) => Metric dtype method -> m ()
resetMetric Metric { _metric_sum = total, _metric_instance = count } =
    liftIO (writeIORef total 0 >> writeIORef count 0)



-- | Current value of the metric: the running sum divided by the instance
-- count, as a 'Float'.
--
-- For a freshly created or reset metric both are zero, so the result is
-- @0 / 0@, i.e. NaN.
getMetric :: (DType dtype, MonadIO m) => Metric dtype method -> m Float
getMetric metric = liftIO $ do
    total <- readIORef (_metric_sum metric)
    count <- readIORef (_metric_instance metric)
    pure (realToFrac total / fromIntegral count)



-- | Render the metric as @\<name: value\>@, with the value from 'getMetric'
-- shown to three decimal places.
formatMetric :: (DType dtype, MonadIO m) => Metric dtype method -> m String
formatMetric metric =
    render <$> getMetric metric
  where
    render = printf "<%s: %0.3f>" (_metric_name metric)



-- | An evaluation method. An instance for a tag type such as 'CrossEntropy'
-- defines how one batch of predictions and labels updates a
-- @'Metric' dtype method@.
class EvalMetricMethod method where
    -- | Fold one batch into the metric's accumulators.
    evaluate
        :: DType dtype
        => Metric dtype method  -- ^ metric whose IORefs are updated
        -> A.NDArray dtype      -- ^ predictions
        -> A.NDArray dtype      -- ^ labels
        -> IO ()



-- | Basic evaluation method: cross entropy (log-loss).
data CrossEntropy = CrossEntropy

instance EvalMetricMethod CrossEntropy where

    -- | Accumulate the log-loss of one batch.
    --
    -- @preds@ must have rank 2 and @label@ rank 1, with equal leading
    -- dimensions; otherwise 'InvalidInput' is thrown. Presumably @preds@ is
    -- (batch_size, num_category), each row holding per-category
    -- probabilities, and @label@ is (batch_size,) holding category indices
    -- (NOTE(review): confirm the dimension order returned by 'A.ndshape').
    --
    -- The negated sum of the log of the picked probabilities is added to the
    -- metric's running sum, and the leading dimension to its instance count.
    evaluate metric preds label = do
        (n1, shp1) <- A.ndshape preds
        (n2, shp2) <- A.ndshape label
        -- '||' short-circuits, so 'head' only runs once both ranks are known
        -- to be non-zero (assuming the rank equals the shape list's length).
        when (n1 /= 2 || n2 /= 1 || head shp1 /= head shp2) (throwM InvalidInput)

        -- 'pick' needs preds and label in the same context, so copy preds
        -- into label's context when they differ. The shape of preds is
        -- already known (shp1), so it is not queried again.
        predsCtx <- context preds
        labelCtx <- context label
        preds_may_copy <- if predsCtx == labelCtx
            then return preds
            else do
                preds_copy <- A.makeEmptyNDArray shp1 labelCtx False
                A._copyto' (A.getHandle preds) [A.getHandle preds_copy] :: IO ()
                return preds_copy

        predprj <- A.pick (A.getHandle preds_may_copy) (A.getHandle label) nil
        predlog <- A.log predprj
        loss    <- A.sum predlog nil >>= A.items . A.NDArray

        -- Strict updates: the lazy 'modifyIORef' stacks one unevaluated
        -- thunk per batch (each retaining that batch's loss vector and shape
        -- list) until the metric is finally read.
        modifyIORef' (_metric_sum metric) (+ (negate $ loss SV.! 0))
        modifyIORef' (_metric_instance metric) (+ head shp1)



-- | Exceptions that can be raised while evaluating a metric.
data EvalMetricExc
    = InvalidInput  -- ^ inputs rejected by 'evaluate', e.g. ranks or batch sizes that do not match
    deriving (Typeable, Show)

-- | Lets 'EvalMetricExc' be thrown with 'throwM' and caught as an exception.
instance Exception EvalMetricExc