module Feldspar.Frontend where
import Prelude (Integral, Ord, RealFloat, RealFrac)
import qualified Prelude as P
import Prelude.EDSL
import Control.Monad.Identity
import Data.Bits (Bits, FiniteBits)
import qualified Data.Bits as Bits
import Data.Complex (Complex)
import Data.Int
import Data.List (genericLength)
import Language.Syntactic (Internal)
import Language.Syntactic.Functional
import qualified Language.Syntactic as Syntactic
import qualified Control.Monad.Operational.Higher as Oper
import Language.Embedded.Imperative (IxRange)
import qualified Language.Embedded.Imperative as Imp
import qualified Data.Inhabited as Inhabited
import Data.TypedStruct
import Feldspar.Primitive.Representation
import Feldspar.Representation
import Feldspar.Sugar ()
-- | Share an expression: bind it once and pass the bound value to the body.
--
-- Equivalent to 'shareTag' with an empty tag.
share :: (Syntax a, Syntax b)
    => a         -- ^ Value to share
    -> (a -> b)  -- ^ Body that receives the shared value
    -> b
share x body = shareTag "" x body
-- | Like 'share', but records a tag in the 'Let' binding.
--
-- NOTE(review): what the tag is used for (for example, naming the generated
-- variable) is not visible here. Confirm against the 'Let' symbol.
shareTag :: (Syntax a, Syntax b)
    => String    -- ^ Tag stored in the 'Let' symbol
    -> a         -- ^ Value to share
    -> (a -> b)  -- ^ Body that receives the shared value
    -> b
shareTag tag x body = sugarSymFeld (Let tag) x body
-- | Pure loop construct (the 'ForLoop' symbol) that threads a state value
-- through a step function, which receives the loop index and the current
-- state.
forLoop :: Syntax st
    => Data Length               -- ^ Loop bound (presumably the iteration count; confirm)
    -> st                        -- ^ Initial state
    -> (Data Index -> st -> st)  -- ^ Step: index and current state to next state
    -> st                        -- ^ Final state
forLoop bound initial step = sugarSymFeld ForLoop bound initial step
-- | Conditional expression (the 'Cond' symbol): chooses between two values
-- depending on a boolean.
cond :: Syntax a
    => Data Bool  -- ^ Condition
    -> a          -- ^ Result when the condition holds
    -> a          -- ^ Result otherwise
    -> a
cond c t f = sugarSymFeld Cond c t f

-- | Operator form of 'cond'. It is meant to be written @c ? t $ f@; the low
-- precedence (@infixl 1@) makes @c@ bind looser than comparisons.
(?) :: Syntax a
    => Data Bool  -- ^ Condition
    -> a          -- ^ Result when the condition holds
    -> a          -- ^ Result otherwise
    -> a
c ? t = cond c t

infixl 1 ?
-- | Multi-way selection on a primitive scrutinee.
--
-- @switch def cases s@ yields the result paired with the first key in
-- @cases@ that equals @s@, or @def@ when no key matches. The cases become a
-- chain of nested 'cond's, with the first case outermost.
switch :: (Syntax a, Syntax b, PrimType (Internal a))
    => b        -- ^ Default result
    -> [(a,b)]  -- ^ Cases as (key, result) pairs
    -> a        -- ^ Scrutinee
    -> b
switch def cases scrut = go cases
  where
    go []                = def
    go ((key, res) : cs) = desugar key == desugar scrut ? res $ go cs
-- | Embed a Haskell value of the internal representation type as a literal
-- (the 'Lit' symbol).
value :: Syntax a => Internal a -> a
value x = sugarSymFeld (Lit x)

-- | The literal 'False'.
false :: Data Bool
false = value False

-- | The literal 'True'.
true :: Data Bool
true = value True

-- | Unit lives in 'FeldDomain' with internal type 'Int32'. Desugaring emits
-- the literal @0@; sugaring ignores the expression and returns @()@.
instance Syntactic.Syntactic ()
  where
    type Domain ()   = FeldDomain
    type Internal () = Int32
    sugar      = const ()
    desugar () = unData 0

-- | A default inhabitant of any 'Syntax' type: the literal
-- 'Inhabited.example' of its internal representation.
example :: Syntax a => a
example = value $ Inhabited.example
-- | The bounds of @'Data' a@ are the literal bounds of the underlying type @a@.
instance (Bounded a, Type a) => Bounded (Data a)
  where
    maxBound = value maxBound
    minBound = value minBound
