{-# LANGUAGE GADTs, DataKinds, TypeApplications, RankNTypes, ScopedTypeVariables, ConstraintKinds #-}
{-# LANGUAGE DeriveAnyClass, DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances, FlexibleContexts #-}
{-# LANGUAGE QuasiQuotes, RecordWildCards, TupleSections, LambdaCase, OverloadedStrings #-}

-- | Search over formatter options: pick a base style, then repeatedly vary
-- subsets of the option variables and keep whichever variation yields the
-- smallest total 'Score' over the prepared files.
module Language.Coformat.Optimization where

import qualified Data.HashMap.Strict as HM
import qualified Data.Text as T
import Control.Concurrent.Async.Pool
import Control.Lens
import Control.Monad.Except.CoHas
import Control.Monad.Extra
import Control.Monad.Logger
import Control.Monad.Reader.Has hiding(update)
import Control.Monad.State.Strict
import Data.Foldable
import Data.List
import Data.Ord
import Data.Proxy
import Data.String.Interpolate.IsString
import Numeric.Natural

import Language.Coformat.Descr
import Language.Coformat.Formatter
import Language.Coformat.Formatter.Failure
import Language.Coformat.Score
import Language.Coformat.Util
import Language.Coformat.Variables

-- | Read-only data needed to run the formatter over the input files.
data FmtEnv = FmtEnv
  { baseStyle :: T.Text                  -- ^ base style handed to every formatter run
  , preparedFiles :: [PreparedFile]      -- ^ files that each candidate option set is scored on
  , constantOpts :: [ConfigItemT 'Value] -- ^ options prepended to the varied options on every run
  , formatterInfo :: FormatterInfo       -- ^ description of the formatter being driven
  }

-- | Read-only description of the search space.
data OptEnv = OptEnv
  { categoricalVariables :: [IxedCategoricalVariable]
  , integralVariables :: [IxedIntegralVariable]
  , maxSubsetSize :: Natural             -- ^ 'fixGD' stops once the subset size exceeds this
  }

-- | Run the formatter on a single prepared file with the given base style and
-- options, and score its output with 'calcScore'.
--
-- The @logStr@ argument is only used as the prefix of the debug log line
-- that reports the resulting score.
runFormat :: (MonadError err m, CoHas UnexpectedFailure err, CoHas ExpectedFailure err, MonadIO m, MonadLogger m)
          => FormatterInfo -> PreparedFile -> String -> T.Text -> [ConfigItemT 'Value] -> m Score
runFormat FormatterInfo { .. } prepared logStr baseSty opts = do
  stdout <- runCommand execName $ formatFile baseSty opts $ filename prepared
  let dist = calcScore prepared stdout
  logDebugN [i|#{logStr}: #{dist}|]
  pure dist

-- | Constraints shared by the optimization routines: logging with IO,
-- an error channel that can carry an 'UnexpectedFailure', and a reader
-- environment that contains a 'FmtEnv'.
type OptMonad err r m = (MonadLoggerIO m, MonadError err m, CoHas UnexpectedFailure err, MonadReader r m, Has FmtEnv r)

-- | Score one option set on every file of the 'FmtEnv' and combine the
-- per-file scores with 'mconcat'.
--
-- The options actually passed to the formatter are 'constantOpts' followed
-- by @varOpts@.
runFormatFiles :: (OptMonad err r m, CoHas ExpectedFailure err) => [ConfigItemT 'Value] -> String -> m Score
runFormatFiles varOpts logStr = do
  FmtEnv { .. } <- ask
  fmap mconcat $ forM preparedFiles $ \prepared ->
    runFormat formatterInfo prepared [i|#{logStr} at #{filename prepared}|] baseStyle $ constantOpts <> varOpts

-- | Score every candidate base style on every file, accumulate the scores
-- per style with '<>', and return the style with the smallest accumulated
-- score together with that score.
--
-- Formatter failures are rendered with 'show' and surface as the 'String'
-- error of the monad (via 'convert'). An error is also raised when there is
-- nothing to choose from, i.e. no (style, file) pair was scored.
chooseBaseStyle :: (MonadError String m, MonadLoggerIO m)
                => FormatterInfo -> [T.Text] -> [ConfigItemT 'Value] -> [PreparedFile] -> m (T.Text, Score)
chooseBaseStyle formatter baseStyles predefinedOpts files = do
  sty2dists <- forConcurrently' ((,) <$> baseStyles <*> files) $ \(sty, file) ->
    convert (show @(Either ExpectedFailure UnexpectedFailure)) $ (sty,) <$> runFormat formatter file [i|Initial guess for #{sty} at #{filename file}|] sty predefinedOpts
  let accumulated = HM.toList $ HM.fromListWith (<>) sty2dists
  forM_ accumulated $ \(sty, acc) -> logInfoN [i|Initial accumulated guess for #{sty}: #{acc}|]
  case accumulated of
    -- 'minimumBy' is partial; report the empty case through the error channel
    -- instead of letting it crash lazily at the caller.
    [] -> throwError ("Unable to choose a base style: no base styles or no files were provided" :: String)
    _ -> pure $ minimumBy (comparing snd) accumulated

-- | All option lists obtained by replacing the value at position @idx@ with
-- each of the results of 'variate' applied to its current value.
--
-- NOTE(review): uses '!!' and '^?!', so this assumes @idx@ is a valid index
-- into @opts@ and that the option there has the type selected by @a@ —
-- confirm against the code that builds the indexed variables.
variateAt :: forall a. (Variate a, Foldable (VariateResult a)) => Proxy a -> Int -> [ConfigItemT 'Value] -> [[ConfigItemT 'Value]]
variateAt _ idx opts = [ update idx (updater v') opts | v' <- toList variated ]
  where
    thePrism :: Prism' (ConfigTypeT 'Value) a
    thePrism = varPrism
    variated = variate $ value (opts !! idx) ^?! thePrism
    updater v cfg = cfg { value = value cfg & thePrism .~ v }

-- | The best option set found so far and its total score.
data OptState = OptState
  { currentOpts :: [ConfigItemT 'Value]
  , currentScore :: Score
  } deriving (Show)

-- | Build the initial 'OptState' from an option set and its score.
initOptState :: [ConfigItemT 'Value] -> Score -> OptState
initOptState currentOpts currentScore = OptState { .. }

-- | Run a scoring action, turning an expected formatter failure into the
-- worst possible score ('maxBound') after logging it. Unexpected failures
-- are rethrown in the caller's monad.
dropExpectedFailures :: OptMonad err r m
                     => (forall err' r' m'. (OptMonad err' r' m', CoHas ExpectedFailure err') => m' Score)
                     -> m Score
dropExpectedFailures act = do
  res <- runExceptT act
  case res of
    Left (UnexpectedFailure failure) -> throwError failure
    Left (ExpectedFailure failure) -> do
      logErrorN [i|Unable to run the formatter: #{show failure}|]
      pure maxBound
    Right sc -> pure sc

-- | All option lists obtained by varying every variable of the subset,
-- one after another ('variateAt' applied to each result of the previous
-- variable). An empty subset yields just the unchanged options.
variateSubset :: [SomeIxedVariable] -> [ConfigItemT 'Value] -> [[ConfigItemT 'Value]]
variateSubset [] opts = [opts]
variateSubset (SomeIxedVariable (IxedVariable (MkDV (_ :: a)) idx) : rest) opts =
  concatMap (variateSubset rest) $ variateAt @a Proxy idx opts

-- | Render the names and values that the given variables have in @opts@,
-- for log messages.
--
-- NOTE(review): indexes with '!!'; assumes every variable index is valid
-- for @opts@.
showVariated :: [SomeIxedVariable] -> [ConfigItemT 'Value] -> String
showVariated vars opts = intercalate ", " [showVar var | var <- vars]
  where
    showVar (SomeIxedVariable (IxedVariable _ idx)) = [i|#{name $ opts !! idx} -> #{value $ opts !! idx}|]

-- | Try every subset produced by @subsetsN subsetSize ixedVariables@: for
-- each one, score all of its variations of the current options on all files
-- and remember the best variation. Returns the overall best option set and
-- its score if that score is strictly smaller than the current one, and
-- 'Nothing' otherwise.
--
-- 'Nothing' is also returned when there are no subsets to try; a subset
-- that produces no variations is skipped. Expected formatter failures count
-- as the worst score (see 'dropExpectedFailures').
chooseBestSubset :: (OptMonad err r m, Has OptState r, Has TaskGroup r)
                 => Natural -> [SomeIxedVariable] -> m (Maybe ([ConfigItemT 'Value], Score))
chooseBestSubset subsetSize ixedVariables = do
  OptState { .. } <- ask
  partialResults <- forConcurrentlyPooled (subsetsN subsetSize ixedVariables) $ \someVarsSubset -> do
    opt2scores <- forM (variateSubset someVarsSubset currentOpts) $ \opts' ->
      fmap (opts',) $ dropExpectedFailures $ runFormatFiles opts' $ showVariated someVarsSubset opts'
    case leastBy (comparing snd) opt2scores of
      -- No variations for this subset: nothing to report.
      Nothing -> pure Nothing
      Just (bestOpts, bestScore) -> do
        when (bestScore < currentScore) $
          logInfoN [i|Total dist for #{showVariated someVarsSubset bestOpts}: #{currentScore} -> #{bestScore}|]
        pure $ Just (bestOpts, bestScore, someVarsSubset)
  case leastBy (comparing (^. _2)) [ found | Just found <- toList partialResults ] of
    Just (bestOpts, bestScore, bestVarsSubset)
      | bestScore < currentScore -> do
          logInfoN [i|Choosing #{showVariated bestVarsSubset bestOpts}|]
          pure $ Just (bestOpts, bestScore)
    _ -> pure Nothing
  where
    -- Total version of 'minimumBy': 'Nothing' for an empty container.
    leastBy :: Foldable t => (a -> a -> Ordering) -> t a -> Maybe a
    leastBy cmp xs = case toList xs of
      [] -> Nothing
      ys -> Just $ minimumBy cmp ys

-- | One optimization step: run 'chooseBestSubset' over the variables
-- selected by @varGetters@ against the current state, and store the
-- returned options and score in the state if an improvement was found.
stepGDGeneric' :: (OptMonad err r m, Has TaskGroup r, Has OptEnv r, MonadState OptState m)
               => Natural -> [OptEnv -> [SomeIxedVariable]] -> m ()
stepGDGeneric' subsetSize varGetters = do
  current <- get
  fmtEnv :: FmtEnv <- ask
  optEnv <- ask
  tg :: TaskGroup <- ask
  -- 'chooseBestSubset' reads the state from its environment, so it runs in
  -- a reader over a tuple holding a snapshot of the current state.
  runReaderT (chooseBestSubset subsetSize $ concatMap ($ optEnv) varGetters) (current, fmtEnv, tg) >>= \case
    Nothing -> pure ()
    Just (opts', score') -> do
      logInfoN [i|Total score after optimization on all files: #{score'}|]
      put OptState { currentOpts = opts', currentScore = score' }

-- | Like 'stepGDGeneric'', but does nothing unless the current score is
-- greater than 'mempty'.
stepGDGeneric :: (OptMonad err r m, Has TaskGroup r, Has OptEnv r, MonadState OptState m)
              => Natural -> [OptEnv -> [SomeIxedVariable]] -> m ()
stepGDGeneric subsetSize varGetters = whenM ((> mempty) <$> gets currentScore) $ stepGDGeneric' subsetSize varGetters

-- | Repeat optimization steps over the categorical and integral variables
-- until nothing changes any more.
--
-- After a step that changed the score the subset size is reset to 1;
-- after a step that did not, it is increased by one. The loop ends when
-- the subset size exceeds 'maxSubsetSize', or when the optional step
-- counter (decremented after every step) is exactly @Just 0@.
fixGD :: (OptMonad err r m, Has TaskGroup r, Has OptEnv r, MonadState OptState m, err ~ UnexpectedFailure)
      => Maybe Int -> Natural -> m ()
fixGD (Just 0) _ = pure ()
fixGD counter curSubsetSize = do
  maxSubsetSize' <- asks maxSubsetSize
  if curSubsetSize > maxSubsetSize'
    then logInfoN [i|Done optimizing|]
    else do
      startScore <- gets currentScore
      stepGDGeneric curSubsetSize [asSome . categoricalVariables, asSome . integralVariables]
      endScore <- gets currentScore
      logInfoN [i|Full optimization step done, went from #{startScore} to #{endScore}|]
      if startScore /= endScore
        then fixGD (subtract 1 <$> counter) 1
        else do
          logInfoN [i|Done optimizing with subset size #{curSubsetSize}, stopped at score #{endScore}|]
          fixGD (subtract 1 <$> counter) (curSubsetSize + 1)