{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DuplicateRecordFields #-}
module DataRobot.PredictResponse
  ( PredictError(..)
  , PredictResult(..)
  , PredictionValue(..)
  , parseResponse
  , predictionValue
  ) where
import Control.Applicative ((<|>))
import Control.Monad.Catch (Exception)
import Data.Aeson (FromJSON(..), ToJSON, Value(..), decode, defaultOptions, genericParseJSON, withObject, (.:), eitherDecode)
import Data.Aeson.Types (Options(..), typeMismatch)
import Data.List (find)
import Data.Maybe (fromMaybe, maybe)
import Data.Text (Text)
import Data.String.Conversions (cs)
import Data.Typeable (Typeable)
import GHC.Generics (Generic)
import Safe (headMay)

import Lens.Micro ((^.))
import Network.Wreq (Response, responseBody, responseHeader)
import Data.ByteString.Lazy (ByteString)


underscorePrefixOptions :: Options
underscorePrefixOptions =
  defaultOptions { fieldLabelModifier = dropWhile (== '_') }

type Code = Int

data PredictError
    = APIError Code Text
    | MissingPrediction
    deriving (Typeable, Show, Generic)

instance Exception PredictError
instance ToJSON PredictError


-- Datarobot successful response
data ResponseSuccess = ResponseSuccess
  { _data :: [Prediction]
  } deriving (Eq, Show, Generic)

instance FromJSON ResponseSuccess where
  parseJSON = genericParseJSON underscorePrefixOptions


-- Datarobot failure response
data ResponseFailure = ResponseFailure
  { _message :: Text
  } deriving (Show, Eq, Generic)

instance FromJSON ResponseFailure where
  parseJSON = genericParseJSON underscorePrefixOptions

-- Datarobot response
data ResponseData
  = ResponseData (Either ResponseFailure ResponseSuccess)
  deriving (Eq, Show)

instance FromJSON ResponseData where
  parseJSON v = ResponseData <$>
    ((Right <$> parseJSON v) <|> (Left <$> parseJSON v))

-- A single prediction value
data PredictionValue = PredictionValue
  { label :: Text
  , value :: Float
  } deriving (Eq, Show, Generic)

instance ToJSON PredictionValue
instance FromJSON PredictionValue where
  parseJSON = withObject "prediction_value" $ \o -> do
      value' <- o .: "value"
      label' <- labelText =<< o .: "label"
      return $ PredictionValue label' value'
    where
      -- Always treat the label as text even though the JSON also allows numbers
      -- This makes key-based lookup easier on the API consumer
      labelText (Number n) = pure $ (cs .show) n
      labelText (String s) = pure s
      labelText invalid    = typeMismatch "label" invalid


-- Combination of prediction values and prediction label
data Prediction = Prediction
  { _prediction       :: Value  -- label or float
  , _predictionValues :: Maybe [PredictionValue]
  } deriving (Eq, Show, Generic)

instance ToJSON Prediction
instance FromJSON Prediction where
  parseJSON = genericParseJSON underscorePrefixOptions



-- | Result from the prediction
data PredictResult = PredictResult
  { prediction       :: Value
  , predictionTimeMs :: Float
  , predictionValues :: Maybe [PredictionValue]
  } deriving (Show, Eq, Generic)

instance ToJSON PredictResult


-- Create a result for a data robot response
handleResponse :: Float -> ResponseData -> Either PredictError PredictResult
handleResponse et (ResponseData (Right rs)) =
    maybe (Left MissingPrediction) Right $ do
        p <- headMay (_data rs)
        pure PredictResult
            { prediction       = _prediction p
            , predictionValues = _predictionValues p
            , predictionTimeMs = et
            }
handleResponse _ (ResponseData (Left err)) =
    responseFailure $ _message err

-- Create a result for a failed response
responseFailure :: Text -> Either PredictError PredictResult
responseFailure e = Left $ APIError 422 e

-- Parse the entire prediction response
-- This is needed because some of the data is delivered in the body and some is delivered via headers
parseResponse :: Response ByteString -> Either PredictError PredictResult
parseResponse r = do
    either (responseFailure . cs) (handleResponse tm) $ eitherDecode b
    where
      b  = r ^. responseBody
      et = r ^. responseHeader "X-DataRobot-Execution-Time"
      tm = fromMaybe 0.0 $ decode (cs et)

-- Find a prediction value probability from a given label
predictionValue :: Text -> PredictResult -> Maybe Float
predictionValue c r = do
    ps <- predictionValues r
    pd <- find ((== c) . label) ps
    pure $ value pd