module Network.Kafka where
import Control.Applicative
import Control.Exception (Exception, IOException)
import Control.Exception.Lifted (catch)
import Control.Lens
import Control.Monad.Trans.Control (MonadBaseControl)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Except (ExceptT(..), runExceptT, MonadError(..))
import Control.Monad.Trans.State
import Control.Monad.State.Class (MonadState)
import Data.ByteString.Char8 (ByteString)
import Data.List.NonEmpty (NonEmpty(..))
import qualified Data.List.NonEmpty as NE
import Data.Monoid ((<>))
import qualified Data.Pool as Pool
import GHC.Generics (Generic)
import System.IO
import qualified Data.Map as M
import Data.Set (Set)
import qualified Data.Set as Set
import qualified Network
import Prelude
import Network.Kafka.Protocol
-- | Host and port pair identifying a broker endpoint; this is the key used
-- for the connection pools in 'KafkaState' and what 'broker2address' yields.
type KafkaAddress = (Host, Port)
-- | Client state threaded through every 'Kafka' action.
data KafkaState = KafkaState
  { _stateName :: KafkaString
    -- ^ Client name; 'makeRequest' wraps it in a 'ClientId' for each request.
  , _stateRequiredAcks :: RequiredAcks
    -- ^ Required-acks setting; 'mkKafkaState' starts it at 'defaultRequiredAcks'.
  , _stateRequestTimeout :: Timeout
    -- ^ Request timeout; 'mkKafkaState' starts it at 'defaultRequestTimeout'.
  , _stateWaitSize :: MinBytes
    -- ^ Minimum-bytes setting; 'mkKafkaState' starts it at 'defaultMinBytes'.
  , _stateBufferSize :: MaxBytes
    -- ^ Maximum-bytes setting; 'mkKafkaState' starts it at 'defaultMaxBytes'.
  , _stateWaitTime :: MaxWaitTime
    -- ^ Maximum wait time; 'mkKafkaState' starts it at 'defaultMaxWaitTime'.
  , _stateCorrelationId :: CorrelationId
    -- ^ Correlation id for the next request; 'makeRequest' increments it.
  , _stateBrokers :: M.Map Leader Broker
    -- ^ Brokers keyed by leader id, filled in by 'updateMetadatas'.
  , _stateConnections :: M.Map KafkaAddress (Pool.Pool Handle)
    -- ^ One connection pool per address, created on demand by
    -- 'withAddressHandle'.
  , _stateTopicMetadata :: M.Map TopicName TopicMetadata
    -- ^ Topic metadata keyed by topic name, filled in by 'updateMetadatas'.
  , _stateAddresses :: NonEmpty KafkaAddress
    -- ^ Known broker addresses; 'withAnyHandle' uses the head and rotates
    -- the list after a successful call.
  } deriving (Generic, Show)

makeLenses ''KafkaState
-- | Constraint bundle required of any monad that runs the client's actions:
-- access to the 'KafkaState', the ability to raise 'KafkaClientError', 'IO'
-- via 'MonadIO', and 'MonadBaseControl' 'IO' (needed by the lifted 'catch'
-- used in 'tryKafka').
type Kafka m = (MonadState KafkaState m, MonadError KafkaClientError m, MonadIO m, MonadBaseControl IO m)
-- | Client identifier handed to 'mkKafkaState', where it becomes the
-- state's name.
type KafkaClientId = KafkaString
-- | Errors raised by the client through 'MonadError'.
data KafkaClientError
  = KafkaNoOffset
    -- ^ Raised by @getLastOffset'@ when the offset response holds no offset
    -- for the requested partition.
  | KafkaDeserializationError String
    -- ^ Raised by 'makeRequest' when the request yields a @Left@ message.
  | KafkaInvalidBroker Leader
    -- ^ Raised by 'getTopicPartitionLeader' when the leader is not present
    -- in the known brokers.
  | KafkaFailedToFetchMetadata
    -- ^ Raised when topic metadata (or a partition leader within it) cannot
    -- be found.
  | KafkaIOException IOException
    -- ^ An 'IOException' caught and re-raised by 'tryKafka'.
  deriving (Eq, Generic, Show)

