{-# LANGUAGE TemplateHaskell #-}
module MXNet.NN.EvalMetric where
import Data.IORef
import Control.Exception.Base (Exception)
import Control.Monad.Trans.Resource (MonadThrow(..))
import Data.Typeable (Typeable)
import Control.Monad
import Control.Monad.IO.Class (MonadIO, liftIO)
import Text.Printf (printf)
import qualified Data.Vector.Storable as SV
import Control.Lens (makeLenses)
import MXNet.Core.Base
import qualified MXNet.Core.Base.NDArray as A
import qualified MXNet.Core.Base.Internal.TH.NDArray as A
-- | Running state of an evaluation metric.
--
-- @dtype@ is the element type of the accumulated sum. @method@ is a phantom
-- tag: it appears in no field and only selects which 'EvalMetricMethod'
-- instance updates the metric.
data Metric dtype method = Metric
    { _metric_name      :: !String         -- ^ display name, printed by 'formatMetric'
    , _metric_labelname :: ![String]       -- ^ label names given at construction (not read by the code in this module)
    , _metric_instance  :: !(IORef Int)    -- ^ number of instances accumulated so far
    , _metric_sum       :: !(IORef dtype)  -- ^ running sum of the per-instance metric values
    }
makeLenses ''Metric
-- | Create a metric whose instance count and running sum both start at zero.
--
-- The first argument only fixes the metric's @method@ type; its value is
-- never inspected.
newMetric :: (DType dtype, MonadIO m) => method -> String -> [String] -> m (Metric dtype method)
newMetric _ name labels =
    -- The instance-count ref is allocated before the sum ref (applicative
    -- effects run left to right).
    liftIO $ Metric name labels <$> newIORef 0 <*> newIORef 0
-- | Discard everything accumulated so far by setting the running sum and the
-- instance count back to zero.
resetMetric :: (DType dtype, MonadIO m) => Metric dtype method -> m ()
resetMetric metric = liftIO (clearSum >> clearCount)
  where
    clearSum   = writeIORef (_metric_sum metric) 0
    clearCount = writeIORef (_metric_instance metric) 0
-- | Current value of the metric: the running sum divided by the number of
-- accumulated instances.
--
-- With no instances recorded this is a 'Float' division by zero, which
-- yields NaN (or an infinity for a non-zero sum) instead of throwing.
getMetric :: (DType dtype, MonadIO m) => Metric dtype method -> m Float
getMetric metric =
    liftIO $ mean <$> readIORef (_metric_sum metric)
                  <*> readIORef (_metric_instance metric)
  where
    mean total count = realToFrac total / fromIntegral count
-- | Render the metric as @<name: value>@, with the value (from 'getMetric')
-- shown to three decimal places.
formatMetric :: (DType dtype, MonadIO m) => Metric dtype method -> m String
formatMetric metric = render <$> getMetric metric
  where
    render :: Float -> String
    render = printf "<%s: %0.3f>" (_metric_name metric)
-- | Evaluation methods, selected by a metric's @method@ type tag.
--
-- Each instance defines how one batch of predictions and labels is folded
-- into a 'Metric'.
class EvalMetricMethod method where
    -- | Accumulate one batch into the metric.
    evaluate
        :: DType dtype
        => Metric dtype method  -- ^ metric whose IORefs are updated
        -> A.NDArray dtype      -- ^ predictions
        -> A.NDArray dtype      -- ^ labels
        -> IO ()
-- | Cross-entropy evaluation method.
--
-- Expects @preds@ to be 2-D and @label@ to be 1-D, with the same leading
-- (batch) dimension. Each row of @preds@ is indexed by the matching label
-- value.
data CrossEntropy = CrossEntropy

instance EvalMetricMethod CrossEntropy where
    -- Adds @-sum(log (pick preds label))@ to the running sum and the batch
    -- size to the instance count. Throws 'InvalidInput' when the ranks are
    -- not 2 and 1, or when the leading dimensions differ.
    evaluate metric preds label = do
        (n1, shp1) <- A.ndshape preds
        (n2, shp2) <- A.ndshape label
        -- Validate the shapes with a pattern match instead of 'head', so a
        -- malformed shape list raises InvalidInput rather than a
        -- partial-function crash. The match also yields the batch size.
        batch <- case (shp1, shp2) of
            (b1 : _, b2 : _) | n1 == 2, n2 == 1, b1 == b2 -> return b1
            _ -> throwM InvalidInput
        -- 'pick' takes both operands together, so when the two arrays live
        -- in different contexts the predictions are copied to the label's
        -- context first.
        preds_may_copy <- do
            c1 <- context preds
            c2 <- context label
            if c1 == c2
                then return preds
                else do
                    -- shp1 already holds the shape of preds; no need to
                    -- query it a second time.
                    preds_copy <- A.makeEmptyNDArray shp1 c2 False
                    A._copyto' (A.getHandle preds) [A.getHandle preds_copy] :: IO ()
                    return preds_copy
        -- Select each row's entry at its label index, take the log, and sum.
        predprj <- A.pick (A.getHandle preds_may_copy) (A.getHandle label) nil
        predlog <- A.log predprj
        loss <- A.sum predlog nil >>= A.items . A.NDArray
        -- Use strict updates. These refs accumulate over every batch, and
        -- the lazy modifyIORef would build an ever-growing thunk chain.
        modifyIORef' (_metric_sum metric) (+ (negate $ loss SV.! 0))
        modifyIORef' (_metric_instance metric) (+ batch)
-- | Errors raised while evaluating a metric.
data EvalMetricExc
    = InvalidInput  -- ^ the prediction/label shapes do not fit the metric
    deriving (Typeable, Show)

instance Exception EvalMetricExc