module Language.Paraiso.OM.Builder.Internal
(
Builder, BuilderState(..),
B, BuilderOf,
buildKernel, initState,
modifyG, getG, freeNode, addNode, addNodeE, valueToNode, lookUpStatic,
bind,
load, store,
reduce, broadcast,
shift, loadIndex,loadSize,
imm, mkOp1, mkOp2,
annotate, (<?>),
withAnnotation
) where
import qualified Algebra.Absolute as Absolute
import qualified Algebra.Additive as Additive
import qualified Algebra.Algebraic as Algebraic
import qualified Algebra.Field as Field
import qualified Algebra.IntegralDomain as IntegralDomain
import qualified Algebra.Lattice as Lattice
import qualified Algebra.Ring as Ring
import qualified Algebra.Transcendental as Transcendental
import qualified Algebra.ZeroTestable as ZeroTestable
import Control.Monad
import qualified "mtl" Control.Monad.State as State
import qualified Data.Graph.Inductive as FGL
import Data.Dynamic (Typeable)
import qualified Data.Dynamic as Dynamic
import Data.Tensor.TypeLevel
import qualified Data.Vector as V
import Language.Paraiso.Name
import qualified Language.Paraiso.OM.Arithmetic as A
import Language.Paraiso.OM.DynValue as DVal
import Language.Paraiso.OM.Graph
import Language.Paraiso.OM.Realm as Realm
import Language.Paraiso.OM.Reduce as Reduce
import Language.Paraiso.OM.Value as Val
import Language.Paraiso.Prelude
import qualified Prelude (Num(..), Fractional(..))
import NumericPrelude hiding ((++))
-- | Run a 'Builder' starting from the state produced by 'initState' for the
-- given setup, and wrap the graph it built as a named 'Kernel'.
buildKernel ::
  Setup v g a          -- ^ setup the builder runs against
  -> Name              -- ^ name of the resulting kernel
  -> Builder v g a ()  -- ^ builder that populates the dataflow graph
  -> Kernel v g a
buildKernel setup0 name0 builder0 =
  Kernel { kernelName = name0, dataflow = finalGraph }
  where
    -- only the final state matters; the builder's result is ()
    finalGraph = target (State.execState builder0 (initState setup0))
-- | The state threaded through every 'Builder' action.
data BuilderState v g a = BuilderState
  { setup   :: Setup v g a
    -- ^ the setup the kernel is built against (read by 'lookUpStatic')
  , context :: BuilderContext a
    -- ^ contextual data consulted when nodes are created
  , target  :: Graph v g a
    -- ^ the dataflow graph built so far
  } deriving (Show)
-- | Context that influences how new nodes are created.
data BuilderContext a = BuilderContext
  { currentAnnotation :: a
    -- ^ annotation attached to nodes added through 'addNodeE'
  } deriving (Show)
-- | Initial builder state for a setup. The setup's 'globalAnnotation' is
-- used as the current annotation, and the graph starts out empty.
initState :: Setup v g a -> BuilderState v g a
initState s =
  BuilderState
    { setup   = s
    , context = BuilderContext (globalAnnotation s)
    , target  = FGL.empty
    }
-- | A builder action. It is a state monad over 'BuilderState', which carries
-- the setup, the current annotation context and the graph under
-- construction. It yields a result of type @val@.
type Builder (v :: * -> *) (g :: *) (a :: *) (val :: *) =
  State.State (BuilderState v g a) val
-- | Builder actions cannot be compared for equality.
--
-- This instance presumably exists only to satisfy a superclass requirement
-- (older GHCs required 'Eq' for 'Prelude.Num', which is instantiated for
-- builders in this module; confirm). Any use of '==' is an error with an
-- explanatory message rather than a bare 'undefined'.
instance Eq (Builder v g a ret) where
  _ == _ = error "Builder.(==): Builder actions cannot be compared for equality"
-- | Builder actions are opaque; 'show' always yields a fixed placeholder.
instance Show (Builder v g a ret) where
  show = const "<<REDACTED>>"
-- | A builder that is polymorphic in its vector, gauge and annotation types.
type B ret = forall (vector :: * -> *) (gauge :: *) (anot :: *). Builder vector gauge anot ret

