import numpy as np
import matplotlib.pyplot as plt

from blatt05 import *

# Note how there is no argument that specifies how many steps
# to take in Romberg's method, only a desired accuracy. This is not
# a hard condition (on the accuracy) but merely guidance, honored
# within the limits of Romberg's method and a "reasonable" number
# of steps.
# Using the argument maxiter is optional, though at times it is
# helpful as a failsafe catching stray or leaky implementations.
# A higher accuracy than 15 decimal places doesn't actually seem to work
# in double precision, and the ever-growing set of lists is somewhat
# wasteful (only ever two of size ~n actually need to be held in memory),
# as is the exponentially growing number of points at which we evaluate
# the function. This could also be done in a loop, though that would be
# less elegant. Nicer still would be if Python/NumPy were clever enough
# to fuse the evaluation of f with the following reduction.
def romberg(a, b, f, eps=10**-15, maxiter=20):
    """Approximate the integral of f over [a, b] with Romberg's method.

    Parameters
    ----------
    a, b : float
        Integration bounds.
    f : callable
        Vectorized integrand; receives a numpy array of sample points
        and must return the corresponding array of values.
    eps : float, optional
        Desired accuracy. Iteration stops once the difference between
        equal-order entries of consecutive tableau diagonals falls
        below this threshold; guidance rather than a hard guarantee.
    maxiter : int, optional
        Failsafe cap on the number of trapezoid refinements, useful
        for catching stray or non-converging cases.

    Returns
    -------
    float
        The highest-order Richardson extrapolation computed.
    """
    fs = f(np.linspace(a, b, 3))

    # Composite trapezoid rule with two subintervals (T_1), written out
    # by hand, plus the first Richardson extrapolation of T_1 against
    # the one-interval trapezoid T_0 = (b - a) * (fs[0] + fs[2]) / 2.
    trap = (b - a)/2 * (fs[0]/2 + fs[1] + fs[2]/2)
    last = [trap, (4 * trap - (b - a) * (fs[0]/2 + fs[2]/2)) / (4 - 1)]

    j = 1
    # Initial error estimate |T_0 - T_1|. BUGFIX: the original omitted
    # the (b - a) factor on T_0, comparing a mean of function values
    # against an integral estimate; that mis-scales the estimate for
    # any interval with b - a != 1 and can stop the loop prematurely.
    currError = np.abs((b - a) * (fs[0] + fs[2])/2 - last[-2])
    while currError >= eps and j < maxiter - 1:
        j = j + 1
        # Halve the step: reuse the previous trapezoid sum and evaluate
        # f only at the new midpoints. Those midpoints are the previous
        # grid (without a) shifted left by the new half-step (b-a)/2**j.
        repl = [last[0]/2 + (b - a) / 2**j * np.sum(f(np.linspace(a, b, 2**(j-1) + 1)[1:] - (b - a) * 2**-j))]
        # Fill in the next diagonal of the Romberg tableau by repeated
        # Richardson extrapolation against the previous diagonal.
        for k in range(len(last)):
            repl.append((4**(k+1) * repl[k] - last[k]) / (4**(k+1) - 1))

        # Error estimate: compare the equal-order entries of the two
        # most recent diagonals.
        currError = np.abs(last[-1] - repl[-2])
        last = repl

    return last[-1]

if __name__ == "__main__":
    # Benchmark integrand. f.callCount tallies how many points f has
    # been evaluated at; we reset it to zero before every benchmarked
    # integration run so each method's cost is measured in isolation.
    def f(x):
        f.callCount = f.callCount + np.size(x)
        return np.cos(x - 1)**2 + np.sin(0.7 * x) + x
    f.callCount = 0

    a, b = -1, 2 * np.pi
    # Closed-form value of the integral of f over [a, b]; used as the
    # reference when computing the approximation error below.
    exactResult = (-10 + 10*np.sqrt(5) + 28*np.pi + 56*np.pi**2 + 7*(np.sin(4) - np.sin(2)) + 40*np.cos(7/10))/28

    # Error vs. number of integrand evaluations for the chained
    # Newton-Cotes rules imported from blatt05, one semilog curve each.
    for method, name, iterCounts in zip([chainedTrapezoid, chainedSimpson, chainedMilne],
                                    ["Trapez-Regel", "Simpson-Regel", "Milne-Regel"],
                                    [2+np.arange(50), 1+np.arange(40), 1+np.arange(25)]):
        callCounts, err = [], []
        for n in iterCounts:
            f.callCount = 0
            err.append(np.abs(exactResult - method(a, b, n, f)))
            callCounts.append(f.callCount)
        plt.semilogy(callCounts, err, label=name)

    # Same benchmark for Romberg; here n caps the iteration count via
    # maxiter instead of prescribing the number of nodes directly.
    callCounts, err = [], []
    for n in 1+np.arange(8):
        f.callCount = 0
        err.append(np.abs(exactResult - romberg(a, b, f, maxiter=n)))
        callCounts.append(f.callCount)
    plt.semilogy(callCounts, err, label="Romberg-Verfahren")

    plt.ylabel("Approximationsfehler")
    plt.xlabel("Anzahl der Aufrufe des Integranden")
    plt.legend()
    plt.savefig("image.png")

    # The Romberg method, whether adaptive or with a limited step count,
    # does not work well for the given function (x-pi)**4 * np.cos(2**n * x):
    # as long as the step count stays relatively small, the nodes of the
    # iterated trapezoid rule coincide with the extrema of the cosine, so
    # essentially only the prefactor function gets integrated. The
    # situation is even worse (and that is probably the function actually
    # meant on the exercise sheet) for an integrand f(x) * np.sin(2**n * x),
    # since there the nodes even coincide with the zeros of the sine. In
    # the latter case, unless the iterated trapezoid rule uses at least
    # 2**(n+1) + 1 nodes, the result of the integration is (incorrectly)
    # zero.
