{-# LANGUAGE CPP, DeriveDataTypeable, DeriveFunctor, DeriveGeneric,
    FlexibleContexts, OverloadedStrings, RecordWildCards #-}

module Network.Wreq.Cache
    (
      shouldCache
    , validateEntry
    , cacheStore
    ) where

import Control.Applicative
import Control.Lens ((^?), (^.), (^..), folded, non, pre, to)
import Control.Monad (guard)
import Data.Attoparsec.ByteString.Char8 as A
import Data.CaseInsensitive (mk)
import Data.Foldable (forM_)
import Data.HashSet (HashSet)
import Data.Hashable (Hashable)
import Data.IntSet (IntSet)
import Data.IORef (newIORef)
import Data.List (sort)
import Data.Maybe (listToMaybe)
import Data.Monoid (First(..), mconcat)
import Data.Time.Clock (UTCTime, addUTCTime, getCurrentTime)
import Data.Time.Format (parseTimeM)
import Data.Time.Locale.Compat (defaultTimeLocale)
import Data.Typeable (Typeable)
import GHC.Generics (Generic)
import Network.HTTP.Types (HeaderName, Method)
import Network.Wreq.Internal.Lens
import Network.Wreq.Internal.Types
import Network.Wreq.Lens
import qualified Data.ByteString.Char8 as B
import qualified Data.HashSet as HashSet
import qualified Data.IntSet as IntSet
import qualified Network.Wreq.Cache.Store as Store

#if MIN_VERSION_base(4,6,0)
import Data.IORef (atomicModifyIORef')
#else
import Data.IORef (IORef, atomicModifyIORef)

atomicModifyIORef' :: IORef a -> (a -> (a, b)) -> IO b
atomicModifyIORef' = atomicModifyIORef
#endif

-- | Build a request-runner transformer backed by an in-memory response
-- cache shared by every request run through it.
--
-- @capacity@ is passed straight to 'Store.empty' (presumably the maximum
-- number of stored entries — confirm in "Network.Wreq.Cache.Store").
--
-- On each request, a fresh cached entry (see 'validateEntry') is returned
-- without running the request.  A stale entry is evicted.  On a miss, the
-- wrapped runner is invoked and the response is stored if 'shouldCache'
-- allows it.
cacheStore :: Int -> IO (Run body -> Run body)
cacheStore capacity = do
  cache <- newIORef (Store.empty capacity)
  return $ \run req@(Req _ request) -> do
    -- Key on the method as well as the URL.  Keying on the URL alone would
    -- answer a POST with a cached GET response (never sending the POST),
    -- and would serve a body-less HEAD response to a later GET.
    let key = (request ^. method, reqURL req)
    before <- getCurrentTime
    mresp <- atomicModifyIORef' cache $ \s ->
      case Store.lookup key s of
        Nothing       -> (s, Nothing)
        Just (ce, s') ->
          case validateEntry before ce of
            -- Stale: drop the entry so it stops occupying space.
            Nothing -> (Store.delete key s, Nothing)
            hit     -> (s', hit)
    case mresp of
      Just resp -> return resp
      Nothing -> do
        resp <- run req
        after <- getCurrentTime
        forM_ (shouldCache after req resp) $ \ce ->
          atomicModifyIORef' cache $ \s -> (Store.insert key ce s, ())
        return resp

-- | HTTP status codes whose responses this module is willing to cache.
cacheableStatuses :: IntSet
cacheableStatuses = IntSet.fromList [200, 203, 300, 301, 410]

-- | Request methods whose responses this module is willing to cache.
cacheableMethods :: HashSet Method
cacheableMethods = HashSet.fromList ["GET", "HEAD", "OPTIONS"]

-- | Cheap pre-filter applied before looking at any caching headers.  Both
-- conditions must hold: the request method is in 'cacheableMethods', and
-- the response status code is in 'cacheableStatuses'.
possiblyCacheable :: Request -> Response body -> Bool
possiblyCacheable req resp =
    (req ^. method) `HashSet.member` cacheableMethods &&
    (resp ^. responseStatus . statusCode) `IntSet.member` cacheableStatuses

-- | Expiry time implied by the @max-age@ directive, counted from @now@.
--
-- Returns 'Nothing' in three cases:
--
-- * the directives contain @no-store@;
-- * they contain @no-cache@ (expected in its simplified, field-less form,
--   see 'simplifyCacheResponse');
-- * they contain no @max-age@ at all.
--
-- If several @max-age@ directives appear, the smallest one is used.
computeExpiration :: UTCTime -> [CacheResponse Seconds] -> Maybe UTCTime
computeExpiration now crs = do
  guard $ NoCache [] `notElem` crs && NoStore `notElem` crs
  -- 'listToMaybe' on the sorted list is a total "minimum".
  age <- listToMaybe $ sort [a | MaxAge a <- crs]
  return $! fromIntegral age `addUTCTime` now

-- | Return the cached response if the entry is still usable at @now@.
--
-- An entry without an expiry time is always returned.  An entry with an
-- expiry time is returned only while that time is strictly later than
-- @now@; otherwise the result is 'Nothing'.
validateEntry :: UTCTime -> CacheEntry body -> Maybe (Response body)
validateEntry now CacheEntry{..} =
  case entryExpires of
    Nothing          -> Just entryResponse
    Just e | e > now -> Just entryResponse
    _                -> Nothing

-- | Decide whether a response may be stored, and if so build its entry.
--
-- @now@ is the time the response was received.  It serves as the base for
-- @max-age@, and as the creation time when no parseable @Date@ header is
-- present.
--
-- Returns 'Nothing' when the response must not be cached:
--
-- * it fails 'possiblyCacheable';
-- * it carries @no-store@ or @no-cache@;
-- * it is already expired;
-- * it is an unexpiring GET whose URL has a query string.
shouldCache :: UTCTime -> Req -> Response body -> Maybe (CacheEntry body)
shouldCache now (Req _ req) resp = do
  guard (possiblyCacheable req resp)
  let crs = resp ^.. responseHeader "Cache-Control" . atto_ parseCacheResponse .
                     folded . to simplifyCacheResponse
  -- This cache never revalidates, so a response that forbids storage
  -- (no-store) or reuse without revalidation (no-cache) must not be kept.
  -- Without this check such responses ended up stored with no expiry.
  guard (NoStore `notElem` crs && NoCache [] `notElem` crs)
  let created = resp ^. pre (responseHeader "Date" . to parseDate . folded)
                      . non now
      -- An Expires value that does not parse (e.g. "0" or "-1") must be
      -- treated as already expired (RFC 7234 section 5.3).
      expiresHeader = maybe created id . parseDate <$>
                      resp ^? responseHeader "Expires"
      -- max-age overrides Expires; Expires applies only when there is no
      -- max-age (previously any Cache-Control header hid Expires entirely).
      mexpires = computeExpiration now crs <|> expiresHeader
  case mexpires of
    Just expires | expires <= created                -> empty
    Nothing      | req ^. method == "GET" &&
                   not (B.null (req ^. queryString)) -> empty
    _ -> return $ CacheEntry created mexpires resp

-- | A count of seconds, as carried by the @max-age@ and @s-maxage@
-- directives.
type Seconds = Int

-- | One directive from a @Cache-Control@ response header.  The @age@
-- parameter is the type of the delta-seconds value carried by 'MaxAge' and
-- 'SMaxAge'.
data CacheResponse age = Public                -- ^ @public@
                       | Private [HeaderName]  -- ^ @private@, with optional field names
                       | NoCache [HeaderName]  -- ^ @no-cache@, with optional field names
                       | NoStore               -- ^ @no-store@
                       | NoTransform           -- ^ @no-transform@
                       | MustRevalidate        -- ^ @must-revalidate@
                       | ProxyRevalidate       -- ^ @proxy-revalidate@
                       | MaxAge age            -- ^ @max-age=N@
                       | SMaxAge age           -- ^ @s-maxage=N@
                       | Extension
                         -- ^ A cache extension directive; its name and value
                         -- are not retained.
                       deriving (Eq, Show, Functor, Typeable, Generic)

instance Hashable age => Hashable (CacheResponse age)

-- | Discard the field-name lists of @private@ and @no-cache@.  After this,
-- every occurrence of those directives compares equal to @Private []@ /
-- @NoCache []@.  All other directives are returned unchanged.
simplifyCacheResponse :: CacheResponse age -> CacheResponse age
simplifyCacheResponse (Private _) = Private []
simplifyCacheResponse (NoCache _) = NoCache []
simplifyCacheResponse cr          = cr

-- | Parse the comma-separated directive list of a @Cache-Control@ response
-- header.
--
-- Directive names are read as complete tokens and compared
-- case-insensitively (RFC 7234 section 5.2).  Names not modelled by
-- 'CacheResponse' become 'Extension', with any @=value@ skipped.  So one
-- unrecognised directive (e.g. @immutable@) no longer makes the whole
-- header fail to parse, which silently dropped directives such as
-- @no-store@ next to it.
parseCacheResponse :: A.Parser [CacheResponse Seconds]
parseCacheResponse = commaSep1 directive
  where
    directive = do
      name <- A.takeWhile1 isTokenChar
      -- 'mk' gives a case-insensitive value, so the literal patterns below
      -- match regardless of letter case.
      case mk name of
        "public"           -> pure Public
        "private"          -> Private <$> fieldNames
        "no-cache"         -> NoCache <$> fieldNames
        "no-store"         -> pure NoStore
        "no-transform"     -> pure NoTransform
        "must-revalidate"  -> pure MustRevalidate
        "proxy-revalidate" -> pure ProxyRevalidate
        "max-age"          -> eq (MaxAge <$> decimal)
        "s-maxage"         -> eq (SMaxAge <$> decimal)
        _                  -> Extension <$ optional (eq extensionValue)
    -- Optional "=..." after private / no-cache: either a quoted list of
    -- field names or (tolerated) a single bare field name.
    fieldNames = eq (headerNames <|> (: []) <$> hdr) <|> pure []
    headerNames = A.char '"' *> commaSep1 hdr <* A.char '"'
    hdr = mk <$> A.takeWhile1 (inClass "a-zA-Z0-9_-")
    -- Extension values: a token or a quoted string (backslash escapes
    -- inside quotes are not interpreted).
    extensionValue = quoted <|> A.takeWhile1 isTokenChar
    quoted = A.char '"' *> A.takeWhile (/= '"') <* A.char '"'
    -- RFC 7230 "tchar".
    isTokenChar = inClass "a-zA-Z0-9!#$%&'*+.^_`|~-"
    commaSep1 p = (p <* skipSpace) `sepBy1` (A.char ',' *> skipSpace)
    eq p = skipSpace *> A.char '=' *> skipSpace *> p

-- | Parse an HTTP date.  Three formats are tried in order, and the first
-- successful parse wins:
--
-- * RFC 1123 (@Sun, 06 Nov 1994 08:49:37 GMT@);
-- * RFC 850 (@Sunday, 06-Nov-94 08:49:37 GMT@);
-- * ANSI C @asctime()@ (@Sun Nov  6 08:49:37 1994@).
--
-- Returns 'Nothing' if no format matches.
parseDate :: B.ByteString -> Maybe UTCTime
parseDate s = getFirst . mconcat . map tryout $ [
    "%a, %d %b %Y %H:%M:%S %Z"
  , "%A, %d-%b-%y %H:%M:%S %Z"
  , "%a %b %e %H:%M:%S %Y"
  ]
  where tryout fmt = First $ parseTimeM True defaultTimeLocale fmt (B.unpack s)