# Version vom 15.5.23
# Geodäten und Krümmung des Zylinders im R^3


from sympy import symbols, diag
#https://de.wikipedia.org/wiki/SymPy
#https://www.sympy.org/en/index.html 

from einsteinpy.symbolic import MetricTensor, ChristoffelSymbols
#from einsteinpy.geodesic import Timelike
#from einsteinpy.plotting import StaticGeodesicPlotter
#https://einsteinpy.org/

import numpy as np
from scipy.integrate import odeint
#https://numpy.org/

import matplotlib.pyplot as plt
#https://matplotlib.org/


#Geodätengleichung (2d) als System 2.Ordnung
#ch Christoffelsymbole
def geod(y, t, ch):
    theta1, phi1, tv, pv = y
# ch ist symbolisch, subs - substituiert die Variable theta mit theta1 (was im Aufruf dann immer eine Zahl ist) etc
    ch2 = ch.tensor().subs([(z, theta1),(phi, phi1)])
# theta1'=tv, phi1'=pv, tv'=- \Gamma^{theta1}_{ij} x^ix^j, pv'=- \Gamma^{phi1}_{ij} x^ix^j,  (x^0=theta1, x^1=phi1)
    dydt = [tv, pv, -ch2[0,1,0]*tv**2-2*ch2[0,0,1]*tv*pv-ch2[0,1,1]*pv**2, -ch2[1,0,0]*tv**2-2*ch2[1,0,1]*tv*pv-ch2[1,1,1]*pv**2]
    return dydt


syms = symbols('theta phi')
z, phi  = syms
    

# Metrik in sphärischen Koordinaten - vgl. ART-Skript Bsp II.1.13
g = MetricTensor(diag(1, 1).tolist(), syms)
print('Metrik:', g.tensor())
#print('g[1=theta,2=phi]=', g[0,1], '\n')

ch = ChristoffelSymbols.from_metric(g)
# \Gamma_{ij}^k: äußerste Klammer k, innerste Klammer i
print('Christoffelsymbole:', ch.tensor())
#Da g konstante Matrix, alle Christoffelsymbole und damit auch die Krümmung Null

###############
# Plotten des Zylinders mittels Meridianen und Breitenkreisen
ax = plt.figure().add_subplot(projection='3d')
for thez in np.linspace(-2,20, 10):
    phy = np.linspace(0, 2*np.pi, 30)
    ax.plot(np.sin(phy), np.cos(phy), thez, color='lightgray')

for phy in np.linspace(0, 2*np.pi, 10):
    s = np.linspace(-2.0,2.0, 20)
    ax.plot([np.sin(phy)]*len(s), [np.cos(phy)]*len(s), s, color='lightgray')
    
#############################
# Geodätengleichung (2d) -- Definition der Geodätengleichung oben
y0 = [0, 0, 0.7,0.35]
y1 = [0, 0, 1,0]
y2 = [0, 0, 0, 1]

t = np.linspace(0, 20, 200)


sol = odeint(geod, y0, t, args=(ch,))
sol2 = odeint(geod, y1, t, args=(ch,))

t2 = np.linspace(0, 20, 200)
sol3 = odeint(geod, y2, t2, args=(ch,))


#Plotten der Lösung
z , phi , xv, yv= zip(*sol)
ax.plot(np.sin(phi), np.cos(phi), z, label='geodesic1')

z , phi , xv, yv= zip(*sol2)
ax.plot(np.sin(phi), np.cos(phi), z, label='geodesic2')

z , phi , xv, yv= zip(*sol3)
ax.plot(np.sin(phi), np.cos(phi), z, label='geodesic3')


ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
#ax.legend()

plt.show()
