{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE RecordWildCards #-}

module ML.DMLC.XGBoost
    ( module ML.DMLC.XGBoost
    , module ML.DMLC.XGBoost.FFI
    ) where

import Foundation
import Foundation.Collection
import Foundation.Numerical

import qualified Prelude (fromIntegral, Show(..))
import Control.Exception (assert)
import Control.Monad (when, foldM_)

import ML.DMLC.XGBoost.FFI
import ML.DMLC.XGBoost.Rabit.FFI

{------------------------------------------------------------------------------
-- Utility functions.
------------------------------------------------------------------------------}

-- | Convert an 'Int' to a 'Float' using the standard @fromIntegral@.
integerToFloat :: Int -> Float
integerToFloat n = Prelude.fromIntegral n

-- | Cast floating point output to an integer label using 'roundNearest'.
valueToLabel :: (IntegralRounding a) => a -> Int32
valueToLabel = roundNearest
{-# SPECIALIZE valueToLabel :: Float -> Int32 #-}
{-# INLINE valueToLabel #-}

-- | Success rate of a prediction: the number of positions whose wanted and
-- actual values map to the same label under 'valueToLabel', divided by the
-- length of the shorter of the two arrays.
--
-- This is 'compareLabels'' with label equality after rounding as the
-- comparison. When the shorter array is empty the result is @0 / 0@.
compareLabels :: UArray Float   -- ^ Wanted
              -> UArray Float   -- ^ Actual output
              -> Float          -- ^ Success rate
compareLabels wanted actual = compareLabels' wanted actual roundsToSameLabel
  where
    roundsToSameLabel x y = valueToLabel x == valueToLabel y

-- | Success rate of a prediction under a caller-supplied notion of equality:
-- the number of paired positions accepted by the predicate, divided by the
-- length of the shorter of the two arrays.
--
-- When the shorter array is empty the result is @0 / 0@.
compareLabels' :: UArray Float              -- ^ Wanted
               -> UArray Float              -- ^ Actual output
               -> (Float -> Float -> Bool)  -- ^ Decide whether two given labels are equal
               -> Float                     -- ^ Success rate
compareLabels' wanted actual sameLabel = integerToFloat hits / integerToFloat total
  where
    -- One verdict per paired position of the two arrays.
    verdicts :: [Bool]
    verdicts = zipWith sameLabel wanted actual
    tally matched count = if matched then count + 1 else count
    hits = foldr' tally 0 verdicts
    CountOf total = min (length wanted) (length actual)

{------------------------------------------------------------------------------
-- DMatrix related APIs.
------------------------------------------------------------------------------}

{------------------------------------------------------------------------------
-- Booster related APIs.
------------------------------------------------------------------------------}

-- | A single named setting passed to a booster. The value can be of any type
-- with a 'Show' instance; it is rendered with @show@ when the parameter is
-- applied (see 'setBoosterParam').
data BoosterParam = forall a. Show a => Param
    { paramName  :: String
    , paramValue :: a
    }

-- | Predefined objective functions.
--
-- Ref: https://github.com/dmlc/xgboost/blob/master/src/objective/regression_obj.cc
data ObjectiveFunction
    = RegLinear
    | RegLogistic
    | BinaryLogistic
    | BinaryLogitraw
    | CountPoisson
    | RegGamma
    | RegTweedie
    deriving Eq

-- | Renders each objective as a @prefix:name@ string (for example
-- @reg:linear@), so an 'ObjectiveFunction' can be used directly as a
-- 'paramValue'.
instance Show ObjectiveFunction where
    show objective = case objective of
        RegLinear      -> "reg:linear"
        RegLogistic    -> "reg:logistic"
        BinaryLogistic -> "binary:logistic"
        BinaryLogitraw -> "binary:logitraw"
        CountPoisson   -> "count:poisson"
        RegGamma       -> "reg:gamma"
        RegTweedie     -> "reg:tweedie"

-- | Apply one parameter to a booster: the value is rendered with @show@ and
-- handed to 'setParam' together with the parameter name.
setBoosterParam :: Booster -> BoosterParam -> IO ()
setBoosterParam booster (Param name value) = setParam booster name (show value)

-- | Create a booster over the given matrices via 'xgbBooster'.
--
-- The @"seed"@ parameter is first set to @"0"@; every entry of @?params@ is
-- then applied in list order with 'setBoosterParam'.
newBooster :: (?params :: [BoosterParam]) => [DMatrix] -> IO Booster
newBooster dmats = do
    booster <- xgbBooster dmats
    setParam booster "seed" "0"
    forM_ ?params (setBoosterParam booster)
    return booster

{------------------------------------------------------------------------------
-- Model train and predict APIs.
------------------------------------------------------------------------------}

-- | Train a booster on the given data.
--
-- A new booster is created with 'newBooster' (using @?params@), then the
-- rabit checkpoint version is loaded and training resumes at iteration
-- @version \`div\` 2@. Iterations run from that start index up to, but not
-- including, the requested number of rounds; nothing is trained when the
-- start index is already at or past it.
--
-- Each iteration calls 'updateOneIter' only when the version is even, and
-- saves a rabit checkpoint after the update and again at the end of the
-- iteration, so a full iteration advances the version counter by two.
--
-- When @?debug@ is set, the version bookkeeping is checked with 'assert'.
xgbTrain :: (?params :: [BoosterParam], ?debug :: Bool)
         => DMatrix     -- ^ Data to be trained
         -> Int32       -- ^ Number of boosting iterations
         -> IO Booster
xgbTrain dtrain rounds = do
    booster <- newBooster [dtrain]
    version <- loadRabitCheckpoint booster
    -- NOTE(review): 'rabitGetWordSize' presumably reports the rabit world
    -- size (number of workers) -- confirm against the FFI module.
    when ?debug $ do
        wdsize <- rabitGetWordSize
        -- If the size is 1, the loaded version must be 0.
        assert (wdsize /= 1 || version == 0) $ return ()
    let startIter = version `div` 2
        -- One boosting round; threads the checkpoint version through the fold.
        step ver i = do
            ver' <- if ver `mod` 2 == 0
                then do
                    updateOneIter booster i dtrain
                    saveRabitCheckpoint booster
                    return (ver + 1)
                else return ver
            when ?debug $ do
                wdsize <- rabitGetWordSize
                current <- rabitVersionNumber
                -- Unless the size is 1, our counter must match rabit's.
                assert (wdsize == 1 || ver' == current) $ return ()
            saveRabitCheckpoint booster
            return (ver' + 1)
    -- Fold over the half-open range of iteration indices. (A literal
    -- two-element list here would run exactly two rounds regardless of the
    -- requested count.)
    foldM_ step version [startIter .. rounds - 1]
    return booster

-- | Predict with a trained booster. All arguments are forwarded unchanged to
-- 'boosterPredict'; @?debug@ is currently not consulted.
xgbPredict :: (?debug :: Bool)
           => Booster
           -> DMatrix
           -> [PredictMask]
           -> Int32         -- ^ Limit number of trees in the prediction; defaults to 0 (use all trees).
           -> IO (UArray Float)
xgbPredict booster dtest masks nlimit = boosterPredict booster dtest masks nlimit