-- | Arithmetic on primitive EDSL values.
--
-- Each operation maps onto the primitive symbol of the same meaning. Integer
-- literals are converted to @a@ and embedded with 'value'.
instance (Num a, PrimType a) => Num (Data a)
  where
    fromInteger = value . fromInteger
    (+)         = sugarSymFeld Add
    -- Defined explicitly so that subtraction emits a 'Sub' symbol instead of
    -- the class default @x + negate y@.
    (-)         = sugarSymFeld Sub
    (*)         = sugarSymFeld Mul
    negate      = sugarSymFeld Neg
    abs         = sugarSymFeld Abs
    signum      = sugarSymFeld Sign
-- | Fractional arithmetic on primitive EDSL values. Division uses the 'FDiv'
-- symbol, and rational literals are embedded with 'value'.
instance (Fractional a, PrimType a) => Fractional (Data a)
  where
    x / y          = sugarSymFeld FDiv x y
    fromRational r = value (fromRational r)
-- | Floating-point operations on primitive EDSL values. Each method maps
-- onto the primitive symbol of the same name. Methods not listed here keep
-- their class defaults.
instance (Floating a, PrimType a) => Floating (Data a)
  where
    -- Constant, exponential, logarithm, square root and power
    pi     = sugarSymFeld Pi
    exp x  = sugarSymFeld Exp x
    log x  = sugarSymFeld Log x
    sqrt x = sugarSymFeld Sqrt x
    x ** y = sugarSymFeld Pow x y

    -- Trigonometric functions and their inverses
    sin x  = sugarSymFeld Sin x
    cos x  = sugarSymFeld Cos x
    tan x  = sugarSymFeld Tan x
    asin x = sugarSymFeld Asin x
    acos x = sugarSymFeld Acos x
    atan x = sugarSymFeld Atan x

    -- Hyperbolic functions and their inverses
    sinh x  = sugarSymFeld Sinh x
    cosh x  = sugarSymFeld Cosh x
    tanh x  = sugarSymFeld Tanh x
    asinh x = sugarSymFeld Asinh x
    acosh x = sugarSymFeld Acosh x
    atanh x = sugarSymFeld Atanh x
-- | Unicode alias for 'pi' (the 'Pi' symbol).
π :: (Floating a, PrimType a) => Data a
π = sugarSymFeld Pi
-- | Integer quotient (the 'Quot' symbol). Presumably truncating, like
-- Prelude's 'P.quot'; confirm against the backend.
quot :: (Integral a, PrimType a) => Data a -> Data a -> Data a
quot a b = sugarSymFeld Quot a b

-- | Integer remainder (the 'Rem' symbol), the companion of 'quot'.
rem :: (Integral a, PrimType a) => Data a -> Data a -> Data a
rem a b = sugarSymFeld Rem a b
-- | Quotient and remainder together.
--
-- The quotient comes from 'quot'. The remainder is derived from it as
-- @a - b * q@ rather than through a separate 'rem' symbol, so the pair
-- satisfies @q * b + r == a@ by construction.
quotRem :: (Integral a, PrimType a) => Data a -> Data a -> (Data a, Data a)
quotRem a b = (q, r)
  where
    q = quot a b
    r = a - b * q
-- | Integer division (the 'Div' symbol). Presumably flooring, like
-- Prelude's 'P.div'; confirm against the backend.
div :: (Integral a, PrimType a) => Data a -> Data a -> Data a
div a b = sugarSymFeld Div a b

-- | Integer modulus (the 'Mod' symbol), the companion of 'div'.
mod :: (Integral a, PrimType a) => Data a -> Data a -> Data a
mod a b = sugarSymFeld Mod a b
-- | Division for the case where the divisor is known to divide the dividend
-- exactly (the 'DivBalanced' symbol).
--
-- The precondition @rem a b == 0@ is checked with an internal assertion.
unsafeBalancedDiv :: (Integral a, PrimType a) => Data a -> Data a -> Data a
unsafeBalancedDiv a b =
    guardValLabel InternalAssertion balanced
      "unsafeBalancedDiv: division not balanced"
      (sugarSymFeld DivBalanced a b)
  where
    balanced = rem a b == 0