-- | A polymorphic builder that yields a 'Value' of realm @r@ and content @c@.
type BuilderOf r c = B (Value r c)
-- | Apply a transformation to the graph under construction.
modifyG ::
  (Graph v g a -> Graph v g a)
  -> Builder v g a ()
modifyG f = State.modify $ \bs -> bs { target = f (target bs) }
-- | Read the graph under construction.
getG :: Builder v g a (Graph v g a)
getG = State.gets target
-- | A node id not yet used in the graph under construction.
--
-- The id is the current node count. It is fresh only while node ids are
-- contiguous from 0. That holds when the graph starts from 'FGL.empty' and
-- nodes are only added through 'addNode'.
-- NOTE(review): confirm no caller removes nodes via 'modifyG'.
freeNode :: B FGL.Node
freeNode = fmap FGL.noNodes getG
-- | Insert a new node into the graph and return its id.
--
-- The node gets one incoming edge from each element of @froms@. The edge
-- from the @i@-th source is labelled @EOrd i@, so argument order is
-- recorded. The node has no outgoing edges when it is inserted.
addNode ::
  [FGL.Node]     -- ^ source nodes, in argument order
  -> Node v g a  -- ^ label of the new node
  -> Builder v g a FGL.Node
addNode froms new = do
  n <- freeNode
  -- Pair each source with its position directly instead of indexing with
  -- (!!): this is total and linear rather than quadratic.
  let ins = [(EOrd i, from) | (i, from) <- zip [0 ..] froms]
  modifyG ((ins, n, new, []) FGL.&)
  return n
-- | Like 'addNode', but the node label is produced by applying @new'@ to the
-- current annotation held in the builder context.
addNodeE ::
  [FGL.Node]
  -> (a -> Node v g a)
  -> Builder v g a FGL.Node
addNodeE froms new' =
  State.gets (currentAnnotation . context) >>= addNode froms . new'
-- | The graph node that holds a value.
--
-- A value that already lives in the graph ('FromNode') yields its node
-- directly. An immediate ('FromImm') is materialised as an 'Imm' instruction
-- node followed by a value node, and the value node's id is returned. Each
-- call on an immediate adds fresh nodes.
valueToNode :: (TRealm r, Typeable c) => Value r c -> B FGL.Node
valueToNode val =
  case val of
    FromNode _ _ n -> return n
    FromImm _ _ -> do
      n0 <- addNodeE [] $ NInst (Imm (Dynamic.toDyn (Val.content val)))
      addNodeE [n0] $ NValue (toDyn val)
-- | Index of the static value named @name0@ in the setup's 'staticValues'.
--
-- Raises an error if the name does not occur exactly once. Also raises an
-- error if the registered type differs from the requested @type0@.
lookUpStatic :: Named DynValue -> B StaticIdx
lookUpStatic (Named name0 type0) = do
  st <- State.get
  let
    vs :: V.Vector (Named DynValue)
    vs = staticValues $ setup st
    -- pair every static value with its position, keep those with the name
    matches = V.filter (\(_, v) -> name v == name0) $ V.indexed vs
    (ret, Named _ type1) =
      if V.length matches /= 1
      then error (show (V.length matches) ++ " match found for '" ++ nameStr name0 ++
                  "' in " ++ show vs)
      else V.head matches
  when (type0 /= type1) $
    error ("type mismatch; expected: " ++ show type1 ++ "; actual: " ++
           nameStr name0 ++ "::" ++ show type0)
  return $ StaticIdx ret
-- | Run an action once and return an action that replays its result.
--
-- The inner action is a plain 'return', so using it several times does not
-- re-run the original action. Any nodes that action adds are therefore
-- added only once.
bind :: (Monad m, Functor m) => m a -> m (m a)
bind action = fmap return action
-- | Load the static value named @name0@.
--
-- The name and the dynamic type built from @r0@/@c0@ are checked with
-- 'lookUpStatic'. A 'Load' instruction node and a value node are then added.
load :: (TRealm r, Typeable c) =>
  r
  -> c
  -> Name
  -> B (Value r c)
