module B9.B9Exec
  ( cmd,
    cmdInteractive,
    hostCmdEither,
    hostCmd,
    hostCmdStdIn,
    Timeout (..),
    HostCommandStdin (..),
  )
where
import B9.B9Config
import B9.B9Error
import B9.B9Logging
import B9.BuildInfo (BuildInfoReader, isInteractive)
import qualified Conduit as CL
import Control.Concurrent
import Control.Concurrent.Async (Concurrently (..), race)
import Control.Eff
import qualified Control.Exception as ExcIO
import Control.Lens (view)
import Control.Monad.IO.Class
import Control.Monad.Trans.Control (control, embed_, restoreM)
import qualified Data.ByteString as Strict
import Data.Conduit
  ( (.|),
    ConduitT,
    Void,
    runConduit,
  )
import qualified Data.Conduit.Binary as CB
import qualified Data.Conduit.List as CL
import Data.Conduit.Process
import Data.Functor ()
import Data.Maybe
import qualified Data.Text as Text
import GHC.Stack
import System.Exit
-- | Run a shell command with the configured default timeout.
--
-- When 'isInteractive' reports an interactive build, the command inherits
-- the standard streams of this process; otherwise it runs without stdin.
--
-- A timeout is reported through 'errorExitL'.
-- NOTE(review): any exit code, including a non-zero one, is accepted here;
-- only a timeout is treated as a failure - confirm this is intended.
cmdInteractive ::
  (HasCallStack, Member ExcB9 e, Member BuildInfoReader e, CommandIO e) =>
  String ->
  Eff e ()
cmdInteractive str = do
  t <- view defaultTimeout <$> getB9Config
  interactive <- isInteractive
  -- Pick the stdin mode first, so the command is started from one place.
  let stdinMode
        | interactive = HostCommandInheritStdin
        | otherwise = HostCommandNoStdin
  hostCmdEither stdinMode str t
    >>= either
      (\e -> errorExitL ("SYSTEM COMMAND FAILED: " ++ show e))
      (const (return ()))
-- | Run a shell command without standard input, using the configured
-- default timeout.
--
-- A timeout is reported through 'errorExitL'.
-- NOTE(review): any exit code, including a non-zero one, is accepted here;
-- only a timeout is treated as a failure - confirm this is intended.
cmd ::
  (HasCallStack, Member ExcB9 e, CommandIO e) =>
  String ->
  Eff e ()
cmd str = do
  cfg <- getB9Config
  result <- hostCmdEither HostCommandNoStdin str (view defaultTimeout cfg)
  either
    (errorExitL . ("SYSTEM COMMAND FAILED: " ++) . show)
    (const (return ()))
    result
-- | Run a shell command on the host without standard input.
--
-- Returns 'True' when the command exits successfully and 'False' (after
-- logging the exit code via 'errorL') when it exits with a failure code.
-- A timeout is raised through 'throwB9Error'.
hostCmd ::
  (CommandIO e, Member ExcB9 e) =>
  -- | The shell command to execute.
  String ->
  -- | An optional 'Timeout', passed on to 'hostCmdEither'.
  Maybe Timeout ->
  -- | Whether the command exited successfully.
  Eff e Bool
hostCmd cmdStr timeout =
  hostCmdEither HostCommandNoStdin cmdStr timeout >>= \outcome ->
    case outcome of
      Right ExitSuccess ->
        return True
      Right (ExitFailure ec) ->
        False <$ errorL ("Command exited with error code: " ++ show cmdStr ++ " " ++ show ec)
      Left e ->
        throwB9Error ("Command timed out: " ++ show cmdStr ++ " " ++ show e)