-- | Build a complex number from two components (the 'Complex' symbol).
-- NOTE(review): presumably the real part comes first, then the imaginary
-- part, as with "Data.Complex"'s @(:+)@. Confirm.
complex :: (Num a, PrimType a, PrimType (Complex a))
    => Data a
    -> Data a
    -> Data (Complex a)
complex re im = sugarSymFeld Complex re im

-- | Build a complex number from polar coordinates (the 'Polar' symbol).
-- NOTE(review): presumably magnitude first, then phase, as with
-- "Data.Complex"'s @mkPolar@. Confirm.
polar :: (Floating a, PrimType a, PrimType (Complex a))
    => Data a
    -> Data a
    -> Data (Complex a)
polar r theta = sugarSymFeld Polar r theta

-- | Real part of a complex number (the 'Real' symbol).
realPart :: (PrimType a, PrimType (Complex a)) => Data (Complex a) -> Data a
realPart z = sugarSymFeld Real z

-- | Imaginary part of a complex number (the 'Imag' symbol).
imagPart :: (PrimType a, PrimType (Complex a)) => Data (Complex a) -> Data a
imagPart z = sugarSymFeld Imag z

-- | Magnitude of a complex number (the 'Magnitude' symbol).
magnitude :: (RealFloat a, PrimType a, PrimType (Complex a)) =>
    Data (Complex a) -> Data a
magnitude z = sugarSymFeld Magnitude z

-- | Phase (argument) of a complex number (the 'Phase' symbol).
phase :: (RealFloat a, PrimType a, PrimType (Complex a)) =>
    Data (Complex a) -> Data a
phase z = sugarSymFeld Phase z

-- | Complex conjugate (the 'Conjugate' symbol).
conjugate :: (RealFloat a, PrimType (Complex a)) =>
    Data (Complex a) -> Data (Complex a)
conjugate z = sugarSymFeld Conjugate z
-- | Convert an integral value to any numeric type (the 'I2N' symbol).
i2n :: (Integral i, Num n, PrimType i, PrimType n) => Data i -> Data n
i2n x = sugarSymFeld I2N x

-- | Convert an integral value to a boolean (the 'I2B' symbol).
-- NOTE(review): presumably nonzero maps to true. Confirm against the backend.
i2b :: (Integral a, PrimType a) => Data a -> Data Bool
i2b x = sugarSymFeld I2B x

-- | Convert a boolean to an integral value (the 'B2I' symbol).
-- NOTE(review): presumably false/true map to 0/1. Confirm.
b2i :: (Integral a, PrimType a) => Data Bool -> Data a
b2i b = sugarSymFeld B2I b

-- | Round a fractional value to a numeric type (the 'Round' symbol). The
-- rounding mode is not visible here.
round :: (RealFrac a, Num b, PrimType a, PrimType b) => Data a -> Data b
round x = sugarSymFeld Round x
-- | Boolean negation (the 'Not' symbol).
not :: Data Bool -> Data Bool
not b = sugarSymFeld Not b

-- | Boolean conjunction (the 'And' symbol).
(&&) :: Data Bool -> Data Bool -> Data Bool
a && b = sugarSymFeld And a b

-- | Boolean disjunction (the 'Or' symbol).
(||) :: Data Bool -> Data Bool -> Data Bool
a || b = sugarSymFeld Or a b

-- These are the same fixities as Prelude's boolean operators.
infixr 3 &&
infixr 2 ||
-- | Equality on primitive values (the 'Eq' symbol).
(==) :: PrimType a => Data a -> Data a -> Data Bool
a == b = sugarSymFeld Eq a b

-- | Inequality, defined as the negation of '=='.
(/=) :: PrimType a => Data a -> Data a -> Data Bool
(/=) a b = not (a == b)

-- | Less than (the 'Lt' symbol).
(<) :: (Ord a, PrimType a) => Data a -> Data a -> Data Bool
a < b = sugarSymFeld Lt a b

-- | Greater than (the 'Gt' symbol).
(>) :: (Ord a, PrimType a) => Data a -> Data a -> Data Bool
a > b = sugarSymFeld Gt a b

-- | Less than or equal (the 'Le' symbol).
(<=) :: (Ord a, PrimType a) => Data a -> Data a -> Data Bool
a <= b = sugarSymFeld Le a b