load r0 c0 name0 = do
  let dynType = mkDyn r0 c0
  idx <- lookUpStatic (Named name0 dynType)
  instNode <- addNodeE [] $ NInst (Load idx)
  valNode <- addNodeE [instNode] $ NValue dynType
  return (FromNode r0 c0 valNode)
-- | Run @builder0@ and store its result into the static value @name0@.
--
-- The name and type are checked with 'lookUpStatic'. A 'Store' instruction
-- node fed by the value's node is then added.
store :: (TRealm r, Typeable c) =>
  Name
  -> Builder v g a (Value r c)
  -> Builder v g a ()
store name0 builder0 = do
  val0 <- builder0
  idx <- lookUpStatic (Named name0 (toDyn val0))
  src <- valueToNode val0
  void $ addNodeE [src] $ NInst (Store idx)
-- | Reduce an array-realm value to a scalar with the given operator.
-- Adds a 'Reduce' instruction node and a scalar value node.
reduce :: (Typeable c) =>
  Reduce.Operator
  -> Builder v g a (Value TArray c)
  -> Builder v g a (Value TScalar c)
reduce op builder1 = do
  val1 <- builder1
  let c1 = Val.content val1
  src <- valueToNode val1
  inst <- addNodeE [src] $ NInst (Reduce op)
  fmap (FromNode TScalar c1) (addNodeE [inst] (NValue (mkDyn TScalar c1)))
-- | Broadcast a scalar-realm value to an array-realm value.
-- Adds a 'Broadcast' instruction node and an array value node.
broadcast :: (Typeable c) =>
  Builder v g a (Value TScalar c)
  -> Builder v g a (Value TArray c)
broadcast builder1 = do
  val1 <- builder1
  let c1 = Val.content val1
  src <- valueToNode val1
  inst <- addNodeE [src] $ NInst Broadcast
  fmap (FromNode TArray c1) (addNodeE [inst] (NValue (mkDyn TArray c1)))
-- | Shift an array-realm value by @vec@ (presumably a displacement on the
-- mesh; confirm). Adds a 'Shift' instruction node and a value node with the
-- same dynamic type as the input.
shift :: (Typeable c)
  => v g
  -> Builder v g a (Value TArray c)
  -> Builder v g a (Value TArray c)
shift vec builder1 = do
  val1 <- builder1
  src <- valueToNode val1
  inst <- addNodeE [src] $ NInst $ Shift vec
  out <- addNodeE [inst] $ NValue (toDyn val1)
  return $ FromNode TArray (Val.content val1) out
-- | Array-realm value obtained from a 'LoadIndex' instruction for @axis@.
-- The content witness @c0@ determines the value's dynamic type.
loadIndex :: (Typeable c) =>
  c
  -> Axis v
  -> Builder v g a (Value TArray c)
loadIndex c0 axis = do
  inst <- addNodeE [] $ NInst (LoadIndex axis)
  out <- addNodeE [inst] (NValue (mkDyn TArray c0))
  return $ FromNode TArray c0 out
-- | Value of realm @r0@ obtained from a 'LoadSize' instruction for @axis@.
-- The witnesses @r0@/@c0@ determine the value's dynamic type.
loadSize :: (TRealm r, Typeable c)
  => r
  -> c
  -> Axis v
  -> Builder v g a (Value r c)
loadSize r0 c0 axis = do
  inst <- addNodeE [] $ NInst (LoadSize axis)
  out <- addNodeE [inst] (NValue (mkDyn r0 c0))
  return $ FromNode r0 c0 out
-- | An immediate (constant) value. No node is added here; the value is
-- materialised by 'valueToNode' when it is consumed.
imm :: (TRealm r, Typeable c) =>
  c
  -> B (Value r c)
imm x = return (FromImm unitTRealm x)
-- | Lift a unary arithmetic operator to builders.
-- Adds an 'Arith' instruction node and a value node of the operand's type.
mkOp1 :: (TRealm r, Typeable c) =>
  A.Operator
  -> (Builder v g a (Value r c))
  -> (Builder v g a (Value r c))
mkOp1 op builder1 = do
  v1 <- builder1
  argNode <- valueToNode v1
  opNode <- addNodeE [argNode] $ NInst (Arith op)
  resNode <- addNodeE [opNode] $ NValue (toDyn v1)
  return $ FromNode (Val.realm v1) (Val.content v1) resNode
