# Copyright (c) 2016, David H. Munro
# All rights reserved.
# This is Open Source software, released under the BSD 2-clause license,
# see http://opensource.org/licenses/BSD-2-Clause for details.
"""Periodic variants of solve_banded.
Adds corner elements to banded matrix connecting final elements of `x`
to first equations and first elements of `x` to final equations.
Also include `solves_banded`, a variant of `solveh_banded` that does not
require a positive definite symmetric matrix.
--------
"""
from numpy import asarray, zeros, eye, arange, roll, diag
from numpy import concatenate
from scipy.linalg import solve_banded, solve

try:
    from numpy import asfarray
except ImportError:  # numpy.asfarray was removed in NumPy 2.0
    def asfarray(a):
        """Stand-in for the removed numpy.asfarray (convert to float64)."""
        return asarray(a, dtype=float)
def solve_periodic(l_and_u, ab, b, overwrite_ab=False, overwrite_b=False,
                   check_finite=True):
    """Variant of solve_banded which includes matrix corner elements.

    Layout of banded matrix `ab` is the same as for solve_banded::

        ab[u+i-j, j] == a[i, j]
        ab[i, j] == a[i+j-u, j]

    except that the row index of `a` wraps around modulo the number of
    equations, so the entries solve_banded ignores (upper left of the upper
    diagonals, lower right of the lower diagonals) hold the corner elements.
    Example (l,u)=(2,1) for 6x6::

        A50 a01 a12 a23 a34 a45
        a00 a11 a22 a33 a44 a55
        a10 a21 a32 a43 a54 A05
        a20 a31 a42 a53 A04 A15

        A50  0   0   x   x   x    copy last row before first
        a00 a01  0   0  A04 A05   <-- matrix begins here
        a10 a11 a12  0   0  A15
        a20 a21 a22 a23  0   0
         0  a31 a32 a33 a34  0
         0   0  a42 a43 a44 a45
        A50  0   0  a53 a54 a55   <-- matrix ends here
         x   x   0   0  A04 A05   copy first two rows after last
         x   x   x   0   0  A15

    Parameters
    ----------
    l_and_u : (int, int)
        Number of lower and upper diagonals, ``(l, u)``.  ``l + u`` must be
        smaller than the number of equations.
    ab : (l + u + 1, M) array_like
        Banded matrix, including its corner elements, in the layout above.
    b : (M,) or (M, K) array_like
        Right-hand side.
    overwrite_ab : bool, optional
        Passed on to solve_banded: allow the data in `ab` to be discarded.
    overwrite_b : bool, optional
        Passed on to solve_banded: allow the data in `b` to be discarded.
    check_finite : bool, optional
        Whether to check that the inputs contain only finite numbers.

    Returns
    -------
    x : (M,) or (M, K) ndarray
        Solution of ``a.dot(x) == b``.  Inputs are converted to float64.

    See Also
    --------
    scipy.linalg.solve_banded : for description of parameters and returns

    Notes
    -----
    Outline of the algorithm:

    1. Solve `ab` without the corner elements using solve_banded.
    2. Solve `ab` without corner elements for ``b=0`` except 1 for each
       of the first `l` equations and last `u` equations, a total of
       ``u+l`` solves with one more call to solve_banded.
    3. Construct a dense ``u+l`` square matrix and solve to find the
       linear combination of the results of (2) which compensates
       for the corner elements omitted in (1).
    """
    # asarray(..., dtype=float) replaces numpy.asfarray (removed in 2.0).
    ab = asarray(ab, dtype=float)
    b = asarray(b, dtype=float)
    dtype = ab.dtype
    n = ab.shape[1]  # number of equations == number of unknowns
    l, u = l_and_u  # number of lower, upper diagonals
    lpu = l + u  # < n (or solve_banded will fail)

    # (1) Solve ignoring the corner elements: solve_banded never reads the
    # ab entries where they are stored.  ab is needed again below, so it
    # must not be overwritten by this call.
    xx = solve_banded(l_and_u, ab, b, overwrite_ab=False,
                      overwrite_b=overwrite_b, check_finite=check_finite)
    if not lpu:
        return xx  # diagonal matrix: there are no corner elements

    # Dense (l+u) x (l+u) matrix of the corner elements.  Its rows are the
    # first l equations followed by the last u equations; its columns are
    # the first u unknowns followed by the last l unknowns:
    #      0  A04 A05    first l equations, last l unknowns (lower diags,
    #      0   0  A15      upper right corner of a)
    #     A50  0   0     last u equations, first u unknowns (upper diags,
    #                      lower left corner of a)
    corn = zeros((lpu, lpu), dtype)
    if u:
        corn[-u:, :u] = _diag_to_norm((u-1, 0), ab[:u, :u])
    if l:
        corn[:l, -l:] = _diag_to_norm((0, l-1), ab[-l:, -l:])

    # (2) Solve the same corner-free system for l+u unit right-hand sides,
    # one per corner equation, with columns ordered like the rows of corn.
    # bb is private scratch, and ab was already checked in (1).
    bb = zeros((n, lpu), dtype)
    bb[:l, :l] = eye(l, dtype=dtype)
    if u:  # guard: bb[-0:, -0:] would select all of bb
        bb[-u:, -u:] = eye(u, dtype=dtype)
    xy = solve_banded(l_and_u, ab, bb, overwrite_ab=overwrite_ab,
                      overwrite_b=True, check_finite=False)

    # (3) Write a = ad + c, with ad the corner-free part and c the corners,
    # so ad.dot(xx) = b and ad.dot(xy) = bb.  For x = xx - xy.dot(p):
    #   a.dot(x) = b - bb.dot(p) + c.dot(xx) - c.dot(xy).dot(p)
    # Only the l+u corner equations are affected; there bb.dot(p) is just p,
    # so a.dot(x) = b requires  p + acy.dot(p) = acx  with
    mask = roll(arange(n) < lpu, -l)  # first u, last l unknowns, in order
    acx = corn.dot(xx[mask, ...])  # corner terms of a.dot(xx)
    acy = corn.dot(xy[mask, :])    # corner terms of a.dot(xy)
    p = solve(eye(lpu, dtype=dtype) + acy, acx, overwrite_a=True,
              overwrite_b=True, check_finite=False)
    return xx - xy.dot(p)
