module Data.Packed.Syntax(vec, mat) where
import Data.Packed.Syntax.Internal
import Language.Haskell.TH as TH
import Language.Haskell.TH.Quote as TH
import Language.Haskell.TH.Syntax as TH
import Data.Packed.Vector(
Vector,
dim,
)
import Data.Packed.Matrix(
Matrix,
rows,
cols,
)
import Data.Packed.ST(
runSTVector,
newUndefinedVector,
unsafeWriteVector,
runSTMatrix,
newUndefinedMatrix,
unsafeWriteMatrix,
)
import Data.Packed.Development(
MatrixOrder(..),
at',
atM',
)
-- | Quasiquoter for vectors, usable in expression and pattern positions.
-- Expression quotes are handled by 'vecExp' and pattern quotes by 'vecPat';
-- type and declaration quotes are rejected by 'qq'.
vec :: QuasiQuoter
vec = qq vecExp vecPat
-- | Quasiquoter for matrices, usable in expression and pattern positions.
-- Expression quotes are handled by 'matExp' and pattern quotes by 'matPat';
-- type and declaration quotes are rejected by 'qq'.
mat :: QuasiQuoter
mat = qq matExp matPat
-- | Assemble a 'QuasiQuoter' from an expression handler and a pattern
-- handler. Type and declaration quotes always fail with a fixed message.
qq :: (String -> Q TH.Exp) -> (String -> Q TH.Pat) -> QuasiQuoter
qq expQuoter patQuoter =
  QuasiQuoter
    { quoteExp  = expQuoter
    , quotePat  = patQuoter
    , quoteType = \_ -> fail "Type quasiquotes not supported"
    , quoteDec  = \_ -> fail "Declaration quasiquotes not supported"
    }
-- | Expression handler for 'vec'. Runs 'listExp' on the quoted text: a
-- @Right@ list of element expressions is handed to 'buildVectorST', and a
-- @Left@ message is reported through 'fail'.
vecExp :: String -> Q TH.Exp
vecExp s = either fail buildVectorST (listExp s)
-- | Generate an expression that, inside 'runSTVector', allocates a vector
-- with one slot per element expression, writes each expression at its
-- zero-based position, and returns the vector.
buildVectorST :: [TH.Exp] -> Q TH.Exp
buildVectorST es =
  [| runSTVector
       (do
          v <- newUndefinedVector $( lift (length es) )
          -- One write per element, chained with >> and ended by return ().
          $( foldr
               (\(i, e) rest -> [| unsafeWriteVector v i $(return e) >> $rest |])
               [| return () |]
               (zip [0 ..] es) )
          return v
       ) |]
-- | Generate a function expression for use as a view: it yields @Nothing@
-- when the argument's 'dim' differs from @n@, and otherwise @Just@ the list
-- of the elements at indices @0 .. n-1@, each read with @at'@.
buildToList :: Int -> Q TH.Exp
buildToList n =
  [| \v ->
       if dim v /= n
         then Nothing
         else Just $( foldr
                        (\i rest -> [| at' v i : $rest |])
                        [| [] |]
                        [0 .. n - 1] ) |]
-- | Pattern handler for 'vec'. Runs 'listPat' on the quoted text: a @Right@
-- list of element patterns becomes a view pattern, and a @Left@ message is
-- reported through 'fail'.
vecPat :: String -> Q TH.Pat
vecPat s = either fail toViewPat (listPat s)
  where
    -- View the scrutinee through 'buildToList' (sized by the number of
    -- element patterns) and match the resulting @Just@ list against them.
    toViewPat ps =
      viewP (buildToList (length ps)) (conP 'Just [return (ListP ps)])
-- | Expression handler for 'mat'. Runs 'matListExp' on the quoted text: on
-- @Right@ the third component of the result (the rows of element
-- expressions) is handed to 'buildMatST', and a @Left@ message is reported
-- through 'fail'.
matExp :: String -> Q TH.Exp
matExp s = either fail build (matListExp s)
  where
    -- The first two components of the triple are not needed here.
    build (_, _, rowExps) = buildMatST rowExps
-- | Generate an expression that, inside 'runSTMatrix', allocates an
-- @r@-by-@c@ 'RowMajor' matrix, writes every element expression at its
-- zero-based (row, column) position, and returns the matrix.
--
-- @r@ is the number of rows and @c@ the length of the first row. The
-- generated writes use 'unsafeWriteMatrix', so the input is validated here:
-- an empty list of rows, or rows whose lengths differ from @c@, make the
-- splice fail instead of producing out-of-range or missing writes.
buildMatST :: [[TH.Exp]] -> Q TH.Exp
buildMatST [] = fail "Empty matrix quasiquote"
buildMatST es@(firstRow : _)
  | any ((/= c) . length) es = fail "Matrix quasiquote rows differ in length"
  | otherwise =
      [| runSTMatrix
           (do
              m <- newUndefinedMatrix RowMajor r c
              -- One write per element in row-major order, chained with >>.
              $( let writes =
                       [ [| unsafeWriteMatrix m ir ic $(return e) |]
                       | (ir, row) <- zip [0 .. r - 1] es
                       , (ic, e) <- zip [0 .. c - 1] row
                       ]
                 in foldr (\h t -> [| $h >> $t |]) [| return () |] writes )
              return m
           ) |]
  where
    r = length es
    c = length firstRow
-- | Pattern handler for 'mat'. Runs 'matListPat' on the quoted text: a
-- @Right@ triple becomes a view pattern, and a @Left@ message is reported
-- through 'fail'.
matPat :: String -> Q TH.Pat
matPat s = either fail toViewPat (matListPat s)
  where
    -- 'buildToLists' receives the second component of the triple as its
    -- first argument and the first component as its second; the resulting
    -- @Just@ list of lists is matched against the nested element patterns.
    toViewPat (rowLen, colLen, rowPats) =
      viewP
        (buildToLists colLen rowLen)
        (conP 'Just [return (ListP (map ListP rowPats))])
-- | Generate a function expression for use as a view: it yields @Nothing@
-- when the argument's @(rows, cols)@ differs from @(r, c)@, and otherwise
-- @Just@ a list of @r@ rows, each a list of the @c@ elements read with
-- @atM'@ at zero-based (row, column) indices.
buildToLists :: Int -> Int -> Q TH.Exp
buildToLists r c =
  [| \m ->
       if (rows m, cols m) /= (r, c)
         then Nothing
         else Just
           $( TH.listE
                [ TH.listE [ [| atM' m ir ic |] | ic <- [0 .. c - 1] ]
                | ir <- [0 .. r - 1]
                ] )
   |]