#include <cmath>
#include <cstddef>
#include <exception>
#include <functional>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

#include <boost/math/constants/constants.hpp>
#include <boost/math/distributions/normal.hpp>
#include <boost/multiprecision/cpp_dec_float.hpp>

using namespace boost::math::constants;
using namespace boost::math;
using namespace boost::multiprecision;
|
|---|
| 10 |
|
|---|
| 11 | //using FLOAT = double;
|
|---|
| 12 | using FLOAT = cpp_dec_float_50;
|
|---|
| 13 |
|
|---|
| 14 | FLOAT N_f(const FLOAT &x) {
|
|---|
| 15 | return exp(-.5*x*x);
|
|---|
| 16 | }
|
|---|
| 17 | FLOAT N_finv(const FLOAT &x) {
|
|---|
| 18 | return sqrt(-2*log(x));
|
|---|
| 19 | }
|
|---|
| 20 | FLOAT N_tail_area(const FLOAT &x) {
|
|---|
| 21 | return root_half_pi<FLOAT>()*erfc(x * half_root_two<FLOAT>());
|
|---|
| 22 | }
|
|---|
| 23 |
|
|---|
| 24 | FLOAT E_f(const FLOAT &x) {
|
|---|
| 25 | return exp(-x);
|
|---|
| 26 | }
|
|---|
| 27 | FLOAT E_finv(const FLOAT &x) {
|
|---|
| 28 | return -log(x);
|
|---|
| 29 | }
|
|---|
| 30 | FLOAT E_tail_area(const FLOAT &x) {
|
|---|
| 31 | return exp(-x);
|
|---|
| 32 | }
|
|---|
| 33 |
|
|---|
| 34 | int main(int argc, char *argv[]) {
|
|---|
| 35 | std::function<FLOAT(const FLOAT&)> f, finv, tail_area;
|
|---|
| 36 | std::string which;
|
|---|
| 37 | unsigned size;
|
|---|
| 38 | bool good = false;
|
|---|
| 39 | if (argc == 3) {
|
|---|
| 40 | which = argv[1];
|
|---|
| 41 | size = std::stoul(argv[2]);
|
|---|
| 42 | if (which == "E" or which == "e") {
|
|---|
| 43 | f = E_f; finv = E_finv; tail_area = E_tail_area;
|
|---|
| 44 | good = true;
|
|---|
| 45 | }
|
|---|
| 46 | else if (which == "N" or which == "n") {
|
|---|
| 47 | f = N_f; finv = N_finv; tail_area = N_tail_area;
|
|---|
| 48 | good = true;
|
|---|
| 49 | }
|
|---|
| 50 | else {
|
|---|
| 51 | std::cerr << "Invalid distribution option `" << which << "'\n\n";
|
|---|
| 52 | }
|
|---|
| 53 | if (size == 0) {
|
|---|
| 54 | std::cerr << "Invalid size option `" << size << "'\n\n";
|
|---|
| 55 | }
|
|---|
| 56 | else {
|
|---|
| 57 | good = true;
|
|---|
| 58 | if (size < 16) std::cerr << "Warning: size seems small (" << size << ")\n\n";
|
|---|
| 59 | if (size % 2 == 0) std::cerr << "Warning: size is not a power of 2\n\n";
|
|---|
| 60 | }
|
|---|
| 61 | }
|
|---|
| 62 |
|
|---|
| 63 | if (not good) {
|
|---|
| 64 | std::cerr << "Usage: " << argv[0] << " {E,N} SIZE -- generate ziggurat tables for the (E)xponential or (N)ormal distribution\n\n";
|
|---|
| 65 | exit(1);
|
|---|
| 66 | }
|
|---|
| 67 |
|
|---|
| 68 | std::vector<FLOAT> x(size+1);
|
|---|
| 69 | std::vector<FLOAT> y(size+1);
|
|---|
| 70 |
|
|---|
| 71 | std::cout.precision(std::numeric_limits<FLOAT>::max_digits10);
|
|---|
| 72 | FLOAT left = 0, right = 10, last = -1;
|
|---|
| 73 | while (left != right) {
|
|---|
| 74 | FLOAT r = 0.5*(left+right);
|
|---|
| 75 | if (r == last) // We're at our precision limit, so stop
|
|---|
| 76 | break;
|
|---|
| 77 | last = r;
|
|---|
| 78 | std::cout << "trying " << r << "\n";
|
|---|
| 79 | x[1] = r;
|
|---|
| 80 | y[1] = f(x[1]);
|
|---|
| 81 |
|
|---|
| 82 | FLOAT A = x[1]*y[1] + tail_area(x[1]);
|
|---|
| 83 |
|
|---|
| 84 | x[0] = A/y[1];
|
|---|
| 85 | y[0] = f(x[0]);
|
|---|
| 86 | for (unsigned i = 2; i <= size; i++) {
|
|---|
| 87 | y[i] = y[i-1] + A/x[i-1];
|
|---|
| 88 | if (y[i] > f(0)) { // x[1] guess was too low
|
|---|
| 89 | left = r;
|
|---|
| 90 | goto NEXTGUESS;
|
|---|
| 91 | }
|
|---|
| 92 | x[i] = finv(y[i]);
|
|---|
| 93 | }
|
|---|
| 94 | // If final is negative, r was too big; if positive, too small.
|
|---|
| 95 | if (y[size] < f(0)) right = r;
|
|---|
| 96 | else left = r;
|
|---|
| 97 | NEXTGUESS:
|
|---|
| 98 | ;
|
|---|
| 99 | }
|
|---|
| 100 |
|
|---|
| 101 | // Make sure y[SIZE] =~ 1, and x[SIZE] =~ 0
|
|---|
| 102 | if (abs(y[size] - 1) > 1e-12) throw "Error: y_n != 1";
|
|---|
| 103 | if (abs(x[size]) > 1e-20) throw "Error: x_n != 0";
|
|---|
| 104 | y[size] = 1;
|
|---|
| 105 | x[size] = 0;
|
|---|
| 106 | y[0] = 0; // This value never gets used, so just set it to 0 to slightly save space in the code
|
|---|
| 107 |
|
|---|
| 108 | std::string structname(which == "N" ? "normal_table" : "exponential_table");
|
|---|
| 109 | std::cout << "\n\n\n// tables for the ziggurat algorithm\n" <<
|
|---|
| 110 | "template<class RealType>\n" <<
|
|---|
| 111 | "struct " << structname << " {\n" <<
|
|---|
| 112 | " static const RealType table_x[" << size+1 << "];\n" <<
|
|---|
| 113 | " static const RealType table_y[" << size+1 << "];\n" <<
|
|---|
| 114 | "};\n\n";
|
|---|
| 115 |
|
|---|
| 116 | std::cout << "template<class RealType>\nconst RealType " << structname << "<RealType>::table_x[" << size+1 << "] = {";
|
|---|
| 117 | std::cout.precision(20);
|
|---|
| 118 | std::cout.setf(std::ios_base::showpoint);
|
|---|
| 119 | for (unsigned i = 0; i < size; i++) {
|
|---|
| 120 | std::cout << (i % 4 == 0 ? "\n " : " ");
|
|---|
| 121 | if (x[i] == 0) std::cout << "0,";
|
|---|
| 122 | else if (x[i] == 1) std::cout << "1,";
|
|---|
| 123 | else std::cout << x[i] << ",";
|
|---|
| 124 | }
|
|---|
| 125 | std::cout.unsetf(std::ios_base::showpoint);
|
|---|
| 126 | std::cout << "\n " << x[size] << "\n};\n\n";
|
|---|
| 127 |
|
|---|
| 128 | std::cout << "template<class RealType>\nconst RealType " << structname << "<RealType>::table_y[" << size+1 << "] = {";
|
|---|
| 129 | std::cout.precision(20);
|
|---|
| 130 | std::cout.setf(std::ios_base::showpoint);
|
|---|
| 131 | for (unsigned i = 0; i < size; i++) {
|
|---|
| 132 | std::cout << (i % 4 == 0 ? "\n " : " ");
|
|---|
| 133 | if (y[i] == 0) std::cout << "0,";
|
|---|
| 134 | else if (y[i] == 1) std::cout << "1,";
|
|---|
| 135 | else std::cout << y[i] << ",";
|
|---|
| 136 | }
|
|---|
| 137 | std::cout.unsetf(std::ios_base::showpoint);
|
|---|
| 138 | std::cout << "\n " << y[size] << "\n};\n\n";
|
|---|
| 139 | }
|
|---|