/*
 * Trac ticket #12099 attachment: rexp_tables.cpp
 * (4.5 KB, added by Jason Rhinelander <jason@…>, 7 years ago)
 *
 * Generator for the ziggurat lookup tables of the normal and exponential
 * distributions.
 */
#include <cmath>
#include <functional>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>

#include <boost/math/constants/constants.hpp>
#include <boost/math/distributions/normal.hpp>
#include <boost/multiprecision/cpp_dec_float.hpp>

7using namespace boost::math::constants;
8using namespace boost::math;
9using namespace boost::multiprecision;
10
11//using FLOAT = double;
12using FLOAT = cpp_dec_float_50;
13
14FLOAT N_f(const FLOAT &x) {
15 return exp(-.5*x*x);
16}
17FLOAT N_finv(const FLOAT &x) {
18 return sqrt(-2*log(x));
19}
20FLOAT N_tail_area(const FLOAT &x) {
21 return root_half_pi<FLOAT>()*erfc(x * half_root_two<FLOAT>());
22}
23
24FLOAT E_f(const FLOAT &x) {
25 return exp(-x);
26}
27FLOAT E_finv(const FLOAT &x) {
28 return -log(x);
29}
30FLOAT E_tail_area(const FLOAT &x) {
31 return exp(-x);
32}
33
34int main(int argc, char *argv[]) {
35 std::function<FLOAT(const FLOAT&)> f, finv, tail_area;
36 std::string which;
37 unsigned size;
38 bool good = false;
39 if (argc == 3) {
40 which = argv[1];
41 size = std::stoul(argv[2]);
42 if (which == "E" or which == "e") {
43 f = E_f; finv = E_finv; tail_area = E_tail_area;
44 good = true;
45 }
46 else if (which == "N" or which == "n") {
47 f = N_f; finv = N_finv; tail_area = N_tail_area;
48 good = true;
49 }
50 else {
51 std::cerr << "Invalid distribution option `" << which << "'\n\n";
52 }
53 if (size == 0) {
54 std::cerr << "Invalid size option `" << size << "'\n\n";
55 }
56 else {
57 good = true;
58 if (size < 16) std::cerr << "Warning: size seems small (" << size << ")\n\n";
59 if (size % 2 == 0) std::cerr << "Warning: size is not a power of 2\n\n";
60 }
61 }
62
63 if (not good) {
64 std::cerr << "Usage: " << argv[0] << " {E,N} SIZE -- generate ziggurat tables for the (E)xponential or (N)ormal distribution\n\n";
65 exit(1);
66 }
67
68 std::vector<FLOAT> x(size+1);
69 std::vector<FLOAT> y(size+1);
70
71 std::cout.precision(std::numeric_limits<FLOAT>::max_digits10);
72 FLOAT left = 0, right = 10, last = -1;
73 while (left != right) {
74 FLOAT r = 0.5*(left+right);
75 if (r == last) // We're at our precision limit, so stop
76 break;
77 last = r;
78 std::cout << "trying " << r << "\n";
79 x[1] = r;
80 y[1] = f(x[1]);
81
82 FLOAT A = x[1]*y[1] + tail_area(x[1]);
83
84 x[0] = A/y[1];
85 y[0] = f(x[0]);
86 for (unsigned i = 2; i <= size; i++) {
87 y[i] = y[i-1] + A/x[i-1];
88 if (y[i] > f(0)) { // x[1] guess was too low
89 left = r;
90 goto NEXTGUESS;
91 }
92 x[i] = finv(y[i]);
93 }
94 // If final is negative, r was too big; if positive, too small.
95 if (y[size] < f(0)) right = r;
96 else left = r;
97NEXTGUESS:
98 ;
99 }
100
101 // Make sure y[SIZE] =~ 1, and x[SIZE] =~ 0
102 if (abs(y[size] - 1) > 1e-12) throw "Error: y_n != 1";
103 if (abs(x[size]) > 1e-20) throw "Error: x_n != 0";
104 y[size] = 1;
105 x[size] = 0;
106 y[0] = 0; // This value never gets used, so just set it to 0 to slightly save space in the code
107
108 std::string structname(which == "N" ? "normal_table" : "exponential_table");
109 std::cout << "\n\n\n// tables for the ziggurat algorithm\n" <<
110 "template<class RealType>\n" <<
111 "struct " << structname << " {\n" <<
112 " static const RealType table_x[" << size+1 << "];\n" <<
113 " static const RealType table_y[" << size+1 << "];\n" <<
114 "};\n\n";
115
116 std::cout << "template<class RealType>\nconst RealType " << structname << "<RealType>::table_x[" << size+1 << "] = {";
117 std::cout.precision(20);
118 std::cout.setf(std::ios_base::showpoint);
119 for (unsigned i = 0; i < size; i++) {
120 std::cout << (i % 4 == 0 ? "\n " : " ");
121 if (x[i] == 0) std::cout << "0,";
122 else if (x[i] == 1) std::cout << "1,";
123 else std::cout << x[i] << ",";
124 }
125 std::cout.unsetf(std::ios_base::showpoint);
126 std::cout << "\n " << x[size] << "\n};\n\n";
127
128 std::cout << "template<class RealType>\nconst RealType " << structname << "<RealType>::table_y[" << size+1 << "] = {";
129 std::cout.precision(20);
130 std::cout.setf(std::ios_base::showpoint);
131 for (unsigned i = 0; i < size; i++) {
132 std::cout << (i % 4 == 0 ? "\n " : " ");
133 if (y[i] == 0) std::cout << "0,";
134 else if (y[i] == 1) std::cout << "1,";
135 else std::cout << y[i] << ",";
136 }
137 std::cout.unsetf(std::ios_base::showpoint);
138 std::cout << "\n " << y[size] << "\n};\n\n";
139}