{-# LANGUAGE TemplateHaskell #-}

-- Prediction of the expected number of type constructors produced by a random
-- generator, modelled as a (multi-type) Galton-Watson branching process.
module Prediction where

import Control.Monad (replicateM)
import Data.Maybe
import Data.Either
import Data.Graph
import Data.List hiding (transpose)
import Data.Matrix hiding ((!), trace)
import Data.Map.Strict (Map, (!))
import qualified Data.Map.Strict as Map

import Test.QuickCheck

import TypeInfo
import Countable

import Debug.Trace

-- Generation size, as used by QuickCheck's `resize`.
type Size = Int

-- A FreqMap is a mapping between type constructor names and Int values
-- representing the frequencies we want each one of them to occur in a random
-- generated value. This mapping is later provided to QuickCheck `frequency` in
-- order to derive random value generators.
type FreqMap = Map Name Int

-- Hardcoded instances distributions. This does not work very well, since even
-- if we hardcode the frequencies, the built-in instances do not reduce the
-- size of the inner generation. The solution requires getting rid of the
-- Arbitrary instances and carrying around concrete generators.
builtInFreqs :: FreqMap
builtInFreqs = Map.fromList
  [ {- Bool   -} ('True, 1), ('False, 1)
  , {- Either -} ('Left, 1), ('Right, 1)
  ]

-- Builds the initial frequency map for every constructor listed by `consList`
-- for the given type environment. Constructors present in `builtInFreqs` start
-- at their hardcoded frequency times 100; every other constructor starts at
-- 100.
initMap :: TypeEnv -> FreqMap
initMap = Map.fromList . map setInitialFreq . consList
  where
    -- A missing entry behaves like a built-in frequency of 1, i.e. 100.
    setInitialFreq cn = (cn, Map.findWithDefault 1 cn builtInFreqs * 100)

-- A ProbMap is similar to a FreqMap in the sense that it represents type
-- constructor names vs. frequencies. The difference lies in the fact that for
-- every data type T = C1 .. | C2 .. | ... | Cn it must hold that
-- pC1 + pC2 + ... + pCn = 1.
type ProbMap = Map Name Double

-- Renders a map as one " * (key,value)" bullet per line, in ascending key
-- order (the order of `Map.toList`).
showMap :: (Show a, Show b) => Map a b -> String
showMap m = intercalate "\n" (map showElem (Map.toList m))
  where
    showElem e = " * " ++ show e

-- Keeps only the entries whose key satisfies the predicate.
filterKeys :: (Name -> Bool) -> Map Name b -> Map Name b
filterKeys f = Map.filterWithKey (\k _ -> f k)

-- Shared implementation of `normalize` and `normalizeTerminals`: divides each
-- frequency by the sum of the frequencies of the entries of the same map whose
-- key satisfies `isSibling env cn`.
siblingRatios :: TypeEnv -> FreqMap -> ProbMap
siblingRatios env freqs = Map.mapWithKey freqRatio freqs
  where
    freqRatio cn cfreq = fromIntegral cfreq / fromIntegral (freqSum cn)
    freqSum cn = sum (filterKeys (isSibling env cn) freqs)

-- Turns constructor frequencies into probabilities, normalizing each
-- constructor against its siblings (as decided by `isSibling`).
normalize :: TypeEnv -> FreqMap -> ProbMap
normalize env freqMap = siblingRatios env freqMap

-- Same as `normalize`, but only the constructors accepted by `isTerminal` are
-- kept, and the normalization is computed among those terminals only.
normalizeTerminals :: TypeEnv -> FreqMap -> ProbMap
normalizeTerminals env freqMap =
  siblingRatios env (filterKeys (isTerminal env) freqMap)

-- Generates the square matrix (one row and one column per type in the
-- environment) whose (i, j) entry is the sum, over the constructors of the
-- i-th type, of the constructor probability times the number of occurrences
-- of the j-th type in that constructor.
genGWMatrix :: TypeEnv -> ProbMap -> Matrix Double
genGWMatrix env probMap = matrix size size genElem
  where
    size = length env
    -- `matrix` calls the generator with 1-based indices in [1..size], so the
    -- list lookups below are always in range.
    genElem (i, j) = sum $ map multProb $ occsFromTo (env !! (i - 1)) (env !! (j - 1))
    multProb (cn, occs) = probMap ! cn * fromIntegral occs
    occsFromTo from to = map (conOccurrences to) (tcons from)
    conOccurrences to con = (cname con, occurrences (tsig to) con)

-- Predicts the distribution for a given type constructor frequencies map.
--
-- The first type of the environment is taken as the root type of the
-- generation.
--
-- NOTE(review): `size` is expected to be at least 1. With `size <= 0` the
-- per-level sum is a `foldr1` over an empty list and the last level is
-- computed from `mT ^ (size - 1)`, i.e. a negative exponent.
predict :: TypeEnv -> Size -> FreqMap -> ProbMap
predict env size freqs = prediction
  where
    -- Normalize the frequencies into probabilities
    allProbs  = normalize env freqs
    termProbs = normalizeTerminals env freqs

    -- Split the type environment into branching types and leaf types.
    -- The root type is the head of the environment; the comparison is total
    -- even for an empty environment.
    isRootType t = Just t == listToMaybe env
    isBranchingType t = isRootType t || any rec (tcons t)
    (branchingTypes, leafTypes) = partition isBranchingType env

    -- Helpers
    bct = tsig . conType branchingTypes   -- branching constructor type
    lct = tsig . conType leafTypes        -- simple constructor type
    m !$ cn = Map.findWithDefault 0 cn m  -- safe lookup

    ----------------------------------------------------------------------------
    -- First we need to calculate the expectancy of the types involved at the
    -- branching process, we do this by calculating the pure random generation
    -- process at the first (size-1) levels, and adding the expectancy of the
    -- pseudo-random generation of the last level.
    ----------------------------------------------------------------------------
    branchingTypesExp = Map.unionWith (+) brFirstLevels brLastLevel

    branchingProbs     = filterKeys isBranchingTypeCon allProbs
    branchingTermProbs = filterKeys isBranchingTypeCon termProbs
    isBranchingTypeCon cn = cn `elem` consList branchingTypes
    branchingSigs = typeSigs branchingTypes

    -- Generate the Galton-Watson matrix with the given branching probabilities.
    mT = genGWMatrix branchingTypes branchingProbs

    -- Initial row vector: a single individual of the first branching type.
    ez0 = fromList 1 (length branchingTypes) (1 : repeat 0)

    -- Row vector for the k-th level. Level 0 is special-cased so that `mT ^ 0`
    -- is never evaluated.
    genLevel 0 = ez0
    genLevel k = ez0 * (mT ^ k)

    {- Branching process @ first (size-1) levels -}
    brFirstLevels = Map.mapWithKey multTypeExp branchingProbs
      where
        multTypeExp cn cp
          -- It is safe to use the geometric series simplification formula
          | length branchingTypes == 1 && mT' /= 1
          = cp * ((1 - mT' ^ size) / (1 - mT'))
          -- Otherwise, we need to sum every level :(
          | otherwise
          = cp * typeExp ! bct cn
        mT' = getElem 1 1 mT
        typeExp = Map.fromList $ zip branchingSigs (toList predMatrix)
        predMatrix = foldr1 (+) (map genLevel [0 .. size - 1])

    {- Branching process @ last level -}
    brLastLevel = Map.mapWithKey sumTermExp branchingTermProbs
      where
        sumTermExp tn tp =
          sum [ tp * allProbs ! cname con
                   * fromIntegral (occurrences (bct tn) con)
                   * prevLvlExp ! bct (cname con)
              | con <- concatMap tcons branchingTypes ]
        prevLvlExp =
          Map.fromList $ zip branchingSigs (toList (genLevel (size - 1)))

    ----------------------------------------------------------------------------
    -- Once we have the expectancy for every type constructor involved at the
    -- branching process, we can incorporate the expectancy of the leaf types by
    -- counting how many times they are generated as result of the branching
    -- process. It is important to note here that a leaf type could generate
    -- another leaf type, so we need to perform a topological sort in order to
    -- start calculating the expectancy of the 'nearest' types to the branching
    -- process ones. This way the farthest ones are not multiplied by zero.
    ----------------------------------------------------------------------------
    prediction = addLeafTypesExp branchingTypesExp sortedLeafTypeCons

    -- Inserts the leaf constructors one at a time, each one computed from the
    -- expectancies accumulated so far.
    addLeafTypesExp acc []       = acc
    addLeafTypesExp acc (cn:cns) =
      addLeafTypesExp (Map.insert cn (sumOccurrences acc cn) acc) cns

    -- Constructors not yet present in the accumulator count as 0 (see `!$`).
    sumOccurrences acc cn =
      sum [ allProbs ! cn * acc !$ cname con
                          * fromIntegral (occurrences (lct cn) con)
          | con <- allCons ]

    allCons      = concatMap tcons env
    leafTypeCons = concatMap tcons leafTypes

    -- Names of the constructors having an argument of the type of `cn`.
    generatorsOf cn = [ cname con | con <- allCons, lct cn `elem` cargs con ]

    -- Each leaf constructor points to its generators, so the reversed
    -- topological order lists generators before the constructors they produce.
    sortedLeafTypeCons = reverse (map (extractCName . gvert) (topSort graph))
    (graph, gvert) = graphFromEdges' leafTypeDeps
    leafTypeDeps = map createVertex leafTypeCons

    extractCName (_, cn, _) = cn
    createVertex con = ((), cname con, generatorsOf (cname con))

-- Generates a bunch of samples using a given generator and prints the average
-- number of type constructors generated in a random sample of size n. The
-- _arb_ parameter needs a type annotation when the function is used with
-- _arbitrary_ in order to break the ambiguity.
-- E.g. confirm 10 (arbitrary @Tree)
confirm :: (Countable a) => Size -> Gen a -> IO ()
confirm size arb = do
  let samples = 100000
  values <- replicateM samples (generate (resize size arb))
  let consCount = Map.unionsWith (+) (map count values)
      consAvg   = Map.map (\c -> fromIntegral c / fromIntegral samples) consCount
  putStrLn (showMap consAvg)