-- | Greater than or equal (the 'Ge' symbol).
(>=) :: (Ord a, PrimType a) => Data a -> Data a -> Data Bool
a >= b = sugarSymFeld Ge a b

infix 4 ==, /=, <, >, <=, >=
-- | The smaller of two values; ties go to the first argument.
min :: (Ord a, PrimType a) => Data a -> Data a -> Data a
min a b = cond (a <= b) a b

-- | The larger of two values; ties go to the first argument.
max :: (Ord a, PrimType a) => Data a -> Data a -> Data a
max a b = cond (a >= b) a b
-- | Bitwise AND (the 'BitAnd' symbol).
(.&.) :: (Bits a, PrimType a) => Data a -> Data a -> Data a
a .&. b = sugarSymFeld BitAnd a b

-- | Bitwise OR (the 'BitOr' symbol).
(.|.) :: (Bits a, PrimType a) => Data a -> Data a -> Data a
a .|. b = sugarSymFeld BitOr a b

-- | Bitwise exclusive OR (the 'BitXor' symbol).
xor :: (Bits a, PrimType a) => Data a -> Data a -> Data a
xor a b = sugarSymFeld BitXor a b

-- | Unicode operator alias for 'xor'.
(⊕) :: (Bits a, PrimType a) => Data a -> Data a -> Data a
a ⊕ b = xor a b

-- | Bitwise complement (the 'BitCompl' symbol).
complement :: (Bits a, PrimType a) => Data a -> Data a
complement a = sugarSymFeld BitCompl a
-- | Left shift (the 'ShiftL' symbol).
shiftL :: (Bits a, PrimType a)
    => Data a      -- ^ Value to shift
    -> Data Int32  -- ^ Number of bit positions
    -> Data a
shiftL x n = sugarSymFeld ShiftL x n

-- | Right shift (the 'ShiftR' symbol).
shiftR :: (Bits a, PrimType a)
    => Data a      -- ^ Value to shift
    -> Data Int32  -- ^ Number of bit positions
    -> Data a
shiftR x n = sugarSymFeld ShiftR x n

-- | Operator alias for 'shiftL'.
(.<<.) :: (Bits a, PrimType a)
    => Data a      -- ^ Value to shift
    -> Data Int32  -- ^ Number of bit positions
    -> Data a
x .<<. n = shiftL x n

-- | Operator alias for 'shiftR'.
(.>>.) :: (Bits a, PrimType a)
    => Data a      -- ^ Value to shift
    -> Data Int32  -- ^ Number of bit positions
    -> Data a
x .>>. n = shiftR x n

-- These are the same fixities as the corresponding "Data.Bits" operators.
infixl 8 `shiftL`, `shiftR`, .<<., .>>.
infixl 7 .&.
infixl 6 `xor`
infixl 5 .|.
-- | The statically known bit width of the argument's type, taken from
-- 'Bits.finiteBitSize'. The argument itself is never inspected.
bitSize :: forall a . FiniteBits a => Data a -> Length
bitSize _ = P.fromIntegral (Bits.finiteBitSize dummy)
  where
    -- Type proxy only; finiteBitSize must not force it.
    dummy :: a
    dummy = P.error "finiteBitSize evaluates its argument"
-- | A value with every bit set (the complement of zero).
allOnes :: (Bits a, Num a, PrimType a) => Data a
allOnes = complement 0

-- | A value whose @n@ least significant bits are set: all ones shifted left
-- by @n@, then complemented.
oneBits :: (Bits a, Num a, PrimType a) => Data Int32 -> Data a
oneBits = complement . (allOnes .<<.)

