-- Copyright (c) 2014-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is distributed under the terms of a BSD license,
-- found in the LICENSE file.

{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- A generic Haxl datasource for performing arbitrary IO concurrently.
-- Every IO operation will be performed in a separate thread.
-- You can use this with any kind of IO, but each different operation
-- requires an instance of the 'ConcurrentIO' class.
--
-- For example, to make a concurrent sleep operation:
--
-- > sleep :: Int -> GenHaxl u w Int
-- > sleep n = dataFetch (Sleep n)
-- >
-- > data Sleep
-- > instance ConcurrentIO Sleep where
-- >   data ConcurrentIOReq Sleep a where
-- >     Sleep :: Int -> ConcurrentIOReq Sleep Int
-- >
-- >   performIO (Sleep n) = threadDelay (n*1000) >> return n
-- >
-- > deriving instance Eq (ConcurrentIOReq Sleep a)
-- > deriving instance Show (ConcurrentIOReq Sleep a)
-- >
-- > instance ShowP (ConcurrentIOReq Sleep) where showp = show
-- >
-- > instance Hashable (ConcurrentIOReq Sleep a) where
-- >   hashWithSalt s (Sleep n) = hashWithSalt s n
--
-- Note that you can have any number of constructors in your
-- ConcurrentIOReq GADT, so most of the boilerplate only needs to be
-- written once.

module Haxl.DataSource.ConcurrentIO
  ( mkConcurrentIOState
  , ConcurrentIO(..)
  ) where

import Control.Concurrent
import Control.Monad
import qualified Data.Text as Text
import Data.Typeable

import Haxl.Core

-- | Class of request families that this data source knows how to run.
--
-- An instance picks a @tag@ type, declares the GADT of its requests as
-- 'ConcurrentIOReq' @tag@, and implements 'performIO' to say how a single
-- request is executed in 'IO'.
class ConcurrentIO tag where
  -- | The requests belonging to @tag@; @a@ is the result type of a request.
  data ConcurrentIOReq tag a
  -- | Execute one request. The 'DataSource' instance in this module calls
  -- this on a separately forked thread for each blocked fetch, and routes
  -- either the result or the 'SomeException' it throws back to the request.
  performIO :: ConcurrentIOReq tag a -> IO a

-- Standalone 'Typeable' instance for the 'ConcurrentIOReq' data family,
-- kept for compatibility with compilers older than GHC 7.10.
deriving instance Typeable ConcurrentIOReq -- not needed by GHC 7.10 and later

-- | The data source has no real state: 'ConcurrentIOState' carries no fields.
--
-- 'getStateType' ignores @tag@ and returns the 'TypeRep' of the
-- 'ConcurrentIOReq' family itself. Every instantiation therefore reports
-- the same state type. This appears to be what lets the single
-- @State (ConcurrentIOReq ())@ built by 'mkConcurrentIOState' serve every
-- 'ConcurrentIO' tag.
instance (Typeable tag) => StateKey (ConcurrentIOReq tag) where
  data State (ConcurrentIOReq tag) = ConcurrentIOState
  -- Deliberately keyed on the unapplied family, not on ConcurrentIOReq tag.
  getStateType _ = typeRep (Proxy :: Proxy ConcurrentIOReq)

-- | Create the (field-less) state value for this data source.
--
-- The tag is fixed to @()@. Because 'getStateType' for 'ConcurrentIOReq'
-- ignores its tag, this one value is keyed under the same state type as
-- every other @ConcurrentIOReq tag@.
mkConcurrentIOState :: IO (State (ConcurrentIOReq ()))
mkConcurrentIOState = pure ConcurrentIOState

-- | Name the data source after the outermost type constructor of @tag@.
--
-- For example, requests of type @ConcurrentIOReq Sleep a@ are reported
-- under the name @"Sleep"@.
instance Typeable tag => DataSourceName (ConcurrentIOReq tag) where
  -- Only the type constructor's name is used, so any type arguments applied
  -- to @tag@ do not appear in the name. 'tyConName' yields the bare name
  -- directly rather than relying on how 'show' renders a 'TyCon'.
  dataSourceName _ =
    Text.pack (tyConName (typeRepTyCon (typeRep (Proxy :: Proxy tag))))

-- | Runs every blocked request on its own thread.
--
-- 'fetch' returns a 'BackgroundFetch'. For each 'BlockedFetch' in the
-- batch, it forks a thread that runs 'performIO' on the request. The
-- outcome is handed to 'putResultFromChildThread' for that request's
-- result variable: either the value, or the 'SomeException' that was
-- thrown.
instance
  (Typeable tag, ShowP (ConcurrentIOReq tag), ConcurrentIO tag)
  => DataSource u (ConcurrentIOReq tag)
 where
  -- The state, flags and user environment are unused: each request carries
  -- everything 'performIO' needs.
  fetch _state _flags _u = BackgroundFetch $ \bfs ->
    forM_ bfs $ \(BlockedFetch req rv) ->
      -- forkFinally catches whatever performIO throws and passes it on as
      -- Left, so a failing request still has an outcome recorded.
      forkFinally (performIO req) (putResultFromChildThread rv)