{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -Wall #-}
module NLP.GizaPlusPlus where

import Control.Arrow ((&&&),(***))
import Control.Exception (bracket, finally)
import Data.List (group, sort, intersperse)
import qualified Data.Map as Map
import System.Directory (removeFile, getTemporaryDirectory, createDirectoryIfMissing)
import System.Exit (ExitCode(..))
import System.FilePath ((</>),(<.>))
import System.IO (openBinaryFile, IOMode(WriteMode), openTempFile, Handle, hClose)
import System.Process (runProcess, waitForProcess)

import Text.ParserCombinators.Parsec (parseFromFile)

import NLP.GizaPlusPlus.Parsec (alignFile)

-- | A word together with its position in a sentence.  The positions
--   produced by 'toAlignment' are zero-based.
data WordPos = WordPos String Int deriving (Read,Show,Eq,Ord)
-- | A single alignment link between two words.  As built by
--   'toAlignment', the first 'WordPos' belongs to the first sentence of
--   the resulting 'Alignment' and the second 'WordPos' to the second one.
data Align   = Align WordPos WordPos deriving (Read,Show,Eq,Ord)

-- ----------------------------------------------------------------------
-- GIZA++ configuration
-- ----------------------------------------------------------------------

-- | Settings for a GIZA++ run.  Every field except 'gizaVerbose' is
--   written to the GIZA++ configuration file by 'fromGizaCfg'.
data GizaCfg = GizaCfg
 { gizaHmmiterations    :: Int  -- ^ value of the @hmmiterations@ entry
 , gizaModel1iterations :: Int  -- ^ value of the @model1iterations@ entry
 , gizaModel2iterations :: Int  -- ^ value of the @model2iterations@ entry
 , gizaModel3iterations :: Int  -- ^ value of the @model3iterations@ entry
 , gizaModel4iterations :: Int  -- ^ value of the @model4iterations@ entry
 , gizaModel5iterations :: Int  -- ^ value of the @model5iterations@ entry
 , gizaPegging          :: Bool -- ^ @pegging@ entry, written as @1@ or @0@
 , gizaNBestAlignments  :: Bool -- ^ @nbestalignments@ entry, written as @1@ or @0@
 , gizaCompactTable     :: Bool -- ^ @compactadtable@ entry, written as @1@ or @0@
 , gizaVerbose          :: Bool -- ^ when 'False', 'align' sends GIZA++'s stdout and stderr to the null device
 }

-- | Out-of-the-box settings: five iterations each for models 1 and 2,
--   three each for models 3 to 5, no HMM iterations, pegging and n-best
--   alignments switched on, compact table and verbose output switched off.
defaultGizaCfg :: GizaCfg
defaultGizaCfg =
  GizaCfg
    { -- iteration counts
      gizaModel1iterations = 5
    , gizaModel2iterations = 5
    , gizaModel3iterations = 3
    , gizaModel4iterations = 3
    , gizaModel5iterations = 3
    , gizaHmmiterations    = 0
      -- on/off switches
    , gizaPegging          = True
    , gizaNBestAlignments  = True
    , gizaCompactTable     = False
    , gizaVerbose          = False
    }

-- | Convert a 'GizaCfg' into a (fragment of a) GIZA++ configuration
--   file (when we call giza, we will append other entries).
--
--   Every setting becomes one @key value@ line, with exactly one space
--   between key and value; booleans are rendered as @1@ or @0@.  The
--   lines are separated by newlines and there is no trailing newline.
--   'gizaVerbose' is not part of the output.
fromGizaCfg :: GizaCfg -> String
fromGizaCfg gz =
 concat . intersperse "\n" $
   [ iEntry "hmmiterations"    gizaHmmiterations
   , iEntry "model1iterations" gizaModel1iterations
   , iEntry "model2iterations" gizaModel2iterations
   , iEntry "model3iterations" gizaModel3iterations
   , iEntry "model4iterations" gizaModel4iterations
   , iEntry "model5iterations" gizaModel5iterations
   , bEntry "pegging"          gizaPegging
   , bEntry "nbestalignments"  gizaNBestAlignments
   , bEntry "compactadtable"   gizaCompactTable ]
 where
  -- one configuration line: key, a single space, then the value
  entry k v  = k ++ " " ++ v
  -- entry whose value is 'show'n
  iEntry k f = entry k (show (f gz))
  -- boolean entry, encoded as 1 (True) or 0 (False)
  bEntry k f = entry k (if f gz then "1" else "0")

-- ----------------------------------------------------------------------
-- running GIZA++
-- ----------------------------------------------------------------------

-- | An aligned sentence pair as returned by 'toAlignment': the words of
--   the first sentence, the words of the second sentence, and the
--   alignment links between them.
type Alignment = ([String],[String],[Align])

-- | Run GIZA++ and extract a list of word alignments
--
--   The sentence pairs are split into words with 'words' and written out
--   as two vocabulary files and one sentence file, together with a
--   configuration file built from the given 'GizaCfg'.  The external
--   @GIZA++@ program is then run on that configuration, and its
--   @output.A3.final@ file is parsed and converted with 'toAlignment'.
--
--   Unless 'gizaVerbose' is set, the standard output and standard error
--   of GIZA++ are sent to the null device.
--
--   Fails (via 'fail') if GIZA++ exits with a non-zero status or if its
--   output cannot be parsed.
align :: GizaCfg
      -> [(String,String)] -- ^ aligned sentence pairs
      -> IO [Alignment]
align gzcfg spairs =
 do tmpDir <- getTemporaryDirectory
    -- the body of the lambda is indented deeper than the enclosing do
    -- block so that the layout does not depend on NondecreasingIndentation
    withTmp tmpDir $ \(cfg, _) -> do
      -- the temporary file only provides a unique name, from which the
      -- working directory handed to GIZA++ is derived
      let subTmpDir = tmpDir </> cfg ++ "-dir"
          v1  = subTmpDir </> "source" <.> "vcb"
          v2  = subTmpDir </> "target" <.> "vcb"
          snt = subTmpDir </> "corpus" <.> "snt"
          realcfg = subTmpDir </> "config"
      -- NOTE(review): subTmpDir and the files in it are not removed once
      -- we are done; only the temporary file itself is cleaned up
      createDirectoryIfMissing False subTmpDir
      -- write giza input and configuration files
      writeVcb v1 idx1
      writeVcb v2 idx2
      writeSnt snt wspairs
      writeFile realcfg . unlines $
        fromGizaCfg gzcfg : [ "s " ++ v1
                            , "t " ++ v2
                            , "c " ++ snt
                            , "o " ++ "output"
                            , "outputpath " ++ subTmpDir ]
      -- go!
      gizaStdout <- if gizaVerbose gzcfg
                       then return Nothing
                       else Just `fmap` openBinaryFile _dev_null WriteMode
      ph <- runProcess "GIZA++" [realcfg] Nothing Nothing
              Nothing -- stdin
              gizaStdout -- stdout
              gizaStdout -- stderr
      -- report a failed run directly instead of tripping over a missing
      -- or incomplete output file further down
      status <- waitForProcess ph
      case status of
        ExitSuccess   -> return ()
        ExitFailure n -> fail $ "GIZA++ exited with status " ++ show n
      -- parse the output
      let algnfile = subTmpDir </> "output.A3.final"
      mparse <- parseFromFile alignFile algnfile
      case mparse of
        Left err -> fail $ "Error parsing GIZA++ output:\n" ++ show err
        Right ps -> return $ map toAlignment ps
 where
  -- each sentence as a list of words
  wspairs  = map (words *** words) $ spairs
  -- number the distinct words of one side and attach their frequencies;
  -- we start from 2 because GIZA++ denotes sentence boundaries with 1
  countAndIdx :: [[String]] -> [(Int,(String,Int))]
  countAndIdx = zip [2..] . count . concat
  (idx1,idx2) = (countAndIdx *** countAndIdx) . unzip $ wspairs
  -- vocabulary file: one "index word frequency" line per distinct word
  writeVcb f = writeFile f . unlines . map toVcb
  toVcb (i,(w,c)) = unwords [ show i, w, show c ]
  -- sentence file: three lines per pair, namely the literal "1" followed
  -- by the word indices of each of the two sentences
  writeSnt f = writeFile f . unlines . concatMap toSnt
  toSnt (s1,s2) = ["1", toSntLine wi_map1 s1
                      , toSntLine wi_map2 s2 ]
  -- Map.! cannot fail here: the maps are built from the very same words
  toSntLine m = unwords . map (show . (m Map.!))
  getWordIdx (i,(w,_)) = (w,i)
  toWordIdxMap = Map.fromList . map getWordIdx
  wi_map1 = toWordIdxMap idx1
  wi_map2 = toWordIdxMap idx2
  --
  withTmp d = withTempFile d "hs-gizapp"

-- | Path of the null device, used by 'align' as a sink for the output of
--   GIZA++.
--
--   @WIN32@ has to be supplied by the build (GHC does not define it on
--   its own), so we additionally test @mingw32_HOST_OS@, which GHC
--   defines when compiling for Windows.
_dev_null :: FilePath
#if defined(WIN32) || defined(mingw32_HOST_OS)
_dev_null = "NUL"
#else
_dev_null = "/dev/null"
#endif

-- ----------------------------------------------------------------------
-- parsing GIZA++ alignments
-- ----------------------------------------------------------------------

-- | A word paired with the indices of the words it is aligned to.
--   'toAlignment' interprets these indices as 1-based positions in the
--   word list that accompanies the pairs.
type OneToManyPair = (String, [Integer])

-- | Turn one parsed sentence pair into an 'Alignment'.
--
--   The first 'OneToManyPair' (the NULL entry) is discarded.  Every
--   remaining pair contributes one 'Align' per index it carries; the
--   1-based indices are shifted down by one so that both positions in an
--   'Align' are zero-based.
--
--   The lookup uses @Map.!@, i.e. we assume that every index in the pairs
--   really refers to a word of the accompanying word list.
toAlignment :: ([String], [OneToManyPair]) -> Alignment
toAlignment (targets, entries) = (sources, targets, links)
 where
  -- everything after the leading NULL entry
  realEntries = drop 1 entries
  sources = map fst realEntries
  links =
    [ Align (WordPos src srcPos)
            (WordPos (targetAt Map.! tgtPos) (fromIntegral tgtPos))
    | (srcPos, (src, tgtIdxs)) <- zip [0 :: Int ..] realEntries
    , tgtIdx <- tgtIdxs
    , let tgtPos = tgtIdx - 1 -- from 1-based to 0-based
    ]
  -- zero-based position -> word
  targetAt :: Map.Map Integer String
  targetAt = Map.fromList (zip [0 ..] targets)

-- ----------------------------------------------------------------------
-- odds and ends
-- ----------------------------------------------------------------------

-- | Tally the occurrences of each distinct element.  The result is
--   sorted by element and contains each distinct element exactly once.
count :: Ord a => [a] -> [(a,Int)]
count xs = [ (x, length grp) | grp@(x:_) <- group (sort xs) ]

-- | Run an action with a freshly created temporary file.
--
--   The file is created with 'openTempFile' in directory @d@ from the
--   file name template @t@, and the action receives its path and open
--   handle.  Afterwards, whether the action returned normally or threw,
--   the handle is closed and the file is removed.
withTempFile :: FilePath -> String -> ((FilePath, Handle) -> IO a) -> IO a
withTempFile d t = bracket (openTempFile d t) release
 where
  -- close before removing so that the handle does not leak; closing a
  -- handle that the action already closed has no effect, and 'finally'
  -- makes sure the file is removed even if closing fails
  release (path, h) = hClose h `finally` removeFile path