import numpy as np
import matplotlib.pyplot as plt

from blatt02 import divDiff

# Für vorgegebene Knotenpunkte xs und Auswertungen fs einer ansonsten
# unbekannten Funktion soll eine Funktion konstruiert werden, die den not-a-knot
# Spline auswertet.
# Um den Spline zu definieren, der eine stückweise kubische Funktion ist, ist
# ggf die Verwendung von np.piecewise anzuraten. Dividierte Differenzen dürfen
# mit der von der zweiten Programmieraufgabe bekannten Funktion divDiff
# berechnet werden.
def splineNotAKnot(xs, fs):
    assert(len(xs) == len(fs))

    n = len(xs) - 1

    # Konstruiere das lineare Gleichungssystem um die Momente zu bestimmen. Die
    # Bezeichnungen entsprechen denen im Skript.
    h = xs[1:] - xs[:-1]
    mu = h[:-1] / (h[:-1] + h[1:])
    lamb = h[1:] / (h[:-1] + h[1:])

    # Jede der nachfolgenden Zeilen entspricht einem der tri-diagonal Streifen
    # in der resultierenden Matrix.
    A = np.diag(np.hstack([mu, -1]), k=-1) \
      + np.diag(np.hstack([lamb[0], 2 * np.ones(n-1), mu[-1]])) \
      + np.diag(np.hstack([-1, lamb]), k=1)
    # Hier habe ich keinen schönen Weg gefunden diese beiden Einträge eben nicht
    # so "manuell" setzen zu müssen.
    A[0,2] = mu[0]
    A[-1,-3] = lamb[-1]

    # Wir benötigen die dividierten Differenzen von f zum einen mit zwei
    # und mit drei konsekutiven Punkten.
    sliding = np.lib.stride_tricks.sliding_window_view
    diffs = np.array([divDiff(x, f) for x, f in zip(sliding(xs, 3), sliding(fs, 3))])
    diffsTwoPoints = np.hstack([diffs[:,-2], divDiff(xs[-2:], fs[-2:])[-1]])

    rhs = 6 * np.hstack([0, diffs[:,-1], 0])

    M = np.linalg.solve(A, rhs)
    c = diffsTwoPoints - h/6 * (M[1:] - M[:-1])
    d = fs[:-1] - h**2/6 * M[:-1]

    def eval(x):
        conditions = (xs[:-1] <= x) * (x < xs[1:])
        # Hier sind wir ein bisschen verschwenderisch, denn wir benötigen ja
        # lediglich die Auswertung der einzelnen kubischen Teilstücke der
        # Spline-Funktionen auf den entsprechend durch conditions angegebenen
        # Intevallen.
        s = 1/h * (M[1:] / 6 * (x - xs[:-1])**3 + M[:-1] / 6 * (xs[1:] - x)**3) + c * (x - xs[:-1]) + d

        return np.piecewise(x, conditions.T, s)

    # Auf das manuelle Vektorisieren der Funktion könnte man hier verzichten,
    # wenn man stärkere Bedinungen an die Eingaben zu eval setzt, damit man bei
    # der Auswertung der splines bessere Kontrolle über das broadcasting der
    # Axen der Eingabe hat.
    return np.vectorize(eval)

if __name__ == "__main__":
    f = np.sin
    xs = np.linspace(0, 2*np.pi, 8)

    spline = splineNotAKnot(xs, f(xs))

    pts = np.linspace(0, 2*np.pi)
    plt.plot(pts, spline(pts))
    plt.plot(xs, f(xs), '.', color="red")

    plt.savefig("image.png")