def _diag_to_norm(l_and_u, ab):
    """Expand a diagonal-ordered band matrix into an ordinary square matrix.

    `ab` uses the solve_banded layout ``ab[u+i-j, j] == a[i, j]`` with
    ``(l, u) = l_and_u``.  Entries of `ab` that fall outside the n x n
    matrix are ignored.  The result has the dtype of `ab`.

    Raises ValueError unless ``ab.shape[0] == l + u + 1``.
    """
    nlower, nupper = l_and_u
    nbands, n = ab.shape
    if nbands != nupper + 1 + nlower:
        raise ValueError("matrix shape inconsistent with given l and u")
    dense = diag(ab[nupper])  # (n, n) with only the main diagonal filled
    # In the flattened (row-major) matrix, neighbouring entries of any one
    # diagonal are n + 1 apart.
    flat, stride = dense.ravel(), n + 1
    for k in range(1, nupper + 1):
        # k-th super-diagonal: row nupper-k of ab, first k entries unused
        flat[k:n*(n - k + 1):stride] = ab[nupper - k, k:]
    for k in range(1, nlower + 1):
        # k-th sub-diagonal: row nupper+k of ab, last k entries unused
        flat[k*n::stride] = ab[nupper + k, :-k]
    return dense
def solves_periodic(ab, b, overwrite_ab=False, overwrite_b=False, lower=False,
                    check_finite=True):
    """Variant of solve_periodic for symmetric matrices.

    Upper form unless keyword `lower` is true.  Upper and lower forms are
    the same as scipy.linalg.solveh_banded, except input to solves_periodic
    need not be positive definite, and the entire ab array is used: the
    entries solveh_banded ignores hold the periodic corner elements::

        A40 A51 a02 a13 a24 a35    upper form
        A50 a01 a12 a23 a34 a45
        a00 a11 a22 a33 a44 a55

        a00 a11 a22 a33 a44 a55    lower form
        a10 a21 a32 a43 a54 A05
        a20 a31 a42 a53 A04 A15

    `overwrite_ab` is accepted for compatibility but has no effect; `ab`
    itself is never modified.  Inputs are converted to float64.

    See Also
    --------
    scipy.linalg.solveh_banded : for description of parameters and returns
    solve_periodic : for more details
    """
    # Convert before reading .shape so any array_like is accepted.
    ab = asarray(ab, dtype=float)
    b = asarray(b, dtype=float)
    u = ab.shape[0] - 1  # number of off-diagonals on each side
    # _solves_symgen returns a newly built (concatenated) array, so the
    # solver may always overwrite it, whatever the caller's overwrite_ab.
    ab = _solves_symgen(ab, u, lower)
    return solve_periodic((u, u), ab, b, overwrite_ab=True,
                          overwrite_b=overwrite_b, check_finite=check_finite)