-- | Keep only the @k@ least significant bits of a value.
lsbs :: (Bits a, Num a, PrimType a) => Data Int32 -> Data a -> Data a
lsbs k = (.&. oneBits k)
-- | Integer base-2 logarithm, rounded down, for arguments @>= 1@.
--
-- The precondition is checked with an internal assertion. The result is
-- built by a fixed sequence of compare-and-shift steps, one per threshold in
-- @ffis@, applied from the largest threshold to the smallest. If the
-- remaining value exceeds @2^(2^i) - 1@, it is shifted right by @2^i@ and
-- @2^i@ is OR-ed into the result.
ilog2 :: (FiniteBits a, Integral a, PrimType a) => Data a -> Data a
ilog2 a = guardValLabel InternalAssertion (a >= 1) "ilog2: argument < 1" $
    snd $ P.foldr (\ffi vr -> share vr (step ffi)) (a,0) ffis
  where
    -- One step. shift is 2^i when v > ff, otherwise 0. It is shared because
    -- it is used twice: to discard bits from v and to record them in r.
    step (ff,i) (v,r) =
        share (b2i (v > fromInteger ff) .<<. value i) $ \shift ->
          (v .>>. i2n shift, r .|. shift)

    -- Thresholds 2^(2^k) - 1 paired with k:
    -- [(0x1,0), (0x3,1), (0xF,2), (0xFF,3), (0xFFFF,4), ...]
    -- The list is cut off at 2^(bitSize/2) - 1, the largest such threshold
    -- that fits in the lower half of the type.
    ffis
        = (`P.zip` [0..])
        $ P.takeWhile (P.<= (2 P.^ (bitSize a `P.div` 2) - 1 :: Integer))
        $ P.map (subtract 1 . (2 P.^) . (2 P.^))
        $ [(0::Integer)..]
-- | Index into an immutable array.
--
-- Each primitive component of the element structure is read separately with
-- an 'ArrIx' symbol, at the index plus the array's offset, and the results
-- are reassembled with 'resugar'. Every read is guarded by an internal
-- assertion that the index is below the array's length.
arrIx :: Syntax a => IArr a -> Data Index -> a
arrIx arr i = resugar (mapStruct readComponent (unIArr arr))
  where
    readComponent :: forall b . PrimType' b => Imp.IArr Index b -> Data b
    readComponent prim = sugarSymFeldPrim
        (GuardVal InternalAssertion "arrIx: index out of bounds")
        (i < length arr)
        (sugarSymFeldPrim (ArrIx prim) (i + iarrOffset arr) :: Data b)
-- | Containers that support indexing with '(!)'.
class Indexed a
  where
    -- | Element type produced by indexing.
    type IndexedElem a
    -- | The element at the given index.
    (!) :: a -> Data Index -> IndexedElem a
    infixl 9 !
-- | Structures with a known length.
class Finite a
  where
    -- | The number of elements.
    length :: a -> Data Length

instance Finite (Arr a)
  where
    length arr = arrLength arr

instance Finite (IArr a)
  where
    length arr = iarrLength arr
-- | Structures from which a contiguous sub-range can be taken.
class Slicable a
  where
    -- | @slice from len x@ is the part of @x@ that starts at index @from@ and
    -- has length @len@.
    slice :: Data Index   -- ^ Start index
          -> Data Length  -- ^ Length of the slice
          -> a
          -> a
-- | Indexing an immutable array yields its element type, via 'arrIx'.
instance Syntax a => Indexed (IArr a)
  where
    type IndexedElem (IArr a) = a
    arr ! i = arrIx arr i
-- | Slicing a mutable array adjusts its offset and length and keeps the same
-- underlying storage. Both bounds are checked with internal assertions.
instance Slicable (Arr a)
  where
    slice from len (Arr off l arr) = Arr off' len' arr
      where
        checked :: Syntax b => Data Bool -> b -> b
        checked ok = guardValLabel InternalAssertion ok "invalid Arr slice"
        off' = checked (from <= l) (off + from)
        len' = checked (from + len <= l) len

-- | Slicing an immutable array adjusts its offset and length and keeps the
-- same underlying storage. Both bounds are checked with internal assertions.
instance Slicable (IArr a)
  where
    slice from len (IArr off l arr) = IArr off' len' arr
      where
        checked :: Syntax b => Data Bool -> b -> b
        checked ok = guardValLabel InternalAssertion ok "invalid IArr slice"
        off' = checked (from <= l) (off + from)
        len' = checked (from + len <= l) len
-- | Convert a 'Syntax' value into its 'Data' representation.
desugar :: Syntax a => a -> Data (Internal a)
desugar x = Data (Syntactic.desugar x)

-- | Convert a 'Data' representation back into a 'Syntax' value.
sugar :: Syntax a => Data (Internal a) -> a
sugar d = Syntactic.sugar (unData d)

