Source code for core

"""
core.py — ToFUL Computational Backend
======================================

All numerical computation for the ToFUL moments calculator lives here.
The frontend (toful.py) should call only the public API at the bottom of
this module; it must not import scipy, numpy, or mpmath directly.

Architecture
------------
::

    toful_parser.py  →  normalise_for_eval()  →  core.py  →  results
                  build_safe_dict()

Performance strategy (applied in order of impact):
  1. compile()  — parse the expression AST once, reuse the code object.
  2. NumPy vectorisation — evaluate PMF/PDF over arrays, not Python loops.
  3. lru_cache — memoize moment results keyed on (expr, support, r, a,
                 precision) so re-renders do not recompute.
  4. Wynn epsilon-algorithm — best general convergence accelerator for series.
  5. Aitken delta-squared — lightweight fallback when epsilon is unstable.
  6. Cohen-Villegas-Zagier — optimal for alternating series.
  7. Ratio-bound tail estimate — geometric series tail correction.
  8. Domain-aware quadrature — Gauss-Laguerre for [0,inf), Gauss-Hermite
     for (-inf,inf), tanh-sinh (via mpmath) for high precision.
  9. SymPy symbolic path — exact closed-form where possible.
  10. mpmath arbitrary precision — fallback above _MPMATH_THRESHOLD dp.
  11. Parallel moment computation — ThreadPoolExecutor over orders r=1..R.
"""

from __future__ import annotations

import math
import warnings
import functools
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
from scipy import integrate, special
from scipy.special import roots_laguerre, roots_hermite

# NOTE(review): this silences *every* warning category process-wide as a side
# effect of importing this module — confirm that is intended.
warnings.filterwarnings("ignore")

# Optional heavy dependencies — imported lazily.
# Tri-state availability flags: None = not probed yet, True/False = cached
# outcome of the first import attempt made by the _try_import_* helpers.
_sympy_available: Optional[bool] = None
_mpmath_available: Optional[bool] = None


def _try_import_sympy() -> bool:
    """Report whether SymPy is importable, probing at most once.

    The outcome of the first import attempt is stored in the module-level
    ``_sympy_available`` flag; later calls return that cached value.
    """
    global _sympy_available
    if _sympy_available is not None:
        return _sympy_available
    try:
        import sympy  # noqa: F401
    except ImportError:
        _sympy_available = False
    else:
        _sympy_available = True
    return _sympy_available


def _try_import_mpmath() -> bool:
    """Report whether mpmath is importable, probing at most once.

    The outcome of the first import attempt is stored in the module-level
    ``_mpmath_available`` flag; later calls return that cached value.
    """
    global _mpmath_available
    if _mpmath_available is not None:
        return _mpmath_available
    try:
        import mpmath  # noqa: F401
    except ImportError:
        _mpmath_available = False
    else:
        _mpmath_available = True
    return _mpmath_available


# ---------------------------------------------------------------------------
# Configuration constants
# ---------------------------------------------------------------------------

# Tunable limits and tolerances shared by the numerical routines below.
# NOTE(review): _MAX_WYNN_TERMS and _MAX_SERIES_TERMS are not referenced in the
# part of the file visible here — confirm they are consumed further down.
_MPMATH_THRESHOLD = 12    # precision (dp) above which we switch to mpmath
_MAX_WYNN_TERMS   = 120   # max terms fed to Wynn epsilon before giving up
_MAX_SERIES_TERMS = 500   # hard cap on infinite-series truncation
_CONVERGENCE_TOL  = 1e-12 # default convergence tolerance
_GAUSS_NODES      = 64    # nodes for Gauss-Laguerre / Hermite quadrature


# ---------------------------------------------------------------------------
# Unicode subscript helpers
# ---------------------------------------------------------------------------

# Decimal digit -> Unicode subscript glyph (U+2080 .. U+2089).
SUBSCRIPT_MAP: Dict[int, str] = {
    0: "\u2080", 1: "\u2081", 2: "\u2082", 3: "\u2083", 4: "\u2084",
    5: "\u2085", 6: "\u2086", 7: "\u2087", 8: "\u2088", 9: "\u2089",
}

# Character translation table derived from SUBSCRIPT_MAP, plus the sign
# characters that ``str(int)`` can produce.
_SUBSCRIPT_TABLE = str.maketrans(
    {
        **{str(digit): glyph for digit, glyph in SUBSCRIPT_MAP.items()},
        "-": "\u208b",
        "+": "\u208a",
    }
)


def to_subscript(n: int) -> str:
    """Render an integer as Unicode subscript characters.

    Digits are mapped through ``SUBSCRIPT_MAP``; a leading minus sign becomes
    the subscript minus (U+208B). Any other character of ``str(n)`` is passed
    through unchanged instead of raising.

    Parameters
    ----------
    n : int
        Value to render. Negative values are supported.

    Returns
    -------
    str

    Examples
    --------
    >>> to_subscript(12)
    '₁₂'
    >>> to_subscript(-3)
    '₋₃'
    """
    return str(n).translate(_SUBSCRIPT_TABLE)
# ---------------------------------------------------------------------------
# Safe eval namespace
# ---------------------------------------------------------------------------


def _float_factorial(n):
    """Factorial that also accepts integral floats (``5.0`` -> ``120``).

    The evaluators below always bind ``x`` as a float, while
    ``math.factorial`` rejects floats on Python >= 3.10. Without this shim
    every expression using ``factorial(x)`` raises and is silently evaluated
    as 0.0.

    Raises
    ------
    ValueError
        If ``n`` is a non-integral (or non-finite) float, or negative.
    """
    if isinstance(n, (float, np.floating)):
        if not float(n).is_integer():
            raise ValueError("factorial() only accepts integral values")
        n = int(n)
    return math.factorial(n)


def _build_safe_dict(x_val: float = 0.0, extra: Optional[dict] = None) -> dict:
    """
    Build the restricted namespace used for all eval() calls.

    This is the single source of truth for what names are available in
    user expressions.

    NOTE(security): passing ``{"__builtins__": {}}`` as globals only hides the
    builtin names; it is not a sandbox. Expressions must come from a trusted
    or pre-validated source.

    Parameters
    ----------
    x_val : float
        Value to bind to 'x'.
    extra : dict, optional
        Additional bindings (distribution parameters, etc.). Entries override
        the defaults below, including 'x'.

    Returns
    -------
    dict
        Namespace for use as locals= in eval().
    """
    ns: dict = {
        # Constants
        "pi": np.pi, "e": np.e, "inf": np.inf, "nan": np.nan,
        # Python builtins re-exposed explicitly
        "abs": abs, "round": round, "min": min, "max": max,
        # Elementary functions
        "sqrt": np.sqrt, "exp": np.exp,
        "log": np.log, "log2": np.log2, "log10": np.log10,
        "sin": np.sin, "cos": np.cos, "tan": np.tan,
        "asin": np.arcsin, "acos": np.arccos, "atan": np.arctan,
        "atan2": np.arctan2,
        "sinh": np.sinh, "cosh": np.cosh, "tanh": np.tanh,
        "ceil": np.ceil, "floor": np.floor, "sign": np.sign,
        # Special functions
        "factorial": _float_factorial,
        "gamma": special.gamma,
        "erf": math.erf, "erfc": math.erfc,
        # Greek parameter aliases (populated by parser)
        "lam": 1.0, "mu": 0.0, "sigma": 1.0,
        "alpha": 1.0, "beta": 1.0, "theta": 0.5,
        # Variable placeholder
        "x": x_val,
    }
    if extra:
        ns.update(extra)
    return ns