def solves_banded(ab, b, overwrite_ab=False, overwrite_b=False, lower=False,
                  check_finite=True):
    """Variant of solve_banded for arbitrary symmetric matrices.

    Upper form unless keyword `lower` is true.  Upper and lower forms are
    the same as scipy.linalg.solveh_banded, except input to solves_banded
    need not be positive definite.  As with solveh_banded, the entries
    marked * are never referenced::

         *   *  a02 a13 a24 a35    upper form
         *  a01 a12 a23 a34 a45
        a00 a11 a22 a33 a44 a55

        a00 a11 a22 a33 a44 a55    lower form
        a10 a21 a32 a43 a54  *
        a20 a31 a42 a53  *   *

    `overwrite_ab` is accepted for compatibility but has no effect; `ab`
    itself is never modified.  Inputs are converted to float64.

    See Also
    --------
    scipy.linalg.solveh_banded : for description of parameters and returns

    Notes
    -----
    Use solveh_banded if the matrix is known to be positive definite.
    """
    # Convert before reading .shape so any array_like is accepted.
    ab = asarray(ab, dtype=float)
    b = asarray(b, dtype=float)
    u = ab.shape[0] - 1  # number of off-diagonals on each side
    # _solves_symgen returns a newly built (concatenated) array, so the
    # solver may always overwrite it, whatever the caller's overwrite_ab.
    ab = _solves_symgen(ab, u, lower)
    return solve_banded((u, u), ab, b, overwrite_ab=True,
                        overwrite_b=overwrite_b, check_finite=check_finite)
def _solves_symgen(ab, u, lower):
    """Expand one triangle of a symmetric band matrix to the general form.

    `ab` holds the upper (``lower`` false) or lower (``lower`` true) form of
    solveh_banded, with ``u == ab.shape[0] - 1`` off-diagonals.  The missing
    diagonals are generated by symmetry, each rolled so that the periodic
    corner entries land where solve_periodic expects them.

    Returns a new float64 array of shape (2*u + 1, n) in the solve_banded
    layout for ``l_and_u = (u, u)``; `ab` itself is not modified.
    """
    # asarray(..., dtype=float) replaces numpy.asfarray (removed in 2.0).
    ab = asarray(ab, dtype=float)
    if lower:
        # Rows u, u-1, ..., 1 of the lower form hold sub-diagonals u, ..., 1;
        # sub-diagonal k rolled right by k is super-diagonal k.
        a = ab[-1:0:-1].copy()
        for i, row in enumerate(a):
            row[:] = roll(row, u - i)
        return concatenate((a, ab), axis=0)
    # Rows u-1, ..., 0 of the upper form hold super-diagonals 1, ..., u;
    # super-diagonal k rolled left by k is sub-diagonal k.
    a = ab[-2::-1].copy()
    for i, row in enumerate(a):
        row[:] = roll(row, -1 - i)
    return concatenate((ab, a), axis=0)