instance Exception KafkaClientError
-- | Symbolic or explicit time; 'protocolTime' converts it to a protocol
-- 'Time'.
data KafkaTime
  = LatestTime
    -- ^ Converted to @Time (-1)@ by 'protocolTime'.
  | EarliestTime
    -- ^ Converted to @Time (-2)@ by 'protocolTime'.
  | OtherTime Time
    -- ^ An explicit 'Time', passed through unchanged by 'protocolTime'.
  deriving (Eq, Generic)
-- | A topic partition together with its leader, as collected by
-- 'brokerPartitionInfo'.
data PartitionAndLeader = PartitionAndLeader
  { _palTopic :: TopicName
    -- ^ Topic the partition belongs to.
  , _palPartition :: Partition
    -- ^ Partition id.
  , _palLeader :: Leader
    -- ^ Leader recorded for the partition in the topic metadata.
  } deriving (Show, Generic, Eq, Ord)

makeLenses ''PartitionAndLeader
-- | A topic name paired with one of its partitions; used as the key type in
-- the input to 'offsetRequest'.
data TopicAndPartition = TopicAndPartition
  { _tapTopic :: TopicName
    -- ^ Topic name.
  , _tapPartition :: Partition
    -- ^ Partition within the topic.
  } deriving (Eq, Generic, Ord, Show)
-- | A message paired with the name of a topic.
data TopicAndMessage = TopicAndMessage
  { _tamTopic :: TopicName
    -- ^ Topic name.
  , _tamMessage :: Message
    -- ^ The message itself; see 'tamPayload' for its payload bytes.
  } deriving (Eq, Generic, Show)

makeLenses ''TopicAndMessage
-- | Extract the payload bytes of the message carried by a 'TopicAndMessage'.
tamPayload :: TopicAndMessage -> ByteString
tamPayload tam = tam ^. tamMessage . payload
-- | Correlation id a fresh 'KafkaState' starts from (see 'mkKafkaState').
defaultCorrelationId :: CorrelationId
defaultCorrelationId = 0
-- | Required-acks value a fresh 'KafkaState' starts with (see 'mkKafkaState').
defaultRequiredAcks :: RequiredAcks
defaultRequiredAcks = 1
-- | Request timeout a fresh 'KafkaState' starts with (see 'mkKafkaState').
-- NOTE(review): the unit is presumably milliseconds -- confirm against the
-- 'Timeout' type in the protocol module.
defaultRequestTimeout :: Timeout
defaultRequestTimeout = 10000
-- | Minimum-bytes value a fresh 'KafkaState' starts with (see 'mkKafkaState').
defaultMinBytes :: MinBytes
defaultMinBytes = MinBytes 0
-- | Maximum-bytes value a fresh 'KafkaState' starts with (see
-- 'mkKafkaState'): @1024 * 1024@.
defaultMaxBytes :: MaxBytes
defaultMaxBytes = 1024 * 1024
-- | Maximum wait time a fresh 'KafkaState' starts with (see 'mkKafkaState').
defaultMaxWaitTime :: MaxWaitTime
defaultMaxWaitTime = 0
-- | Build an initial 'KafkaState' from a client id and a single bootstrap
-- address. Every tunable starts at its @default*@ value, and the broker,
-- connection and topic-metadata maps start empty.
mkKafkaState :: KafkaClientId -> KafkaAddress -> KafkaState
mkKafkaState cid addy = KafkaState
  { _stateName = cid
  , _stateRequiredAcks = defaultRequiredAcks
  , _stateRequestTimeout = defaultRequestTimeout
  , _stateWaitSize = defaultMinBytes
  , _stateBufferSize = defaultMaxBytes
  , _stateWaitTime = defaultMaxWaitTime
  , _stateCorrelationId = defaultCorrelationId
  , _stateBrokers = M.empty
  , _stateConnections = M.empty
  , _stateTopicMetadata = M.empty
  , _stateAddresses = addy :| []
  }