@functools.lru_cache(maxsize=512)
def _cached_compile(expr_str: str):
    """
    Compile an expression string to a code object, memoised per string.

    Repeated calls with the same string return the cached code object, so the
    expression is parsed only once even inside tight evaluation loops.

    Parameters
    ----------
    expr_str : str
        Valid Python expression string.

    Returns
    -------
    code object

    Raises
    ------
    SyntaxError
        If ``expr_str`` is not a valid Python expression.
    """
    return compile(expr_str, "<toful_expr>", "eval")


def _eval_scalar(code, x_val: float, extra: Optional[dict] = None) -> float:
    """Evaluate a compiled expression at a single scalar x value.

    Parameters
    ----------
    code : code object
        Pre-compiled expression (from _cached_compile).
    x_val : float
    extra : dict, optional

    Returns
    -------
    float
        Result of evaluation, or 0.0 on any exception (deliberate
        best-effort: points where the expression is undefined contribute
        nothing).
    """
    ns = _build_safe_dict(x_val, extra)
    try:
        return float(eval(code, {"__builtins__": {}}, ns))
    except Exception:
        return 0.0


def _eval_array(
    expr_str: str,
    x_arr: np.ndarray,
    extra: Optional[dict] = None,
) -> np.ndarray:
    """
    Evaluate an expression at every element of a NumPy array.

    The expression is compiled once (cached) and the namespace template is
    built once. Evaluation itself is element-wise: ``np.vectorize`` is a
    convenience loop, not true SIMD vectorisation. Each element is bound to
    'x' as a Python float, and any element whose evaluation raises yields 0.0.

    Parameters
    ----------
    expr_str : str
    x_arr : np.ndarray
    extra : dict, optional

    Returns
    -------
    np.ndarray of float64
    """
    code = _cached_compile(expr_str)
    ns_template = _build_safe_dict(0.0, extra)

    def _at(x: float) -> float:
        # Fresh copy per point so one evaluation cannot leak state into the next.
        ns = dict(ns_template)
        ns["x"] = float(x)
        try:
            return float(eval(code, {"__builtins__": {}}, ns))
        except Exception:
            return 0.0

    try:
        return np.vectorize(_at, otypes=[float])(x_arr)
    except Exception:
        # Plain loop for inputs np.vectorize cannot broadcast.
        return np.array([_at(float(xi)) for xi in x_arr], dtype=float)


# ---------------------------------------------------------------------------
# Convergence acceleration algorithms
# ---------------------------------------------------------------------------


def _wynn_epsilon(partial_sums: List[float]) -> Tuple[float, bool, str]:
    """
    Wynn epsilon-algorithm for sequence acceleration.

    Builds the epsilon-table column by column from the partial sums and keeps
    the first entry of the highest finite even column as the limit estimate
    (odd columns are auxiliary).

    The recursive rule is:
        eps[-1, n] = 0              for all n
        eps[0,  n] = S_n            (n-th partial sum)
        eps[k+1, n] = eps[k-1, n+1] + 1 / (eps[k, n+1] - eps[k, n])

    Reference: Wynn, P. (1956). On a device for computing the e_m(S_n)
    transformation. Mathematical Tables and Aids to Computation.

    Parameters
    ----------
    partial_sums : list of float
        Sequence S_0, S_1, ..., S_n.

    Returns
    -------
    estimate : float
    converged : bool
        True when at least the second column produced a finite estimate.
    info : str
    """
    n = len(partial_sums)
    if n == 0:
        return 0.0, False, "Too few terms for Wynn epsilon"
    if n < 3:
        return partial_sums[-1], False, "Too few terms for Wynn epsilon"

    eps_prev = [0.0] * n           # epsilon^{-1} column (all zeros by definition)
    eps_curr = list(partial_sums)  # epsilon^{0} column = partial sums

    best_estimate = partial_sums[-1]
    best_col = 0

    for col in range(1, n):
        eps_next: List[float] = []
        # Bound the row loop by the column that actually exists: after a
        # near-singular break the column is shorter than n - col, and indexing
        # by n - col would run past its end.
        for row in range(len(eps_curr) - 1):
            denom = eps_curr[row + 1] - eps_curr[row]
            if abs(denom) < 1e-300:
                break  # near-singular: the table cannot be extended further
            eps_next.append(eps_prev[row + 1] + 1.0 / denom)

        if not eps_next:
            break

        eps_prev = eps_curr[: len(eps_next) + 1]
        eps_curr = eps_next

        if col % 2 == 0:
            candidate = eps_next[0]
            if np.isfinite(candidate):
                best_estimate = candidate
                best_col = col

    converged = best_col >= 2 and bool(np.isfinite(best_estimate))
    info = (
        f"Wynn epsilon converged at column {best_col}: {best_estimate:.15g}"
        if converged
        else f"Wynn epsilon: best estimate {best_estimate:.15g} (col {best_col})"
    )
    return best_estimate, converged, info


def _aitken_delta2(s0: float, s1: float, s2: float) -> float:
    """
    Aitken delta-squared method: extrapolate limit from three partial sums.

    Formula:  S* = S_0 - (S_1 - S_0)^2 / (S_2 - 2*S_1 + S_0)

    Parameters
    ----------
    s0, s1, s2 : float
        Three consecutive partial sums.

    Returns
    -------
    float
        The Aitken estimate, or ``s2`` if the second difference is too small
        to divide by.
    """
    first_diff = s1 - s0
    second_diff = (s2 - s1) - first_diff
    if abs(second_diff) < 1e-300:
        return s2
    return s0 - first_diff ** 2 / second_diff


def _cohen_villegas_zagier(
    expr_str: str,
    n_terms: int,
    extra: Optional[dict] = None,
) -> Tuple[float, bool, str]:
    """
    Cohen-Villegas-Zagier acceleration of an alternating series.

    Estimates ``sum_{k>=0} f(k)`` where ``f`` is the expression and its values
    alternate in sign. Writing ``f(k) = (-1)^k a_k``, Algorithm 1 of the
    reference is applied to the ``a_k``; the error decays roughly like
    ``(3 + sqrt(8))^(-n)``.

    Reference: Cohen, H., Villegas, F. R., & Zagier, D. (2000).
    Convergence acceleration of alternating series.
    Experimental Mathematics.

    Parameters
    ----------
    expr_str : str
        Expression evaluated at the non-negative integers k = 0 .. n_terms-1.
    n_terms : int
        Number of terms used (values <= 0 give an estimate of 0.0).
    extra : dict, optional

    Returns
    -------
    (estimate, True, info_string)
    """
    n = max(int(n_terms), 0)
    code = _cached_compile(expr_str)

    # d_n = ((3 + sqrt 8)^n + (3 + sqrt 8)^-n) / 2 normalises the weights.
    d = (3.0 + math.sqrt(8.0)) ** n
    d = 0.5 * (d + 1.0 / d)

    b, c, total = -1.0, -d, 0.0
    for k in range(n):
        c = b - c
        fk = _eval_scalar(code, float(k), extra)
        # The weight c already alternates in sign, so strip the (-1)^k that
        # the series term carries to recover a_k.
        total += c * (fk if k % 2 == 0 else -fk)
        b = (k + n) * (k - n) * b / ((k + 0.5) * (k + 1.0))

    return total / d, True, f"Cohen-Villegas-Zagier ({n} terms)"


