遅延評価を活用して線形時間でGoogle Code Jam 2014 Round 1AのB問題を解く

公式の解説で,「遅延評価を使ってもできる」と書いてくれなかったので。
注意:この記事は,Google Code Jam 2014 Round 1AのB問題 Full Binary Tree についてのネタバレを含みます。
この問題は木DPをする典型的な問題であり,公式の解説にもあるように,

  • 木を根付き木にした後,
  • 頂点でのスコアを,子のスコアのうち大きい方から2つを用いて計算する

ことで解けます。何を根として選ぶかをすべて試すとO(N2)解法となり,すべての根の可能性について同時にDPをする*1とO(N)解法になります。だから,根の選び方によって隣接頂点のうち「親以外が子となる」ので,隣接頂点のスコアのうち大きい方から3つを保存するような構造を持つ――待ってください。
もちろん,降順ソートされたスコアのすべてを計算すると,O(N2)時間かかってしまいます。しかし,先頭3つを結果的に計算できれば十分で,明示的に保存する構造を作る必要はない,というのが今回のテーマです。
必要な部分(大きい方から3つ)は必要だから計算されるのです。遅延評価を使いましょう。

実装

降順ソートされたデータはリストで持ちます。スコア計算の際に,親p以外のうち大きい方から2つとるのを,

take 2 [x | (x,v) <- dat, Just v /= p]

で行えば,リストの先頭側の評価が優先され,後ろ側の評価は遅延されます。
DPの更新では,親(と仮定した隣接頂点)に保存されている「子のスコアの降順リスト」に自分のスコアを挿入します。ここで,安易にsortBy関数を使うとソート済みのリストへの挿入であることが用いられずに,元のリストがすべて計算されてしまって駄目です。

insertBy (flip compare `on` fst) :: Ord a => (a, b) -> [(a, b)] -> [(a, b)]

のように*2,ソート済みである仮定を用いて挿入しましょう。
最後に今回書いたコードの全体を載せます。どこにも “3” と書いていませんが,O(N)時間で動作しますよ!

{-# LANGUAGE BangPatterns #-}
import Control.Monad
import Control.Applicative
import Data.Function
import Data.List
import Data.Maybe
import Data.Tuple (swap)
import qualified System.IO
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as V
import Data.Tree
import Control.Monad.ST
import qualified Data.ByteString.Char8 as BS

main = gcj $ do
    n <- readInt <$> BS.getLine
    edges <- replicateM (n-1) $ map readInt . BS.words <$> BS.getLine
    return . show $ solve n $ makeTree 1 $ makeGraph n edges

type Vertex = Int
type Graph = V.Vector [Vertex]

makeGraph n edges = V.accum (flip (:))
    (V.fromList (undefined : replicate n []))  -- index: [1 .. n]
    [(v,w) | [v,w] <- edges ++ map reverse edges]
    :: Graph

makeTree root graph = makeTree' root Nothing graph
    :: Tree Vertex
makeTree' v p g = Node v [makeTree' w (Just v) g | w <- g V.! v, Just w /= p]
    -- Induction: visit except parent.

-- directed edges of tree. DFS, pre-order.
dirEdges :: Tree a -> [(a,a)]  -- (parent,child)
dirEdges (Node x ts) = [(x,y) | (Node y _) <- ts] ++ concatMap dirEdges ts

solve :: Int -> Tree Vertex -> Int
solve n tr = let
    vs = [1 .. n]
    es = dirEdges tr
    !dp = runST $ do
        dp <- V.new (n+1)
        sequence_ [V.write dp k [] | k <- vs]  -- initialize
        sequence_ (map (update dp) $ reverse es ++ map swap es)  -- run DP
        V.freeze dp
    in n - maximum [score Nothing (dp V.! k) | k <- vs]

update dp (p,v) = do
    !res <- score (Just p) <$> V.read dp v
    V.write dp p . insertBy (flip compare `on` fst) (res,v) =<< V.read dp p

-- the score of dat without p
score :: Maybe Vertex -> [(Int, Vertex)] -> Int
score p dat = let best = take 2 [x | (x,v) <- dat, Just v /= p]
    in if length best == 2
        then 1 + sum best
        else 1

readInt = fst . fromJust . BS.readInt

gcj solver = do
    System.IO.hSetBuffering System.IO.stdout System.IO.NoBuffering
    t <- readInt <$> BS.getLine
    forM_ [1..t] $ \num -> solver >>= \ans ->
        putStrLn $ "Case #"++show num++": "++ans

*1:根を1つ固定して,葉から根の順ですべての辺を走査した後,辺を逆向きだと考えて根から葉の順に走査する。

*2:スコアと頂点のpairについて,「スコアの比較の逆順」と指定している。