{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
module MXNet.NN.Optimizer (
    Optimizer(..),
    OptArgsCst,
    SGD, ADAM
) where

import qualified Data.HashMap.Strict as M
import Data.IORef
import MXNet.Core.Base.NDArray (NDArray)
import qualified MXNet.Core.Base.NDArray as A
import MXNet.Core.Base.HMap
import MXNet.Core.Base.Internal.TH.NDArray as A

-- | Constraint for using an optimizer: the supplied arguments must be
-- showable and must match the optimizer's accepted argument list.
type OptArgsCst opt args = (ShowKV args, MatchKVList args (OptArgsList opt))

-- | Abstract optimizer type class
class Optimizer (opt :: * -> [KV *] -> *) where
    -- | Specific (extra) arguments accepted by the optimizer
    type OptArgsList opt :: [KV *]
    -- | Make the optimizer from a learning rate and its extra arguments
    makeOptimizer :: OptArgsCst opt args => Float -> HMap args -> IO (opt dtype args)
    -- | Run one optimization step, given the parameter's symbol name, its
    -- weight, and its gradient, returning the updated weight.
    optimize :: OptArgsCst opt args => opt dtype args -> String -> NDArray dtype -> NDArray dtype -> IO (NDArray dtype)

-- | SGD optimizer
data SGD dtype args = SGD Float (HMap args)

instance Optimizer SGD where
    type OptArgsList SGD = '["wd" ':= Float, "rescale_grad" ':= Float, "clip_gradient" ':= Float]
    makeOptimizer lr args = return $ SGD lr args
    -- SGD keeps no per-parameter state, so the symbol name is ignored.
    optimize (SGD lr args) _ weight gradient =
        A.NDArray <$> A.sgd_update (A.getHandle weight) (A.getHandle gradient) lr args

-- | ADAM optimizer
data ADAM dtype args = ADAM Float (HMap args) (IORef (M.HashMap String (NDArray dtype, NDArray dtype)))

instance Optimizer ADAM where
    type OptArgsList ADAM = '["beta1" ':= Float, "beta2" ':= Float, "epsilon" ':= Float,
                              "wd" ':= Float, "rescale_grad" ':= Float, "clip_gradient" ':= Float]
    makeOptimizer lr args = do
        empty <- newIORef M.empty
        return $ ADAM lr args empty
    optimize (ADAM lr args emaref) symbol weight gradient = do
        -- Look up (or lazily create as zeros) the per-parameter moving
        -- average and moving variance, keyed by the symbol name.
        ema <- readIORef emaref
        (moving_avg, moving_var) <- case M.lookup symbol ema of
            Nothing -> do
                avg <- A.zeros_like (A.getHandle weight)
                var <- A.zeros_like (A.getHandle weight)
                writeIORef emaref (M.insert symbol (A.NDArray avg, A.NDArray var) ema)
                return (avg, var)
            Just (a, v) -> return (A.getHandle a, A.getHandle v)
        A.NDArray <$> A.adam_update (A.getHandle weight) (A.getHandle gradient) moving_avg moving_var lr args
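
{- Usage sketch (illustrative, not part of this module). It shows one SGD
   update step through the class interface. Assumptions: a hyper-parameter
   HMap satisfying `OptArgsCst SGD args` has already been built with the
   combinators from MXNet.Core.Base.HMap; the learning rate 0.01 and the
   symbol name "fc1_weight" are made up for the example; the caller enables
   ScopedTypeVariables so the `args` in the annotation refers to the one in
   the signature. For a stateful optimizer such as ADAM, the same optimizer
   value must be reused across steps and each parameter must be passed a
   consistent symbol name, so its moving averages accumulate correctly.

   sgdStep :: forall args. OptArgsCst SGD args
           => HMap args -> NDArray Float -> NDArray Float -> IO (NDArray Float)
   sgdStep hyper weight gradient = do
       -- build a stateless SGD optimizer with learning rate 0.01
       opt <- makeOptimizer 0.01 hyper :: IO (SGD Float args)
       -- one update step; SGD ignores the symbol name argument
       optimize opt "fc1_weight" weight gradient
-}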