def _is_alternating(terms: List[float]) -> bool:
    """Return True if the term signs alternate for all consecutive pairs.

    Fewer than four terms are never considered alternating; a zero term
    breaks the alternation.
    """
    if len(terms) < 4:
        return False
    return all(left * right < 0 for left, right in zip(terms, terms[1:]))


# ---------------------------------------------------------------------------
# Series pattern detection and extension
# ---------------------------------------------------------------------------
@dataclass
class SeriesPattern:
    """Detected pattern of a discrete sequence.

    ``params`` depends on ``kind``: ``{"start", "diff"}`` for arithmetic,
    ``{"start", "ratio"}`` for geometric, ``{"values"}`` for custom.
    """

    kind: str  # "arithmetic", "geometric", "custom"
    params: Dict
    description: str = ""  # human-readable summary of the detected pattern
def detect_series_pattern(values: List[float]) -> "SeriesPattern":
    """
    Detect whether seed values form an arithmetic or geometric sequence.

    Arithmetic is tested first, so a constant sequence is reported as
    arithmetic with step 0. The equality tolerance is ``1e-12`` scaled by the
    magnitude of the data, so large-valued seeds such as
    ``[1e6 + 0.1, 1e6 + 0.2, 1e6 + 0.3]`` are not rejected because of
    floating-point rounding in their differences.

    Parameters
    ----------
    values : list of float
        Seed values from the user range input.

    Returns
    -------
    SeriesPattern
        ``kind`` is "arithmetic", "geometric" or "custom". Fewer than two
        values always yield "custom".
    """
    if len(values) < 2:
        # Copy so the pattern does not alias the caller's list.
        return SeriesPattern("custom", {"values": list(values)}, "Single value")

    arr = np.asarray(values, dtype=float)

    diffs = np.diff(arr)
    diff_tol = 1e-12 * max(1.0, float(np.max(np.abs(arr))))
    if np.all(np.abs(diffs - diffs[0]) < diff_tol):
        return SeriesPattern(
            "arithmetic",
            {"start": float(arr[0]), "diff": float(diffs[0])},
            f"Arithmetic: start={arr[0]}, step={diffs[0]}",
        )

    # Ratios are only defined when no seed value is zero.
    if np.all(arr != 0):
        ratios = arr[1:] / arr[:-1]
        ratio_tol = 1e-12 * max(1.0, abs(float(ratios[0])))
        if np.all(np.abs(ratios - ratios[0]) < ratio_tol):
            return SeriesPattern(
                "geometric",
                {"start": float(arr[0]), "ratio": float(ratios[0])},
                f"Geometric: start={arr[0]}, ratio={ratios[0]}",
            )

    return SeriesPattern("custom", {"values": list(values)}, "Custom pattern")
def generate_extended_series(pattern: "SeriesPattern", max_terms: int = 500) -> np.ndarray:
    """
    Generate a NumPy array of the first max_terms values matching the pattern.

    - "arithmetic": ``start + k * diff``
    - "geometric":  ``start * ratio ** k``
    - anything else: the seed values, continued with the step between the
      last two seeds. A single seed continues with step 1, and an empty seed
      list yields 0, 1, 2, ...

    Parameters
    ----------
    pattern : SeriesPattern
    max_terms : int
        Number of values to return; values <= 0 give an empty array.

    Returns
    -------
    np.ndarray
    """
    max_terms = max(int(max_terms), 0)
    index = np.arange(max_terms, dtype=float)

    if pattern.kind == "arithmetic":
        return pattern.params["start"] + index * pattern.params["diff"]

    if pattern.kind == "geometric":
        return pattern.params["start"] * (pattern.params["ratio"] ** index)

    seeds = np.asarray(pattern.params.get("values", [0.0]), dtype=float)
    if seeds.size == 0:
        # No seed at all: behave like the default seed [0.0].
        return index
    if seeds.size < 2:
        return seeds[0] + index

    step = seeds[-1] - seeds[-2]
    n_extra = max(max_terms - seeds.size, 0)
    continuation = seeds[-1] + np.arange(1, n_extra + 1) * step
    return np.concatenate([seeds, continuation])[:max_terms]
# --------------------------------------------------------------------------- # Discrete series summation with full convergence-acceleration cascade # ---------------------------------------------------------------------------
@dataclass
class SeriesAnalysis:
    """Diagnostic output from a single discrete series computation."""

    value: float      # best estimate of the sum
    converged: bool   # whether the producing strategy considered it converged
    terms_used: int   # number of series terms that fed the estimate
    method: str       # identifier of the strategy that produced ``value``
    info: str         # human-readable diagnostic message
    partial_sums: List[float] = field(default_factory=list)
    term_values: List[float] = field(default_factory=list)
def _cvz_from_terms(terms: List[float]) -> float:
    """Cohen-Villegas-Zagier estimate of an alternating series from its terms.

    ``terms`` are the signed terms t_0, t_1, ... of the series. Algorithm 1 of
    Cohen, Rodriguez Villegas & Zagier (2000) is applied to
    ``a_k = (-1)^k t_k``.
    """
    n = len(terms)
    d = (3.0 + math.sqrt(8.0)) ** n
    d = 0.5 * (d + 1.0 / d)
    b, c, total = -1.0, -d, 0.0
    for k, term in enumerate(terms):
        c = b - c
        # The weight c already alternates in sign; remove the term's own sign.
        total += c * (term if k % 2 == 0 else -term)
        b = (k + n) * (k - n) * b / ((k + 0.5) * (k + 1.0))
    return total / d


