{-# LANGUAGE CPP               #-}
{-# LANGUAGE ConstraintKinds   #-}
{-# LANGUAGE OverloadedStrings #-}

-- | This module implements functions to build constraint / kvar
--   dependency graphs, partition them and print statistics about
--   their structure.

module Language.Fixpoint.Graph.Partition (

  -- * Split constraints
    CPart (..)
  , partition, partition', partitionN

  -- * Information about cores
  , MCInfo (..)
  , mcInfo

  -- * Debug
  , dumpPartitions

  ) where

import           GHC.Conc                  (getNumProcessors)
import           Control.Monad             (forM_)
-- import           GHC.Generics              (Generic)
import           Language.Fixpoint.Misc         -- hiding (group)
import           Language.Fixpoint.Utils.Files
import           Language.Fixpoint.Types.Config
-- import           Language.Fixpoint.Types.PrettyPrint
-- import qualified Language.Fixpoint.Types.Visitor      as V
import qualified Language.Fixpoint.Types              as F
import           Language.Fixpoint.Graph.Types
import           Language.Fixpoint.Graph.Deps

import qualified Data.HashMap.Strict                  as M
-- import qualified Data.Graph                           as G
-- import qualified Data.Tree                            as T
-- import           Data.Function (on)
import           Data.Maybe                     (fromMaybe)
import           Data.Hashable
import           Text.PrettyPrint.HughesPJ.Compat
import           Data.List (sortBy)
import           Data.Function (on)
-- import qualified Data.HashSet              as S

-- import qualified Language.Fixpoint.Solver.Solution    as So
-- import Data.Graph.Inductive



--------------------------------------------------------------------------------
-- | Constraint Partition Container --------------------------------------------
--------------------------------------------------------------------------------

-- | A fragment of a query: a subset of its well-formedness entries and a
--   subset of its constraints (keyed by constraint id).
data CPart c a = CPart
  { pws :: !(M.HashMap F.KVar (F.WfC a))   -- ^ the 'F.WfC' entries in this part
  , pcm :: !(M.HashMap Integer (c a))      -- ^ the constraints in this part
  }

-- | Field-wise combination of two parts using the 'Semigroup' of
--   'M.HashMap' on each field (i.e. a map union).
instance Semigroup (CPart c a) where
  lhs <> rhs = CPart (pws lhs <> pws rhs) (pcm lhs <> pcm rhs)

-- | The empty part holds no entries in either map. 'mappend' is left to its
--   default, which is '(<>)'.
instance Monoid (CPart c a) where
  mempty = CPart { pws = mempty, pcm = mempty }

--------------------------------------------------------------------------------
-- | Multicore info ------------------------------------------------------------
--------------------------------------------------------------------------------

-- | Settings that drive multicore partitioning (see 'partitionN').
data MCInfo = MCInfo
  { mcCores       :: !Int  -- ^ number of partitions to aim for
  , mcMinPartSize :: !Int  -- ^ lower size threshold for a partition
  , mcMaxPartSize :: !Int  -- ^ upper size threshold used when merging
  }
  deriving (Show)

-- | Build the multicore settings from the configuration. The core count is
--   taken from @cores cfg@ when it is set; otherwise the number of processors
--   reported by 'getNumProcessors' is used. Both size thresholds come
--   straight from the configuration.
mcInfo :: Config -> IO MCInfo
mcInfo cfg = toInfo <$> getNumProcessors
  where
    toInfo numProcs = MCInfo
      { mcCores       = fromMaybe numProcs (cores cfg)
      , mcMinPartSize = minPartSize cfg
      , mcMaxPartSize = maxPartSize cfg
      }

-- | Split the query into the finest possible partitions ('partition'' with
--   no multicore info) and write each one to its own file with
--   'dumpPartitions'. No solving happens here: the result is always 'mempty'.
partition :: (F.Fixpoint a, F.Fixpoint (c a), F.TaggedC c a) => Config -> F.GInfo c a -> IO (F.Result (Integer, a))
partition cfg fi = mempty <$ dumpPartitions cfg (partition' Nothing fi)

------------------------------------------------------------------------------
-- | Partition an FInfo into multiple disjoint FInfos.
--
--   * With 'Nothing', produce the maximum possible number of partitions:
--     one per component returned by 'decompose'.
--   * With @Just mi@, build one 'CPart' per component and let 'partitionN'
--     merge them according to the thresholds in @mi@.
--
--   The components go through 'applyNonNull', with the whole input as the
--   fallback value (presumably used when 'decompose' yields no components).
------------------------------------------------------------------------------
partition' :: (F.TaggedC c a)
           => Maybe MCInfo -> F.GInfo c a -> [F.GInfo c a]
------------------------------------------------------------------------------
partition' mn fi = maybe finest merged mn
  where
    comps     = decompose fi
    finest    = build mkPartition id
    merged mi = partitionN mi fi (build mkPartition' finfoToCpart)
    -- Used at two result types: 'F.GInfo' and 'CPart'.
    build mkPart wrap = applyNonNull [wrap fi] (partitionByConstraints mkPart fi) comps

-- | Partition an FInfo into a specific number of partitions of roughly equal
-- amounts of work.
--
-- If the whole input is no larger than 'mcMinPartSize', it is returned as a
-- single partition. Otherwise the given parts are kept sorted by ascending
-- 'cpartSize', and the two smallest parts are merged repeatedly (the merged
-- part is re-inserted in order). Merging stops once at most 'mcCores' parts
-- remain and either:
--
--   * the smallest part already has at least 'mcMinPartSize', or
--   * merging the two smallest parts would reach 'mcMaxPartSize'.
partitionN :: MCInfo        -- ^ Describes thresholds and partition amounts
           -> F.GInfo c a   -- ^ The original FInfo
           -> [CPart c a]   -- ^ A list of the smallest possible CParts
           -> [F.GInfo c a] -- ^ At most N partitions of at least thresh work
partitionN mi fi cp
   | cpartSize (finfoToCpart fi) <= minThresh = [fi]
   | otherwise = map (cpartToFinfo fi) $ toNParts sortedParts
   where
      -- Invariant: the list is sorted by ascending size, so its first two
      -- elements are the two smallest parts.
      toNParts p
         | isDone p  = p
         | otherwise = toNParts $ insertSorted firstTwo rest
            where (firstTwo, rest) = unionFirstTwo p
      isDone []  = True
      isDone [_] = True
      isDone fi'@(a:b:_) = length fi' <= prts
                            && (cpartSize a >= minThresh
                                || cpartSize a + cpartSize b >= maxThresh)
      -- Sort in the SAME order that 'insertSorted' maintains. Previously the
      -- initial sort was descending while re-insertion was ascending. That
      -- mismatch made the loop merge the two largest parts and broke the
      -- invariant above.
      sortedParts = sortBy bySize cp
      unionFirstTwo (a:b:xs) = (a <> b, xs)
      unionFirstTwo _        = errorstar "Partition.partitionN.unionFirstTwo called on bad arguments"
      -- Ascending by 'cpartSize'.
      bySize = compare `on` cpartSize
      -- Insert before the first element that is not strictly smaller than
      -- the new part, keeping the list in ascending order.
      insertSorted a []     = [a]
      insertSorted a (x:xs)
         | bySize x a == LT = x : insertSorted a xs
         | otherwise        = a : x : xs
      prts      = mcCores mi
      minThresh = mcMinPartSize mi
      maxThresh = mcMaxPartSize mi


-- | The "size" of a 'CPart': the number of constraints it holds plus the
-- number of 'F.WfC' entries it holds. Used to decide whether a part is
-- substantial enough to be worth parallelizing.
cpartSize :: CPart c a -> Int
cpartSize part = constraintCount + wfCount
  where
    constraintCount = M.size (pcm part)
    wfCount         = M.size (pws part)

-- | Rebuild an FInfo from a 'CPart'. The constraint map and the WfC map are
-- taken from the part; every other field is copied from the given FInfo.
cpartToFinfo :: F.GInfo c a -> CPart c a -> F.GInfo c a
cpartToFinfo fi part = fi { F.cm = pcm part, F.ws = pws part }

-- | View a whole FInfo as a single 'CPart', keeping only its WfC map and its
-- constraint map.
finfoToCpart :: F.GInfo c a -> CPart c a
finfoToCpart fi = CPart (F.ws fi) (F.cm fi)

-------------------------------------------------------------------------------------
-- | Write each partition to its own file. Partition @i@ (counting from 0) is
--   rendered with 'F.toFixpoint' and written to @queryFile (Part i) cfg@.
dumpPartitions :: (F.Fixpoint (c a), F.Fixpoint a) => Config -> [F.GInfo c a] -> IO ()
-------------------------------------------------------------------------------------
dumpPartitions cfg fis = mapM_ dumpOne (zip [0 ..] fis)
  where
    dumpOne (i, fi) = writeFile target contents
      where
        target   = queryFile (Part i) cfg
        contents = render (F.toFixpoint cfg fi)


-- | Type alias for a function that builds one partition. 'mkPartition' and
--   'mkPartition'' are the two primary functions that conform to this
--   interface.
--
--   Arguments, in order:
--
--   1. the original query;
--   2. a map from group number to the constraints in that group;
--   3. a map from group number to the WfC entries in that group;
--   4. the group number to build.
--
--   The result is typically an 'F.GInfo' (from 'mkPartition') or a 'CPart'
--   (from 'mkPartition'').
type PartitionCtor c a b
  =  F.GInfo c a
  -> M.HashMap Int [(Integer, c a)]
  -> M.HashMap Int [(F.KVar, F.WfC a)]
  -> Int
  -> b

-- | Build one partition per component of @kvss@.
--
--   Components are numbered from 1. Every KVar and constraint vertex is
--   mapped to the number of its component. The query's WfC entries and
--   constraints are then grouped by that number, and the constructor is
--   applied once per component number.
partitionByConstraints :: PartitionCtor c a b -- ^ mkPartition or mkPartition'
                       -> F.GInfo c a
                       -> KVComps
                       -> ListNE b          -- ^ [F.FInfo a] or [CPart c a]
partitionByConstraints f fi kvss = [ f fi consByGroup wfsByGroup g | (g, _) <- numbered ]
  where
    -- component number paired with its vertices: 1, 2, ...
    numbered   = zip [1 ..] kvss
    -- every vertex paired with the number of its component
    ownership  = [ (v, g) | (g, vertices) <- numbered, v <- vertices ]
    kvarToGrp  = M.fromList [ (k, g) | (KVar k, g) <- ownership ]   -- k        |-> g
    cstrToGrp  = M.fromList [ (i, g) | (Cstr i, g) <- ownership ]   -- i        |-> g
    -- group number |-> entries belonging to that group
    wfsByGroup  = groupMap (groupFun kvarToGrp . fst) (M.toList (F.ws fi))  -- g |-> [w]
    consByGroup = groupMap (groupFun cstrToGrp . fst) (M.toList (F.cm fi))  -- g |-> [(i, ci)]

-- | A 'PartitionCtor' that yields a full query. It is the original query with
--   its constraint and WfC maps replaced by the entries of group @j@. A group
--   absent from a map yields an empty map.
mkPartition :: F.GInfo c a
            -> M.HashMap Int [(Integer, c a)]
            -> M.HashMap Int [(F.KVar, F.WfC a)]
            -> Int
            -> F.GInfo c a
mkPartition fi icM iwM j = fi { F.cm = groupCons, F.ws = groupWfs }
  where
    groupCons = M.fromList (M.lookupDefault [] j icM)
    groupWfs  = M.fromList (M.lookupDefault [] j iwM)

-- | A 'PartitionCtor' that yields only a 'CPart' holding the entries of
--   group @j@. The query argument is ignored. A group absent from a map
--   yields an empty map.
mkPartition' :: F.GInfo c a
             -> M.HashMap Int [(Integer, c a)]
             -> M.HashMap Int [(F.KVar, F.WfC a)]
             -> Int
             -> CPart c a
mkPartition' _ icM iwM j = CPart (M.fromList wfEntries) (M.fromList consEntries)
  where
    wfEntries   = M.lookupDefault [] j iwM
    consEntries = M.lookupDefault [] j icM

-- | Look up the group number of @k@ using 'safeLookup'. The message passed
--   to 'safeLookup' names this function and shows the missing key.
groupFun :: (Show k, Eq k, Hashable k) => M.HashMap k Int -> k -> Int
groupFun m k = safeLookup errMsg k m
  where
    errMsg = "groupFun: " ++ show k