-- Copyright (C) 2014  Fraser Tweedale
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

{-|

JOSE error types.

-}
module Crypto.JOSE.Error
  (
    Error(..)
  , AsError(..)

  -- * JOSE compact serialisation errors
  , InvalidNumberOfParts(..), expectedParts, actualParts
  , CompactTextError(..)
  , CompactDecodeError(..)
  , _CompactInvalidNumberOfParts
  , _CompactInvalidText
  ) where

import Data.Semigroup ((<>))
import Numeric.Natural

import Control.Monad.Trans (MonadTrans(..))
import qualified Crypto.PubKey.RSA as RSA
import Crypto.Error (CryptoError)
import Crypto.Random (MonadRandom(..))
import Control.Lens (Getter, to)
import Control.Lens.TH (makeClassyPrisms, makePrisms)
import qualified Data.Text as T
import qualified Data.Text.Encoding.Error as T


-- | The wrong number of parts was found when decoding a
-- compact JOSE object.
--
data InvalidNumberOfParts =
  InvalidNumberOfParts
    Natural  -- ^ expected number of parts
    Natural  -- ^ actual number of parts
  deriving (Eq)

-- | Human-readable rendering, e.g. @Expected 3 parts; got 2@.
instance Show InvalidNumberOfParts where
  show (InvalidNumberOfParts n m) =
    "Expected " <> show n <> " parts; got " <> show m

-- | Get the expected ('expectedParts') or actual ('actualParts')
-- number of parts from an 'InvalidNumberOfParts' error.
expectedParts, actualParts :: Getter InvalidNumberOfParts Natural
-- The first constructor argument is the expected count.
expectedParts = to $ \(InvalidNumberOfParts n _) -> n
-- The second constructor argument is the actual count.
actualParts   = to $ \(InvalidNumberOfParts _ n) -> n


-- | Bad UTF-8 data in a compact object.  Carries the index of the
-- offending part and the underlying UTF-8 decoding exception.
data CompactTextError = CompactTextError
  Natural              -- ^ index of the part that failed to decode
  T.UnicodeException   -- ^ the decoding error
  deriving (Eq)

-- | Human-readable rendering: the part index followed by the
-- 'show'n decoding exception.
instance Show CompactTextError where
  show (CompactTextError n s) =
    "Invalid text at part " <> show n <> ": " <> show s


-- | An error when decoding a JOSE compact object.
-- JSON decoding errors that occur during compact object processing
-- are reported as 'JSONDecodeError' instead.
--
data CompactDecodeError
  = CompactInvalidNumberOfParts InvalidNumberOfParts
  -- ^ The object did not have the expected number of parts
  | CompactInvalidText CompactTextError
  -- ^ A part of the object was not valid UTF-8 text
  deriving (Eq)
makePrisms ''CompactDecodeError

-- | Human-readable rendering: a short prefix naming the failure
-- kind, followed by the 'show'n inner error.
instance Show CompactDecodeError where
  show (CompactInvalidNumberOfParts e) = "Invalid number of parts: " <> show e
  show (CompactInvalidText e) = "Invalid text: " <> show e



-- | All the errors that can occur.
--
data Error
  = AlgorithmNotImplemented   -- ^ A requested algorithm is not implemented
  | AlgorithmMismatch String  -- ^ A requested algorithm cannot be used
  | KeyMismatch T.Text        -- ^ Wrong type of key was given
  | KeySizeTooSmall           -- ^ Key size is too small
  | OtherPrimesNotSupported   -- ^ RSA private key with >2 primes not supported
  | RSAError RSA.Error        -- ^ RSA encryption, decryption or signing error
  | CryptoError CryptoError   -- ^ Various cryptonite library error cases
  | CompactDecodeError CompactDecodeError
  -- ^ Error decoding a compact serialisation (wrong number of parts,
  --   or a part that is not valid UTF-8 text)
  | JSONDecodeError String    -- ^ JSON (Aeson) decoding error
  | NoUsableKeys              -- ^ No usable keys were found in the key store
  -- NOTE(review): 'JWSCritUnprotected' is undocumented; the name suggests
  -- a @crit@ header parameter appearing in an unprotected header.
  -- Confirm against the raise site and add a haddock.
  | JWSCritUnprotected
  | JWSNoValidSignatures
  -- ^ 'AnyValidated' policy active, and no valid signature encountered
  | JWSInvalidSignature
  -- ^ 'AllValidated' policy active, and invalid signature encountered
  | JWSNoSignatures
  -- ^ 'AllValidated' policy active, and there were no signatures on object
  --   that matched the allowed algorithms
  deriving (Eq, Show)
makeClassyPrisms ''Error


-- | Lift 'getRandomBytes' through any monad transformer @t@ whose
-- base monad @m@ is an instance of 'MonadRandom'.
--
-- This is an orphan instance, which is why @-fno-warn-orphans@ is set
-- for this module.  The instance head @t m@ needs @FlexibleInstances@.
-- The context mentions @t m@ itself, so @UndecidableInstances@ is also
-- required; both pragmas are enabled at the top of the module.
instance (
    MonadRandom m
  , MonadTrans t
  , Functor (t m)
  , Monad (t m)
  ) => MonadRandom (t m) where
    getRandomBytes = lift . getRandomBytes