{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fno-warn-missing-methods #-} -- | A minimal Feldspar core language implementation. The intention of this module is to demonstrate -- how to quickly make a language prototype using Syntactic. module NanoFeldspar where import Prelude hiding (max, min, not, (==), length, map, sum, zip, zipWith) import qualified Prelude import Data.Typeable import Language.Syntactic hiding (fold, printExpr, showAST, drawAST, writeHtmlAST) import qualified Language.Syntactic as Syntactic import Language.Syntactic.Functional import Language.Syntactic.Functional.Sharing import Language.Syntactic.Functional.Tuple import Language.Syntactic.Sugar.BindingTyped () import Language.Syntactic.Sugar.TupleTyped () import Language.Syntactic.TH -------------------------------------------------------------------------------- -- * Types -------------------------------------------------------------------------------- -- | Convenient class alias class (Typeable a, Show a, Eq a, Ord a) => Type a instance (Typeable a, Show a, Eq a, Ord a) => Type a type Length = Int type Index = Int -------------------------------------------------------------------------------- -- * Abstract syntax -------------------------------------------------------------------------------- data Arithmetic sig where Add :: (Type a, Num a) => Arithmetic (a :-> a :-> Full a) Sub :: (Type a, Num a) => Arithmetic (a :-> a :-> Full a) Mul :: (Type a, Num a) => Arithmetic (a :-> a :-> Full a) deriveSymbol ''Arithmetic deriveEquality ''Arithmetic instance StringTree Arithmetic instance EvalEnv Arithmetic env instance Render Arithmetic where renderSym Add = "(+)" renderSym Sub = "(-)" renderSym Mul = "(*)" renderArgs = renderArgsSmart instance Eval Arithmetic where evalSym Add = (+) evalSym Sub = (-) evalSym Mul = (*) data Parallel sig where Parallel :: Type a => Parallel (Length :-> (Index -> a) :-> Full [a]) deriveSymbol ''Parallel deriveRender id ''Parallel deriveEquality ''Parallel instance StringTree Parallel instance EvalEnv Parallel env instance Eval Parallel where evalSym Parallel = \len ixf -> Prelude.map ixf [0 .. len-1] data ForLoop sig where ForLoop :: Type st => ForLoop (Length :-> st :-> (Index -> st -> st) :-> Full st) deriveSymbol ''ForLoop deriveRender id ''ForLoop deriveEquality ''ForLoop instance StringTree ForLoop instance EvalEnv ForLoop env instance Eval ForLoop where evalSym ForLoop = \len init body -> foldl (flip body) init [0 .. len-1] type FeldDomain = Typed ( BindingT :+: Let :+: Tuple :+: Arithmetic :+: Parallel :+: ForLoop :+: Construct ) -- `Construct` can be used to create arbitrary symbols from a name and an -- evaluation function. We could have used `Construct` for all symbols, but -- the problem with `Construct` is that it does not know about the arity or -- type of the construct it represents, so it's easy to make mistakes, e.g. -- when transforming expressions with `Construct` symbols. newtype Data a = Data { unData :: ASTF FeldDomain a } -- | Declaring 'Data' as syntactic sugar instance Type a => Syntactic (Data a) where type Domain (Data a) = FeldDomain type Internal (Data a) = a desugar = unData sugar = Data -- | Specialization of the 'Syntactic' class for the Feldspar domain class (Syntactic a, Domain a ~ FeldDomain, Type (Internal a)) => Syntax a instance (Syntactic a, Domain a ~ FeldDomain, Type (Internal a)) => Syntax a instance Type a => Show (Data a) where show = showExpr -------------------------------------------------------------------------------- -- * "Backends" -------------------------------------------------------------------------------- -- | Interface for controlling code motion cmInterface :: CodeMotionInterface FeldDomain cmInterface = defaultInterface VarT LamT sharable (const True) where sharable :: ASTF FeldDomain a -> ASTF FeldDomain b -> Bool sharable (Sym _) _ = False -- Simple expressions not shared sharable (lam :$ _) _ | Just _ <- prLam lam = False -- Lambdas not shared sharable _ (lam :$ _) | Just _ <- prLam lam = False -- Don't place let bindings over lambdas. This ensures that function -- arguments of higher-order constructs such as `Parallel` are always -- lambdas. sharable (sel :$ _) _ | Just Sel1 <- prj sel = False | Just Sel2 <- prj sel = False | Just Sel3 <- prj sel = False | Just Sel4 <- prj sel = False -- Tuple selection not shared sharable (arrl :$ _ ) _ | Just (Construct "arrLen" _) <- prj arrl = False -- Array length not shared sharable (gix :$ _ :$ _) _ | Just (Construct "arrIx" _) <- prj gix = False -- Array indexing not shared sharable _ _ = True -- | Show the expression showExpr :: (Syntactic a, Domain a ~ FeldDomain) => a -> String showExpr = render . codeMotion cmInterface . desugar -- | Print the expression printExpr :: (Syntactic a, Domain a ~ FeldDomain) => a -> IO () printExpr = putStrLn . showExpr -- | Show the syntax tree using unicode art showAST :: (Syntactic a, Domain a ~ FeldDomain) => a -> String showAST = Syntactic.showAST . codeMotion cmInterface . desugar -- | Draw the syntax tree on the terminal using unicode art drawAST :: (Syntactic a, Domain a ~ FeldDomain) => a -> IO () drawAST = putStrLn . showAST -- | Write the syntax tree to an HTML file with foldable nodes writeHtmlAST :: (Syntactic a, Domain a ~ FeldDomain) => a -> IO () writeHtmlAST = Syntactic.writeHtmlAST "tree.html" . codeMotion cmInterface . desugar -- | Evaluate an expression eval :: (Syntactic a, Domain a ~ FeldDomain) => a -> Internal a eval = evalClosed . desugar -------------------------------------------------------------------------------- -- * Front end -------------------------------------------------------------------------------- -- | Literal value :: Syntax a => Internal a -> a value a = sugar $ injT $ Construct (show a) a false :: Data Bool false = value False true :: Data Bool true = value True -- | Force computation force :: Syntax a => a -> a force = resugar instance (Type a, Num a) => Num (Data a) where fromInteger = value . fromInteger (+) = sugarSymTyped Add (-) = sugarSymTyped Sub (*) = sugarSymTyped Mul -- | Explicit sharing share :: (Syntax a, Syntax b) => a -> (a -> b) -> b share = sugarSymTyped Let -- | Parallel array parallel :: Type a => Data Length -> (Data Index -> Data a) -> Data [a] parallel = sugarSymTyped Parallel -- | For loop forLoop :: Syntax st => Data Length -> st -> (Data Index -> st -> st) -> st forLoop = sugarSymTyped ForLoop -- | Conditional expression (?) :: forall a . Syntax a => Data Bool -> (a,a) -> a c ? (t,f) = sugarSymTyped sym c t f where sym :: Construct (Bool :-> Internal a :-> Internal a :-> Full (Internal a)) sym = Construct "cond" (\c t f -> if c then t else f) -- | Get the length of an array arrLen :: Type a => Data [a] -> Data Length arrLen = sugarSymTyped $ Construct "arrLen" Prelude.length -- | Index into an array arrIx :: Type a => Data [a] -> Data Index -> Data a arrIx = sugarSymTyped $ Construct "arrIx" eval where eval as i | i >= len || i < 0 = error "arrIx: index out of bounds" | otherwise = as !! i where len = Prelude.length as not :: Data Bool -> Data Bool not = sugarSymTyped $ Construct "not" Prelude.not (==) :: Type a => Data a -> Data a -> Data Bool (==) = sugarSymTyped $ Construct "(==)" (Prelude.==) max :: Type a => Data a -> Data a -> Data a max = sugarSymTyped $ Construct "max" Prelude.max min :: Type a => Data a -> Data a -> Data a min = sugarSymTyped $ Construct "min" Prelude.min -------------------------------------------------------------------------------- -- * Vector library -------------------------------------------------------------------------------- data Vector a where Indexed :: Data Length -> (Data Index -> a) -> Vector a instance Syntax a => Syntactic (Vector a) where type Domain (Vector a) = FeldDomain type Internal (Vector a) = [Internal a] desugar = desugar . freezeVector . map resugar sugar = map resugar . thawVector . sugar length :: Vector a -> Data Length length (Indexed len _) = len indexed :: Data Length -> (Data Index -> a) -> Vector a indexed = Indexed index :: Vector a -> Data Index -> a index (Indexed _ ixf) = ixf (!) :: Vector a -> Data Index -> a Indexed _ ixf ! i = ixf i infixl 9 ! freezeVector :: Type a => Vector (Data a) -> Data [a] freezeVector vec = parallel (length vec) (index vec) thawVector :: Type a => Data [a] -> Vector (Data a) thawVector arr = Indexed (arrLen arr) (arrIx arr) zip :: Vector a -> Vector b -> Vector (a,b) zip a b = indexed (length a `min` length b) (\i -> (index a i, index b i)) unzip :: Vector (a,b) -> (Vector a, Vector b) unzip ab = (indexed len (fst . index ab), indexed len (snd . index ab)) where len = length ab permute :: (Data Length -> Data Index -> Data Index) -> (Vector a -> Vector a) permute perm vec = indexed len (index vec . perm len) where len = length vec reverse :: Vector a -> Vector a reverse = permute $ \len i -> len-i-1 (...) :: Data Index -> Data Index -> Vector (Data Index) l ... h = indexed (h-l+1) (+l) map :: (a -> b) -> Vector a -> Vector b map f (Indexed len ixf) = Indexed len (f . ixf) zipWith :: (a -> b -> c) -> Vector a -> Vector b -> Vector c zipWith f a b = map (uncurry f) $ zip a b fold :: Syntax b => (a -> b -> b) -> b -> Vector a -> b fold f b (Indexed len ixf) = forLoop len b (\i st -> f (ixf i) st) fold1 :: Syntax a => (a -> a -> a) -> Vector a -> a fold1 f (Indexed len ixf) = forLoop len (ixf 0) (\i st -> f (ixf i) st) sum :: (Num a, Syntax a) => Vector a -> a sum = fold (+) 0 type Matrix a = Vector (Vector (Data a)) -- | Transpose of a matrix. Assumes that the number of rows is > 0. transpose :: Type a => Matrix a -> Matrix a transpose a = indexed (length (a!0)) $ \k -> indexed (length a) $ \l -> a ! l ! k -------------------------------------------------------------------------------- -- * Examples -------------------------------------------------------------------------------- -- | Fibonacci function fib :: Data Int -> Data Int fib n = fst $ forLoop n (0,1) $ \_ (a,b) -> (b,a+b) -- | The span of a vector (difference between greatest and smallest element) spanVec :: Vector (Data Int) -> Data Int spanVec vec = hi-lo where (lo,hi) = fold (\a (l,h) -> (min a l, max a h)) (vec!0,vec!0) vec -- This demonstrates how tuples interplay with sharing. Tuples are essentially -- useless without sharing. This function would get two identical for loops if -- it wasn't for sharing. -- | Scalar product scProd :: Vector (Data Float) -> Vector (Data Float) -> Data Float scProd a b = sum (zipWith (*) a b) forEach = flip map -- | Matrix multiplication matMul :: Matrix Float -> Matrix Float -> Matrix Float matMul a b = forEach a $ \a' -> forEach (transpose b) $ \b' -> scProd a' b'