-- | Convert between two 'Syntax' types that share an internal representation.
resugar :: (Syntax a, Syntax b, Internal a ~ Internal b) => a -> b
resugar x = Syntactic.resugar x
-- | Guard a value with a user assertion: 'guardValLabel' with the label
-- @UserAssertion ""@.
guardVal :: Syntax a
    => Data Bool  -- ^ Condition that must hold
    -> String     -- ^ Message attached to the assertion
    -> a          -- ^ Guarded value
    -> a
guardVal ok msg x = guardValLabel (UserAssertion "") ok msg x

-- | Guard a value with a labelled assertion (the 'GuardVal' symbol).
-- NOTE(review): what happens when the condition fails is decided by the
-- backend and is not visible here.
guardValLabel :: Syntax a
    => AssertionLabel  -- ^ Kind of assertion
    -> Data Bool       -- ^ Condition that must hold
    -> String          -- ^ Message attached to the assertion
    -> a               -- ^ Guarded value
    -> a
guardValLabel lbl ok msg x = sugarSymFeld (GuardVal lbl msg) ok x
-- | Embed a 'Comp' computation in a pure expression (the 'UnsafePerform'
-- symbol). The computation's result is desugared to 'Data' first.
-- NOTE(review): as the name warns, nothing here controls when or how often
-- the computation runs.
unsafePerform :: Syntax a => Comp a -> a
unsafePerform comp = sugarSymFeld (UnsafePerform (fmap desugar comp))
-- | Monads that can embed 'Comp' computations and provide structured
-- imperative control flow.
class Monad m => MonadComp m
  where
    -- | Lift a 'Comp' computation into the monad.
    liftComp :: Comp a -> m a
    -- | Conditional statement: the second argument is the then-branch, the
    -- third the else-branch.
    iff :: Data Bool -> m () -> m () -> m ()
    -- | Loop over an index range, running the body for each index.
    for :: (Integral n, PrimType n) =>
        IxRange (Data n) -> (Data n -> m ()) -> m ()
    -- | While loop: the first action computes the continuation condition,
    -- the second is the body.
    while :: m (Data Bool) -> m () -> m ()

-- | 'Comp' delegates each construct to the matching "Language.Embedded.Imperative"
-- instruction, unwrapping and rewrapping with 'unComp' and 'Comp'.
instance MonadComp Comp
  where
    liftComp          = id
    iff c thn els     = Comp (Imp.iff c (unComp thn) (unComp els))
    for range body    = Comp (Imp.for range (unComp . body))
    while cont body   = Comp (Imp.while (unComp cont) (unComp body))
-- | Create a reference without an initial value, using the base name @"r"@.
newRef :: (Syntax a, MonadComp m) => m (Ref a)
newRef = newNamedRef "r"

-- | Create a reference without an initial value. One primitive reference is
-- allocated for each component of the type's structure.
newNamedRef :: (Syntax a, MonadComp m)
    => String  -- ^ Base name (presumably for the generated variable; confirm)
    -> m (Ref a)
newNamedRef base =
    liftComp $ fmap Ref $ mapStructA (\_ -> Comp (Imp.newNamedRef base)) typeRep

-- | Create a reference initialised with a value, using the base name @"r"@.
initRef :: (Syntax a, MonadComp m) => a -> m (Ref a)
initRef x = initNamedRef "r" x

-- | Create a reference initialised with a value. Each primitive component of
-- the value initialises its own primitive reference.
initNamedRef :: (Syntax a, MonadComp m)
    => String  -- ^ Base name (presumably for the generated variable; confirm)
    -> a       -- ^ Initial value
    -> m (Ref a)
initNamedRef base x =
    liftComp $ fmap Ref $ mapStructA (\v -> Comp (Imp.initNamedRef base v)) (resugar x)
-- | Read the current value of a reference, component by component.
getRef :: (Syntax a, MonadComp m) => Ref a -> m a
getRef r =
    liftComp $ fmap resugar $ mapStructA (\ref -> Comp (Imp.getRef ref)) (unRef r)

-- | Write a value to a reference, one primitive component at a time.
setRef :: (Syntax a, MonadComp m) => Ref a -> a -> m ()
setRef r x =
    liftComp $ sequence_ $
      zipListStruct (\ref v -> Comp (Imp.setRef ref v)) (unRef r) (resugar x)

