import numpy as np
import matplotlib.pyplot as plt
from math import ceil, comb, factorial

from blatt02 import interpolPoly

# Die Interpolationspolynom produzierenden Funktion sollen wie im Folgenden
# dargestellt funktionieren:
#     pf = barIntEqui(-5, 5, [-1, 2, 3, 0])
#     pf(np.linspace(-1, 1))
# Hierbei wird also das Interpolationspolynom mit vier äquidistanten
# Stützstellen in der barizentrischen Darstellung berechnet. Die angegebenen
# Auswertungen entsprechen also der zu interpolierenden Funktion an den
# Stützstellen x = -5, -5/3, 5/3, 5.
# Beachten Sie außerdem, dass es für die Verwendung von Tschebyscheff Knoten
# ggf nötig wird diese auf das Intervall [a, b] zu skalieren, wie in der
# Vorlesung erklärt.

def barInt(ss, lambdas, fs):
    # ss für Stützstellen
    def eval(x):
        # Falls wir das Interpolationspolynom an einer der Stützstellen
        # auswerten wollen, so geben wir einfach die gegebene Auswertung der
        # unbekannten Funktion zurück.
        if x in ss:
            return fs[np.where(x == ss)]

        return np.sum(lambdas * fs / (x - ss)) / np.sum(lambdas / (x - ss))
    return np.vectorize(eval)

def barIntEqui(a, b, fs):
    n = len(fs) - 1
    # Im Skript laufen die Indizes von den lambdas zwischen 0 und n
    # einschließlich, die Formeln gehen also davon aus, dass es n + 1
    # Auswertungen gibt.
    lambdas = np.fromiter(map(lambda j: (-1)**j * comb(n, j), range(n+1)), dtype=float)
    return barInt(np.linspace(a, b, len(fs)), lambdas, fs)

def tschFstSS(a, b, n):
    return (a + b)/2 + (b - a)/2 * np.cos(np.pi / (2*(n-1) + 2) * (2 * np.arange(n) + 1))

def tschSndSS(a, b, n):
    return (a + b)/2 + (b - a)/2 * np.cos(np.pi / (n-1) * np.arange(n))

def barIntTschFst(a, b, fs):
    n = len(fs)
    # np.tile nutzen wir um ein alternierendes Vorzeichen zu garantieren.
    # Unglücklicherweise ist das zweite Argument allerdings die Anzahl an
    # Wiederholungen und nicht die gewünschte resultierende Länge.
    lambdas = np.tile([1, -1], ceil(n/2))[:n] * np.sin(np.pi / (2*(n-1) + 2) * (2 * np.arange(n) + 1))
    return barInt(tschFstSS(a, b, n), lambdas, fs)

def barIntTschSnd(a, b, fs):
    n = len(fs)
    lambdas = np.tile([1, -1], ceil(n/2))[:n] * np.hstack([1/2, np.repeat(1, n-2), 1/2])
    return barInt(tschSndSS(a, b, n), lambdas, fs)

if __name__ == "__main__":
    heaviside = lambda x: (np.sign(x) + 1) / 2
    a, b, n = -2, 2, 60

    ss = np.linspace(a, b, n)
    newton = interpolPoly(ss, heaviside(ss))
    equi = barIntEqui(a, b, heaviside(ss))

    ss = tschFstSS(a, b, n)
    tschFst = barIntTschFst(a, b, heaviside(ss))

    ss = tschFstSS(a, b, n)
    tschSnd = barIntTschSnd(a, b, heaviside(ss))

    xs = np.linspace(a, b, 100)
    fig, axs = plt.subplots(5, figsize=(15, 15))

    axs[0].plot(xs, heaviside(xs), label="Original")
    axs[1].plot(xs, newton(xs), label="Newton")
    axs[2].plot(xs, equi(xs), label="Äquidistant")
    axs[3].plot(xs, tschFst(xs), label="Tschebycheff 1. Art")
    axs[4].plot(xs, tschSnd(xs), label="Tschebycheff 2. Art")

    for i in range(0, 5):
        axs[i].legend()

    plt.savefig("image.png")
