#include <cctype>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

#include <boost/math/constants/constants.hpp>
#include <boost/math/distributions/normal.hpp>
#include <boost/multiprecision/cpp_dec_float.hpp>

using namespace boost::math::constants;
using namespace boost::math;
using namespace boost::multiprecision;
|
---|
10 |
|
---|
11 | //using FLOAT = double;
|
---|
12 | using FLOAT = cpp_dec_float_50;
|
---|
13 |
|
---|
14 | FLOAT N_f(const FLOAT &x) {
|
---|
15 | return exp(-.5*x*x);
|
---|
16 | }
|
---|
17 | FLOAT N_finv(const FLOAT &x) {
|
---|
18 | return sqrt(-2*log(x));
|
---|
19 | }
|
---|
20 | FLOAT N_tail_area(const FLOAT &x) {
|
---|
21 | return root_half_pi<FLOAT>()*erfc(x * half_root_two<FLOAT>());
|
---|
22 | }
|
---|
23 |
|
---|
24 | FLOAT E_f(const FLOAT &x) {
|
---|
25 | return exp(-x);
|
---|
26 | }
|
---|
27 | FLOAT E_finv(const FLOAT &x) {
|
---|
28 | return -log(x);
|
---|
29 | }
|
---|
30 | FLOAT E_tail_area(const FLOAT &x) {
|
---|
31 | return exp(-x);
|
---|
32 | }
|
---|
33 |
|
---|
34 | int main(int argc, char *argv[]) {
|
---|
35 | std::function<FLOAT(const FLOAT&)> f, finv, tail_area;
|
---|
36 | std::string which;
|
---|
37 | unsigned size;
|
---|
38 | bool good = false;
|
---|
39 | if (argc == 3) {
|
---|
40 | which = argv[1];
|
---|
41 | size = std::stoul(argv[2]);
|
---|
42 | if (which == "E" or which == "e") {
|
---|
43 | f = E_f; finv = E_finv; tail_area = E_tail_area;
|
---|
44 | good = true;
|
---|
45 | }
|
---|
46 | else if (which == "N" or which == "n") {
|
---|
47 | f = N_f; finv = N_finv; tail_area = N_tail_area;
|
---|
48 | good = true;
|
---|
49 | }
|
---|
50 | else {
|
---|
51 | std::cerr << "Invalid distribution option `" << which << "'\n\n";
|
---|
52 | }
|
---|
53 | if (size == 0) {
|
---|
54 | std::cerr << "Invalid size option `" << size << "'\n\n";
|
---|
55 | }
|
---|
56 | else {
|
---|
57 | good = true;
|
---|
58 | if (size < 16) std::cerr << "Warning: size seems small (" << size << ")\n\n";
|
---|
59 | if (size % 2 == 0) std::cerr << "Warning: size is not a power of 2\n\n";
|
---|
60 | }
|
---|
61 | }
|
---|
62 |
|
---|
63 | if (not good) {
|
---|
64 | std::cerr << "Usage: " << argv[0] << " {E,N} SIZE -- generate ziggurat tables for the (E)xponential or (N)ormal distribution\n\n";
|
---|
65 | exit(1);
|
---|
66 | }
|
---|
67 |
|
---|
68 | std::vector<FLOAT> x(size+1);
|
---|
69 | std::vector<FLOAT> y(size+1);
|
---|
70 |
|
---|
71 | std::cout.precision(std::numeric_limits<FLOAT>::max_digits10);
|
---|
72 | FLOAT left = 0, right = 10, last = -1;
|
---|
73 | while (left != right) {
|
---|
74 | FLOAT r = 0.5*(left+right);
|
---|
75 | if (r == last) // We're at our precision limit, so stop
|
---|
76 | break;
|
---|
77 | last = r;
|
---|
78 | std::cout << "trying " << r << "\n";
|
---|
79 | x[1] = r;
|
---|
80 | y[1] = f(x[1]);
|
---|
81 |
|
---|
82 | FLOAT A = x[1]*y[1] + tail_area(x[1]);
|
---|
83 |
|
---|
84 | x[0] = A/y[1];
|
---|
85 | y[0] = f(x[0]);
|
---|
86 | for (unsigned i = 2; i <= size; i++) {
|
---|
87 | y[i] = y[i-1] + A/x[i-1];
|
---|
88 | if (y[i] > f(0)) { // x[1] guess was too low
|
---|
89 | left = r;
|
---|
90 | goto NEXTGUESS;
|
---|
91 | }
|
---|
92 | x[i] = finv(y[i]);
|
---|
93 | }
|
---|
94 | // If final is negative, r was too big; if positive, too small.
|
---|
95 | if (y[size] < f(0)) right = r;
|
---|
96 | else left = r;
|
---|
97 | NEXTGUESS:
|
---|
98 | ;
|
---|
99 | }
|
---|
100 |
|
---|
101 | // Make sure y[SIZE] =~ 1, and x[SIZE] =~ 0
|
---|
102 | if (abs(y[size] - 1) > 1e-12) throw "Error: y_n != 1";
|
---|
103 | if (abs(x[size]) > 1e-20) throw "Error: x_n != 0";
|
---|
104 | y[size] = 1;
|
---|
105 | x[size] = 0;
|
---|
106 | y[0] = 0; // This value never gets used, so just set it to 0 to slightly save space in the code
|
---|
107 |
|
---|
108 | std::string structname(which == "N" ? "normal_table" : "exponential_table");
|
---|
109 | std::cout << "\n\n\n// tables for the ziggurat algorithm\n" <<
|
---|
110 | "template<class RealType>\n" <<
|
---|
111 | "struct " << structname << " {\n" <<
|
---|
112 | " static const RealType table_x[" << size+1 << "];\n" <<
|
---|
113 | " static const RealType table_y[" << size+1 << "];\n" <<
|
---|
114 | "};\n\n";
|
---|
115 |
|
---|
116 | std::cout << "template<class RealType>\nconst RealType " << structname << "<RealType>::table_x[" << size+1 << "] = {";
|
---|
117 | std::cout.precision(20);
|
---|
118 | std::cout.setf(std::ios_base::showpoint);
|
---|
119 | for (unsigned i = 0; i < size; i++) {
|
---|
120 | std::cout << (i % 4 == 0 ? "\n " : " ");
|
---|
121 | if (x[i] == 0) std::cout << "0,";
|
---|
122 | else if (x[i] == 1) std::cout << "1,";
|
---|
123 | else std::cout << x[i] << ",";
|
---|
124 | }
|
---|
125 | std::cout.unsetf(std::ios_base::showpoint);
|
---|
126 | std::cout << "\n " << x[size] << "\n};\n\n";
|
---|
127 |
|
---|
128 | std::cout << "template<class RealType>\nconst RealType " << structname << "<RealType>::table_y[" << size+1 << "] = {";
|
---|
129 | std::cout.precision(20);
|
---|
130 | std::cout.setf(std::ios_base::showpoint);
|
---|
131 | for (unsigned i = 0; i < size; i++) {
|
---|
132 | std::cout << (i % 4 == 0 ? "\n " : " ");
|
---|
133 | if (y[i] == 0) std::cout << "0,";
|
---|
134 | else if (y[i] == 1) std::cout << "1,";
|
---|
135 | else std::cout << y[i] << ",";
|
---|
136 | }
|
---|
137 | std::cout.unsetf(std::ios_base::showpoint);
|
---|
138 | std::cout << "\n " << y[size] << "\n};\n\n";
|
---|
139 | }
|
---|