def _sum_series_with_acceleration(
    expr_str: str,
    x_values: np.ndarray,
    r: int = 0,
    a: float = 0.0,
    tol: float = _CONVERGENCE_TOL,
    extra: Optional[dict] = None,
) -> "SeriesAnalysis":
    """
    Compute sum_{x in x_values} (x - a)^r * f(x) using the full
    convergence-acceleration cascade:

    1. Evaluation of all terms over the support
    2. Early exit if term magnitudes already below tolerance
    3. Wynn epsilon-algorithm on partial sums
    4. Aitken delta-squared on last three partial sums
    5. Cohen-Villegas-Zagier for alternating series
    6. Geometric ratio-bound tail correction
    7. Partial-sum fallback (uncertain)

    Parameters
    ----------
    expr_str : str
        Eval-normalised expression string.
    x_values : np.ndarray
        Support values (finite or truncated infinite series).
    r : int
        Power weight exponent (0 for probability sum, r for r-th moment).
    a : float
        Reference point (0 for raw, mu for central).
    tol : float
    extra : dict, optional

    Returns
    -------
    SeriesAnalysis
        An empty support yields value 0.0 with method "empty".
    """
    x_values = np.asarray(x_values, dtype=float)

    # f(x) over the full support, then the moment weights (x - a)^r.
    pmf_vals = _eval_array(expr_str, x_values, extra)
    weights = np.ones_like(x_values) if r == 0 else (x_values - a) ** r

    term_values: List[float] = (weights * pmf_vals).tolist()
    partial_sums: List[float] = np.cumsum(term_values).tolist()
    n = len(partial_sums)

    def _report(value, converged, method, info, terms_used=None):
        return SeriesAnalysis(
            value=value,
            converged=converged,
            terms_used=n if terms_used is None else terms_used,
            method=method,
            info=info,
            partial_sums=partial_sums,
            term_values=term_values,
        )

    if n == 0:
        # Nothing to sum; every strategy below would index an empty list.
        return _report(0.0, True, "empty", "Empty support: sum is 0")

    last_sum = partial_sums[-1]

    # Strategy 1: term-magnitude early exit
    if n >= 20:
        recent = np.abs(term_values[-10:])
        if np.all(recent < tol * (1.0 + abs(last_sum))):
            return _report(
                last_sum, True, "term-magnitude",
                f"Converged: |term| < {tol:.2e} for last 10 terms",
            )

    # Strategy 2: Wynn epsilon
    if n >= 5:
        try:
            wynn_val, wynn_ok, wynn_info = _wynn_epsilon(partial_sums)
        except (IndexError, ArithmeticError):
            # The epsilon table can break down on repeated partial sums;
            # treat that as "no estimate" and move on to the next strategy.
            wynn_ok = False
        if wynn_ok and abs(last_sum - wynn_val) < tol * max(1.0, abs(wynn_val)):
            return _report(wynn_val, True, "wynn-epsilon", wynn_info)

    # Strategy 3: Aitken delta-squared
    if n >= 3:
        aitken_val = _aitken_delta2(partial_sums[-3], partial_sums[-2], last_sum)
        diff = abs(last_sum - aitken_val)
        if diff < tol * max(1.0, abs(aitken_val)):
            return _report(
                aitken_val, True, "aitken-delta2",
                f"Aitken delta-squared: {aitken_val:.15g} (delta={diff:.2e})",
            )

    # Strategy 4: Cohen-Villegas-Zagier for alternating series.
    # It is applied to the weighted terms actually being summed, so the
    # support and the (x - a)^r weight are both honoured.
    if _is_alternating(term_values[:min(30, n)]):
        n_cvz = min(n, 60)
        cvz_val = _cvz_from_terms(term_values[:n_cvz])
        if math.isfinite(cvz_val):
            return _report(
                cvz_val, True, "cohen-villegas-zagier",
                f"Cohen-Villegas-Zagier ({n_cvz} terms)",
                terms_used=n_cvz,
            )

    # Strategy 5: geometric ratio-bound tail correction
    if n >= 30:
        recent_abs = np.abs(term_values[-20:])
        nz = recent_abs[recent_abs > 1e-300]
        if len(nz) >= 4:
            avg_ratio = float(np.mean(nz[1:] / nz[:-1]))
            if avg_ratio < 0.999:
                # Tail of a geometric series continuing from the last term.
                tail = float(term_values[-1]) * avg_ratio / (1.0 - avg_ratio + 1e-300)
                return _report(
                    last_sum + tail, True, "ratio-bound",
                    f"Geometric tail: ratio={avg_ratio:.6f}, tail~{tail:.4g}",
                )
            if avg_ratio >= 1.001:
                return _report(
                    last_sum, False, "ratio-bound",
                    f"Series appears to diverge (ratio={avg_ratio:.4f})",
                )

    # Fallback: return raw partial sum, convergence uncertain
    return _report(
        last_sum, False, "partial-sum",
        f"Convergence uncertain after {n} terms (partial sum={last_sum:.10g})",
    )


# ---------------------------------------------------------------------------
# Probability validation
# ---------------------------------------------------------------------------
@dataclass
class ValidationResult:
    """Returned by the probability validators."""

    is_valid: bool          # True when the function passed validation
    message: str            # human-readable verdict
    integral_or_sum: float  # computed total mass (sum for PMF, integral for PDF)
    analysis: Dict = field(default_factory=dict)  # method-specific diagnostics
def validate_drv_probabilities(
    expr_str: str,
    x_values,
    is_infinite: bool = False,
    tol: float = 0.01,
    extra: Optional[dict] = None,
) -> "ValidationResult":
    """
    Validate that a discrete PMF sums to 1 and has no negative values.

    Parameters
    ----------
    expr_str : str
        Eval-normalised PMF expression.
    x_values : array-like
        Support values (finite window or extended series for infinite).
    is_infinite : bool
        Whether the support is effectively unbounded beyond x_values.
    tol : float
        Tolerance on ``abs(sum - 1.0)``. Loosened to 0.05 for infinite series
        whose sum did not converge, to allow for truncation error.
    extra : dict, optional

    Returns
    -------
    ValidationResult
        An empty support is reported as invalid.
    """
    x_arr = np.asarray(x_values, dtype=float)
    series_type = "infinite" if is_infinite else "finite"

    if x_arr.size == 0:
        return ValidationResult(
            is_valid=False,
            message="Empty support: no x values to sum over",
            integral_or_sum=0.0,
            analysis={"series_type": series_type},
        )

    pmf_vals = _eval_array(expr_str, x_arr, extra)

    # A small negative slack absorbs rounding noise around zero.
    neg_mask = pmf_vals < -1e-15
    if np.any(neg_mask):
        bad_x = x_arr[neg_mask][:3].tolist()
        return ValidationResult(
            is_valid=False,
            message=f"Negative probabilities detected at x = {bad_x}",
            integral_or_sum=float(np.sum(pmf_vals)),
            analysis={"series_type": series_type},
        )

    analysis_obj = _sum_series_with_acceleration(
        expr_str, x_arr, r=0, a=0.0, extra=extra
    )
    total = analysis_obj.value

    effective_tol = 0.05 if is_infinite and not analysis_obj.converged else tol
    is_valid = abs(total - 1.0) <= effective_tol

    if is_valid:
        msg = f"Valid PMF (sum = {total:.10g}, method={analysis_obj.method})"
    else:
        msg = (
            f"PMF sums to {total:.10g} (expected 1.0, "
            f"|delta| = {abs(total - 1.0):.4g})"
        )

    return ValidationResult(
        is_valid=is_valid,
        message=msg,
        integral_or_sum=total,
        analysis={
            "series_type": series_type,
            "terms_computed": analysis_obj.terms_used,
            "convergence_info": analysis_obj.info,
            "method": analysis_obj.method,
            "converged": analysis_obj.converged,
            "partial_sums": analysis_obj.partial_sums[-10:],
        },
    )
