Programming
Haskell
Functional Programming
March 25

Этот ваш хаскель (не) только для факториалов и годен

Когда речь заходит о любимых языках, я обычно говорю, что при прочих равных предпочитаю C++ для числодробилок и хаскель для всего остального. Полезно периодически проверять, насколько такое деление обосновано, а тут ещё недавно возник один праздный и очень простой вопрос: как себя будет вести сумма всех делителей числа с ростом этого самого числа, скажем, для первого миллиарда чисел. Эту задачу просто запрогать (аж стыдно называть получившееся числодробилкой), так что она выглядит как отличный вариант для такой проверки.


Кроме того, я всё ещё не владею навыком точного предсказания производительности хаскель-кода, так что полезно пробовать заведомо плохие подходы, чтобы посмотреть, как будет деградировать производительность.


Ну и вдобавок можно легонько выпендриться более эффективным алгоритмом, чем лобовой поиск делителей для каждого числа от $1$ до $n$.


Алгоритм


Итак, начнём с алгоритма.


Как найти сумму всех делителей числа $n$? Можно пройтись по всем $k_1 \in \{ 1 \dots \lfloor \sqrt n \rfloor \}$ и для каждого такого $k_1$ проверить остаток от деления $n$ на $k_1$. Если остаток — $0$, то добавляем к аккумулятору $k_1 + k_2$, где $k_2 = \frac{n}{k_1}$, если $k_1 \neq k_2$, и просто $k_1$ иначе.


Можно ли применять этот алгоритм $n$ раз, для каждого числа от $1$ до $n$? Можно, конечно. Какова будет сложность? Легко видеть, что порядка $O(n^\frac{3}{2})$ делений — для каждого числа мы делаем ровно корень-из-него делений, а чисел у нас $n$. Можем ли мы лучше? Оказывается, что да.


Одна из проблем этого метода — мы тратим слишком много сил впустую. Слишком много делений не приводят нас к успеху, давая ненулевой остаток. Естественно попытаться быть чуть более ленивыми и подойти к задаче с другой стороны: давайте просто будем генерировать всевозможные кандидаты на делители и смотреть, каким числам они удовлетворяют?


Итак, пусть теперь нам нужно одним махом для каждого числа от $1$ до $n$ посчитать сумму всех его делителей. Для этого пройдёмся по всем $k_1 \in \{ 1 \dots \lfloor \sqrt n \rfloor \}$, и для каждого такого $k_1$ пройдёмся по всем $k_2 \in \{ k_1 \dots \lfloor \frac{n}{k} \rfloor \}$. Для каждой пары $(k_1, k_2)$ добавим в ячейку с индексом $k_1 \cdot k_2$ значение $k_1 + k_2$, если $k_1 \neq k_2$, и $k_1$ иначе.


Этот алгоритм делает ровно $n^\frac{1}{2}$ делений, и каждое умножение (которое дешевле деления) приводит нас к успеху: на каждой итерации мы что-нибудь увеличиваем. Это сильно эффективнее, чем лобовой подход.


Кроме того, имея этот самый лобовой подход, можно сравнить обе реализации и убедиться, что они дают одинаковые результаты для достаточно маленьких чисел, что должно немного добавлять уверенности.


Первая реализация


И, кстати, это прямо почти псевдокод начальной реализации на хаскеле:


module Divisors.Multi(divisorSums) where

import Data.IntMap.Strict as IM

divisorSums :: Int -> Int
divisorSums n = IM.fromListWith (+) premap IM.! n
  where premap = [ (k1 * k2, if k1 /= k2 then k1 + k2 else k1)
                 | k1 <- [ 1 .. floor $ sqrt $ fromIntegral n ]
                 , k2 <- [ k1 .. n `quot` k1 ]
                 ]

Main-модуль простой, и я его не привожу.


Кроме того, здесь мы показываем сумму только для самого $n$ для простоты сравнения с другими реализациями. Несмотря на то, что хаскель — ленивый язык, в этом случае будут вычислены все суммы (хотя полное обоснование этого выходит за рамки этой заметки), так что тут не получится, что мы ненароком что-нибудь не посчитаем.