-- | Apply a function to the contents of a reference. The old value is read
-- with 'unsafeFreezeRef', not 'getRef'.
-- NOTE(review): presumably this avoids a copy, because the old value is only
-- used to compute the new one. Confirm.
modifyRef :: (Syntax a, MonadComp m) => Ref a -> (a -> a) -> m ()
modifyRef r f = do
    old <- unsafeFreezeRef r
    setRef r (f old)

-- | Read a reference with 'Imp.unsafeFreezeRef' instead of 'Imp.getRef'.
-- NOTE(review): the result presumably stays tied to the reference, so later
-- writes may be visible through it. Confirm.
unsafeFreezeRef :: (Syntax a, MonadComp m) => Ref a -> m a
unsafeFreezeRef r =
    liftComp $ fmap resugar $
      mapStructA (\ref -> Comp (Imp.unsafeFreezeRef ref)) (unRef r)
-- | Allocate a mutable array of the given length, using the base name @"a"@.
newArr :: (Type (Internal a), MonadComp m) => Data Length -> m (Arr a)
newArr len = newNamedArr "a" len

-- | Allocate a mutable array with offset 0 and the given length. One
-- primitive array is allocated for each component of the element structure.
newNamedArr :: (Type (Internal a), MonadComp m)
    => String       -- ^ Base name (presumably for the generated variable; confirm)
    -> Data Length  -- ^ Length
    -> m (Arr a)
newNamedArr base len =
    liftComp $ fmap (Arr 0 len) $
      mapStructA (\_ -> Comp (Imp.newNamedArr base len)) typeRep

-- | Create an array from a list of constants, using the base name @"a"@.
constArr :: (PrimType (Internal a), MonadComp m)
    => [Internal a]
    -> m (Arr a)
constArr xs = constNamedArr "a" xs

-- | Create an array from a list of constants. The array has offset 0 and
-- the list's length, and holds a single primitive component.
constNamedArr :: (PrimType (Internal a), MonadComp m)
    => String        -- ^ Base name (presumably for the generated variable; confirm)
    -> [Internal a]  -- ^ Initial contents
    -> m (Arr a)
constNamedArr base xs = liftComp $ do
    prim <- Comp (Imp.constNamedArr base xs)
    return (Arr 0 (value (genericLength xs)) (Single prim))
-- | Read the element at an index of a mutable array.
--
-- An internal assertion first checks that the index is below the array's
-- length. Each primitive component is then read at the index plus the
-- array's offset.
getArr :: (Syntax a, MonadComp m) => Arr a -> Data Index -> m a
getArr arr i = do
    assertLabel
      InternalAssertion
      (i < length arr)
      "getArr: index out of bounds"
    liftComp $ fmap resugar $
      mapStructA (\prim -> Comp (Imp.getArr prim ix)) (unArr arr)
  where
    ix = i + arrOffset arr

-- | Write an element at an index of a mutable array.
--
-- An internal assertion first checks that the index is below the array's
-- length. Each primitive component is then written at the index plus the
-- array's offset.
setArr :: forall m a . (Syntax a, MonadComp m) =>
    Arr a -> Data Index -> a -> m ()
setArr arr i x = do
    assertLabel
      InternalAssertion
      (i < length arr)
      "setArr: index out of bounds"
    liftComp $ sequence_ $
      zipListStruct
        (\v prim -> Comp (Imp.setArr prim ix v))
        components
        (unArr arr)
  where
    ix         = i + arrOffset arr
    components = resugar x :: Struct PrimType' Data (Internal a)
-- | Copy the contents of the second array into the first.
--
-- Copies @length src@ elements, starting at each array's offset. An internal
-- assertion first checks that the destination is at least as long as the
-- source.
copyArr :: MonadComp m
    => Arr a  -- ^ Destination
    -> Arr a  -- ^ Source
    -> m ()
copyArr dst src = do
    assertLabel
      InternalAssertion
      (length dst >= length src)
      "copyArr: destination too small"
    liftComp $ sequence_ $
      zipListStruct
        (\d s -> Comp (Imp.copyArr (d, arrOffset dst) (s, arrOffset src) (length src)))
        (unArr dst)
        (unArr src)
