遅延評価を活用して線形時間でGoogle Code Jam 2014 Round 1AのB問題を解く
公式の解説で,「遅延評価を使ってもできる」と書いてくれなかったので。
注意:この記事は,Google Code Jam 2014 Round 1AのB問題 Full Binary Tree についてのネタバレを含みます。
この問題は木DPをする典型的な問題であり,公式の解説にもあるように,
- 木を根付き木にした後,
- 頂点でのスコアを,子のスコアのうち大きい方から2つを用いて計算する
ことで解けます。何を根として選ぶかをすべて試すとO(N2)解法となり,すべての根の可能性について同時にDPをする*1とO(N)解法になります。だから,根の選び方によって隣接頂点のうち「親以外が子となる」ので,隣接頂点のスコアのうち大きい方から3つを保存するような構造を持つ――待ってください。
もちろん,降順ソートされたスコアのすべてを計算すると,O(N2)時間かかってしまいます。しかし,先頭3つを結果的に計算できれば十分で,明示的に保存する構造を作る必要はない,というのが今回のテーマです。
必要な部分(大きい方から3つ)は必要だから計算されるのです。遅延評価を使いましょう。
実装
降順ソートされたデータはリストで持ちます。スコア計算の際に,親p以外のうち大きい方から2つとるのを,
take 2 [x | (x,v) <- dat, Just v /= p]
で行えば,リストの先頭側の評価が優先され,後ろ側の評価は遅延されます。
DPの更新では,親(と仮定した隣接頂点)に保存されている「子のスコアの降順リスト」に自分のスコアを挿入します。ここで,安易にsortBy関数を使うとソート済みのリストへの挿入であることが用いられずに,元のリストがすべて計算されてしまって駄目です。
insertBy (flip compare `on` fst) :: Ord a => (a, b) -> [(a, b)] -> [(a, b)]
のように*2,ソート済みである仮定を用いて挿入しましょう。
最後に今回書いたコードの全体を載せます。どこにも “3” と書いていませんが,O(N)時間で動作しますよ!
{-# LANGUAGE BangPatterns #-} import Control.Monad import Control.Applicative import Data.Function import Data.List import Data.Maybe import Data.Tuple (swap) import qualified System.IO import qualified Data.Vector as V import qualified Data.Vector.Mutable as V import Data.Tree import Control.Monad.ST import qualified Data.ByteString.Char8 as BS main = gcj $ do n <- readInt <$> BS.getLine edges <- replicateM (n-1) $ map readInt . BS.words <$> BS.getLine return . show $ solve n $ makeTree 1 $ makeGraph n edges type Vertex = Int type Graph = V.Vector [Vertex] makeGraph n edges = V.accum (flip (:)) (V.fromList (undefined : replicate n [])) -- index: [1 .. n] [(v,w) | [v,w] <- edges ++ map reverse edges] :: Graph makeTree root graph = makeTree' root Nothing graph :: Tree Vertex makeTree' v p g = Node v [makeTree' w (Just v) g | w <- g V.! v, Just w /= p] -- Induction: visit except parent. -- directed edges of tree. DFS, pre-order. dirEdges :: Tree a -> [(a,a)] -- (parent,child) dirEdges (Node x ts) = [(x,y) | (Node y _) <- ts] ++ concatMap dirEdges ts solve :: Int -> Tree Vertex -> Int solve n tr = let vs = [1 .. n] es = dirEdges tr !dp = runST $ do dp <- V.new (n+1) sequence_ [V.write dp k [] | k <- vs] -- initialize sequence_ (map (update dp) $ reverse es ++ map swap es) -- run DP V.freeze dp in n - maximum [score Nothing (dp V.! k) | k <- vs] update dp (p,v) = do !res <- score (Just p) <$> V.read dp v V.write dp p . insertBy (flip compare `on` fst) (res,v) =<< V.read dp p -- the score of dat without p score :: Maybe Vertex -> [(Int, Vertex)] -> Int score p dat = let best = take 2 [x | (x,v) <- dat, Just v /= p] in if length best == 2 then 1 + sum best else 1 readInt = fst . fromJust . BS.readInt gcj solver = do System.IO.hSetBuffering System.IO.stdout System.IO.NoBuffering t <- readInt <$> BS.getLine forM_ [1..t] $ \num -> solver >>= \ans -> putStrLn $ "Case #"++show num++": "++ans