优化可变数组状态繁重操作代码

Optimizing mutable array state heavy manipulation code

提问人:Dulguun Otgon 提问时间:2/14/2016 更新时间:6/2/2016 访问量:245

问:

我一直在努力及时完成这个关于hackerrank的练习。 但是由于超时,我下面的 Haskell 解决方案在测试用例 13 到 15 上失败。enter image description here

我的 Haskell 解决方案

import           Data.Vector(Vector(..),fromList,(!),(//),toList)
import           Data.Vector.Mutable
import qualified Data.Vector as V 
import           Data.ByteString.Lazy.Char8 (ByteString(..))
import qualified Data.ByteString.Lazy.Char8 as L
import Data.ByteString.Lazy.Builder
import Data.Maybe
import Control.Applicative
import Data.Monoid
import Prelude hiding (length)

readInt' = fst . fromJust . L.readInt 
toB []     = mempty
toB (x:xs) = string8 (show x) <> string8 " " <> toB xs

main = do 
  [firstLine, secondLine] <- L.lines <$> L.getContents
  let [n,k] = map readInt' $ L.words firstLine
  let xs = largestPermutation n k $ fromList $ map readInt' $ Prelude.take n $ L.words secondLine
  L.putStrLn $ toLazyByteString $ toB $ toList xs


largestPermutation n k v
  | i >= l || k == 0 = v 
  | n == x           = largestPermutation (n-1) k v
  | otherwise        = largestPermutation (n-1) (k-1) (replaceOne n x (i+1) (V.modify (\v' -> write v' i n) v))
        where l = V.length v 
              i = l - n
              x = v!i

replaceOne n x i v
  | n == h = V.modify (\v' -> write v' i x ) v
  | otherwise = replaceOne n x (i+1) v
    where h = v!i 

我发现的最佳解决方案不断更新 2 个阵列。一个数组是主要目标,另一个数组用于快速索引查找。

更好的 Java 解决方案

public static void main(String[] args) {
  Scanner input = new Scanner(System.in);
  int n = input.nextInt();
  int k = input.nextInt();
  int[] a = new int[n];
  int[] index = new int[n + 1];
  for (int i = 0; i < n; i++) {
      a[i] = input.nextInt();
      index[a[i]] = i;
  }
  for (int i = 0; i < n && k > 0; i++) {
      if (a[i] == n - i) {
          continue;
      }
      a[index[n - i]] = a[i];
      index[a[i]] = index[n - i];
      a[i] = n - i;
      index[n - i] = i;
      k--; 
  }
  for (int i = 0; i < n; i++) {
      System.out.print(a[i] + " ");
  }
}

我的问题是

  1. 在Haskell中,这种算法的优雅和快速实现是什么?
  2. 有没有比 Java 解决方案更快的方法来解决这个问题?
  3. 一般来说,我应该如何在 Haskell 中优雅而高效地处理繁重的阵列更新?
阵 列 算法 哈斯克尔 可变

评论

2赞 leftaroundabout 2/14/2016
请输入签名...
0赞 Sibi 2/14/2016
只是想知道您的 Java 解决方案是否通过了现场的所有时间板测试?
0赞 Dulguun Otgon 2/14/2016
@leftaroundabout对不起。缺少类型签名会降低其可读性吗?我以为更少的代码=更多的可读性,愚蠢的我。
0赞 leftaroundabout 2/14/2016
是的,但类型签名的作用至少与自文档一样多,就像它们作为代码一样。
0赞 dfeuer 6/2/2016
没有类型签名的代码真的很难阅读。

答:

7赞 behzad.nouri 2/14/2016 #1

您可以对可变数组进行的一项优化是完全不使用它们。特别是,您链接到的问题具有右折叠解决方案

这个想法是你折叠列表并贪婪地将具有最大值的项目交换到右边,并保持已经在 Data.Map 中进行的交换:

import qualified Data.Map as M
import Data.Map (empty, insert)

solve :: Int -> Int -> [Int] -> [Int]
solve n k xs = foldr go (\_ _ _ -> []) xs n empty k
    where
    go x run i m k
        -- out of budget to do a swap or no swap necessary
        | k == 0 || y == i = y : run (pred i) m k
        -- make a swap and record the swap made in the map
        | otherwise        = i : run (pred i) (insert i y m) (k - 1)
        where
        -- find the value current position is swapped with
        y = find x
        find k = case M.lookup k m of
            Just a  -> find a
            Nothing -> k

在上面,是一个函数,它给定反向索引、当前映射和剩余的掉期预算,求解列表的其余部分。通过反向索引,我的意思是列表在相反方向上的索引:。runimkn, n - 1, ..., 1

折叠函数 ,通过更新 的值在每一步构建函数,并将这些值传递给下一步。最后,我们用初始参数和初始交换预算调用这个函数。gorunimki = nm = emptyk

可以通过维护反向映射来优化递归搜索,但这已经比您发布的 Java 代码快得多。find


编辑:上述解决方案,仍然为树木访问支付对数成本。这是使用可变 STUArray 和一元折叠foldM_的替代解决方案,其执行速度实际上比上述更快:

import Control.Monad.ST (ST)
import Control.Monad (foldM_)
import Data.Array.Unboxed (UArray, elems, listArray, array)
import Data.Array.ST (STUArray, readArray, writeArray, runSTUArray, thaw)

-- first 3 args are the scope, which will be curried
swap :: STUArray s Int Int -> STUArray s Int Int -> Int
     -> Int -> Int -> ST s Int
swap   _   _ _ 0 _ = return 0  -- out of budget to make a swap
swap arr rev n k i = do
    xi <- readArray arr i
    if xi + i == n + 1
    then return k -- no swap necessary
    else do -- make a swap, and reduce budget
        j <- readArray rev (n + 1 - i)
        writeArray rev xi j
        writeArray arr j  xi
        writeArray arr i (n + 1 - i)
        return $ pred k

solve :: Int -> Int -> [Int] -> [Int]
solve n k xs = elems $ runSTUArray $ do
    arr <- thaw (listArray (1, n) xs :: UArray Int Int)
    rev <- thaw (array (1, n) (zip xs [1..]) :: UArray Int Int)
    foldM_ (swap arr rev n) k [1..n]
    return arr

评论

1赞 Dulguun Otgon 2/15/2016
哇,与其他 java 解决方案相比,它几乎是即时的。然而,这部分让我有点绊倒。我应该如何尝试和理解这部分?你能在 @behzad.nouri 添加更多括号吗?foldr go (const []) (zip [n,n - 1..] xs) (empty, k)
2赞 behzad.nouri 2/15/2016
@DulguunOtgon此页面页面应该会有所帮助。您可以将 Lean 推得太右,以便它:)再次向左返回。该页面解释了这是如何完成的(通过咖喱折叠功能)。foldr
1赞 behzad.nouri 2/15/2016
@DulguunOtgon使用可变 STUArrays 的替代解决方案编辑的答案
0赞 Dulguun Otgon 2/17/2016
我必须在哪些领域熟练掌握才能编写和理解类似第一个版本的东西?
0赞 behzad.nouri 2/17/2016
@DulguunOtgon我在 StackOverflow 上遇到了这个技巧。如果你对书籍、博客和问答网站有足够的了解,随着时间的推移,你就会掌握这些技巧。foldr
1赞 hilberts_drinking_problem 2/20/2016 #2

不完全是 #2 的答案,但有一个左折叠解决方案,一次最多需要在内存中加载 ~K 个值。

因为问题涉及排列,所以我们知道 1 到 N 将出现在输出中。如果 K > 0,则至少前 K 项将是 N、N-1、...N - K,因为我们至少可以负担得起 K 次掉期。此外,我们预计一些 (K/N) 数字处于最佳位置。

这提出了一个算法:

初始化地图/字典并将输入扫描为 .对于每个 ,如果 ,我们“递减”并更新字典 s.t. .当(交换不足)或我们用完输入(可以输出 {N, N-1, ...1}).xszip xs [n, n-1..](x, i)x \= iKdct[i] = xK == 0

接下来,如果我们有更多,我们会查看每一个,并打印出是否不在我们的字典中或其他。x <- xsxxdct[x]

只有当我们的字典包含循环时,上述算法才能无法产生最佳排列。在这种情况下,我们使用交换来移动绝对值为 >= 的元素。但这意味着我们将一个元素移动到其原始位置!因此,我们始终可以在每个周期(即增量)上保存一个交换。K|cycle|K

最后,这给出了内存效率高的算法。

第 0 步:获取 N、K

第 1 步:读取输入排列和输出 {N, N-1, ...N-K-E}, N <- N - K - E, K <- 0, 更新字典如上,

其中 E = 元素数 X 等于 N - (X 的索引)

第 2 步:从字典中删除并计算周期;let = 循环次数;如果 ,让 ,转到步骤 1,cyclescycles > 0K <- |cycles|

否则,请转到步骤 3。我们可以通过优化字典来使这一步更有效率。

第 3 步:按原样输出其余输入。

下面的 Python 代码实现了这个想法,如果使用更好的循环检测,可以非常快地完成。当然,数据最好以块的形式读取,这与下面不同。

from collections import deque

n, t = map(int, raw_input().split())

xs = deque(map(int, raw_input().split()))

dct = {}

cycles = True
while cycles:
    while t > 0 and xs:
        x = xs.popleft()
        if x != n:
            dct[n] = x
            t -= 1
        print n,
        n -= 1

    cycles = False
    for k, v in dct.items():
        visited = set()
        cycle = False
        while v in dct:
            if v in visited:
                cycle = True
                break
            visited.add(v)
            v, buf = dct[v], v
            dct[buf] = v
        if cycle:
            cycles = True
            for i in visited:
                del dct[i]
            t += 1
        else:
            dct[k] = v

while xs:
    x = xs.popleft()
    print dct.get(x, x),