-- | Lift a binary arithmetic operator to builders.
--
-- Both operands are run left to right. An 'Arith' instruction node takes
-- them as ordered inputs, and a value node is typed like the left operand.
mkOp2 :: (TRealm r, Typeable c) =>
  A.Operator
  -> (Builder v g a (Value r c))
  -> (Builder v g a (Value r c))
  -> (Builder v g a (Value r c))
mkOp2 op builder1 builder2 = do
  lhs <- builder1
  rhs <- builder2
  lhsNode <- valueToNode lhs
  rhsNode <- valueToNode rhs
  opNode <- addNodeE [lhsNode, rhsNode] $ NInst (Arith op)
  resNode <- addNodeE [opNode] $ NValue (toDyn lhs)
  return $ FromNode (Val.realm lhs) (Val.content lhs) resNode
-- | Run @builder1@ with the current annotation transformed by @f@, then
-- restore the annotation that was in effect before.
withAnnotation :: (a -> a) -> Builder v g a ret -> Builder v g a ret
withAnnotation f builder1 = do
  saved <- State.gets (currentAnnotation . context)
  setAnot (f saved)
  ret <- builder1
  setAnot saved
  return ret
  where
    -- overwrite only the annotation, leaving the rest of the state intact
    setAnot x = State.modify $ \st ->
      st { context = (context st) { currentAnnotation = x } }
-- | Run @builder1@, then apply @f@ to the annotation of the graph node that
-- holds its result.
--
-- An immediate result is first materialised by 'valueToNode'. The returned
-- value always refers to that node.
annotate :: (TRealm r, Typeable c) => (a -> a) -> Builder v g a (Value r c) -> Builder v g a (Value r c)
annotate f builder1 = do
  v1 <- builder1
  n1 <- valueToNode v1
  let touch ctx@(ins, n2, node2, outs)
        | n2 == n1  = (ins, n2, fmap f node2, outs)
        | otherwise = ctx
  modifyG (FGL.gmap touch)
  return $ FromNode (Val.realm v1) (Val.content v1) n1
-- | Infix synonym for 'annotate'. At @infixr 0@ it binds as loosely as '$',
-- so @f \<?\> x + y@ annotates the node of @x + y@.
infixr 0 <?>
(<?>) :: (TRealm r, Typeable c) => (a -> a) -> Builder v g a (Value r c) -> Builder v g a (Value r c)
f <?> builder = annotate f builder
-- | Builders of values form an additive group.
--
-- 'zero' is an immediate. Addition, subtraction and negation each add an
-- arithmetic node through 'mkOp2' / 'mkOp1'.
instance (TRealm r, Typeable c, Additive.C c) => Additive.C (Builder v g a (Value r c)) where
  zero   = return $ FromImm unitTRealm Additive.zero
  (+)    = mkOp2 A.Add
  (-)    = mkOp2 A.Sub
  negate = mkOp1 A.Neg
-- | Builders of values form a ring.
--
-- 'one' and 'fromInteger' produce immediates, and '*' adds a multiplication
-- node. '^' uses exponentiation by squaring. Each intermediate builder is
-- memoised with 'bind', so it is run once and roughly @log2 n@
-- multiplication nodes are emitted instead of @n - 1@.
instance (TRealm r, Typeable c, Ring.C c) => Ring.C (Builder v g a (Value r c)) where
  one = return $ FromImm unitTRealm Ring.one
  (*) = mkOp2 A.Mul
  fromInteger = imm . fromInteger
  a ^ n
    | n < 0     = error "Builder.(^): negative exponent"
    | n == 0    = fromInteger 1
    | n == 1    = a
    | otherwise = do
        ba <- bind a
        f ba n
    where
      -- f x k computes x^k for k >= 1, where x is an already-run builder
      f x 1 = x
      f x n2 = do
        let n3 = div n2 2
            -- an odd exponent needs one extra factor of x
            modify = if n2 - 2 * n3 > 0 then (x *) else id
        bx_n3 <- bind (f x n3)
        modify $ bx_n3 * bx_n3
