# Source code for npplus.pcwise
# Copyright (c) 2016, David H. Munro
# All rights reserved.
# This is Open Source software, released under the BSD 2-clause license,
# see http://opensource.org/licenses/BSD-2-Clause for details.
"""A decorator to neatly write piecewise functions.

The decorator is an aid to writing functions of one variable `x` which
have different algorithms in different domains of `x`::

    @pcwise
    def fun(x):
        def funlo(x):
            return 1. - x
        def funmid(x):
            return numpy.sin(x)
        def funhi(x):
            return x**2 - 1.
        return funlo, xa, funmid, xb, funhi

This defines ``fun(x)`` to be ``funlo(x)`` for ``x<xa``, ``funmid(x)`` for
``xa<=x<xb``, and ``funhi(x)`` for ``xb<=x``, where ``xa < xb`` are the
domain boundaries.  Any number of domains is allowed.
"""
# Public API of this module: only the decorator is exported by ``import *``.
__all__ = ['pcwise']
from numpy import array, asarray, diff, zeros
from functools import wraps
def pcwise(f):
    """Decorator to simplify creating ``f(x)`` with algorithm dependent on `x`.

    Use the pcwise decorator like this::

        @pcwise
        def f(x):
            '''Return a function of a single real variable x.'''
            def f0(x):
                ...algorithm when x < x1
            def f1(x):
                ...algorithm when x1 <= x < x2
            def f2(x):
                ...algorithm when x2 <= x < x3
            <and so on>
            return (f0, x1, f1, x2, f2, ...)

    The decorated function is called exactly once, with argument None, at
    decoration time in order to obtain the ``(f0, x1, f1, ..., xN, fN)``
    sequence.  The breakpoints ``xI`` must be strictly increasing.  A point
    lying exactly on breakpoint ``xI`` is handled by ``fI``.

    Any of the `fN` in the return value may be a string to raise a
    ValueError with that string for any points in that range.

    Parameters
    ----------
    f : callable
        Function which, when called as ``f(None)``, returns the sequence
        ``(f0, x1, f1, ..., xN, fN)`` alternating piece functions (or
        error-message strings) with breakpoints.

    Returns
    -------
    callable
        ``multif(x)``, wrapped to look like `f`, accepting a scalar or
        array_like `x`.  For a 0-D `x`, the selected piece is called with
        the 0-D array and its return value is passed through unchanged.
        Otherwise each piece is called once with the 1-D array of the points
        in its domain, and the results are assembled into a new array with
        the shape of `x`.  That array has float dtype, or complex dtype if any
        piece returned complex values.

    Raises
    ------
    TypeError
        If `f` does not return an odd-length ``(f0, x1, ..., xN, fN)``
        sequence, or if the breakpoints are not strictly increasing.
    """
    fxtuple = f(None)
    funcs = list(fxtuple[0::2])
    xvals = array(fxtuple[1::2])
    if len(funcs) != len(xvals)+1:
        raise TypeError("@pcwise function must return "
                        "(f0,x1,...xN,fN) sequence.")
    if xvals.size > 1 and (diff(xvals) <= 0).any():
        raise TypeError("@pcwise function must return "
                        "(f0,x1,...xN,fN) sequence with increasing xI.")
    for i, fn in enumerate(funcs):
        try:
            fn + ''  # duck-typed test: succeeds only for string-like fn
        except TypeError:
            pass
        else:  # replace string by function that raises ValueError(fn)
            funcs[i] = _value_error_func(fn)
    funcs = tuple(funcs)

    @wraps(f)
    def multif(x):
        x = asarray(x)
        # side='right' yields xvals[i-1] <= x < xvals[i], so a point exactly
        # at breakpoint xI selects fI, matching the documented domains.
        ialg = xvals.searchsorted(x, side='right')
        if not x.shape:
            return funcs[ialg](x)
        # Evaluate each piece only on the points that fall in its domain.
        pieces = []
        for i, piece in enumerate(funcs):
            mask = ialg == i
            xm = x[mask]
            if xm.size:
                pieces.append((mask, asarray(piece(xm))))
        # Keep the historical float result for real-valued pieces, but
        # promote to complex instead of silently discarding imaginary parts.
        is_complex = any(val.dtype.kind == 'c' for _, val in pieces)
        result = zeros(x.shape, dtype=complex if is_complex else float)
        for mask, val in pieces:
            result[mask] = val
        return result

    return multif
def _value_error_func(msg):
def raise_value_error(x):
raise ValueError(msg)
return raise_value_error