from __future__ import division, print_function, absolute_import import sys import time import numpy as np from numpy.testing import dec, assert_ from scipy._lib.six import reraise from scipy.special._testutils import assert_func_equal try: import mpmath except ImportError: try: import sympy.mpmath as mpmath except ImportError: pass # ------------------------------------------------------------------------------ # Machinery for systematic tests with mpmath # ------------------------------------------------------------------------------ class Arg(object): """ Generate a set of numbers on the real axis, concentrating on 'interesting' regions and covering all orders of magnitude. """ def __init__(self, a=-np.inf, b=np.inf, inclusive_a=True, inclusive_b=True): self.a = a self.b = b self.inclusive_a = inclusive_a self.inclusive_b = inclusive_b if self.a == -np.inf: self.a = -np.finfo(float).max/2 if self.b == np.inf: self.b = np.finfo(float).max/2 def values(self, n): """Return an array containing approximatively `n` numbers.""" n1 = max(2, int(0.3*n)) n2 = max(2, int(0.2*n)) n3 = max(8, n - n1 - n2) v1 = np.linspace(-1, 1, n1) v2 = np.r_[np.linspace(-10, 10, max(0, n2-4)), -9, -5.5, 5.5, 9] if self.a >= 0 and self.b > 0: v3 = np.r_[ np.logspace(-30, -1, 2 + n3//4), np.logspace(5, np.log10(self.b), 1 + n3//4), ] v4 = np.logspace(1, 5, 1 + n3//2) elif self.a < 0 < self.b: v3 = np.r_[ np.logspace(-30, -1, 2 + n3//8), np.logspace(5, np.log10(self.b), 1 + n3//8), -np.logspace(-30, -1, 2 + n3//8), -np.logspace(5, np.log10(-self.a), 1 + n3//8) ] v4 = np.r_[ np.logspace(1, 5, 1 + n3//4), -np.logspace(1, 5, 1 + n3//4) ] elif self.b < 0: v3 = np.r_[ -np.logspace(-30, -1, 2 + n3//4), -np.logspace(5, np.log10(-self.b), 1 + n3//4), ] v4 = -np.logspace(1, 5, 1 + n3//2) else: v3 = [] v4 = [] v = np.r_[v1, v2, v3, v4, 0] if self.inclusive_a: v = v[v >= self.a] else: v = v[v > self.a] if self.inclusive_b: v = v[v <= self.b] else: v = v[v < self.b] return np.unique(v) class FixedArg(object): def __init__(self, values): self._values = np.asarray(values) def values(self, n): return self._values class ComplexArg(object): def __init__(self, a=complex(-np.inf, -np.inf), b=complex(np.inf, np.inf)): self.real = Arg(a.real, b.real) self.imag = Arg(a.imag, b.imag) def values(self, n): m = max(2, int(np.sqrt(n))) x = self.real.values(m) y = self.imag.values(m) return (x[:,None] + 1j*y[None,:]).ravel() class IntArg(object): def __init__(self, a=-1000, b=1000): self.a = a self.b = b def values(self, n): v1 = Arg(self.a, self.b).values(max(1 + n//2, n-5)).astype(int) v2 = np.arange(-5, 5) v = np.unique(np.r_[v1, v2]) v = v[(v >= self.a) & (v < self.b)] return v class MpmathData(object): def __init__(self, scipy_func, mpmath_func, arg_spec, name=None, dps=None, prec=None, n=5000, rtol=1e-7, atol=1e-300, ignore_inf_sign=False, distinguish_nan_and_inf=True, nan_ok=True, param_filter=None): self.scipy_func = scipy_func self.mpmath_func = mpmath_func self.arg_spec = arg_spec self.dps = dps self.prec = prec self.n = n self.rtol = rtol self.atol = atol self.ignore_inf_sign = ignore_inf_sign self.nan_ok = nan_ok if isinstance(self.arg_spec, np.ndarray): self.is_complex = np.issubdtype(self.arg_spec.dtype, np.complexfloating) else: self.is_complex = any([isinstance(arg, ComplexArg) for arg in self.arg_spec]) self.ignore_inf_sign = ignore_inf_sign self.distinguish_nan_and_inf = distinguish_nan_and_inf if not name or name == '': name = getattr(scipy_func, '__name__', None) if not name or name == '': name = getattr(mpmath_func, '__name__', None) self.name = name self.param_filter = param_filter def check(self): np.random.seed(1234) # Generate values for the arguments if isinstance(self.arg_spec, np.ndarray): argarr = self.arg_spec.copy() else: num_args = len(self.arg_spec) ms = np.asarray([1.5 if isinstance(arg, ComplexArg) else 1.0 for arg in self.arg_spec]) ms = (self.n**(ms/sum(ms))).astype(int) + 1 argvals = [] for arg, m in zip(self.arg_spec, ms): argvals.append(arg.values(m)) argarr = np.array(np.broadcast_arrays(*np.ix_(*argvals))).reshape(num_args, -1).T # Check old_dps, old_prec = mpmath.mp.dps, mpmath.mp.prec try: if self.dps is not None: dps_list = [self.dps] else: dps_list = [20] if self.prec is not None: mpmath.mp.prec = self.prec # Proper casting of mpmath input and output types. Using # native mpmath types as inputs gives improved precision # in some cases. if np.issubdtype(argarr.dtype, np.complexfloating): pytype = mpc2complex def mptype(x): return mpmath.mpc(complex(x)) else: def mptype(x): return mpmath.mpf(float(x)) def pytype(x): if abs(x.imag) > 1e-16*(1 + abs(x.real)): return np.nan else: return mpf2float(x.real) # Try out different dps until one (or none) works for j, dps in enumerate(dps_list): mpmath.mp.dps = dps try: assert_func_equal(self.scipy_func, lambda *a: pytype(self.mpmath_func(*map(mptype, a))), argarr, vectorized=False, rtol=self.rtol, atol=self.atol, ignore_inf_sign=self.ignore_inf_sign, distinguish_nan_and_inf=self.distinguish_nan_and_inf, nan_ok=self.nan_ok, param_filter=self.param_filter) break except AssertionError: if j >= len(dps_list)-1: reraise(*sys.exc_info()) finally: mpmath.mp.dps, mpmath.mp.prec = old_dps, old_prec def __repr__(self): if self.is_complex: return "" % (self.name,) else: return "" % (self.name,) def assert_mpmath_equal(*a, **kw): d = MpmathData(*a, **kw) d.check() def nonfunctional_tooslow(func): return dec.skipif(True, " Test not yet functional (too slow), needs more work.")(func) # ------------------------------------------------------------------------------ # Tools for dealing with mpmath quirks # ------------------------------------------------------------------------------ def mpf2float(x): """ Convert an mpf to the nearest floating point number. Just using float directly doesn't work because of results like this: with mp.workdps(50): float(mpf("0.99999999999999999")) = 0.9999999999999999 """ return float(mpmath.nstr(x, 17, min_fixed=0, max_fixed=0)) def mpc2complex(x): return complex(mpf2float(x.real), mpf2float(x.imag)) def trace_args(func): def tofloat(x): if isinstance(x, mpmath.mpc): return complex(x) else: return float(x) def wrap(*a, **kw): sys.stderr.write("%r: " % (tuple(map(tofloat, a)),)) sys.stderr.flush() try: r = func(*a, **kw) sys.stderr.write("-> %r" % r) finally: sys.stderr.write("\n") sys.stderr.flush() return r return wrap try: import posix import signal POSIX = ('setitimer' in dir(signal)) except ImportError: POSIX = False class TimeoutError(Exception): pass def time_limited(timeout=0.5, return_val=np.nan, use_sigalrm=True): """ Decorator for setting a timeout for pure-Python functions. If the function does not return within `timeout` seconds, the value `return_val` is returned instead. On POSIX this uses SIGALRM by default. On non-POSIX, settrace is used. Do not use this with threads: the SIGALRM implementation does probably not work well. The settrace implementation only traces the current thread. The settrace implementation slows down execution speed. Slowdown by a factor around 10 is probably typical. """ if POSIX and use_sigalrm: def sigalrm_handler(signum, frame): raise TimeoutError() def deco(func): def wrap(*a, **kw): old_handler = signal.signal(signal.SIGALRM, sigalrm_handler) signal.setitimer(signal.ITIMER_REAL, timeout) try: return func(*a, **kw) except TimeoutError: return return_val finally: signal.setitimer(signal.ITIMER_REAL, 0) signal.signal(signal.SIGALRM, old_handler) return wrap else: def deco(func): def wrap(*a, **kw): start_time = time.time() def trace(frame, event, arg): if time.time() - start_time > timeout: raise TimeoutError() return None # turn off tracing except at function calls sys.settrace(trace) try: return func(*a, **kw) except TimeoutError: sys.settrace(None) return return_val finally: sys.settrace(None) return wrap return deco def exception_to_nan(func): """Decorate function to return nan if it raises an exception""" def wrap(*a, **kw): try: return func(*a, **kw) except Exception: return np.nan return wrap def inf_to_nan(func): """Decorate function to return nan if it returns inf""" def wrap(*a, **kw): v = func(*a, **kw) if not np.isfinite(v): return np.nan return v return wrap def mp_assert_allclose(res, std, atol=0, rtol=1e-17): """ Compare lists of mpmath.mpf's or mpmath.mpc's directly so that it can be done to higher precision than double. """ try: len(res) except TypeError: res = list(res) n = len(std) if len(res) != n: raise AssertionError("Lengths of inputs not equal.") failures = [] for k in range(n): try: assert_(mpmath.fabs(res[k] - std[k]) <= atol + rtol*mpmath.fabs(std[k])) except AssertionError: failures.append(k) ndigits = int(abs(np.log10(rtol))) msg = [""] msg.append("Bad results ({} out of {}) for the following points:" .format(len(failures), n)) for k in failures: resrep = mpmath.nstr(res[k], ndigits, min_fixed=0, max_fixed=0) stdrep = mpmath.nstr(std[k], ndigits, min_fixed=0, max_fixed=0) if std[k] == 0: rdiff = "inf" else: rdiff = mpmath.fabs((res[k] - std[k])/std[k]) rdiff = mpmath.nstr(rdiff, 3) msg.append("{}: {} != {} (rdiff {})".format(k, resrep, stdrep, rdiff)) if failures: assert_(False, "\n".join(msg))