习题来自 https://bartoszmilewski.com/2014/11/24/types-and-functions/
Define a higher-order function (or a function object) memoize in your favorite language. This function takes a pure function f as an argument and returns a function that behaves almost exactly like f, except that it only calls the original function once for every argument, stores the result internally, and subsequently returns this stored result every time it’s called with the same argument. You can tell the memoized function from the original by watching its performance. For instance, try to memoize a function that takes a long time to evaluate. You’ll have to wait for the result the first time you call it, but on subsequent calls, with the same argument, you should get the result immediately.
（当然，C++ 不是我最喜欢的语言，但它确实很有挑战性，所以用它写了这个练习。）
#include <chrono>
#include <concepts>
#include <functional>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
// Detection idiom: is_std_hashable<T> is true exactly when the expression
// std::hash<T>{}(value_of_type_T) is well-formed, which is the usual sign that
// std::hash<T> is an enabled specialization.
template <typename T>
using std_hash_call_result_t =
    decltype(std::declval<std::hash<T>>()(std::declval<T>()));

// Primary template: chosen when the partial specialization below is discarded.
template <typename T, typename = void>
struct is_std_hashable : std::false_type {};

// Viable only if std_hash_call_result_t<T> can be formed (SFINAE otherwise).
template <typename T>
struct is_std_hashable<T, std::void_t<std_hash_call_result_t<T>>>
    : std::true_type {};

// Convenience variable template mirroring the standard *_v helpers.
template <typename T>
constexpr bool is_std_hashable_v = is_std_hashable<T>::value;
template <typename... Args>
requires std::equality_comparable<std::tuple<Args...>>
struct Arguments {
std::tuple<Args...> data;
bool operator==(const Arguments &other) const = default;
};
template <typename R, typename... Args> using fn_type = R(Args...);
/// Memoizes a pure function passed as a function pointer (a plain function
/// such as `fib` decays to one here).
///
/// The returned mutable callable has the same parameter list as `f`. It calls
/// `f` at most once per distinct (decayed) argument list, stores the result,
/// and returns the stored copy on later calls with equal arguments.
///
/// @param f  pure function to memoize.
/// @return   callable `R(Args...)`. Each copy of it owns an independent cache,
///           since the cache is captured by value.
///
/// Requirements: every decayed argument type must satisfy the constraints of
/// Arguments<> and have a std::hash<Arguments<...>>. R must be
/// copy-constructible; it does not need to be default-constructible. The cache
/// is mutated without any synchronization.
template <typename R, typename... Args> auto memorize(fn_type<R, Args...> f) {
  // Key on decayed types so parameters such as `const std::string &` are
  // cached by value. The original key type, Arguments<Args...>, did not match
  // the decayed key actually built and failed to compile for references.
  using Key = Arguments<std::decay_t<Args>...>;
  return [f, cache = std::unordered_map<Key, R>()](Args... args) mutable -> R {
    Key key{std::tuple<std::decay_t<Args>...>(args...)};
    // Single lookup on a hit. cache[key] would hash again and would require R
    // to be default-constructible.
    if (auto it = cache.find(key); it != cache.end()) {
      return it->second;
    }
    // The key already holds its own copies of the arguments, so they can be
    // forwarded (moved, for by-value parameters) into f.
    R result = f(std::forward<Args>(args)...);
    cache.emplace(std::move(key), result);
    return result;
  };
}
template <typename R, typename... Args>
auto memorize(std::function<R(Args...)> &&f) {
auto cache = std::unordered_map<Arguments<Args...>, R>();
return [=](Args... args) mutable -> R {
auto key = Arguments{std::make_tuple(args...)};
if (cache.find(key) == cache.end()) {
auto result = f(args...);
cache[key] = result;
return result;
}
return cache[key];
};
}
/// std::hash specialization so that Arguments<Args...> can key an
/// std::unordered_map (as memorize() does).
///
/// Each element is hashed with std::hash<Element>, and the per-element hashes
/// are mixed with the boost::hash_combine recipe. A plain XOR of the element
/// hashes would be order-insensitive ((a, b) and (b, a) collide) and would
/// cancel equal elements (every (x, x) pair hashes to 0). That causes heavy
/// bucket collisions for common argument patterns.
template <typename... Args> struct std::hash<Arguments<Args...>> {
  std::size_t operator()(const Arguments<Args...> &t) const {
    static_assert((is_std_hashable_v<Args> && ...),
                  "the element is not std::hash-able");
    return combine(t.data, std::index_sequence_for<Args...>{});
  }

private:
  // Folds all element hashes into one seed, in tuple order. An empty tuple
  // hashes to 0.
  template <std::size_t... I>
  static std::size_t combine(const std::tuple<Args...> &t,
                             std::index_sequence<I...>) {
    std::size_t seed = 0;
    ((seed ^= std::hash<Args>{}(std::get<I>(t)) + 0x9e3779b9 + (seed << 6) +
              (seed >> 2)),
     ...);
    return seed;
  }
};
// Intentionally naive, exponential-time Fibonacci. It makes a good
// slow-to-evaluate pure function for demonstrating memoization. For n <= 1
// (including negative n) it returns n unchanged.
long long fib(long long n) {
  const bool is_base_case = n <= 1;
  return is_base_case ? n : fib(n - 2) + fib(n - 1);
}
// struct NotHashable {
// constexpr bool operator==(const NotHashable &) const { return true; }
// };
// int foo(int x, NotHashable bar) { return 1; }
// Runs `call` `repeats` times. For each run it prints
// "calculating <label>(<shown_arg>) = <result>  took <N> microseconds".
// Intervals are measured with steady_clock, which is monotonic, unlike
// system_clock, which can jump when the wall clock is adjusted.
template <typename Call>
static void time_repeated_calls(const char *label, long long shown_arg,
                                int repeats, Call &&call) {
  for (int i = 0; i < repeats; ++i) {
    const auto t0 = std::chrono::steady_clock::now();
    std::cout << "calculating " << label << "(" << shown_arg << ") = ";
    std::cout << call() << " ";
    const auto t1 = std::chrono::steady_clock::now();
    const auto duration =
        std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0);
    // '\n' rather than std::endl: no need to flush after every line.
    std::cout << " took " << duration.count() << " microseconds" << '\n';
  }
}

// Demo: the first call of each memoized function pays the full cost of
// fib(n); the remaining calls are served from the memoization cache.
int main() {
  constexpr long long n = 30;  // single source for both label and argument
  constexpr int repeats = 100;

  // Function-pointer overload. Only the outer call is cached: fib's own
  // recursive calls still go to the plain, un-memoized fib.
  auto memoized_fib = memorize(fib);
  time_repeated_calls("fib", n, repeats, [&] { return memoized_fib(n); });

  // std::function overload, wrapping a nullary lambda (one cache entry).
  auto memoized_lamfib = memorize(std::function([=]() { return fib(n); }));
  time_repeated_calls("fib2", n, repeats, [&] { return memoized_lamfib(); });

  return 0;
}