-- | Integral division on builders.
--
-- 'div' and 'mod' each add an arithmetic node. 'divMod' pairs the two. Each
-- component runs both operand builders again, so shared operands should be
-- memoised with 'bind' by the caller if duplicate nodes matter.
instance (TRealm r, Typeable c, IntegralDomain.C c) => IntegralDomain.C (Builder v g a (Value r c)) where
  div = mkOp2 A.Div
  mod = mkOp2 A.Mod
  divMod x y = (div x y, mod x y)
-- | Prelude 'Prelude.Num' instance, so ordinary numeric literals and
-- operators work on builders.
--
-- Every method delegates to the NumericPrelude instances in this module.
-- 'abs' and 'signum' use the 'Absolute.C' instance, which has the same
-- constraints, instead of being 'undefined'.
instance (TRealm r, Typeable c, Ring.C c) => Prelude.Num (Builder v g a (Value r c)) where
  (+) = (Additive.+)
  (*) = (Ring.*)
  (-) = (Additive.-)
  negate = Additive.negate
  abs = Absolute.abs
  signum = Absolute.signum
  fromInteger = Ring.fromInteger
-- | Field operations on builders. Division and reciprocal add arithmetic
-- nodes, and rational literals become immediates.
instance (TRealm r, Typeable c, Field.C c) => Field.C (Builder v g a (Value r c)) where
  recip = mkOp1 A.Inv
  (/) = mkOp2 A.Div
  fromRational' q = imm (fromRational' q)
-- | Prelude 'Prelude.Fractional' instance. Division delegates to 'Field.C',
-- and fractional literals become immediates.
instance (TRealm r, Typeable c, Field.C c, Prelude.Fractional c) => Prelude.Fractional (Builder v g a (Value r c)) where
  fromRational q = imm (Prelude.fromRational q)
  recip = Field.recip
  (/) = (Field./)
-- | Boolean builders. The constants are immediates, and the connectives add
-- 'A.Not', 'A.And' and 'A.Or' arithmetic nodes.
instance (TRealm r) => Boolean (Builder v g a (Value r Bool)) where
  (&&)  = mkOp2 A.And
  (||)  = mkOp2 A.Or
  not   = mkOp1 A.Not
  true  = imm True
  false = imm False
-- | Algebraic operations on builders. 'sqrt' adds an 'A.Sqrt' node, and a
-- rational power becomes an 'A.Pow' node whose exponent is an immediate.
instance (TRealm r, Typeable c, Algebraic.C c) => Algebraic.C (Builder v g a (Value r c)) where
  sqrt = mkOp1 A.Sqrt
  (^/) base q = mkOp2 A.Pow base (fromRational' q)
-- | Lattice operations on builders: 'dn' is an 'A.Min' node and 'up' is an
-- 'A.Max' node.
instance (TRealm r, Typeable c) => Lattice.C (Builder v g a (Value r c)) where
  dn = mkOp2 A.Min
  up = mkOp2 A.Max
-- | A builder cannot be tested for zero at build time; calling 'isZero' is
-- an error.
instance (TRealm r, Typeable c) => ZeroTestable.C (Builder v g a (Value r c)) where
  isZero = const (error "isZero undefined for builder.")
-- | Absolute value and sign on builders, each adding an arithmetic node.
instance (TRealm r, Typeable c, Ring.C c) => Absolute.C (Builder v g a (Value r c)) where
  signum = mkOp1 A.Signum
  abs    = mkOp1 A.Abs
-- | Transcendental functions on builders. 'pi' is an immediate, and each
-- function adds the corresponding unary arithmetic node.
instance (TRealm r, Typeable c, Transcendental.C c) =>
    Transcendental.C (Builder v g a (Value r c)) where
  pi   = imm pi
  -- exponential and logarithm
  exp  = mkOp1 A.Exp
  log  = mkOp1 A.Log
  -- trigonometric functions
  sin  = mkOp1 A.Sin
  cos  = mkOp1 A.Cos
  tan  = mkOp1 A.Tan
  -- inverse trigonometric functions
  asin = mkOp1 A.Asin
  acos = mkOp1 A.Acos
  atan = mkOp1 A.Atan