146 lines
4.9 KiB
TypeScript
146 lines
4.9 KiB
TypeScript
import { type Options } from "./types";
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Individual distribution samplers
|
||
// ---------------------------------------------------------------------------
|
||
|
||
// Second standard-normal deviate left over from the most recent Box-Muller
// transform, or null when nothing is cached. Shared by every normalRandom call.
let spareNormal: number | null = null;

/**
 * Samples N(mean, stddev²) using the Box-Muller transform with pair caching.
 *
 * One transform yields two independent standard normals: the cosine deviate is
 * returned immediately and the sine deviate is parked in `spareNormal`, so the
 * following call is served from the cache without consuming `Math.random()`.
 */
function normalRandom(mean: number, stddev: number): number {
  const cached = spareNormal;
  if (cached !== null) {
    spareNormal = null;
    return mean + stddev * cached;
  }

  // Math.random() may return exactly 0 and ln(0) = -Infinity, so re-draw u1
  // up to 100 extra times; if it is still 0, substitute Number.EPSILON.
  let u1 = Math.random();
  let retriesLeft = 100;
  while (u1 === 0 && retriesLeft-- > 0) {
    u1 = Math.random();
  }
  const radius = Math.sqrt(-2 * Math.log(u1 === 0 ? Number.EPSILON : u1));
  const theta = 2 * Math.PI * Math.random();

  spareNormal = radius * Math.sin(theta);
  return mean + stddev * radius * Math.cos(theta);
}
|
||
|
||
/**
 * Samples Binomial(n, p).
 *
 * When n > 10 000 and both n·p and n·(1−p) exceed 5, the result is a
 * Normal(np, np(1−p)) draw rounded to the nearest integer and clamped to
 * [0, n]. Otherwise the n Bernoulli trials are simulated one by one, which
 * costs n calls to Math.random().
 */
function binomialRandom(n: number, p: number): number {
  const useNormalApprox = n > 10_000 && n * p > 5 && n * (1 - p) > 5;

  if (!useNormalApprox) {
    let successes = 0;
    for (let trial = 0; trial < n; trial++) {
      successes += Math.random() < p ? 1 : 0;
    }
    return successes;
  }

  const approx = Math.round(normalRandom(n * p, Math.sqrt(n * p * (1 - p))));
  return Math.min(n, Math.max(0, approx));
}
|
||
|
||
/**
 * Samples Poisson(λ).
 *
 * For λ ≤ 100 this uses Knuth's multiplication method: keep multiplying
 * uniforms until the running product drops to e^(−λ), and count how many
 * multiplications stayed above it. That costs about λ + 1 RNG calls per
 * sample, so for λ > 100 a Normal(λ, λ) draw is rounded and clamped at 0.
 */
function poissonRandom(lambda: number): number {
  if (lambda > 100) {
    return Math.max(0, Math.round(normalRandom(lambda, Math.sqrt(lambda))));
  }

  const threshold = Math.exp(-lambda);
  let product = Math.random();
  let count = 0;
  while (product > threshold) {
    count++;
    product *= Math.random();
  }
  return count;
}
|
||
|
||
/**
 * Samples Exponential(λ) by inverting the CDF: X = −ln(U) / λ with U uniform.
 * The caller is responsible for ensuring λ > 0.
 */
function exponentialRandom(lambda: number): number {
  // A falsy draw (exactly 0) would give ln(0) = -Infinity; use EPSILON instead.
  const uniform = Math.random() || Number.EPSILON;
  const unitExponential = -Math.log(uniform);
  return unitExponential / lambda;
}
|
||
|
||
/**
 * Samples Hypergeometric(N, K, n) by simulating an urn: draw min(n, N) items
 * without replacement from N items of which K are successes, and count the
 * successes drawn. Each draw succeeds with probability
 * (successes left) / (items left).
 */
function hypergeometricRandom(N: number, K: number, n: number): number {
  const totalDraws = Math.min(n, N);
  let successesLeft = K;
  let poolSize = N;
  let hits = 0;

  for (let draw = 0; draw < totalDraws; draw++, poolSize--) {
    const pSuccess = successesLeft / poolSize;
    if (Math.random() < pSuccess) {
      hits++;
      successesLeft--;
    }
  }
  return hits;
}
|
||
|
||
// ---------------------------------------------------------------------------
|
||
// Dispatcher
|
||
// ---------------------------------------------------------------------------
|
||
|
||
/**
 * Checks that `opts` describes a well-formed sampling request and throws on
 * the first problem found.
 *
 * Range checks are written as negated "valid" predicates (e.g. `!(x > 0)`)
 * so that `NaN` is rejected. A plain `x <= 0` is false for NaN and would let
 * it through. Continuous distribution parameters must also be finite.
 *
 * @param opts - Generator options; only the fields relevant to `opts.dist`
 *   are inspected beyond `count` and `decimals`.
 * @throws Error naming the first invalid field, or an unknown `dist`.
 */