def _pdf_probe_points(lower: float, upper: float) -> np.ndarray:
    """Pick the x values used to spot-check a PDF for negative values."""
    lower_finite = np.isfinite(lower)
    upper_finite = np.isfinite(upper)
    if lower_finite and upper_finite:
        return np.linspace(lower, upper, 100)
    if not lower_finite and not upper_finite:
        # Denser sampling around the origin, sparse in the tails.
        return np.concatenate([
            np.linspace(-100, -1, 20),
            np.linspace(-1, 1, 40),
            np.linspace(1, 100, 20),
        ])
    if not lower_finite:
        return np.linspace(upper - 200, upper, 50)
    return np.linspace(lower, lower + 200, 50)


def validate_crv_pdf(
    expr_str: str,
    bounds: Tuple[float, float],
    precision: int = 8,
    extra: Optional[dict] = None,
) -> "ValidationResult":
    """
    Validate that a continuous PDF integrates to 1 and is non-negative.

    Non-negativity is a spot check on a fixed grid, not a proof. The
    integration itself is delegated to ``_quadrature_integrate``, which
    selects the rule from the domain and the requested precision.

    Parameters
    ----------
    expr_str : str
        Eval-normalised PDF expression.
    bounds : (float, float)
        Integration limits.
    precision : int
        Target decimal places. The acceptance tolerance on
        ``abs(integral - 1)`` is ``max(1e-4, 10**(2 - precision))``.
    extra : dict, optional

    Returns
    -------
    ValidationResult
    """
    lower, upper = bounds
    code = _cached_compile(expr_str)

    def pdf(x: float) -> float:
        # Clamp at zero so rounding-level negatives do not reduce the mass.
        return max(0.0, _eval_scalar(code, x, extra))

    test_x = _pdf_probe_points(lower, upper)
    test_vals = _eval_array(expr_str, test_x, extra)
    neg_mask = test_vals < -1e-10
    if np.any(neg_mask):
        bad = test_x[neg_mask][:3].tolist()
        return ValidationResult(
            is_valid=False,
            message=f"Negative PDF value at x ~ {bad}",
            integral_or_sum=0.0,
        )

    integral_val, error, method = _quadrature_integrate(
        pdf, lower, upper, precision=precision, moment_r=0, a=0.0
    )

    tol = max(1e-4, 10 ** (-precision + 2))
    is_valid = abs(integral_val - 1.0) <= tol

    if is_valid:
        msg = f"Valid PDF (integral = {integral_val:.10g} +/- {error:.2e}, method={method})"
    else:
        msg = f"PDF integrates to {integral_val:.10g} (expected 1.0, method={method})"

    return ValidationResult(
        is_valid=is_valid,
        message=msg,
        integral_or_sum=integral_val,
        analysis={"method": method, "error_estimate": error},
    )
# ---------------------------------------------------------------------------
# Domain-aware quadrature
# ---------------------------------------------------------------------------


def _quadrature_integrate(
    func: Callable[[float], float],
    lower: float,
    upper: float,
    precision: int = 8,
    moment_r: int = 0,
    a: float = 0.0,
) -> Tuple[float, float, str]:
    """
    Integrate func(x) from lower to upper using the best available method.

    Selection logic:
      - precision > _MPMATH_THRESHOLD  -> mpmath tanh-sinh (if installed)
      - (-inf, +inf)                   -> Gauss-Hermite  (cross-checked with quad)
      - [0, +inf)                      -> Gauss-Laguerre (cross-checked with quad)
      - everything else                -> scipy.integrate.quad (Gauss-Kronrod)

    When the Gauss rule and adaptive quad disagree by more than a relative
    1e-4, the quad result is returned.

    Parameters
    ----------
    func : callable
        The integrand (already includes the (x-a)^r moment weight if needed).
    lower, upper : float
        Integration limits.
    precision : int
        Target decimal places.
    moment_r, a : int, float
        Accepted for interface symmetry; not used in this function.

    Returns
    -------
    (value, error_estimate, method_name)
        On total failure the result is ``(0.0, inf, "integration-failed: ...")``.
    """
    eps = 10 ** (-min(precision + 4, 15))

    # High-precision path via mpmath
    if precision > _MPMATH_THRESHOLD and _try_import_mpmath():
        return _mpmath_integrate(func, lower, upper, precision)

    # Check the sign of each infinity: (+inf, +inf) or (0, -inf) must not be
    # mistaken for the whole real line or the positive half line.
    whole_line = lower == -np.inf and upper == np.inf
    half_line = abs(lower) < 1e-10 and upper == np.inf

    if whole_line or half_line:
        if whole_line:
            gauss_rule, label = _gauss_hermite_integrate, "gauss-hermite+quad"
        else:
            gauss_rule, label = _gauss_laguerre_integrate, "gauss-laguerre+quad"
        try:
            g_val, g_err = gauss_rule(func)
            q_val, q_err = integrate.quad(
                func, lower, upper, limit=200, epsabs=eps, epsrel=eps
            )
        except Exception:
            # Best effort: fall through to the plain adaptive path below.
            pass
        else:
            if abs(g_val - q_val) < 1e-4 * max(1.0, abs(g_val)):
                return float(g_val), float(g_err), label
            return float(q_val), float(q_err), "quad-adaptive"

    # Standard adaptive quad, then a retry with default tolerances.
    try:
        val, err = integrate.quad(func, lower, upper, limit=500, epsabs=eps, epsrel=eps)
        return float(val), float(err), "quad-adaptive"
    except Exception:
        try:
            val, err = integrate.quad(func, lower, upper, limit=100)
            return float(val), float(err), "quad-fallback"
        except Exception as ex:
            return 0.0, np.inf, f"integration-failed: {ex}"


