# Version vom 15.5.23
# Geodäten der Sphäre im R^3

from sympy import symbols, diag, cos
#https://de.wikipedia.org/wiki/SymPy
#https://www.sympy.org/en/index.html 

from einsteinpy.symbolic import MetricTensor, ChristoffelSymbols
#https://einsteinpy.org/

import numpy as np
#https://numpy.org/

from scipy.integrate import odeint
#https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.odeint.html

import matplotlib.pyplot as plt
#https://matplotlib.org/

#Geodätengleichung (2d) als System 2.Ordnung
#ch Christoffelsymbole
def geod(y, t, ch):
    theta1, phi1, tv, pv = y
# ch ist symbolisch, subs - substituiert die Variable theta mit theta1 (was im Aufruf dann immer eine Zahl ist) etc
    ch2 = ch.tensor().subs([(theta, theta1),(phi, phi1)])
# theta1'=tv, phi1'=pv, tv'=- \Gamma^{theta1}_{ij} x^ix^j, pv'=- \Gamma^{phi1}_{ij} x^ix^j,  (x^0=theta1, x^1=phi1)
    dydt = [tv, pv, -ch2[0,1,0]*tv**2-2*ch2[0,0,1]*tv*pv-ch2[0,1,1]*pv**2, -ch2[1,0,0]*tv**2-2*ch2[1,0,1]*tv*pv-ch2[1,1,1]*pv**2]
    return dydt



syms = symbols('theta phi')
theta, phi = syms
    

# Metrik in sphärischen Koordinaten - vgl. ART-Skript Bsp II.1.13
g = MetricTensor(diag(1, cos(theta)**2).tolist(), syms)
print('Metrik:', g.tensor())

ch = ChristoffelSymbols.from_metric(g)
# \Gamma_{ij}^k: äußerste Klammer k, innerste Klammer i
# print('Christoffelsymbole:', ch.tensor())


#Plotten
ax = plt.figure().add_subplot(projection='3d')

#Plotten der Sphäre mittels Breiten- und Längenkreisen
for thex in np.linspace(-np.pi/2,np.pi/2, 10):
    phy = np.linspace(0, 2*np.pi, 50)
    ax.plot(np.cos(thex)*np.sin(phy), np.cos(thex)*np.cos(phy), np.sin(thex), color='lightgray')
    ax.plot(np.cos(phy)*np.sin(thex), np.cos(phy)*np.cos(thex), np.sin(phy), color='lightgray')


#############################
# Geodätengleichung (2d) -- Definition der Geodätengleichung oben
# Anfangswerte (hier starten jeweils in theta=0, phi=0 mit unterschiedlichen Anfangsgeschwindigkeiten)
y0 = [0, 0, 2,1]
y1 = [0, 0, 1,1]

t = np.linspace(0, 10, 101)

sol = odeint(geod, y0, t, args=(ch,))
sol2 = odeint(geod, y1, t, args=(ch,))

#sol ist ein Array von 4er Tupeln (für jedes t der Wert (4er Tupel) der Lösung der ODE bei den jeweiligen Anfangswerten )
#zip macht daraus 4 einzelne Arrays
the, ph, thep, php = zip(*sol)
ax.plot(np.cos(the)*np.sin(ph), np.cos(the)*np.cos(ph), np.sin(the), label='geodesic1')


the, ph, thep, php = zip(*sol2)
ax.plot(np.cos(the)*np.sin(ph), np.cos(the)*np.cos(ph), np.sin(the), label='geodesic2')


ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.legend()


plt.show()


