import numpy as np
import matplotlib.pyplot as plt

# Für gegenebe Stützstellen xs und Auswertungen fs einer unbekannten Funktion
# gibt die Newtonschen dividierten Differenzen zurürck.
#
# Wären die Stützstellen und Auswertungen durch nummeriert
#   xs = [x_0, x_1, ..., x_n]
#   fs = [f[x_0], f[x_1], ..., f[x_n]]
# so entspricht die Ausgabe den dividierten Differenzen:
#   divDiff(xs, fs) = [f[x_0], f[x_0, x_1],..., f[x_0, x_1,...,x_n]]
# die für die Evaluierung des Interpolationspolynoms nötig sind. Beispielsweise
# wäre:
#   divDiff([3, 1, 5], [1, -3, 2]) # => [1, 2, -3/8]
def divDiff(xs, fs):
    assert(len(xs) != 0)
    assert(len(xs) == len(fs))

    sliding_window = np.lib.stride_tricks.sliding_window_view
    n = len(fs)
    w = np.copy(fs)
    dd = [fs[0]]

    for i in range(1, n):
        denom = np.sum(sliding_window(w, 2) * [-1, 1], axis = 1)
        numer = np.sum(sliding_window(xs, i+1) * np.hstack([-1, np.zeros(i-1), 1]), axis = 1)
        w = denom / numer
        dd = np.hstack([dd, w[0]])

    return dd

# Gibt eine Funktion zurück, die mittels Horner-Schema das Interpolationspolynom
# zu den gegebenen Daten (xs, fs) auswertet. Zu verwenden als:
#   xs = [-1, 1, 0, -2]
#   fs = [2, 4, 5, -5]
#   pf = interpolPoly(xs, fs)
#
#   pf(-3) # => -16.0
# (N.B. Aufgrund numerischer Fehler mag, bei Ihnen, das Ergebnis nicht exakt
# -16.0 sein.)
def interpolPoly(xs, fs):
    coeff = divDiff(xs, fs)
    def eval(x):
        res = coeff[len(coeff) - 1]
        for i in range(len(coeff) - 2, -1, -1):
            res = res * (x - xs[i]) + coeff[i]
        return res

    return eval

# Vergessen Sie nicht die auf dem Blatt beschriebenen Funktionen mit zugehörigen
# Interpolationspolynomen zu plotten.
if __name__ == "__main__":
    xs = np.array([-1, 0, 1])
    ws = np.array([-2*np.pi, -np.pi, 0, np.pi, 2*np.pi])

    pts = np.linspace(-5, 5)
    plt.plot(pts, np.sin(pts))
    plt.plot(pts, interpolPoly(xs, np.sin(xs))(pts))
    plt.plot(pts, interpolPoly(ws, np.sin(ws))(pts))
    plt.savefig("blatt02sin.png")

    xs = np.linspace(-1, 1, 8)
    ws = np.linspace(-1, 1, 12)
    zs = np.linspace(-1, 1, 15)

    pts = np.linspace(-1, 1)
    f = lambda x: 1 / (1 + 25 * x**2)
    plt.plot(pts, f(pts))
    plt.plot(pts, interpolPol(xs, f(xs))(pts), label="8 Punkte")
    plt.plot(pts, interpolPol(ws, f(ws))(pts), label="12 Punkte")
    plt.plot(pts, interpolPol(zs, f(zs))(pts), label="15 Punkte")
    plt.savefig("blatt02runge.png")
