{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Ormolu.Printer.Operators
( OpTree (..),
opTreeLoc,
reassociateOpTree,
)
where
import Data.Function (on)
import qualified Data.List as L
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Ord (Down (Down), comparing)
import GHC
import OccName (mkVarOcc)
import Ormolu.Utils (unSrcSpan)
-- | A binary tree of operator applications. Leaves carry operands of type
-- @ty@; branches carry an operator of type @op@ between a left and a right
-- subtree.
data OpTree ty op
  = -- | A single operand.
    OpNode ty
  | -- | An operator applied to a left and a right subtree.
    OpBranch (OpTree ty op) op (OpTree ty op)
-- | The source span covered by an operator tree: the spans of all operand
-- leaves merged pairwise with 'combineSrcSpans'. Operator locations are not
-- consulted.
opTreeLoc :: OpTree (Located a) b -> SrcSpan
opTreeLoc = \case
  OpNode (L loc _) -> loc
  OpBranch lhs _ rhs -> opTreeLoc lhs `combineSrcSpans` opTreeLoc rhs
-- | Re-associate an operator tree using fixities inferred from the source
-- layout of its operators.
--
-- The steps are:
--
--   1. Normalize the tree, so that every left child becomes a leaf.
--   2. Derive a fixity map from the normalized tree with 'buildFixityMap'.
--   3. Re-nest the normalized tree with 'reassociateOpTreeWith'.
reassociateOpTree ::
  -- | Extract an operator's name, if it has one.
  (op -> Maybe RdrName) ->
  -- | Tree to re-associate.
  OpTree (Located ty) (Located op) ->
  OpTree (Located ty) (Located op)
reassociateOpTree getOpName opTree =
  let normalized = normalizeOpTree opTree
      fixities = buildFixityMap getOpName normalized
      nameOf = getOpName . unLoc
   in reassociateOpTreeWith fixities nameOf normalized
-- | Rebuild an operator tree so that its nesting reflects the given
-- fixities.
--
-- An operator falls back to 'defaultFixity' in either of these cases:
--
--   * its name cannot be obtained (@getOpName@ returns 'Nothing');
--   * its name is missing from @fixityMap@.
--
-- NOTE(review): the recursion only walks the right spine. The sole caller
-- in this module passes a tree produced by 'normalizeOpTree', in which every
-- left child is a leaf. Other input shapes look unintended — confirm.
reassociateOpTreeWith ::
  forall ty op.
  -- | Fixity of each operator name; 'lookup' takes the first match.
  [(RdrName, Fixity)] ->
  -- | Extract an operator's name, if it has one.
  (op -> Maybe RdrName) ->
  OpTree ty op ->
  OpTree ty op
reassociateOpTreeWith fixityMap getOpName = go
  where
    -- Fixity of an operator, falling back to 'defaultFixity'.
    fixityOf :: op -> Fixity
    fixityOf op = fromMaybe defaultFixity $ do
      opName <- getOpName op
      lookup opName fixityMap
    -- Turn right-nested applications into left-nested ones, pushing an
    -- operator back to the right only when the second component of
    -- 'compareFixity' is True. The first component is not used.
    go :: OpTree ty op -> OpTree ty op
    -- Leaves and single applications need no work.
    go t@(OpNode _) = t
    go t@(OpBranch (OpNode _) _ (OpNode _)) = t
    -- a op (b op' c)  ==>  (a op b) op' c, then continue.
    go (OpBranch l@(OpNode _) op (OpBranch l' op' r')) =
      go (OpBranch (OpBranch l op l') op' r')
    -- (l op r) op' r': if op/op' should associate right, re-nest as
    -- l op (r op' r'); otherwise keep the left-nested shape.
    go (OpBranch (OpBranch l op r) op' r'@(OpNode _)) =
      if snd $ compareFixity (fixityOf op) (fixityOf op')
        then OpBranch l op (go $ OpBranch r op' r')
        else OpBranch (OpBranch l op r) op' r'
    -- (l op r) op' (l' op'' r'): resolve op vs op' on the left part (with
    -- l' as op''s left operand), then keep going with op'' and r'.
    go (OpBranch (OpBranch l op r) op' (OpBranch l' op'' r')) =
      if snd $ compareFixity (fixityOf op) (fixityOf op')
        then go $ OpBranch (OpBranch l op (go $ OpBranch r op' l')) op'' r'
        else go $ OpBranch (OpBranch (OpBranch l op r) op' l') op'' r'
-- | Infer a fixity for every named operator in the tree from how it is laid
-- out in the original source.
--
-- Each operator occurrence on the right spine gets a layout score:
--
--   * it starts a line (the left operand ends on an earlier line): its
--     start column divided by @maxCol + 1@, a value below @1@;
--   * it ends a line (the right operand starts on a later line): @1@;
--   * otherwise (mid-line): @2@.
--
-- Occurrences are skipped (and the walk continues to the right) when a
-- location cannot be resolved or @getOpName@ yields 'Nothing'.
--
-- Each operator name is then assigned the most common score among its
-- occurrences (see @mode@). Names are ordered by that score, and names
-- whose scores are (nearly) equal are grouped together. Group @i@
-- (counting from 0) receives fixity @infixl i@, so a lower score means a
-- lower precedence. @$@ is pinned to score @-1@, so it always forms group
-- 0.
--
-- Only the right spine is scored; left children are never inspected. The
-- caller in this module passes a tree normalized by 'normalizeOpTree',
-- whose left children are always leaves.
buildFixityMap ::
  forall ty op.
  (op -> Maybe RdrName) ->
  OpTree (Located ty) (Located op) ->
  [(RdrName, Fixity)]
buildFixityMap getOpName opTree =
  concatMap (\(i, ns) -> map (\(n, _) -> (n, fixity i InfixL)) ns)
    . zip [0 ..]
    . L.groupBy (doubleWithinEps 0.00001 `on` snd)
    . (overrides ++)
    . modeScores
    $ score opTree
  where
    -- Scores that bypass the layout heuristic. Every score produced by
    -- 'score' is non-negative. Prepending -1 therefore keeps the list
    -- sorted and isolates these names in the lowest group. 'lookup' on the
    -- result finds the override before any layout-derived entry.
    overrides :: [(RdrName, Double)]
    overrides =
      [ (mkRdrUnqual $ mkVarOcc "$", -1)
      ]

    -- One (name, layout score) pair per scorable operator occurrence on
    -- the right spine of the tree.
    score :: OpTree (Located ty) (Located op) -> [(RdrName, Double)]
    score (OpNode _) = []
    score (OpBranch l o r) = fromMaybe (score r) $ do
      le <- srcSpanEndLine <$> unSrcSpan (opTreeLoc l)
      ob <- srcSpanStartLine <$> unSrcSpan (getLoc o)
      oe <- srcSpanEndLine <$> unSrcSpan (getLoc o)
      rb <- srcSpanStartLine <$> unSrcSpan (opTreeLoc r)
      oc <- srcSpanStartCol <$> unSrcSpan (getLoc o)
      opName <- getOpName (unLoc o)
      let s
            | le < ob =
              -- Operator begins a line: rank by column, normalized below 1.
              fromIntegral oc / fromIntegral (maxCol + 1)
            | oe < rb =
              -- Operator ends a line.
              1
            | otherwise =
              -- Operator sits in the middle of a line.
              2
      return $ (opName, s) : score r

    -- Collapse all occurrences of each operator name to that name's modal
    -- score, then order the names by ascending score.
    modeScores :: [(RdrName, Double)] -> [(RdrName, Double)]
    modeScores =
      L.sortOn snd
        . mapMaybe
          ( \case
              [] -> Nothing
              xs@((n, _) : _) -> Just (n, mode $ map snd xs)
          )
        . L.groupBy ((==) `on` fst)
        . L.sort

    -- Most frequent score, treating scores within 0.0001 of each other as
    -- equal. The result is the smallest member of the largest cluster.
    -- Among equally large clusters the lowest-valued one wins, because
    -- 'L.minimumBy' keeps the first minimum. Written without 'head' so
    -- that it is total; the fallbacks below cannot be reached.
    mode :: [Double] -> Double
    mode xs =
      case L.groupBy (doubleWithinEps 0.0001) (L.sort xs) of
        [] -> 0 -- unreachable: 'modeScores' only passes non-empty lists
        clusters -> case L.minimumBy (comparing (Down . length)) clusters of
          (m : _) -> m
          [] -> 0 -- unreachable: 'L.groupBy' never yields an empty group

    -- Largest operator start column anywhere in the tree. Unresolvable
    -- locations count as 0. Every scored operator's column is at most
    -- this, so dividing by @maxCol + 1@ keeps line-leading scores below 1.
    maxCol :: Int
    maxCol = go opTree
      where
        go (OpNode (L _ _)) = 0
        go (OpBranch l (L o _) r) =
          maximum
            [ go l,
              maybe 0 srcSpanStartCol (unSrcSpan o),
              go r
            ]
-- | Rotate an operator tree so that every left child is a leaf.
--
-- @(a op1 b) op2 c@ is rewritten to @a op1 (b op2 c)@ until no branch has a
-- branch on its left. The left-to-right order of operands and operators is
-- unchanged; only the nesting differs.
normalizeOpTree :: OpTree ty op -> OpTree ty op
normalizeOpTree = \case
  leaf@(OpNode _) -> leaf
  OpBranch lhs@(OpNode _) o rhs ->
    OpBranch lhs o (normalizeOpTree rhs)
  OpBranch (OpBranch innerL innerOp innerR) o rhs ->
    normalizeOpTree (OpBranch innerL innerOp (OpBranch innerR o rhs))
-- | Build a 'Fixity' from a precedence level and an associativity
-- direction, using 'NoSourceText'.
fixity :: Int -> FixityDirection -> Fixity
fixity precedence direction = Fixity NoSourceText precedence direction
-- | @doubleWithinEps eps x y@ holds when @x@ and @y@ differ by strictly
-- less than @eps@.
doubleWithinEps :: Double -> Double -> Double -> Bool
doubleWithinEps eps x y = distance < eps
  where
    distance = abs (x - y)