-- | Run a shell command on the host with a caller-selected standard input.
--
-- Returns 'True' when the command exits successfully and 'False' (after
-- logging the exit code via 'errorL') when it exits with a failure code.
-- A timeout is raised through 'throwB9Error'.
hostCmdStdIn ::
  (CommandIO e, Member ExcB9 e) =>
  -- | Where the command's standard input comes from.
  HostCommandStdin ->
  -- | The shell command to execute.
  String ->
  -- | An optional 'Timeout', passed on to 'hostCmdEither'.
  Maybe Timeout ->
  -- | Whether the command exited successfully.
  Eff e Bool
hostCmdStdIn hostStdIn cmdStr timeout =
  hostCmdEither hostStdIn cmdStr timeout
    >>= either
      (\e -> throwB9Error ("Command timed out: " ++ show cmdStr ++ " " ++ show e))
      ( \exitCode -> case exitCode of
          ExitSuccess ->
            return True
          ExitFailure ec -> do
            errorL ("Command exited with error code: " ++ show cmdStr ++ " " ++ show ec)
            return False
      )
-- | Selects how 'hostCmdEither' wires up the standard streams of the
-- command it starts.
data HostCommandStdin
  = -- | Start the command without a standard input stream; its standard
    -- output and standard error are passed to the process loggers.
    HostCommandNoStdin
  | -- | Let the command inherit standard input, standard output and
    -- standard error; nothing is passed to the process loggers.
    HostCommandInheritStdin
  | -- | Feed the bytes produced by the given conduit to the command's
    -- standard input; its standard output and standard error are passed
    -- to the process loggers.
    HostCommandStdInConduit (ConduitT () Strict.ByteString IO ())
-- | Run a shell command on the host and report how it ended.
--
-- The effective timeout is the given one, or the configured
-- 'defaultTimeout' when 'Nothing' is passed, multiplied by the configured
-- 'timeoutFactor' (1 when unset). If neither yields a timeout, the command
-- is awaited without a time limit.
--
-- Returns @Left@ with the effective timeout when the timer fired before the
-- command finished, and @Right@ with the exit code otherwise. Any exception
-- raised while running the command is logged and reported as
-- @Right (ExitFailure 126)@.
--
-- When the timeout fires, the started process is asked to terminate, and the
-- pipe handles of the process are released on every exit path.
hostCmdEither ::
  forall e.
  (CommandIO e) =>
  -- | Where the command's standard input comes from.
  HostCommandStdin ->
  -- | The shell command to execute.
  String ->
  -- | An optional 'Timeout'; 'Nothing' selects the configured default.
  Maybe Timeout ->
  Eff e (Either Timeout ExitCode)
