{-# LANGUAGE CPP               #-}
{-# LANGUAGE ConstraintKinds   #-}
{-# LANGUAGE OverloadedStrings #-}
module Language.Fixpoint.Graph.Partition (
  
    CPart (..)
  , partition, partition', partitionN
  
  , MCInfo (..)
  , mcInfo
  
  , dumpPartitions
  ) where
import           GHC.Conc                  (getNumProcessors)
import           Control.Monad             (forM_)
import           Language.Fixpoint.Misc         
import           Language.Fixpoint.Utils.Files
import           Language.Fixpoint.Types.Config
import qualified Language.Fixpoint.Types              as F
import           Language.Fixpoint.Graph.Types
import           Language.Fixpoint.Graph.Deps
import qualified Data.HashMap.Strict                  as M
import           Data.Maybe                     (fromMaybe)
import           Data.Hashable
#if !MIN_VERSION_base(4,14,0)
import           Data.Semigroup                 (Semigroup (..))
#endif
import           Text.PrettyPrint.HughesPJ.Compat
import           Data.List (sortBy)
-- | A piece of a query: a subset of its 'F.WfC' entries together with a
--   subset of its constraints. The two fields correspond to the 'F.ws' and
--   'F.cm' fields of an 'F.GInfo' (see 'finfoToCpart' / 'cpartToFinfo').
data CPart c a = CPart
  { pws :: !(M.HashMap F.KVar (F.WfC a)) -- ^ 'F.WfC' entries, keyed by 'F.KVar'
  , pcm :: !(M.HashMap Integer (c a))    -- ^ constraints, keyed by 'Integer'
  }
-- | Parts are combined field by field, using each map's own '<>'.
instance Semigroup (CPart c a) where
  CPart ws1 cm1 <> CPart ws2 cm2 = CPart (ws1 <> ws2) (cm1 <> cm2)
-- | The empty part has no 'F.WfC' entries and no constraints.
instance Monoid (CPart c a) where
   mempty  = CPart { pws = mempty, pcm = mempty }
   -- Defined explicitly so that 'mappend' and '<>' always agree.
   mappend = (<>)
-- | Parameters that steer 'partitionN'; built from a 'Config' by 'mcInfo'.
data MCInfo = MCInfo
  { mcCores       :: !Int -- ^ upper bound on the number of parts to produce
  , mcMinPartSize :: !Int -- ^ size threshold taken from 'minPartSize'
  , mcMaxPartSize :: !Int -- ^ size threshold taken from 'maxPartSize'
  } deriving (Show)
-- | Build the 'MCInfo' for a run. The core count is @cores c@ when that is
--   set, and otherwise the processor count reported by 'getNumProcessors';
--   the two size thresholds are copied from the 'Config'.
mcInfo :: Config -> IO MCInfo
mcInfo c = fmap fromDetected getNumProcessors
  where
    -- The processor count is always queried; it is only used as the
    -- fallback when the configuration does not specify a core count.
    fromDetected np =
      MCInfo (fromMaybe np (cores c)) (minPartSize c) (maxPartSize c)
-- | Split the query with @'partition'' Nothing@, write every piece to disk
--   via 'dumpPartitions', and return an empty result ('mempty'); nothing is
--   solved here.
partition :: (F.Fixpoint a, F.Fixpoint (c a), F.TaggedC c a) => Config -> F.GInfo c a -> IO (F.Result (Integer, a))
partition cfg fi = dumpPartitions cfg pieces >> return mempty
  where
    pieces = partition' Nothing fi
-- | Split a query along the components computed by 'decompose'.
--
--   * With 'Nothing', each component becomes its own 'F.GInfo' (built by
--     'mkPartition').
--
--   * With @'Just' mi@, each component becomes a 'CPart' (built by
--     'mkPartition''), and those parts are then coarsened by 'partitionN'.
partition' :: (F.TaggedC c a)
           => Maybe MCInfo -> F.GInfo c a -> [F.GInfo c a]
partition' mn fi = maybe perComponent coarsened mn
  where
    perComponent = split mkPartition id
    coarsened mi = partitionN mi fi (split mkPartition' finfoToCpart)
    -- @whole fi@ is the singleton default handed to 'applyNonNull';
    -- presumably it is what is returned when 'decompose' yields no
    -- components -- confirm against 'applyNonNull'.
    split mk whole =
      applyNonNull [whole fi] (partitionByConstraints mk fi) (decompose fi)
-- | Coarsen a list of fine-grained parts into fewer, larger ones.
--
--   If the whole query is no larger than 'mcMinPartSize' it is returned
--   unsplit. Otherwise the two smallest parts are merged repeatedly until
--   there are at most 'mcCores' parts and, in addition, either the smallest
--   part has reached 'mcMinPartSize' or merging the two smallest parts would
--   reach 'mcMaxPartSize'.
partitionN :: MCInfo        -- ^ core count and size thresholds
           -> F.GInfo c a   -- ^ the original, unsplit query
           -> [CPart c a]   -- ^ the initial (finest) parts
           -> [F.GInfo c a] -- ^ the coarsened parts, converted back to queries
partitionN mi fi cp
   | cpartSize (finfoToCpart fi) <= minThresh = [fi]
   | otherwise = map (cpartToFinfo fi) $ toNParts sortedParts
   where
      -- Invariant: the work list is in ascending order of 'cpartSize', so
      -- its first two elements are always the two smallest parts.
      toNParts p
         | isDone p  = p
         | otherwise = toNParts $ insertSorted firstTwo rest
            where (firstTwo, rest) = unionFirstTwo p
      isDone []  = True
      isDone [_] = True
      isDone ps@(a:b:_) = length ps <= prts
                           && (cpartSize a >= minThresh
                               || cpartSize a + cpartSize b >= maxThresh)
      -- The initial order must be ascending: 'isDone' inspects the head as
      -- the smallest part and 'insertSorted' inserts in ascending position.
      sortedParts = sortBy sortPredicate cp
      unionFirstTwo (a:b:xs) = (a `mappend` b, xs)
      unionFirstTwo _        = errorstar "Partition.partitionN.unionFirstTwo called on bad arguments"
      -- Smaller parts compare as 'LT', i.e. they come first.
      sortPredicate lhs rhs = compare (cpartSize lhs) (cpartSize rhs)
      -- Skip past every strictly smaller part, then place @a@.
      insertSorted a []     = [a]
      insertSorted a (x:xs) = if sortPredicate a x == GT
                              then x : insertSorted a xs
                              else a : x : xs
      prts      = mcCores mi
      minThresh = mcMinPartSize mi
      maxThresh = mcMaxPartSize mi
-- | Size of a part: the number of entries in 'pcm' plus the number of
--   entries in 'pws'.
cpartSize :: CPart c a -> Int
cpartSize part = M.size (pcm part) + M.size (pws part)
-- | Turn a part back into a query: a copy of the given 'F.GInfo' whose
--   'F.cm' and 'F.ws' fields are replaced by the part's maps. All other
--   fields are inherited unchanged.
cpartToFinfo :: F.GInfo c a -> CPart c a -> F.GInfo c a
cpartToFinfo base part =
  base { F.cm = pcm part
       , F.ws = pws part
       }
-- | View a whole query as a single part, keeping only its 'F.ws' and 'F.cm'
--   fields.
finfoToCpart :: F.GInfo c a -> CPart c a
finfoToCpart fi = CPart (F.ws fi) (F.cm fi)
-- | Write each query to its own file, in list order. The i-th query
--   (counting from 0) is rendered with 'F.toFixpoint' and written to
--   @queryFile (Part i) cfg@.
dumpPartitions :: (F.Fixpoint (c a), F.Fixpoint a) => Config -> [F.GInfo c a] -> IO ()
dumpPartitions cfg fis =
  sequence_
    [ writeFile (queryFile (Part i) cfg) (render (F.toFixpoint cfg fi))
    | (i, fi) <- zip [0 ..] fis
    ]
-- | A function that builds one part (of type @b@) for a given component.
--   'mkPartition' and 'mkPartition'' are the two instances used in this
--   module; 'partitionByConstraints' supplies the arguments.
type PartitionCtor c a b
  =  F.GInfo c a                        -- the query being split
  -> M.HashMap Int [(Integer, c a)]     -- constraints, grouped by component index
  -> M.HashMap Int [(F.KVar, F.WfC a)]  -- WfC entries, grouped by component index
  -> Int                                -- index of the component to build
  -> b
-- | Build one part per component. Components are numbered from 1 in list
--   order; every entry of 'F.ws' and 'F.cm' is grouped under the number of
--   the component that contains its key, and the constructor is then called
--   once per component number.
--
--   Keys are mapped to component numbers with 'groupFun', which delegates to
--   'safeLookup'; a key of 'F.ws' or 'F.cm' that appears in no component is
--   therefore presumably an error -- confirm against 'safeLookup'.
partitionByConstraints :: PartitionCtor c a b -- ^ builds a part from the grouped maps and a component number
                       -> F.GInfo c a         -- ^ the query whose 'F.ws' and 'F.cm' are split
                       -> KVComps             -- ^ the components
                       -> ListNE b            -- ^ one part per component, in order
partitionByConstraints mk fi comps =
    map (mk fi cstrsByComp wfsByComp . fst) numbered
  where
    numbered    = zip [1 ..] comps
    -- Every vertex paired with the number of the component it belongs to.
    members     = concatMap (\(j, vs) -> map (\v -> (v, j)) vs) numbered
    -- Only 'KVar' / 'Cstr' vertices are indexed; other vertices are dropped.
    kvarComp    = M.fromList [ (k, j) | (KVar k, j) <- members ]
    cstrComp    = M.fromList [ (c, j) | (Cstr c, j) <- members ]
    wfsByComp   = groupMap (groupFun kvarComp . fst) (M.toList (F.ws fi))
    cstrsByComp = groupMap (groupFun cstrComp . fst) (M.toList (F.cm fi))
-- | 'PartitionCtor' that yields a full query: a copy of @fi@ whose 'F.cm'
--   and 'F.ws' hold only the entries grouped under component @j@. A
--   component with no entries in a map gets an empty map for that field.
mkPartition :: F.GInfo c a
            -> M.HashMap Int [(Integer, c a)]
            -> M.HashMap Int [(F.KVar, F.WfC a)]
            -> Int
            -> F.GInfo c a
mkPartition fi icM iwM j = fi { F.cm = cstrs, F.ws = wfs }
  where
    cstrs = M.fromList (M.lookupDefault [] j icM)
    wfs   = M.fromList (M.lookupDefault [] j iwM)
-- | 'PartitionCtor' that yields a 'CPart' holding only the entries grouped
--   under component @j@. The query argument is ignored; a component with no
--   entries in a map gets an empty map for that field.
mkPartition' :: F.GInfo c a
             -> M.HashMap Int [(Integer, c a)]
             -> M.HashMap Int [(F.KVar, F.WfC a)]
             -> Int
             -> CPart c a
mkPartition' _ icM iwM j = CPart wfs cstrs
  where
    wfs   = M.fromList (M.lookupDefault [] j iwM)
    cstrs = M.fromList (M.lookupDefault [] j icM)
-- | Look up the component number recorded for a key, via 'safeLookup'. The
--   message passed along names this function and shows the key.
groupFun :: (Show k, Eq k, Hashable k) => M.HashMap k Int -> k -> Int
groupFun table key = safeLookup msg key table
  where
    msg = "groupFun: " ++ show key