Как быстро это работает? На моём i7 3930k в один поток 100'000 элементов отрабатывается за 0.4 с. При этом 0.15 с тратится на вычисления и 0.25 с — на GC. И занимаем мы примерно 8 мегабайт памяти, хотя, так как размер инта — 8 байт, в идеале нам должно хватить 800 килобайт.


Хорошо (на самом деле нет). Как эти числа будут расти с увеличением, гм, числа́? Для 1'000'000 элементов оно работает уже примерно 7.5 секунд, три секунды тратя на вычисления и 4.5 секунды тратя на GC, а также занимая 80 мегабайт (в 10 раз больше, чем нужно). И даже если мы на секунду прикинемся Senior Java Software Developer'ами и начнём тюнить GC, существенно картину мы не поменяем. Плохо. Похоже, миллиарда чисел мы не дождёмся никогда, да и по памяти не влезем: на моей машине всего 64 гигабайта оперативной памяти, а нужно будет примерно 80, если тенденция сохранится.


Кажется, время сделать


Вариант на C++


Попробуем получить некоторое представление о том, к чему вообще имеет смысл стремиться, а для этого напишем код на плюсах.


Ну, раз у нас уже есть отлаженный алгоритм, то тут всё просто:


#include <vector>
#include <string>
#include <cmath>
#include <iostream>

int main(int argc, char **argv)
{
    if (argc != 2)
    {
        std::cerr << "Usage: " << argv[0] << " maxN" << std::endl;
        return 1;
    }
    int64_t n = std::stoi(argv[1]);

    std::vector<int64_t> arr;
    arr.resize(n + 1);

    for (int64_t k1 = 1; k1 <= static_cast<int64_t>(std::sqrt(n)); ++k1)
    {
        for (int64_t k2 = k1; k2 <= n / k1; ++k2)
        {
            auto val = k1 != k2 ? k1 + k2 : k1;
            arr[k1 * k2] += val;
        }
    }

    std::cout << arr.back() << std::endl;
}

Если вдруг кое-что хочется написать про этот код

Компилятор отлично делает loop-invariant code motion в этом случае, вычисляя корень один раз за всю жизнь программы, и вычисляя n / k1 один раз на одну итерацию внешнего цикла.


И спойлер про простоту

Этот код у меня заработал не с первого раза, несмотря на то, что я копировал уже имеющийся алгоритм. Я сделал пару очень тупых ошибок, которые вроде бы и не относятся напрямую к типам, но всё же они были сделаны. Но это так, мысли вслух.


-O3 -march=native, clang 8, миллион элементов обрабатывается за 0.024 с, занимая положенные 8 мегабайт памяти. Миллиард — 155 секунд, 8 гигабайт памяти, как и ожидалось. Ой. Хаскель никуда не годится. Хаскель надо выкидывать. Только факториалы и препроморфизмы на нём и писать! Или нет?


Второй вариант


Очевидно, что прогонять все сгенерированные данные через IntMap, то есть, по факту, относительно обычную мапу — мягко скажем, не самое мудрое решение (да, это тот самый заведомо паршивый вариант, о котором говорилось вначале). Почему бы нам не использовать массив, как и в коде на C++?


Попробуем:


module Divisors.Multi(divisorSums) where

import qualified Data.Array.IArray as A
import qualified Data.Array.Unboxed as A

divisorSums :: Int -> Int
divisorSums n = arr A.! n
  where arr = A.accumArray (+) 0 (1, n) premap :: A.UArray Int Int
        premap = [ (k1 * k2, if k1 /= k2 then k1 + k2 else k1)
                 | k1 <- [ 1 .. floor bound ]
                 , k2 <- [ k1 .. n `quot` k1 ]
                 ]
        bound = sqrt $ fromIntegral n :: Double

Здесь мы сразу используем unboxed-версию массива, так как Int достаточно простой, и ленивость в нём нам не нужна. Boxed-версия отличалась бы только типом arr, так что в идиоматичности мы тоже не теряем. Кроме того, здесь отдельно вынесен байндинг для bound, но не потому, что компилятор глупый и не делает LICM, а потому, что тогда можно явно указать его тип и избежать предупреждения от компилятора о defaulting'е аргумента floor.