hostCmdEither inputSource cmdStr timeoutArg = do
  -- The tag identifies all log lines belonging to this command.
  let tag = "[" ++ printHash cmdStr ++ "]"
  traceL $ "COMMAND " ++ tag ++ ": " ++ cmdStr
  tf <- fromMaybe 1 . view timeoutFactor <$> getB9Config
  timeout <-
    fmap (TimeoutMicros . \(TimeoutMicros t) -> tf * t)
      <$> maybe
        (view defaultTimeout <$> getB9Config)
        (return . Just)
        timeoutArg
  -- Drop down to IO so that exceptions thrown while the command runs can be
  -- caught and turned into an exit code.
  control $ \runInIO ->
    do
      ExcIO.catch
        (runInIO (go timeout tag))
        ( \(e :: ExcIO.SomeException) -> do
            runInIO (errorL ("COMMAND " ++ tag ++ " interrupted: " ++ show e))
            runInIO (return (Right (ExitFailure 126) :: Either Timeout ExitCode))
        )
      >>= restoreM
  where
    go :: Maybe Timeout -> String -> Eff e (Either Timeout ExitCode)
    go timeout tag = do
      traceLC <- traceMsgProcessLogger tag
      errorLC <- errorMsgProcessLogger tag
      let timer t@(TimeoutMicros micros) = do
            threadDelay micros
            return t
      (cph, runCmd) <- case inputSource of
        HostCommandNoStdin -> do
          (ClosedStream, cpOut, cpErr, cph) <- streamingProcess (shell cmdStr)
          let runCmd =
                runConcurrently
                  ( Concurrently (runConduit (cpOut .| runProcessLogger traceLC))
                      *> Concurrently (runConduit (cpErr .| runProcessLogger errorLC))
                      *> Concurrently (waitForStreamingProcess cph)
                  )
          return (cph, runCmd)
        HostCommandInheritStdin -> do
          (Inherited, Inherited, Inherited, cph) <- streamingProcess (shell cmdStr)
          let runCmd = waitForStreamingProcess cph
          return (cph, runCmd)
        HostCommandStdInConduit inputC -> do
          -- Request the sink together with its close action, so that the
          -- command sees end-of-file once all input has been written.
          ((stdIn, closeStdIn), cpOut, cpErr, cph) <- streamingProcess (shell cmdStr)
          let runCmd =
                runConcurrently
                  ( Concurrently (runConduit (cpOut .| runProcessLogger traceLC))
                      *> Concurrently (runConduit (cpErr .| runProcessLogger errorLC))
                      *> Concurrently (runConduit (inputC .| stdIn) `ExcIO.finally` closeStdIn)
                      *> Concurrently (waitForStreamingProcess cph)
                  )
          return (cph, runCmd)
      let awaitCmd = do
            outcome <- maybe (fmap Right) (race . timer) timeout runCmd
            case outcome of
              -- 'race' only stops waiting for the command; the process
              -- itself has to be stopped explicitly.
              Left _ -> terminateProcess (streamingProcessHandleRaw cph)
              Right _ -> return ()
            return outcome
      -- Release the process handles even if awaiting the command throws.
      e <- liftIO (awaitCmd `ExcIO.finally` closeStreamingProcessHandle cph)
      case e of
        Left _ ->
          errorL $ "COMMAND TIMED OUT " ++ tag
        Right ExitSuccess ->
          traceL $ "COMMAND FINISHED " ++ tag
        Right (ExitFailure ec) ->
          errorL $ "COMMAND FAILED EXIT CODE: " ++ show ec ++ " " ++ tag
      return e
-- | A conduit sink for the raw bytes of a process output stream.
-- 'hostCmdEither' connects the standard output and standard error of a
-- command to such sinks; 'mkMsgProcessLogger' builds them.
newtype ProcessLogger
  = MkProcessLogger
      {runProcessLogger :: ConduitT Strict.ByteString Void IO ()}
-- | Build a 'ProcessLogger' that passes every output line, prefixed with
-- the given tag, to 'traceL'.
traceMsgProcessLogger :: (CommandIO e) => String -> Eff e ProcessLogger
traceMsgProcessLogger tag = mkMsgProcessLogger traceL tag
-- | Build a 'ProcessLogger' that passes every output line, prefixed with
-- the given tag, to 'errorL'.
errorMsgProcessLogger :: (CommandIO e) => String -> Eff e ProcessLogger
errorMsgProcessLogger tag = mkMsgProcessLogger errorL tag
-- | Build a 'ProcessLogger' from a logging action and a prefix.
--
-- The resulting sink splits the incoming bytes into lines, decodes each
-- line leniently as UTF-8 and hands @prefix ++ ": " ++ line@ to the given
-- logging action.
mkMsgProcessLogger :: (CommandIO e) => (String -> Eff e ()) -> String -> Eff e ProcessLogger
mkMsgProcessLogger logFun prefix = do
  -- 'embed_' turns the effectful logging action into a callback that can be
  -- invoked from the conduit running in IO.
  emitLine <- embed_ (logFun . withPrefix)
  let sink =
        CB.lines
          .| CL.decodeUtf8LenientC
          .| CL.mapM_ (liftIO . emitLine)
  return (MkProcessLogger sink)
  where
    withPrefix line = prefix ++ ": " ++ Text.unpack line