import math
import numpy as np
import matplotlib.pyplot as plt

# The following two functions calculate the LU decomposition of the given matrix
# A using either partial or full pivoting. The resulting lower and upper
# triangular matrices overwrite the passed matrix in place. With partial
# pivoting the return value is a permutation array of the indices of the rows
# of A; with full pivoting the function returns a tuple of permutation arrays
# of the indices of the rows and columns of A, respectively.
# Thus the following equalities shall hold, at least approximately (that is up
# to numerical error):
#     A = np.array([[...], [...], ..., [...]])
#
#     partial = A.copy()
#     p = partialPivotLU(partial)
#     L, U = np.tril(partial, -1), np.triu(partial)
#     L = L + np.diag(np.ones(L.shape[0]))
#
#     L @ U == A[p, :] # up to numerical error
#
#     full = A.copy()
#     p, q = fullPivotLU(full)
#     L, U = np.tril(full, -1), np.triu(full)
#     L = L + np.diag(np.ones(L.shape[0]))
#
#     L @ U == A[p, :][:, q] # up to numerical error
#     # (not A[p, q], which pairs p and q elementwise into a 1-D array)
#
# You can base your solution on the code provided in the lecture notes, but note
# that it constructs a full permutation matrix and does not perform the LU
# decomposition in place.
def partialPivotLU(A):
    """Factor the square matrix ``A`` in place with row (partial) pivoting.

    At step ``k`` the entry of largest magnitude in column ``k``, at or below
    the diagonal, is swapped onto the diagonal. Afterwards the strict lower
    triangle of ``A`` holds the multipliers of the unit lower-triangular ``L``
    and the upper triangle (including the diagonal) holds ``U``, so that
    ``A_orig[p, :] == L @ U`` up to rounding.

    Parameters
    ----------
    A : numpy.ndarray
        Square matrix, overwritten with the packed factors. Entries are
        divided in place, so an integer array is rejected by NumPy.

    Returns
    -------
    numpy.ndarray
        Row permutation ``p`` (indices into the rows of the original ``A``).

    Raises
    ------
    AssertionError
        If ``A`` is not square or a pivot magnitude is not above ``10**-10``
        (these checks are skipped under ``python -O``).
    """
    rows, cols = A.shape
    assert rows == cols, "Can only compute LU decomposition of square matrix."

    perm = np.arange(rows)
    for k in range(rows - 1):
        # Largest-magnitude candidate in column k, searched from the diagonal down.
        pivotRow = k + int(np.argmax(np.abs(A[k:, k])))

        # Swap whole rows so the already-stored multipliers of L move along too.
        perm[k], perm[pivotRow] = perm[pivotRow], perm[k]
        A[[k, pivotRow]] = A[[pivotRow, k]]

        pivot = A[k, k]
        assert np.abs(pivot) > 10**-10, "Are you sure this matrix has an LU decomposition?"

        # Turn the sub-column into multipliers of L, then apply the rank-1
        # update (outer product) to the trailing block.
        A[k + 1:, k] /= pivot
        A[k + 1:, k + 1:] -= A[k + 1:, k, np.newaxis] * A[np.newaxis, k, k + 1:]

    return perm

def fullPivotLU(A):
    """Factor the square matrix ``A`` in place with complete (row and column) pivoting.

    At step ``k`` the largest-magnitude entry of the trailing block
    ``A[k:, k:]`` is brought onto the diagonal by one row swap and one column
    swap. Afterwards the strict lower triangle of ``A`` holds the multipliers
    of the unit lower-triangular ``L`` and the upper triangle (including the
    diagonal) holds ``U``, so that ``A_orig[p, :][:, q] == L @ U`` up to
    rounding.

    Parameters
    ----------
    A : numpy.ndarray
        Square matrix, overwritten with the packed factors. Entries are
        divided in place, so an integer array is rejected by NumPy.

    Returns
    -------
    tuple of numpy.ndarray
        ``(p, q)``: row and column permutations of the original ``A``.

    Raises
    ------
    AssertionError
        If ``A`` is not square or a pivot magnitude is not above ``10**-10``
        (these checks are skipped under ``python -O``).
    """
    size, width = A.shape
    assert size == width, "Can only compute LU decomposition of square matrix."

    rowOrder = np.arange(size)
    colOrder = np.arange(size)
    for k in range(size - 1):
        # Locate the overall largest-magnitude entry of the trailing block.
        trailing = np.abs(A[k:, k:])
        r, c = np.unravel_index(np.argmax(trailing), trailing.shape)
        r, c = k + r, k + c

        # Record the swaps in the permutation vectors ...
        rowOrder[[k, r]] = rowOrder[[r, k]]
        colOrder[[k, c]] = colOrder[[c, k]]
        # ... and apply them to the whole matrix: full-row swaps carry the
        # stored L multipliers along, full-column swaps keep finished U rows
        # consistent with the column permutation.
        A[[k, r], :] = A[[r, k], :]
        A[:, [k, c]] = A[:, [c, k]]

        pivot = A[k, k]
        assert np.abs(pivot) > 10**-10, "Are you sure this matrix has an LU decomposition?"

        # Multipliers of L, then rank-1 update of the trailing block.
        A[k + 1:, k] /= pivot
        A[k + 1:, k + 1:] -= np.outer(A[k + 1:, k], A[k, k + 1:])

    return rowOrder, colOrder

def pascalMatrix(n):
    """Return the ``n x n`` symmetric Pascal matrix as a float array.

    Entry ``(i, j)`` is the binomial coefficient ``C(i + j, i)``.

    Parameters
    ----------
    n : int
        Matrix size; ``0`` yields an empty ``(0, 0)`` array.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, n)`` with dtype ``float64``.
    """
    # Convert each exact integer coefficient straight to float: routing it
    # through np.int64 raises OverflowError once C(2n - 2, n - 1) exceeds the
    # int64 range (n >= 35), even though float64 can still represent it.
    # Declaring otypes lets np.vectorize handle the empty n == 0 case, where it
    # cannot infer the output type from a sample call.
    entry = lambda i, j: float(math.comb(int(i + j), int(i)))
    return np.fromfunction(np.vectorize(entry, otypes=[float]), (n, n))

if __name__ == "__main__":
    # Compare the Frobenius norm of the reconstruction residual of both
    # pivoting strategies on Pascal matrices of sizes 2 through 26.

    def unpackedProduct(packed):
        """Return L @ U for an in-place LU result (L has an implicit unit diagonal)."""
        lower = np.tril(packed, -1) + np.eye(packed.shape[0])
        return lower @ np.triu(packed)

    sizes = np.arange(2, 27)
    errsPartial = []
    errsFull = []

    for n in sizes:
        base = pascalMatrix(n)

        work = base.copy()
        rowPerm = partialPivotLU(work)
        errsPartial.append(np.linalg.norm(base[rowPerm, :] - unpackedProduct(work)))

        work = base.copy()
        rowPerm, colPerm = fullPivotLU(work)
        errsFull.append(np.linalg.norm(base[np.ix_(rowPerm, colPerm)] - unpackedProduct(work)))

    # Author's observation from earlier runs: full pivoting generally gave the
    # smaller error, but its curve was less smooth, with spikes around n = 20
    # and n = 25.
    curves = np.column_stack([errsPartial, errsFull])
    plt.semilogy(sizes, curves, label=["partial pivoting", "full pivoting"])
    plt.legend()
    plt.savefig("image.png")

