import numpy as np

from blatt08 import partialPivotLU

# forward substitution: solves Lx = b for a lower triangular matrix L.
#
# Only the entries of L on and below the diagonal are read; the diagonal
# entries are used as divisors. b may be any array-like whose first axis
# matches the rows of L. A new array with the shape of b is returned.
#
# The result is never stored in an integer/bool array: if both L and b have
# such a dtype the computation is carried out in float64, otherwise in the
# common dtype of L and b.
def forward(L, b):
    b = np.asarray(b)
    dtype = np.result_type(L.dtype, b.dtype)
    if dtype.kind in "biu":
        # an integer result buffer would silently truncate every component
        dtype = np.dtype(np.float64)

    x = np.zeros_like(b, dtype=dtype)
    for i in range(L.shape[0]):
        # x[:i] is already final, so row i contains a single unknown
        x[i] = (b[i] - L[i, :i] @ x[:i]) / L[i, i]
    return x

# back substitution: solves Ux = y for an upper triangular matrix U.
#
# Only the entries of U on and above the diagonal are read (so a combined LU
# matrix can be passed directly); the diagonal entries are used as divisors.
# y may be any array-like whose first axis matches the rows of U. A new array
# with the shape of y is returned.
#
# The result is never stored in an integer/bool array: if both U and y have
# such a dtype the computation is carried out in float64, otherwise in the
# common dtype of U and y.
def backward(U, y):
    y = np.asarray(y)
    dtype = np.result_type(U.dtype, y.dtype)
    if dtype.kind in "biu":
        # an integer result buffer would silently truncate every component
        dtype = np.dtype(np.float64)

    x = np.zeros_like(y, dtype=dtype)
    for i in reversed(range(U.shape[0])):
        # x[i+1:] is already final, so row i contains a single unknown
        x[i] = (y[i] - U[i, i+1:] @ x[i+1:]) / U[i, i]
    return x

# given correctly shaped data returns the (approximate) solution to the linear
# system Ax = b, calculated using an LU decomposition with (partial) pivoting.
#
# A and b are left untouched: the decomposition is done on a private copy of
# A. Integer/bool input is promoted to float64 first, because the
# decomposition is stored in place in that copy and an integer buffer could
# not hold the (generally fractional) factors.
def solveLS(A, b):
    A = np.asarray(A)
    b = np.asarray(b)

    dtype = np.result_type(A.dtype, b.dtype)
    if dtype.kind in "biu":
        dtype = np.dtype(np.float64)

    # np.array always copies here, so the caller's matrix is never modified
    M = np.array(A, dtype=dtype)
    rhs = b.astype(dtype, copy=False)

    # partialPivotLU is used as an in-place factorization: afterwards M holds
    # the multipliers of L strictly below the diagonal and U on/above it, and
    # p is the row permutation that has to be applied to the right-hand side.
    p = partialPivotLU(M)
    L = np.tril(M, -1) + np.eye(p.size)

    y = forward(L, rhs[p])
    # backward only reads the upper triangle, so M can serve as U directly
    return backward(M, y)

# given correctly shaped data returns the (approximate) solution to the linear
# system Ax = b. In addition to using an LU decomposition an afterburner
# iteration scheme (iterative refinement) is used to improve the accuracy of
# the result.
#
#   A        matrix of the system; only used to compute residuals b - Ax.
#   b        right-hand side (integer/bool input is promoted to float64).
#   LU       combined LU factors: the multipliers of the unit lower triangular
#            factor strictly below the diagonal, the upper triangular factor
#            on and above it.
#   p        row permutation belonging to LU; it is applied to every
#            right-hand side as rhs[p].
#   maxiter  maximum number of correction steps.
#   tol      the iteration stops as soon as the norm of a correction drops
#            below this (absolute) tolerance.
def solveLSafterburner(A, b, LU, p, maxiter = 20, tol = 1e-6):
    b = np.asarray(b)
    if b.dtype.kind in "biu":
        # an integer right-hand side would make the substitutions truncate
        b = b.astype(np.float64)

    L = np.tril(LU, -1) + np.eye(p.size)

    # solves (LU) z = rhs using the given factorization
    def solve(rhs):
        return backward(LU, forward(L, rhs[p]))

    # first we find the usual solution we'd get with the LU decomposition. This
    # solution x may be inaccurate, so it is corrected iteratively below.
    x = solve(b)

    for _ in range(maxiter):
        # the correction d approximately solves A d = r for the residual r
        r = b - A @ x
        d = solve(r)

        x = x + d
        if np.linalg.norm(d) < tol:
            break

    return x

if __name__ == "__main__":
    # Demo: solve a small system three ways and print the results in the
    # order rough / refined / reference.
    rhs = np.array([14, 20, 24], dtype=float)
    exact = np.array([[2, 1, 6], [2, 6, 9], [4, 2, 10]], dtype=float)

    # slightly perturbed copy of the system matrix (first diagonal entry +0.1)
    perturbed = exact + np.diag([0.1, 0, 0])

    # plain LU solve of the perturbed system
    rough = solveLS(perturbed, rhs)

    # `perturbed` is handed on as the LU argument below, i.e. partialPivotLU
    # is expected to have replaced its contents with the LU factors.
    perm = partialPivotLU(perturbed)
    refined = solveLSafterburner(exact, rhs, perturbed, perm)

    # plain LU solve of the unperturbed system, for comparison
    reference = solveLS(exact, rhs)

    for solution in (rough, refined, reference):
        print(solution)

