{-# LANGUAGE DataKinds #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}

module Internal.Quasi.Operator.Quasi where

import Data.List.Split (chunksOf)
import Data.Proxy
import qualified GHC.Natural as Natural
import GHC.TypeNats
import Internal.Matrix
import qualified Internal.Quasi.Operator.Parser as Parser
import qualified Internal.Quasi.Parser as Parser
import Internal.Quasi.Quasi
import Language.Haskell.TH.Quote
import Language.Haskell.TH.Syntax
import QLinear.Identity

{- | Macro constructor for operator.

Takes an operator written in coordinate form,
@(x1, ..., xn) => (expr1, ..., exprn)@, and builds its @n × n@ matrix.
The number of parameters must equal the number of output expressions,
and neither list may be empty.

>>> [operator| (x, y) => (y, x) |]
[0,1]
[1,0]

>>> [operator| (x, y) => (2 * x, y + x) |] ~*~ [vector| 3 4 |]
[6]
[7]

Do note, the constructor __doesn't prove__ linearity.
It just builds the matrix of the given operator.

Only the expression context is supported; the pattern, type and declaration
contexts are delegated to 'isNotDefinedAs'.
-}
operator :: QuasiQuoter
operator =
  QuasiQuoter
    { quoteExp = expr,
      quotePat = notDefined "Pattern",
      quoteType = notDefined "Type",
      quoteDec = notDefined "Declaration"
    }
  where
    notDefined = isNotDefinedAs "operator"

    -- Builds the splice
    --   matrixOfOperator (Matrix @n @1 @_ (n, 1) [[\[x1, ..., xn] -> expr1], ...])
    -- i.e. an n × 1 column of lambdas, each taking all coordinates as a list,
    -- which 'matrixOfOperator' then expands into the full n × n matrix.
    expr :: String -> Q Exp
    expr source = do
      let (params, lams, n) = unwrap $ parse source
          sizeType = LitT . NumTyLit
          -- The pair is built by applying the '(,)' constructor rather than with
          -- 'TupE', whose field type changed from [Exp] to [Maybe Exp] in
          -- template-haskell 2.16; this form is the same expression on every version.
          size = AppE (AppE (ConE '(,)) (LitE (IntegerL n))) (LitE (IntegerL 1))
          func = VarE 'matrixOfOperator
          constructor = foldl AppTypeE (ConE 'Matrix) [sizeType n, sizeType 1, WildCardT]
          -- One single-element row per output expression; every lambda binds the
          -- full parameter list as a list pattern.
          value = ListE $ map (ListE . pure . LamE [ListP params]) lams
      pure $ AppE func $ foldl AppE constructor [size, value]
      where
        -- Parses the quasi-quote body into parameter patterns and output
        -- expressions, then validates them; the validated count becomes n.
        parse :: String -> Either [String] ([Pat], [Exp], Integer)
        parse src = do
          (params, lams) <- Parser.parse Parser.definition "QLinear" src
          size <- checkSize (params, lams)
          pure (params, lams, size)

        -- Rejects an empty parameter list or body, and requires exactly one
        -- output expression per parameter so that the matrix is square.
        checkSize :: ([Pat], [Exp]) -> Either [String] Integer
        checkSize ([], _) = Left ["Parameters of operator cannot be empty"]
        checkSize (_, []) = Left ["Body of operator cannot be empty"]
        checkSize (names, exprs)
          | length names == length exprs = Right $ fromIntegral (length names)
          | otherwise = Left ["Number of arguments and number of lambdas must be equal"]

-- | Builds the @n × n@ matrix of an operator given as an @n × 1@ column of
-- coordinate functions, each taking the whole input vector as a list.
--
-- Entry @(i, j)@ of the result is the @i@-th function applied to the @j@-th
-- row of 'e', i.e. @f_i(e_j)@ when 'e' is the identity matrix.
matrixOfOperator ::
  forall n a b.
  (KnownNat n, HasIdentity a) =>
  Matrix n 1 ([a] -> b) ->
  Matrix n n b
matrixOfOperator (Matrix _ fs) =
  -- Row-major: all n images of f_1, then all n images of f_2, ...; chunking
  -- by n regroups them into the rows of the result.
  Matrix (n, n) $ chunksOf n [f line | f <- concat fs, line <- identity]
  where
    -- Rows of 'e' at the requested size, used as the input vectors.
    (Matrix _ identity) = e :: Matrix n n a
    n = Natural.naturalToInt $ natVal (Proxy @n)