-- | Put an address at the front of the state's address list, then remove
-- duplicates with 'NE.nub'.
addKafkaAddress :: KafkaAddress -> KafkaState -> KafkaState
addKafkaAddress address = over stateAddresses (NE.nub . NE.cons address)
-- | Run a Kafka action from the given initial state. The final state is
-- discarded ('evalStateT'); the result is either the raised
-- 'KafkaClientError' or the action's value.
runKafka :: KafkaState -> StateT KafkaState (ExceptT KafkaClientError IO) a -> IO (Either KafkaClientError a)
runKafka initial action = runExceptT (evalStateT action initial)
-- | Run an action, converting any 'IOException' it throws into a
-- 'KafkaIOException' raised through 'MonadError'.
tryKafka :: Kafka m => m a -> m a
tryKafka action = action `catch` rethrowAsKafkaError
  where
    rethrowAsKafkaError ioErr =
      throwError (KafkaIOException (ioErr :: IOException))
-- | Send a request over the given handle and return its decoded response.
--
-- Each call takes the current correlation id from the state and then
-- increments it. An 'IOException' during the exchange becomes a
-- 'KafkaIOException' (via 'tryKafka'); a @Left@ result from 'doRequest'
-- becomes a 'KafkaDeserializationError'.
makeRequest :: Kafka m => Handle -> ReqResp (m a) -> m a
makeRequest h reqresp = do
  (clientId, correlationId) <- nextIds
  outcome <- tryKafka (doRequest clientId correlationId h reqresp)
  either (throwError . KafkaDeserializationError) return outcome
  where
    -- Read the client name and current correlation id, then bump the
    -- correlation id so the next request gets a fresh one.
    nextIds :: MonadState KafkaState m => m (ClientId, CorrelationId)
    nextIds = do
      current <- use stateCorrelationId
      stateCorrelationId += 1
      name <- use stateName
      return (ClientId name, current)
