{-# LANGUAGE CPP
           , GADTs
           , DataKinds
           , TypeFamilies
           , FlexibleContexts
           , UndecidableInstances
           , LambdaCase
           , OverloadedStrings
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs -fsimpl-tick-factor=1000 #-}

-- Runtime helpers: a sampling monad ('Measure') built on mwc-random,
-- vector-backed arrays, a small pattern-matching encoding, and numeric
-- coercions that are mostly identities or 'fromIntegral'.
module Language.Hakaru.Runtime.Prelude where

#if __GLASGOW_HASKELL__ < 710
import           Data.Functor                    ((<$>))
import           Control.Applicative             (Applicative(..))
#endif
import           Data.Foldable                   as F
import qualified System.Random.MWC               as MWC
import qualified System.Random.MWC.Distributions as MWCD
import           Data.Number.Natural
import qualified Data.Vector                     as V
import qualified Data.Vector.Unboxed             as U
import qualified Data.Vector.Generic             as G
import           Control.Monad
import           Prelude                         hiding (product)

----------------------------------------------------------------
-- Choosing a vector representation

-- | Combine two vector representations: the result is unboxed only
-- when both inputs are unboxed, otherwise it is boxed.
type family   MinBoxVec (v1 :: * -> *) (v2 :: * -> *) :: * -> *
type instance MinBoxVec V.Vector v        = V.Vector
type instance MinBoxVec v        V.Vector = V.Vector
type instance MinBoxVec U.Vector U.Vector = U.Vector

-- | The vector representation used to store elements of type @a@.
-- 'Int' and 'Double' are stored unboxed, vectors are stored boxed, and
-- pairs use the 'MinBoxVec' of their components.
type family   MayBoxVec a :: * -> *
type instance MayBoxVec Int          = U.Vector
type instance MayBoxVec Double       = U.Vector
type instance MayBoxVec (U.Vector a) = V.Vector
type instance MayBoxVec (V.Vector a) = V.Vector
type instance MayBoxVec (a,b)        = MinBoxVec (MayBoxVec a) (MayBoxVec b)

----------------------------------------------------------------
-- Functions and bindings

-- | Lambda abstraction; the identity on Haskell functions.
lam :: (a -> b) -> a -> b
lam = id

-- | Function application.
app :: (a -> b) -> a -> b
app f x = f x

-- | Bind a value to a name and pass it to the continuation.
let_ :: a -> (a -> b) -> b
let_ x f = let x1 = x in f x1

-- | Annotation: the first argument is ignored, the second is returned.
ann_ :: a -> b -> b
ann_ _ a = a

----------------------------------------------------------------
-- The sampling monad

-- | A sampler: given a generator it produces @Just@ a sample, or
-- @Nothing@ when the sample is rejected (see 'reject').
newtype Measure a = Measure { unMeasure :: MWC.GenIO -> IO (Maybe a) }

instance Functor Measure where
    fmap = liftM

instance Applicative Measure where
    pure x = Measure $ \_ -> return (Just x)
    (<*>)  = ap

instance Monad Measure where
    return  = pure
    -- A rejected sample must short-circuit to @Nothing@. Matching with
    -- @Just x <- ...@ inside the IO do-block would instead call IO's
    -- 'fail' and throw, so 'run' could never see the @Nothing@.
    m >>= f = Measure $ \g ->
        unMeasure m g >>= \case
            Nothing -> return Nothing
            Just x  -> unMeasure (f x) g

-- | Lift a sampling action that always succeeds into 'Measure'.
makeMeasure :: (MWC.GenIO -> IO a) -> Measure a
makeMeasure f = Measure $ \g -> Just <$> f g

----------------------------------------------------------------
-- Primitive distributions

-- | Sample with 'MWC.uniformR' over the range @(lo, hi)@.
uniform :: Double -> Double -> Measure Double
uniform lo hi = makeMeasure $ MWC.uniformR (lo, hi)

-- | Sample with 'MWCD.normal'; the arguments are passed through as
-- mean and standard deviation.
normal :: Double -> Double -> Measure Double
normal mu sd = makeMeasure $ MWCD.normal mu sd

-- | Sample with 'MWCD.beta'; both parameters are passed through
-- unchanged.
beta :: Double -> Double -> Measure Double
beta a b = makeMeasure $ MWCD.beta a b

-- | Sample with 'MWCD.gamma'; both parameters are passed through
-- unchanged.
gamma :: Double -> Double -> Measure Double
gamma a b = makeMeasure $ MWCD.gamma a b

-- | Sample an index with 'MWCD.categorical' from a vector of weights.
categorical :: MayBoxVec Double Double -> Measure Int
categorical a = makeMeasure (\g -> fromIntegral <$> MWCD.categorical a g)

-- | Run @f@ at each index @0 .. n-1@ and collect the samples into a
-- vector. If any element is rejected the whole plate is rejected,
-- because the elements are sequenced in the 'Measure' monad.
plate :: (G.Vector (MayBoxVec a) a)
      => Int
      -> (Int -> Measure a)
      -> Measure (MayBoxVec a a)
plate n f = G.generateM (fromIntegral n) $ \x ->
    f (fromIntegral x)

----------------------------------------------------------------
-- Data constructors

-- | Pair constructor.
pair :: a -> b -> (a, b)
pair = (,)

-- | Boolean constants.
true, false :: Bool
true  = True
false = False

-- | The unit value.
unit :: ()
unit = ()

----------------------------------------------------------------
-- Pattern matching

-- | Sub-pattern markers accepted by 'ppair'.
data Pattern = PVar | PWild

-- | One alternative of a 'case_': yields @Just@ the branch result when
-- the scrutinee matches, @Nothing@ otherwise.
newtype Branch a b = Branch { extract :: a -> Maybe b }

-- | Branches that match only 'True' (respectively 'False') and return
-- the given body.
ptrue, pfalse :: a -> Branch Bool a
ptrue  b = Branch { extract = extractBool True  b }
pfalse b = Branch { extract = extractBool False b }

-- | @extractBool b a p@ is @Just a@ when the scrutinee @p@ equals the
-- expected value @b@, and @Nothing@ otherwise.
extractBool :: Bool -> a -> Bool -> Maybe a
extractBool b a p
    | p == b    = Just a
    | otherwise = Nothing

-- | Branch on a pair. Only the case where both components are bound
-- ('PVar', 'PVar') is implemented; any other combination calls 'error'.
ppair :: Pattern -> Pattern -> (x -> y -> b) -> Branch (x,y) b
ppair PVar PVar c = Branch { extract = \(x, y) -> Just (c x y) }
ppair _    _    _ = error "ppair: TODO"

-- | Try each branch in order and return the result of the first one
-- that matches. Calls 'error' when no branch matches.
case_ :: a -> [Branch a b] -> b
case_ e_ bs_ = go e_ bs_
  where
    go _ []     = error "case_: unable to match any branches"
    go e (b:bs) =
        case extract b e of
            Just b' -> b'
            Nothing -> go e bs

-- | Build a branch by applying a pattern to its body.
branch :: (c -> Branch a b) -> c -> Branch a b
branch pat body = pat body

----------------------------------------------------------------
-- Measure combinators

-- | The measure that always returns the given value.
dirac :: a -> Measure a
dirac = return

-- | Weighting is not tracked by this sampler: the weight argument is
-- ignored and the measure is returned unchanged.
pose :: Double -> Measure a -> Measure a
pose _ a = a

-- | Choose one component, using the first components as the weights
-- for 'MWCD.categorical', and then sample from the chosen measure.
--
-- An empty list has no component to choose from, so it behaves like
-- 'reject' rather than handing an empty weight vector to
-- 'MWCD.categorical'.
superpose :: [(Double, Measure a)] -> Measure a
superpose []  = reject
superpose pms = do
    i <- makeMeasure $ MWCD.categorical (U.fromList $ map fst pms)
    -- NOTE(review): relies on 'MWCD.categorical' returning an index
    -- below the number of weights, which is the length of @pms@.
    snd (pms !! i)

-- | The measure that never produces a sample.
reject :: Measure a
reject = Measure $ \_ -> return Nothing

----------------------------------------------------------------
-- Literals and coercions

-- | Natural-number literal; naturals are represented as 'Int'.
nat_ :: Int -> Int
nat_ = id

-- | Integer literal.
int_ :: Int -> Int
int_ = id

-- | Integer to natural; an identity that performs no range check.
unsafeNat :: Int -> Int
unsafeNat = id

-- | Natural to probability, via 'fromIntegral'.
nat2prob :: Int -> Double
nat2prob = fromIntegral

-- | Integer to real, via 'fromIntegral'.
fromInt :: Int -> Double
fromInt = fromIntegral

-- | Natural to integer; an identity.
nat2int :: Int -> Int
nat2int = id

-- | Natural to real, via 'fromIntegral'.
nat2real :: Int -> Double
nat2real = fromIntegral

-- | Probability to real; an identity.
fromProb :: Double -> Double
fromProb = id

-- | Real to probability; an identity that performs no range check.
unsafeProb :: Double -> Double
unsafeProb = id

-- | Real literal from a 'Rational'.
real_ :: Rational -> Double
real_ = fromRational

-- | Probability literal from a 'NonNegativeRational'.
prob_ :: NonNegativeRational -> Double
prob_ = fromRational . fromNonNegativeRational

-- | Positive infinity, computed as @1/0@.
infinity :: Double
infinity = 1/0

-- | Absolute value.
abs_ :: Num a => a -> a
abs_ = abs

-- | @thRootOf a b@ is the @a@-th root of @b@, computed as
-- @b ** (1 / a)@.
thRootOf :: Int -> Double -> Double
thRootOf a b = b ** (recip $ fromIntegral a)

----------------------------------------------------------------
-- Arrays

-- | Build a vector of length @n@ whose element at index @i@ is @f i@.
array
    :: (G.Vector (MayBoxVec a) a)
    => Int
    -> (Int -> a)
    -> MayBoxVec a a
array n f = G.generate (fromIntegral n) (f . fromIntegral)

-- | Index into a vector, using the bounds-checked 'G.!'.
(!) :: (G.Vector (MayBoxVec a) a)
    => MayBoxVec a a
    -> Int
    -> a
a ! b = a G.! (fromIntegral b)

-- | Number of elements in a vector.
size :: (G.Vector (MayBoxVec a) a)
     => MayBoxVec a a
     -> Int
size v = fromIntegral (G.length v)

-- | @product a b f@ multiplies @f i@ for @i@ in @[a .. b-1]@. An empty
-- range yields 1.
product
    :: Num a
    => Int
    -> Int
    -> (Int -> a)
    -> a
product a b f = F.foldl' (\x y -> x * f y) 1 [a .. b-1]

-- | @summate a b f@ adds @f i@ for @i@ in @[a .. b-1]@. An empty range
-- yields 0.
summate
    :: Num a
    => Int
    -> Int
    -> (Int -> a)
    -> a
summate a b f = F.foldl' (\x y -> x + f y) 0 [a .. b-1]

----------------------------------------------------------------
-- Running

-- | Draw one sample and print it; a rejected sample prints nothing.
run :: Show a
    => MWC.GenIO
    -> Measure a
    -> IO ()
run g k = unMeasure k g >>= \case
    Just a  -> print a
    Nothing -> return ()

-- | Feed each result of @f@ back into @f@, forever. This only stops if
-- the monad itself aborts, which is why the result type is free.
iterateM_
    :: Monad m
    => (a -> m a)
    -> a
    -> m b
iterateM_ f = g
  where
    g x = f x >>= g

-- | Print the argument, then run the action on it.
withPrint :: Show a => (a -> IO b) -> a -> IO b
withPrint f x = print x >> f x