def _gauss_laguerre_integrate(
    func: Callable[[float], float],
    n: int = _GAUSS_NODES,
) -> Tuple[float, float]:
    """
    Gauss-Laguerre quadrature for integral from 0 to inf of f(x) dx.

    The Laguerre weight function is e^{-x}, so the rule approximates the
    integral of e^{-x} g(x); we therefore evaluate g(x) = f(x) * e^x.

    The error estimate is the absolute difference from the rule with half as
    many nodes (at least one).

    Parameters
    ----------
    func : callable
    n : int
        Number of quadrature nodes.

    Returns
    -------
    (estimate, error_estimate)
    """
    def _rule(order: int) -> float:
        nodes, weights = roots_laguerre(order)
        vals = np.array([func(float(t)) * np.exp(float(t)) for t in nodes])
        return float(np.dot(weights, vals))

    estimate = _rule(n)
    coarse = _rule(max(1, n // 2))
    return estimate, abs(estimate - coarse)


def _gauss_hermite_integrate(
    func: Callable[[float], float],
    n: int = _GAUSS_NODES,
) -> Tuple[float, float]:
    """
    Gauss-Hermite quadrature for integral from -inf to inf of f(x) dx.

    The Hermite weight function is e^{-x^2}, so we compensate by multiplying
    f(x) by e^{x^2}.

    The error estimate is the absolute difference from the rule with half as
    many nodes (at least one).

    Parameters
    ----------
    func : callable
    n : int
        Number of quadrature nodes.

    Returns
    -------
    (estimate, error_estimate)
    """
    def _rule(order: int) -> float:
        nodes, weights = roots_hermite(order)
        vals = np.array([func(float(t)) * np.exp(float(t) ** 2) for t in nodes])
        return float(np.dot(weights, vals))

    estimate = _rule(n)
    coarse = _rule(max(1, n // 2))
    return estimate, abs(estimate - coarse)


def _mpmath_integrate(
    func: Callable[[float], float],
    lower: float,
    upper: float,
    precision: int,
) -> Tuple[float, float, str]:
    """
    Quadrature via ``mpmath.quad`` at elevated working precision.

    The working precision is raised only for the duration of the call, so
    mpmath's global ``mp.dps`` is left untouched for other users.

    NOTE: the integrand is still evaluated in double precision (it receives
    and returns Python floats), which bounds the accuracy actually achieved.

    Parameters
    ----------
    func : callable
    lower, upper : float
    precision : int
        Target decimal places.

    Returns
    -------
    (value, error_estimate, "mpmath-tanh-sinh")
        ``(0.0, inf, "mpmath-failed: ...")`` if mpmath raises.
    """
    import mpmath

    lo = mpmath.mpf("-inf") if (np.isinf(lower) and lower < 0) else mpmath.mpf(lower)
    hi = mpmath.mpf("+inf") if (np.isinf(upper) and upper > 0) else mpmath.mpf(upper)

    try:
        with mpmath.workdps(precision + 10):
            val, err = mpmath.quad(
                lambda x: mpmath.mpf(func(float(x))),
                [lo, hi],
                error=True,
            )
            return float(val), float(err), "mpmath-tanh-sinh"
    except Exception as ex:
        return 0.0, np.inf, f"mpmath-failed: {ex}"


# ---------------------------------------------------------------------------
# SymPy symbolic moment computation
# ---------------------------------------------------------------------------


def _try_sympy_moment(
    expr_str: str,
    bounds: Tuple[float, float],
    r: int,
    a: float,
) -> Tuple[Optional[float], str]:
    """
    Attempt a fully symbolic moment computation via SymPy.

    If SymPy can integrate (x-a)^r * PDF(x) in closed form and the result
    evaluates to a finite number, it is returned as a float and numerical
    integration can be skipped.

    NOTE(security): ``sympify`` evaluates the string; the expression must come
    from a trusted or pre-validated source.
    NOTE(review): distribution parameters are not substituted here, so an
    expression with free parameters is expected to fall through to the
    "sympy-integration-failed" result — confirm callers rely on that.

    Parameters
    ----------
    expr_str : str
        Eval-normalised PDF expression.
    bounds : (float, float)
    r : int
        Moment order.
    a : float
        Reference point.

    Returns
    -------
    (value, method_name) if symbolic succeeded, else (None, reason)
    """
    if not _try_import_sympy():
        return None, "sympy-unavailable"

    # Skip expressions with Python conditionals (if/else) — SymPy cannot
    # evaluate relational truth values and will raise TypeError.
    if " if " in expr_str or " else " in expr_str:
        return None, "sympy-skipped: conditional expression"

    import sympy as sp

    x = sp.Symbol("x", real=True)
    sympy_locals = {
        "x": x,
        "factorial": sp.factorial,
        "sqrt": sp.sqrt,
        "exp": sp.exp,
        "log": sp.log,
        "sin": sp.sin,
        "cos": sp.cos,
        "tan": sp.tan,
        "pi": sp.pi,
        "e": sp.E,
        "abs": sp.Abs,
        "gamma": sp.gamma,
    }

    try:
        expr = sp.sympify(expr_str, locals=sympy_locals)
    except Exception:
        # sympify can raise SympifyError, SyntaxError and assorted others.
        return None, "sympy-parse-failed"

    integrand = (x - a) ** r * expr
    lower, upper = bounds
    lo = -sp.oo if (np.isinf(lower) and lower < 0) else sp.nsimplify(lower, rational=True)
    hi = sp.oo if (np.isinf(upper) and upper > 0) else sp.nsimplify(upper, rational=True)

    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            result = sp.integrate(integrand, (x, lo, hi))
        val = float(result.evalf(20))
        if np.isfinite(val):
            return val, "sympy-exact"
    except Exception:
        # Unevaluated integrals and symbolic results cannot be floated.
        pass

    return None, "sympy-integration-failed"


# ---------------------------------------------------------------------------
# Moment computation — result dataclass and engines
# ---------------------------------------------------------------------------
@dataclass
class MomentResult:
    """Result for a single moment computation."""

    order: int           # moment order r
    reference: float     # reference point a the moment is taken about
    value: float         # computed moment
    method: str          # identifier of the method that produced ``value``
    converged: bool      # whether the computation is considered reliable
    terms_or_nodes: int  # series terms used, or quadrature node count
    info: str            # human-readable diagnostic message
def _compute_drv_moment(
    expr_str: str,
    x_values: np.ndarray,
    r: int,
    a: float,
    is_infinite: bool = False,
    tol: float = _CONVERGENCE_TOL,
    extra: dict | None = None,
) -> MomentResult:
    """
    Compute the r-th moment of a discrete RV about point a.

    Delegates the summation to ``_sum_series_with_acceleration`` and
    repackages its analysis as a ``MomentResult``.

    Parameters
    ----------
    expr_str : str
        Eval-normalised PMF expression.
    x_values : np.ndarray
        Support values (finite window or truncated infinite series).
    r : int
        Moment order.
    a : float
        Reference point.
    is_infinite : bool
        Accepted for API compatibility; not forwarded to the summation
        routine.
    tol : float
        Convergence tolerance passed to the summation routine.
    extra : dict, optional
        Extra variable bindings passed to the summation routine.

    Returns
    -------
    MomentResult
    """
    analysis = _sum_series_with_acceleration(
        expr_str, x_values, r=r, a=a, tol=tol, extra=extra
    )
    return MomentResult(
        order=r,
        reference=a,
        value=analysis.value,
        method=analysis.method,
        converged=analysis.converged,
        terms_or_nodes=analysis.terms_used,
        info=analysis.info,
    )


def _compute_crv_moment(
    expr_str: str,
    bounds: Tuple[float, float],
    r: int,
    a: float,
    precision: int = 8,
    extra: dict | None = None,
) -> MomentResult:
    """
    Compute the r-th moment of a continuous RV about point a.

    Pipeline:
      1. SymPy symbolic integral (closed form where possible)
      2. Domain-aware numerical quadrature (Gauss-Laguerre / Hermite / quad)
      3. mpmath tanh-sinh for precision > _MPMATH_THRESHOLD

    Parameters
    ----------
    expr_str : str
        Eval-normalised PDF expression.
    bounds : (float, float)
        Lower and upper integration limits.
    r : int
        Moment order.
    a : float
        Reference point.
    precision : int
        Target decimal places, forwarded to the quadrature layer.
    extra : dict, optional
        Extra variable bindings used when evaluating the PDF.

    Returns
    -------
    MomentResult
        ``converged`` is True only when both the value and its error
        estimate are finite.
    """
    # Attempt symbolic path first
    sym_val, sym_method = _try_sympy_moment(expr_str, bounds, r, a)
    if sym_val is not None:
        return MomentResult(
            order=r,
            reference=a,
            value=sym_val,
            method=sym_method,
            converged=True,
            terms_or_nodes=0,
            info="Exact closed-form result via SymPy",
        )

    # Numerical path
    code = _cached_compile(expr_str)

    def integrand(x: float) -> float:
        pdf_val = max(0.0, _eval_scalar(code, x, extra))
        if pdf_val == 0.0:
            # Avoid 0 * inf = nan (or an overflow) far out in the tails,
            # where the PDF has underflowed but (x - a)^r is huge.
            return 0.0
        return float((x - a) ** r) * pdf_val

    val, err, method = _quadrature_integrate(
        integrand,
        bounds[0],
        bounds[1],
        precision=precision,
        moment_r=r,
        a=a,
    )
    # Backends may signal failure with a finite placeholder value and an
    # infinite error estimate, so the value alone is not enough.
    converged = bool(np.isfinite(val) and np.isfinite(err))
    return MomentResult(
        order=r,
        reference=a,
        value=val,
        method=method,
        converged=converged,
        terms_or_nodes=_GAUSS_NODES,
        info=f"integral (x-a)^{r} f(x) dx = {val:.10g} +/- {err:.2e}",
    )


# ---------------------------------------------------------------------------
# Parallel moment computation — public entry point
# ---------------------------------------------------------------------------
def compute_moments_parallel(
    expr_str: str,
    support_or_bounds,
    max_order: int,
    a: float,
    is_crv: bool,
    is_infinite: bool = False,
    precision: int = 8,
    tol: float = _CONVERGENCE_TOL,
    extra: dict | None = None,
    max_workers: int = 4,
) -> Dict[int, MomentResult]:
    """
    Compute moments of orders 1 through max_order in parallel.

    Moment computations for different orders are independent.  All orders
    are submitted to a ThreadPoolExecutor.  For max_order <= 2 the
    threading overhead is avoided and computation is sequential.

    The returned dict is always keyed in ascending order (1, 2, ..., R),
    whichever path is taken.

    Parameters
    ----------
    expr_str : str
        Eval-normalised PMF/PDF expression.
    support_or_bounds : np.ndarray or (float, float)
        For DRV: array of support values.
        For CRV: (lower, upper) tuple.
    max_order : int
        Compute moments mu_1 through mu_{max_order}.  Values below 1 give
        an empty result.
    a : float
        Reference point (0=raw, mu=central, other=custom).
    is_crv : bool
        True for continuous RV, False for discrete.
    is_infinite : bool
        For DRV: whether support is unbounded beyond support array.
    precision : int
        Target decimal places (affects quadrature tolerance and mpmath).
    tol : float
        Convergence tolerance for discrete series.
    extra : dict, optional
        Extra variable bindings.
    max_workers : int
        Thread pool size; clamped to the range [1, max_order].

    Returns
    -------
    dict mapping int -> MomentResult
    """

    def _compute_one(r: int) -> Tuple[int, MomentResult]:
        if is_crv:
            return r, _compute_crv_moment(
                expr_str, support_or_bounds, r, a, precision, extra
            )
        return r, _compute_drv_moment(
            expr_str, support_or_bounds, r, a, is_infinite, tol, extra
        )

    orders = range(1, max_order + 1)

    if max_order <= 2:
        return dict(map(_compute_one, orders))

    # ThreadPoolExecutor rejects max_workers <= 0; fall back to one worker.
    workers = max(1, min(max_workers, max_order))
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # Executor.map yields results in submission order, so the dict is
        # built deterministically; worker exceptions propagate on iteration.
        return dict(pool.map(_compute_one, orders))
# --------------------------------------------------------------------------- # Statistical summary from moments # ---------------------------------------------------------------------------
@dataclass
class StatisticalSummary:
    """Summary statistics derived from the mean and central moments."""

    # Only the mean is mandatory; the other measures stay None until the
    # central moment they depend on is available.
    mean: float
    variance: float | None = None
    std_dev: float | None = None
    skewness: float | None = None
    kurtosis: float | None = None
    excess_kurtosis: float | None = None
def compute_statistical_summary(
    mean: float,
    central_moments: Dict[int, MomentResult],
) -> StatisticalSummary:
    """
    Build a StatisticalSummary from the mean and a set of central moments.

    Formulae used:
        variance         = mu_2
        std_dev          = sqrt(abs(mu_2))
        skewness         = mu_3 / sigma^3
        kurtosis         = mu_4 / sigma^4
        excess_kurtosis  = kurtosis - 3

    Parameters
    ----------
    mean : float
        First raw moment (E[X]).
    central_moments : dict
        Mapping order -> MomentResult for moments taken about the mean.

    Returns
    -------
    StatisticalSummary
        Fields whose required moment is missing are left as None.
        Skewness and kurtosis are also left as None when the standard
        deviation is not greater than 1e-15.
    """
    summary = StatisticalSummary(mean=mean)

    # Every other measure depends on the second central moment.
    if 2 not in central_moments:
        return summary

    summary.variance = central_moments[2].value
    summary.std_dev = np.sqrt(abs(summary.variance))
    sigma = summary.std_dev

    # Standardised moments are undefined for a (numerically) zero spread;
    # a NaN sigma also fails this comparison and is skipped.
    if not sigma > 1e-15:
        return summary

    if 3 in central_moments:
        summary.skewness = central_moments[3].value / (sigma ** 3)

    if 4 in central_moments:
        summary.kurtosis = central_moments[4].value / (sigma ** 4)
        summary.excess_kurtosis = summary.kurtosis - 3.0

    return summary
# --------------------------------------------------------------------------- # Range parsing utilities # ---------------------------------------------------------------------------
def parse_range_input(range_input: str) -> Tuple[np.ndarray, bool, str]:
    """
    Turn a discrete range string into a NumPy array of support values.

    Accepts:
        "1, 2, 3, 4, 5"       -> finite support
        "0, 1, 2, 3, ..."     -> infinite arithmetic series
                                 (extended to _MAX_SERIES_TERMS)
        "1, 2, 4, 8, ..."     -> infinite geometric series

    Parameters
    ----------
    range_input : str
        Range string, already cleaned by
        toful_parser.normalise_range_input().

    Returns
    -------
    (values_array, is_infinite, pattern_description)

    Raises
    ------
    ValueError
        If the string cannot be parsed as a comma-separated list of
        numbers, or contains no numbers at all.
    """
    text = range_input.strip()

    # A trailing ASCII or Unicode ellipsis marks an unbounded series.
    is_infinite = text.endswith(("...", "\u2026"))
    if is_infinite:
        text = text.rstrip(".\u2026").rstrip(",").strip()

    tokens = (piece.strip() for piece in text.split(","))
    try:
        base_values = [float(token) for token in tokens if token]
    except ValueError as exc:
        raise ValueError(f"Cannot parse range '{range_input}': {exc}") from exc

    if not base_values:
        raise ValueError(f"Empty range string: '{range_input}'")

    if not is_infinite:
        values = np.array(sorted(base_values), dtype=float)
        return values, False, f"Finite support with {len(values)} values"

    # Infer the progression from the listed terms and extend it.
    pattern = detect_series_pattern(base_values)
    extended = generate_extended_series(pattern, max_terms=_MAX_SERIES_TERMS)
    return extended, True, pattern.description
def parse_continuous_bound(bound_str: str) -> float:
    """
    Parse a single bound string to a float, accepting infinity variants.

    Matching is case-insensitive and ignores surrounding whitespace.
    Recognised infinity tokens are ``inf``, ``infinity``, ``oo`` and the
    Unicode infinity sign, each optionally prefixed with ``+`` or ``-``.
    The Unicode minus sign (U+2212) is treated as an ASCII ``-``.

    Parameters
    ----------
    bound_str : str
        E.g. "0", "1.5", "inf", "-inf", "oo", "\u221e", "-\u221e".

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If the text is neither an infinity token nor a valid float literal.
    """
    # Normalise the typographic minus so "−∞" and "−2" parse like "-∞"/"-2".
    s = bound_str.strip().lower().replace("\u2212", "-")

    inf_words = ("inf", "infinity", "oo", "\u221e")
    pos_inf = set(inf_words) | {"+" + word for word in inf_words}
    neg_inf = {"-" + word for word in inf_words}

    if s in pos_inf:
        return np.inf
    if s in neg_inf:
        return -np.inf
    return float(s)
# --------------------------------------------------------------------------- # Legacy compatibility shim # --------------------------------------------------------------------------- # The old code in toful.py instantiated these classes directly. # Provide thin wrappers so old call sites keep working without modification # until toful.py is updated.
class InfiniteSeriesHandler:
    """Legacy shim — use detect_series_pattern() and generate_extended_series()."""

    @staticmethod
    def detect_series_pattern(values):
        """Return ``(kind, params)`` of the pattern detected in *values*."""
        # Inside the method body the bare name resolves to the module-level
        # function, not to this static method.
        pattern = detect_series_pattern(values)
        return pattern.kind, pattern.params

    @staticmethod
    def generate_extended_series(pattern_type, params, max_terms=100):
        """Extend the described series and return the terms as a list."""
        pattern = SeriesPattern(kind=pattern_type, params=params)
        extended = generate_extended_series(pattern, max_terms)
        return extended.tolist()

    @staticmethod
    def estimate_infinite_sum(func_str, values, pattern_type, params):
        """Return ``(value, converged, info)`` for the accelerated sum.

        ``pattern_type`` and ``params`` are accepted for signature
        compatibility only; they are not used.
        """
        support = np.asarray(values, dtype=float)
        outcome = _sum_series_with_acceleration(func_str, support)
        return outcome.value, outcome.converged, outcome.info
class EnhancedProbabilityValidator:
    """Legacy shim — use validate_drv_probabilities() / validate_crv_pdf()."""

    @staticmethod
    def validate_drv_probabilities(func_str, range_values, is_infinite=False):
        """Return ``(is_valid, message, integral_or_sum, analysis)`` for a PMF."""
        report = validate_drv_probabilities(func_str, range_values, is_infinite)
        return (
            report.is_valid,
            report.message,
            report.integral_or_sum,
            report.analysis,
        )

    @staticmethod
    def validate_crv_pdf(func_str, range_bounds):
        """Return ``(is_valid, message, integral_or_sum)`` for a PDF."""
        report = validate_crv_pdf(func_str, range_bounds)
        return report.is_valid, report.message, report.integral_or_sum
class EnhancedMomentCalculator:
    """Legacy shim — use compute_moments_parallel() or _compute_drv_moment()."""

    @staticmethod
    def calculate_drv_moment(func_str, range_values, r, a,
                             is_infinite=False, max_iter=10**6, tol=1e-12):
        """Return ``(value, analysis_dict)`` for a discrete moment.

        ``max_iter`` is accepted for signature compatibility only; it is
        not used.
        """
        support = np.asarray(range_values, dtype=float)
        outcome = _compute_drv_moment(func_str, support, r, a, is_infinite, tol)
        analysis = dict(
            converged=outcome.converged,
            terms_used=outcome.terms_or_nodes,
            convergence_info=outcome.info,
        )
        return outcome.value, analysis

    @staticmethod
    def calculate_drv_moment_infinite(func_str, range_values, r, a,
                                      max_iter=10**6, tol=1e-12):
        """Same as ``calculate_drv_moment`` with ``is_infinite=True``."""
        return EnhancedMomentCalculator.calculate_drv_moment(
            func_str, range_values, r, a, is_infinite=True, tol=tol
        )

    @staticmethod
    def calculate_crv_moment(func_str, range_bounds, r, a):
        """Return only the numeric value of a continuous moment."""
        return _compute_crv_moment(func_str, range_bounds, r, a).value
def rth_moment(pmf, support, r, c=0, tol=1e-12, max_iter=10**6):
    """
    Legacy function — kept for backward compatibility.

    Computes sum((x - c)^r * pmf(x)) over the support and applies Wynn's
    epsilon algorithm to the partial sums.

    Parameters
    ----------
    pmf : callable
        Probability mass function taking a float and returning a number.
    support : sequence of numbers, np.ndarray, or the string "infinite"
        "infinite" selects the non-negative integers 0, 1, 2, ... truncated
        to ``min(max_iter, _MAX_SERIES_TERMS)`` terms.
    r : int
        Moment order.
    c : float
        Reference point.
    tol : float
        Accepted for signature compatibility; not used.
    max_iter : int
        Upper bound on the number of terms for an infinite support.

    Returns
    -------
    float
        The accelerated estimate when Wynn's algorithm reports success,
        otherwise the last partial sum.  An empty support yields 0.0.
    """
    # Only compare against the sentinel when support really is a string:
    # `ndarray == "infinite"` may be evaluated elementwise, and the truth
    # value of the resulting array is ambiguous.
    if isinstance(support, str) and support == "infinite":
        x_vals = np.arange(0, min(max_iter, _MAX_SERIES_TERMS), dtype=float)
    else:
        x_vals = np.asarray(support, dtype=float)

    # An empty sum is zero; without this guard partial_sums[-1] would raise.
    if x_vals.size == 0:
        return 0.0

    # pmf is an arbitrary Python callable, so it is evaluated point by point
    # rather than through the expression-string machinery.
    weights = np.ones_like(x_vals) if r == 0 else (x_vals - c) ** r
    pmf_vals = np.array([float(pmf(float(x))) for x in x_vals])
    partial_sums = list(np.cumsum(weights * pmf_vals))

    # Apply Wynn epsilon directly on the partial sums
    wynn_val, wynn_ok, _ = _wynn_epsilon(partial_sums)
    return wynn_val if wynn_ok else partial_sums[-1]