-- | Send a metadata request using whichever address 'withAnyHandle' picks.
metadata :: Kafka m => MetadataRequest -> m MetadataResponse
metadata request = withAnyHandle (\h -> metadata' h request)
-- | Send a metadata request over a specific handle.
metadata' :: Kafka m => Handle -> MetadataRequest -> m MetadataResponse
metadata' h = makeRequest h . MetadataRR
-- | Look up the broker that leads the given partition of a topic.
--
-- Topic metadata is taken from the state, refreshing it once if it is
-- missing. Raises 'KafkaFailedToFetchMetadata' when the metadata or the
-- partition's leader cannot be found, and 'KafkaInvalidBroker' when the
-- leader is not among the known brokers.
getTopicPartitionLeader :: Kafka m => TopicName -> Partition -> m Broker
getTopicPartitionLeader t p = do
  tmd <- findMetadataOrElse [t] (stateTopicMetadata . at t) KafkaFailedToFetchMetadata
  leader <- expect KafkaFailedToFetchMetadata lookupLeader tmd
  brokers <- use stateBrokers
  expect (KafkaInvalidBroker leader) (view (at leader)) brokers
  where
    -- First leader found for partition @p@ in the topic's metadata.
    lookupLeader =
      firstOf (findPartitionMetadata t . folded . findPartition p . partitionMetadataLeader)
-- | Apply a partial lookup to a value, raising the supplied error when the
-- lookup yields 'Nothing'.
expect :: Kafka m => KafkaClientError -> (a -> Maybe b) -> a -> m b
expect err project input =
  case project input of
    Just found -> return found
    Nothing -> throwError err
-- | Collect every partition of a topic together with its leader.
--
-- Topic metadata is taken from the state, refreshing it once if it is
-- missing; raises 'KafkaFailedToFetchMetadata' if it is still absent.
brokerPartitionInfo :: Kafka m => TopicName -> m (Set PartitionAndLeader)
brokerPartitionInfo t = do
  tmd <- findMetadataOrElse [t] (stateTopicMetadata . at t) KafkaFailedToFetchMetadata
  return $ Set.fromList
    [ PartitionAndLeader t (part ^. partitionId) (part ^. partitionMetadataLeader)
    | part <- tmd ^. partitionsMetadata
    ]
-- | Read an optional value out of the state through the given getter. If it
-- is absent, refresh metadata for the listed topics once and read again; if
-- it is still absent, raise the supplied error.
findMetadataOrElse :: Kafka m => [TopicName] -> Getting (Maybe a) KafkaState (Maybe a) -> KafkaClientError -> m a
findMetadataOrElse ts s err =
  use s >>= maybe refreshThenRetry return
  where
    -- Second (and last) attempt, made after a metadata refresh.
    refreshThenRetry = do
      updateMetadatas ts
      use s >>= maybe (throwError err) return
-- | Convert a 'KafkaTime' to a protocol 'Time': 'LatestTime' becomes
-- @Time (-1)@, 'EarliestTime' becomes @Time (-2)@, and an 'OtherTime' value
-- is passed through unchanged.
protocolTime :: KafkaTime -> Time
protocolTime kafkaTime = case kafkaTime of
  LatestTime -> Time (-1)
  EarliestTime -> Time (-2)
  OtherTime explicit -> explicit
-- | Request metadata for the given topics and merge the response into the
-- state:
--
-- * the address of every returned broker is appended to the known
--   addresses, after which duplicates are removed;
-- * every returned broker is stored under its node id;
-- * every returned topic's metadata is stored under its topic name.
--
-- Errors raised by 'metadata' propagate unchanged.
updateMetadatas :: Kafka m => [TopicName] -> m ()
updateMetadatas ts = do
  md <- metadata $ MetadataReq ts
  let brokers = md ^.. metadataResponseBrokers . folded
      tmds = md ^.. topicsMetadata . folded
      addresses = map broker2address brokers
  -- Match on ':|' instead of going through the partial 'NE.fromList':
  -- the existing list is non-empty by construction, so let the type show it.
  stateAddresses %= \(known :| rest) -> NE.nub (known :| (rest ++ addresses))
  stateBrokers %= \m -> foldr addBroker m brokers
  stateTopicMetadata %= \m -> foldr addTopicMetadata m tmds
    where addBroker :: Broker -> M.Map Leader Broker -> M.Map Leader Broker
          addBroker b = M.insert (Leader . Just $ b ^. brokerNode . nodeId) b
          addTopicMetadata :: TopicMetadata -> M.Map TopicName TopicMetadata -> M.Map TopicName TopicMetadata
          addTopicMetadata tm = M.insert (tm ^. topicMetadataName) tm
-- | Refresh the stored metadata for a single topic.
updateMetadata :: Kafka m => TopicName -> m ()
updateMetadata = updateMetadatas . (: [])
-- | Refresh metadata by calling 'updateMetadatas' with an empty topic list.
-- NOTE(review): going by this function's name, an empty list presumably asks
-- the broker for every topic -- confirm against the protocol implementation.
updateAllMetadata :: Kafka m => m ()
updateAllMetadata = updateMetadatas []
-- | Run an action with a pooled handle connected to the given broker's
-- address.
withBrokerHandle :: Kafka m => Broker -> (Handle -> m a) -> m a
withBrokerHandle = withAddressHandle . broker2address
-- | Run an action with a pooled handle for the given address.
--
-- The pool for the address is looked up in the state; if none exists yet,
-- one is created and recorded. An 'IOException' thrown while creating the
-- pool or running the action is re-raised as a 'KafkaIOException'.
withAddressHandle :: Kafka m => KafkaAddress -> (Handle -> m a) -> m a
withAddressHandle address kafkaAction = do
  known <- use (stateConnections . at address)
  pool <- case known of
    Just existing -> return existing
    Nothing -> do
      fresh <- tryKafka (liftIO openPool)
      stateConnections . at address .= Just fresh
      return fresh
  tryKafka (Pool.withResource pool kafkaAction)
  where
    (host, port) = address

    -- NOTE(review): the trailing 1 10 1 are presumably stripe count, idle
    -- time in seconds and max resources per stripe, following resource-pool's
    -- 'Pool.createPool' argument order -- confirm against the installed
    -- version.
    openPool :: IO (Pool.Pool Handle)
    openPool = Pool.createPool connect hClose 1 10 1

    -- Open a fresh connection to the address.
    connect = Network.connectTo (host ^. hostString) (port ^. portId)
-- | The host and port of a broker, paired as a 'KafkaAddress'.
broker2address :: Broker -> KafkaAddress
broker2address b = (b ^. brokerHost, b ^. brokerPort)
-- | Run an action with a handle for the first address in the state's
-- address list, then rotate the list by one so the next call starts from
-- the following address.
--
-- The rotation is sequenced after the action, so it only happens when
-- 'withAddressHandle' completes without raising.
withAnyHandle :: Kafka m => (Handle -> m a) -> m a
withAnyHandle f = do
  (addy :| _) <- use stateAddresses
  x <- withAddressHandle addy f
  stateAddresses %= rotate
  return x
    where
      -- Move the head to the back. Written as a total pattern match so no
      -- round trip through the partial 'NE.fromList' is needed.
      rotate :: NonEmpty a -> NonEmpty a
      rotate (only :| []) = only :| []
      rotate (first :| (second : rest)) = second :| (rest ++ [first])
-- | Per-partition parameters of an offset request; see 'offsetRequest'.
data PartitionOffsetRequestInfo = PartitionOffsetRequestInfo
  { _kafkaTime :: KafkaTime
    -- ^ Time to query; converted with 'protocolTime' when the request is
    -- built.
  , _maxNumOffsets :: MaxNumberOfOffsets
    -- ^ Maximum number of offsets, passed through to the request as-is.
  }
-- | Find the leader of the given topic partition and ask it for the offset
-- matching the given 'KafkaTime'.
--
-- Raises whatever 'getTopicPartitionLeader' and @getLastOffset'@ raise.
getLastOffset :: Kafka m => KafkaTime -> Partition -> TopicName -> m Offset
getLastOffset m p t = do
  leaderBroker <- getTopicPartitionLeader t p
  withBrokerHandle leaderBroker queryOffset
  where
    queryOffset h = getLastOffset' h m p t
-- | Over the given handle, request at most one offset for the topic
-- partition at the given 'KafkaTime' and return the first offset found for
-- that partition in the response.
--
-- Raises 'KafkaNoOffset' when the response contains none.
getLastOffset' :: Kafka m => Handle -> KafkaTime -> Partition -> TopicName -> m Offset
getLastOffset' h m p t = do
  response <- makeRequest h (OffsetRR request)
  case firstOf (offsetResponseOffset p) response of
    Just offset -> return offset
    Nothing -> throwError KafkaNoOffset
  where
    request = offsetRequest [(TopicAndPartition t p, PartitionOffsetRequestInfo m 1)]
-- | Build an 'OffsetRequest' from per-partition request info.
--
-- Entries are grouped by topic name; each partition contributes a
-- @(partition, protocol time, max offsets)@ triple to its topic's list.
-- The request is always built with @ReplicaId (-1)@.
offsetRequest :: [(TopicAndPartition, PartitionOffsetRequestInfo)] -> OffsetRequest
offsetRequest ts = OffsetReq (ReplicaId (-1), M.toList perTopic)
  where
    perTopic = M.unionsWith (<>)
      [ M.singleton topic [(partition, protocolTime time, maxOffsets)]
      | (TopicAndPartition topic partition, PartitionOffsetRequestInfo time maxOffsets) <- ts
      ]