0.045 с для миллиона элементов (всего в два раза хуже плюсов!). 8 мегабайт памяти, ноль миллисекунд в GC (!). На размерах побольше тенденция сохраняется — примерно в два раза медленнее, чем C++, и столько же памяти. Отличный результат! Но можем ли мы лучше?


Оказывается, что да. accumArray проверяет индексы, чего нам в этом случае делать не надо — индексы корректны по построению. Попробуем заменить вызов accumArray на unsafeAccumArray:


module Divisors.Multi(divisorSums) where

import qualified Data.Array.Base as A
import qualified Data.Array.IArray as A
import qualified Data.Array.Unboxed as A

divisorSums :: Int -> Int
divisorSums n = arr A.! (n - 1)
  where arr = A.unsafeAccumArray (+) 0 (0, n - 1) premap :: A.UArray Int Int
        premap = [ (k1 * k2 - 1, if k1 /= k2 then k1 + k2 else k1)
                 | k1 <- [ 1 .. floor bound ]
                 , k2 <- [ k1 .. n `quot` k1 ]
                 ]
        bound = sqrt $ fromIntegral n :: Double

Как видим, изменения минимальны, кроме необходимости индексироваться с нуля (что, на мой взгляд, является багом в API библиотеки, но это другой вопрос). Какова производительность?


Миллион элементов — 0.021 с (уау, в рамках погрешности, но быстрее, чем плюсы!). Естественно, те же 8 мегабайт памяти, тот же 0 мс в GC.


Миллиард элементов — 152 с (похоже, оно действительно быстрее плюсов!). Чуть меньше 8 гигабайт. 0 мс в GC. Код по-прежнему идиоматичен. Думаю, можно сказать, что это победа.


В заключение


Во-первых, я был удивлён, что замена accumArray на unsafe-версию даст такой прирост. Разумнее было бы ожидать процентов 10-20 (в конце концов, в плюсах замена operator[] на at() не даёт существенного снижения производительности), но никак не половину!


Во-вторых, на мой взгляд, действительно круто, что достаточно идиоматичный чистый код без торчащего наружу мутабельного состояния достигает такого уровня производительности.


В-третьих, конечно, возможны дальнейшие оптимизации, причём на всех уровнях. Я уверен, например, что из кода на плюсах можно выжать ещё чуточку больше. Однако, на мой взгляд, во всяких таких бенчмарках важен баланс между затраченными усилиями (и объёмом кода) и полученным выхлопом. Иначе всё в конце концов в пределе сойдётся к вызову LLVM JIT или чего подобного. Кроме того, наверняка есть более эффективные алгоритмы решения этой задачи, но представленный результат непродолжительных раздумий тоже сойдёт для этого небольшого воскресного приключения.


В-четвёртых, моё любимое: надо развивать системы типов. unsafe здесь не нужен, я как программист могу доказать, что k_1 * k_2 <= n для всех k_1, k_2, встречающихся в цикле. В идеальном мире зависимо типизированных языков я бы конструировал это доказательство статически и передавал бы его в соответствующую функцию, что убирало бы необходимость проверок в рантайме. Но, увы, в хаскеле нет полноценных завтипов, а в языках, где завтипы есть (и которые я знаю), нет array и аналогов.


Ну и в-пятых: я не знаю других языков программирования достаточно, чтобы претендовать на околобенчмарки на этих языках, но один мой приятель написал аналог на питоне. Практически ровно в сто раз медленнее, и похуже по памяти. А сам алгоритм предельно простой, поэтому если кто-то знающий напишет в комментариях аналог на Go, Rust, Julia, D, Java, Malbolge или чём ещё и поделится сравнением, например, с C++-кодом на их машине — будет, наверное, здорово.


P.S.: Сорри за немного кликбейтный заголовок. У меня не получилось придумать ничего лучше.

+48
11.2k 60
Comments 144
Top of the day