export function validateOptions(opts: Options): void {
  if (!Number.isInteger(opts.count) || opts.count < 1) {
    throw new Error(`invalid count: ${opts.count}`);
  }
  if (!Number.isInteger(opts.decimals) || opts.decimals < 0 || opts.decimals > 100) {
    throw new Error(`decimals must be 0–100, got ${opts.decimals}`);
  }

  // Captured before the switch so the default branch can still report it
  // after the exhaustive cases have narrowed `opts.dist` away.
  const dist: string = opts.dist;

  switch (opts.dist) {
    case "uniform":
      if (!Number.isFinite(opts.min) || !Number.isFinite(opts.max)) {
        throw new Error(`min and max must be finite numbers, got min=${opts.min}, max=${opts.max}`);
      }
      if (opts.min > opts.max) throw new Error(`min (${opts.min}) > max (${opts.max})`);
      break;
    case "normal":
      if (!Number.isFinite(opts.mean)) {
        throw new Error(`mean must be a finite number, got ${opts.mean}`);
      }
      if (!(opts.stddev > 0)) throw new Error(`stddev must be > 0, got ${opts.stddev}`);
      if (!Number.isFinite(opts.stddev)) {
        throw new Error(`stddev must be finite, got ${opts.stddev}`);
      }
      break;
    case "binomial":
      // Number.isInteger also rejects NaN and ±Infinity.
      if (opts.trials < 0 || !Number.isInteger(opts.trials))
        throw new Error(`trials must be a non-negative integer, got ${opts.trials}`);
      if (!(opts.prob >= 0 && opts.prob <= 1))
        throw new Error(`prob must be 0–1, got ${opts.prob}`);
      break;
    case "poisson":
    case "exponential":
      if (!(opts.lambda > 0)) throw new Error(`lambda must be > 0, got ${opts.lambda}`);
      if (!Number.isFinite(opts.lambda)) {
        throw new Error(`lambda must be finite, got ${opts.lambda}`);
      }
      break;
    case "hypergeometric":
      if (opts.popSize < 0 || !Number.isInteger(opts.popSize))
        throw new Error(`population size N must be a non-negative integer, got ${opts.popSize}`);
      if (opts.successes < 0 || opts.successes > opts.popSize || !Number.isInteger(opts.successes))
        throw new Error(`successes K must be 0–N, got ${opts.successes} (N=${opts.popSize})`);
      if (opts.draws < 0 || opts.draws > opts.popSize || !Number.isInteger(opts.draws))
        throw new Error(`draws n must be 0–N, got ${opts.draws} (N=${opts.popSize})`);
      break;
    default:
      // Reachable when options come from untyped input (e.g. a CLI cast).
      throw new Error(`unknown distribution: ${dist}`);
  }
}
|
||
|
||
/**
 * Draws an integer uniformly from the integers contained in [min, max].
 *
 * Bounds are snapped inward (ceil(min), floor(max)), so non-integer bounds can
 * never produce a value outside the requested range. For integer bounds this
 * is exactly `floor(random * (max - min + 1)) + min`.
 *
 * @throws Error if no integer lies within [min, max].
 */
function uniformIntRandom(min: number, max: number): number {
  const lo = Math.ceil(min);
  const hi = Math.floor(max);
  if (lo > hi) {
    throw new Error(`no integer lies within [${min}, ${max}]; use decimals > 0 for this range`);
  }
  return Math.floor(Math.random() * (hi - lo + 1)) + lo;
}

/**
 * Yields `opts.count` samples from the distribution selected by `opts.dist`.
 *
 * `validateOptions` runs first. Because generator bodies run lazily, an invalid
 * configuration throws on the first `next()` call, not when `generate` is called.
 * With `decimals === 0` every sample is rounded to an integer. Otherwise the
 * raw floating-point value is yielded; presumably the caller formats it to
 * `decimals` places (not visible here).
 *
 * @throws Error if validation fails, or if `dist` is "uniform" with
 *   `decimals === 0` and no integer lies within [min, max].
 */
export function* generate(opts: Options): Generator<number> {
  validateOptions(opts);

  for (let i = 0; i < opts.count; i++) {
    let v: number;
    switch (opts.dist) {
      case "uniform":
        // Integer mode samples inclusively from [min, max]; float mode from [min, max).
        v = opts.decimals === 0
          ? uniformIntRandom(opts.min, opts.max)
          : Math.random() * (opts.max - opts.min) + opts.min;
        break;
      case "normal":
        v = normalRandom(opts.mean, opts.stddev);
        break;
      case "binomial":
        v = binomialRandom(opts.trials, opts.prob);
        break;
      case "poisson":
        v = poissonRandom(opts.lambda);
        break;
      case "exponential":
        v = exponentialRandom(opts.lambda);
        break;
      case "hypergeometric":
        v = hypergeometricRandom(opts.popSize, opts.successes, opts.draws);
        break;
    }
    yield opts.decimals > 0 ? v : Math.round(v);
  }
}
|