-- | Freeze a mutable array into an immutable one by copying it: a fresh
-- array of the same length is allocated, the contents are copied into it,
-- and the copy is frozen with 'unsafeFreezeArr'.
freezeArr :: (Type (Internal a), MonadComp m) => Arr a -> m (IArr a)
freezeArr arr = liftComp $
    newArr (length arr) >>= \copy ->
    copyArr copy arr >>
    unsafeFreezeArr copy

-- | Freeze the first @len@ elements of an array. The whole array is copied
-- by 'freezeArr' and then sliced.
freezeSlice :: (Type (Internal a), MonadComp m) =>
    Data Length -> Arr a -> m (IArr a)
freezeSlice len arr = fmap (slice 0 len) (freezeArr arr)

-- | View a mutable array as immutable, keeping its offset and length.
-- No copy is made here, unlike 'freezeArr'.
-- NOTE(review): the result presumably aliases the mutable array. Confirm.
unsafeFreezeArr :: MonadComp m => Arr a -> m (IArr a)
unsafeFreezeArr arr = liftComp $
    fmap (IArr (arrOffset arr) (length arr)) $
      mapStructA (\prim -> Comp (Imp.unsafeFreezeArr prim)) (unArr arr)

-- | 'unsafeFreezeArr' followed by taking the first @len@ elements.
unsafeFreezeSlice :: MonadComp m => Data Length -> Arr a -> m (IArr a)
unsafeFreezeSlice len arr = fmap (slice 0 len) (unsafeFreezeArr arr)

-- | Make a mutable copy of an immutable array. The array is first viewed as
-- mutable with 'unsafeThawArr', then copied into a freshly allocated array
-- of the same length.
thawArr :: (Type (Internal a), MonadComp m) => IArr a -> m (Arr a)
thawArr arr = liftComp $ do
    aliased <- unsafeThawArr arr
    fresh   <- newArr (length arr)
    copyArr fresh aliased
    return fresh

-- | View an immutable array as mutable, keeping its offset and length.
-- No copy is made here, unlike 'thawArr'.
-- NOTE(review): the result presumably aliases the immutable array. Confirm.
unsafeThawArr :: MonadComp m => IArr a -> m (Arr a)
unsafeThawArr arr = liftComp $
    fmap (Arr (iarrOffset arr) (length arr)) $
      mapStructA (\prim -> Comp (Imp.unsafeThawArr prim)) (unIArr arr)
-- | Create an immutable array from a list of constants: 'constArr'
-- followed by 'unsafeFreezeArr'.
constIArr :: (PrimType (Internal a), MonadComp m) =>
    [Internal a] -> m (IArr a)
constIArr xs = constArr xs >>= unsafeFreezeArr

-- | Conditional whose branches return a value.
--
-- The chosen branch stores its result in a fresh reference ('newRef',
-- 'setRef'), which is read back afterwards with 'unsafeFreezeRef'.
ifE :: (Syntax a, MonadComp m)
    => Data Bool  -- ^ Condition
    -> m a        -- ^ Branch run when the condition holds
    -> m a        -- ^ Branch run otherwise
    -> m a
ifE c thn els = do
    result <- newRef
    let store = (>>= setRef result)
    iff c (store thn) (store els)
    unsafeFreezeRef result
-- | Emit the imperative 'Imp.break' instruction.
-- NOTE(review): presumably exits the innermost enclosing loop. Confirm.
break :: MonadComp m => m ()
break = liftComp (Comp Imp.break)

-- | Emit a user assertion: 'assertLabel' with the label @UserAssertion ""@.
assert :: MonadComp m
    => Data Bool  -- ^ Condition that must hold
    -> String     -- ^ Message attached to the assertion
    -> m ()
assert ok msg = assertLabel (UserAssertion "") ok msg

-- | Emit an 'Assert' instruction carrying a label, a condition and a message.
assertLabel :: MonadComp m
    => AssertionLabel  -- ^ Kind of assertion
    -> Data Bool       -- ^ Condition that must hold
    -> String          -- ^ Message attached to the assertion
    -> m ()
assertLabel lbl ok msg = liftComp (Comp (Oper.singleInj (Assert lbl ok msg)))

-- | Store a value in a fresh reference initialised with it, then read it
-- back with 'unsafeFreezeRef'.
-- NOTE(review): presumably meant to make the value be computed once, at
-- this point in the program. Confirm.
shareM :: (Syntax a, MonadComp m) => a -> m a
shareM x = initRef x >>= unsafeFreezeRef