{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE DuplicateRecordFields #-} module DataRobot.PredictResponse ( PredictError(..) , PredictResult(..) , responseResult , classProbability ) where import Control.Applicative ((<|>)) import Control.Monad.Catch (Exception) import Data.Aeson (FromJSON(..), ToJSON, Value) import Data.HashMap.Strict (HashMap) import qualified Data.HashMap.Strict as HM import Data.Maybe (fromMaybe) import Data.Text (Text) import Data.Typeable (Typeable) import GHC.Generics (Generic) import Safe (headMay) data PredictError = APIError Code Text | MissingPrediction deriving (Typeable, Show) instance Exception PredictError newtype PredictResponse = PredictResponse (Either PredictFailure PredictSuccess ) deriving (Show, Eq) instance FromJSON PredictResponse where parseJSON v = PredictResponse <$> ((Right <$> parseJSON v) <|> (Left <$> parseJSON v)) type Code = Int data PredictFailure = PredictFailure { code :: Code , status :: Text } deriving (Show, Eq, Generic) instance FromJSON PredictFailure data PredictSuccess = PredictSuccess { predictions :: [Prediction] , execution_time :: Float , model_id :: Text -- , task :: Text } deriving (Show, Eq, Generic) instance FromJSON PredictSuccess data Prediction = Prediction { prediction :: Value , class_probabilities :: Maybe (HashMap Text Float) } deriving (Show, Eq, Generic) instance FromJSON Prediction -- | Result from the prediction data PredictResult = PredictResult { prediction :: Value , predictionTimeMs :: Float , modelId :: Text , classProbabilities :: Maybe (HashMap Text Float) } deriving (Show, Eq, Generic) instance ToJSON PredictResult responseResult :: PredictResponse -> Either PredictError PredictResult responseResult (PredictResponse (Right ps)) = fromMaybe (Left MissingPrediction) $ do p <- headMay (predictions ps) return $ Right $ PredictResult { prediction = prediction (p :: Prediction) , predictionTimeMs = execution_time ps , modelId = model_id ps , classProbabilities = class_probabilities p } responseResult (PredictResponse (Left pf)) = Left $ APIError (code pf) (status pf) classProbability :: Text -> PredictResult -> Maybe Float classProbability c r = do cps <- classProbabilities r HM.lookup c cps