import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from ..evaluation import Evaluation
from ..hamiltonian import get_pop_fourier
from ..dynamics import get_reduced_dm_eigs
from ..utils import get_conversion
from ..io import OPTIONS, load_color_palette
# ----------------------------------------------------------------------
def _get_colors():
    """
    Build the colour lookup tables used by the plotting methods.

    Returns
    -------
    tuple of dict
        ``(colors_dna_bases, colors_particles)``. The first maps each entry of
        ``OPTIONS["dna_bases"]`` to a colour taken from the "icefire7" entry of
        the "seaborn" palette file; the second maps each entry of
        ``OPTIONS["particles"]`` to a fixed hex colour.
    """
    base_names = OPTIONS["dna_bases"]
    particle_names = OPTIONS["particles"]
    palette = load_color_palette("seaborn")["icefire7"]
    # palette positions for A, T, G, C, F (in that order)
    base_colors = [palette[slot] for slot in (5, 6, 1, 0, 3)]
    # hex colours for electron, hole, exciton (in that order)
    particle_colors = ["#459DD9", "#B74244", "#4EB572"]
    return dict(zip(base_names, base_colors)), dict(zip(particle_names, particle_colors))
# Module-level colour lookups, built once at import time:
# DNA base -> colour and particle -> colour (see _get_colors).
COLORS_DNA_BASES, COLORS_PARTICLES = _get_colors()
# ----------------------------------------------------------------------
[docs]
class Visualization(Evaluation):
"""
Visualization class for plotting and analyzing quantum DNA data.
This class extends the Evaluation class and provides various methods for visualizing
quantum DNA data, including heatmaps, population plots, coherence plots, eigenstate
distributions, Fourier analysis, and cumulative average population plots.
Attributes
----------
tb_sites : list
List of tight-binding sites.
kwargs : dict
Additional keyword arguments passed to the Evaluation class.
Methods
-------
plot_heatmap(heatmap_type="seaborn", fig=None, ax=None, dpi=None, vmax_list=None, cmaps=None, **plot_kwargs)
Plot heatmaps for particle populations using seaborn or matplotlib.
plot_pop(tb_site, fig=None, ax=None, dpi=None, add_legend=True, **plot_kwargs)
Plot population dynamics for a specific tight-binding site.
plot_pops(fig=None, ax=None, dpi=None, **plot_kwargs)
Plot population dynamics for all tight-binding sites.
plot_pop_fourier(init_state, end_state, times, t_unit, fig=None, ax=None, dpi=None, add_legend=True, **plot_kwargs)
Plot population dynamics using Fourier analysis.
plot_coh(fig=None, ax=None, dpi=None, add_legend=True, **plot_kwargs)
Plot coherence dynamics for particles.
plot_test_fourier(tb_site, fig=None, ax=None, dpi=None, **plot_kwargs)
Test Fourier analysis by comparing population dynamics and Fourier results.
plot_eigv(energy_unit="eV", fig=None, ax=None, dpi=None, color=None)
Plot eigenvalues of the system.
plot_eigs(eigenstate_idx, fig=None, ax=None, dpi=None)
Plot eigenstate distributions for a given eigenstate index.
plot_fourier(init_state, end_state, x_axis, fig=None, ax=None, dpi=None)
Plot Fourier amplitudes and frequencies or periods.
plot_average_pop(J_list, J_unit="100meV", fig=None, ax=None, dpi=None, **plot_kwargs)
Plot cumulative average population for varying Coulomb parameters.
"""
def __init__(self, tb_sites, **kwargs):
    """
    Store the constructor arguments and initialise the Evaluation base class.

    Parameters
    ----------
    tb_sites : list
        Tight-binding sites; kept on ``self.tb_sites`` and passed to the base class.
    **kwargs
        Kept on ``self.kwargs`` and forwarded unchanged to the base class.
    """
    self.tb_sites = tb_sites
    self.kwargs = kwargs
    super().__init__(self.tb_sites, **self.kwargs)
# ----------------------------------------------------------------------
[docs]
def plot_heatmap(
    self,
    heatmap_type="seaborn",
    fig=None,
    ax=None,
    dpi=None,
    vmax_list=None,
    cmaps=None,
    **plot_kwargs,
):
    """
    Draw one population heatmap per particle (rows: sites, columns: time steps).

    Parameters
    ----------
    heatmap_type : str
        ``"seaborn"`` draws with ``sns.heatmap``; ``"matplotlib"`` draws with
        ``Axes.imshow``. Any other value leaves the panels without an image.
    fig, ax : optional
        Existing figure and axes. When ``fig`` is None a new figure with one
        panel per particle is created. ``ax`` is reshaped to ``(rows, cols)``.
    dpi : optional
        Forwarded to ``plt.subplots`` when a new figure is created.
    vmax_list : sequence, optional
        Upper colour limit per particle, indexed like ``self.particles``.
        Without it the limit is 1, except for "exciton", which is scaled to
        the maximum of its own data.
    cmaps : dict, optional
        Colormap name per particle. Entries given here override the defaults
        (electron: "Blues", hole: "Reds", exciton: "Greens").
    **plot_kwargs
        ``direction`` ("horizontal" by default, or "vertical") selects the panel
        layout and is consumed here; everything else is forwarded to the
        heatmap call.

    Returns
    -------
    fig, ax
        The figure and the 2-D array of axes.
    """
    direction = plot_kwargs.pop("direction", "horizontal")
    num_particles = len(self.particles)
    if direction == "vertical":
        x_num, y_num = num_particles, 1
    else:
        x_num, y_num = 1, num_particles
    if fig is None:
        if x_num == 1 and y_num == 1:
            fig, ax = plt.subplots(dpi=dpi)
        else:
            fig, ax = plt.subplots(
                x_num,
                y_num,
                figsize=(3.4 * y_num, 2.1 * x_num),
                sharex=True,
                sharey=True,
                dpi=dpi,
            )
    # Always work on a 2-D grid, also for a single Axes or caller-supplied axes.
    ax = np.array(ax).reshape((x_num, y_num))
    # ---
    pop_dict = self.get_pop()
    # Previously the caller's cmaps were unconditionally overwritten; now they
    # are merged on top of the defaults that were effectively in use.
    default_cmaps = {"electron": "Blues", "hole": "Reds", "exciton": "Greens"}
    cmaps = {**default_cmaps, **(cmaps or {})}
    # One of the grid dimensions is 1, so the flat panel index equals the particle index.
    for idx, (panel, particle) in enumerate(zip(ax.flat, self.particles)):
        particle_pop = np.array(
            [value for key, value in pop_dict.items() if key.startswith(particle)]
        )
        # NOTE(review): the exciton auto-scaling is treated as part of the
        # "no vmax_list" branch, so an explicit vmax_list wins - confirm.
        if vmax_list is not None:
            vmax = vmax_list[idx]
        elif particle == "exciton":
            vmax = np.max(particle_pop)
        else:
            vmax = 1
        if heatmap_type == "seaborn":
            heatmap = sns.heatmap(
                particle_pop,
                xticklabels=[],
                yticklabels=[],
                cmap=cmaps[particle],
                ax=panel,
                cbar=False,
                vmax=vmax,
                **plot_kwargs,
            )
            heatmap.figure.colorbar(heatmap.collections[0], ax=panel)
        if heatmap_type == "matplotlib":
            im = panel.imshow(
                particle_pop,
                cmap=cmaps[particle],
                aspect="auto",
                vmax=vmax,
                **plot_kwargs,
            )
            # Attach the colorbar to this panel only (was the whole row ax[i]),
            # matching the seaborn branch.
            im.figure.colorbar(im, ax=panel)
        panel.set_ylabel(particle.capitalize())
        # ticks: four time labels from 0 to t_end, one site label per row
        panel.set_xticks([])
        panel.set_yticks([])
        y_len, x_len = particle_pop.shape
        xticks = np.linspace(0, x_len, 4)
        panel.set_xticks(xticks, labels=[int(x) for x in np.linspace(0, self.t_end, 4)])
        yticks = np.arange(y_len) + 0.5
        panel.set_yticks(yticks, labels=self.tb_sites_flattened)
    for bottom_panel in ax[-1]:
        bottom_panel.set_xlabel("Time [" + self.t_unit + "]")
    return fig, ax
[docs]
def plot_heatmap2(
    self,
    heatmap_type="seaborn",
    fig=None,
    ax=None,
    dpi=None,
    vmax_list=None,
    cmaps=None,
    number=None,
    **plot_kwargs,
):
    """
    Variant of ``plot_heatmap`` with fixed site labels selected by ``number``.

    Differences to ``plot_heatmap`` visible in this method: no axis labels are
    set, the y tick labels come from a hard-coded table keyed by ``number``
    (0-3; any other value leaves the y axis without ticks), and in seaborn mode
    a colorbar (ticks at 0, vmax/2, vmax) is only drawn for ``number`` 1 or 3.

    Parameters
    ----------
    heatmap_type : str
        ``"seaborn"`` or ``"matplotlib"``; any other value draws no image.
    fig, ax : optional
        Existing figure and axes. When ``fig`` is None a new figure with one
        panel per particle is created. ``ax`` is reshaped to ``(rows, cols)``.
    dpi : optional
        Forwarded to ``plt.subplots`` when a new figure is created.
    vmax_list : sequence, optional
        Upper colour limit per particle, indexed like ``self.particles``.
        Without it the limit is 1, except for "exciton", which is scaled to
        the maximum of its own data.
    cmaps : dict, optional
        Colormap name per particle. Entries given here override the defaults
        (electron: "Blues", hole: "Reds", exciton: "Greens").
    number : int, optional
        Selects the hard-coded y tick labels and the seaborn colorbar.
    **plot_kwargs
        ``direction`` ("horizontal" by default, or "vertical") is consumed
        here; everything else is forwarded to the heatmap call.

    Returns
    -------
    fig, ax
        The figure and the 2-D array of axes.
    """
    direction = plot_kwargs.pop("direction", "horizontal")
    num_particles = len(self.particles)
    if direction == "vertical":
        x_num, y_num = num_particles, 1
    else:
        x_num, y_num = 1, num_particles
    if fig is None:
        if x_num == 1 and y_num == 1:
            fig, ax = plt.subplots(dpi=dpi)
        else:
            fig, ax = plt.subplots(
                x_num,
                y_num,
                figsize=(3.4 * y_num, 2.1 * x_num),
                sharex=True,
                sharey=True,
                dpi=dpi,
            )
    # Always work on a 2-D grid, also for a single Axes or caller-supplied axes.
    ax = np.array(ax).reshape((x_num, y_num))
    # ---
    pop_dict = self.get_pop()
    # Previously the caller's cmaps were unconditionally overwritten; now they
    # are merged on top of the defaults that were effectively in use.
    default_cmaps = {"electron": "Blues", "hole": "Reds", "exciton": "Greens"}
    cmaps = {**default_cmaps, **(cmaps or {})}
    # Hard-coded y tick labels per figure number (0/2 and 1/3 share a set).
    labels_unmodified = ["01C", "02G", "03G", "06G", "05C", "04C"]
    labels_modified = ["01M", "02G", "03G", "06G", "05M", "04C"]
    site_labels = {
        0: labels_unmodified,
        1: labels_modified,
        2: labels_unmodified,
        3: labels_modified,
    }
    # One of the grid dimensions is 1, so the flat panel index equals the particle index.
    for idx, (panel, particle) in enumerate(zip(ax.flat, self.particles)):
        particle_pop = np.array(
            [value for key, value in pop_dict.items() if key.startswith(particle)]
        )
        # NOTE(review): the exciton auto-scaling is treated as part of the
        # "no vmax_list" branch, so an explicit vmax_list wins - confirm.
        if vmax_list is not None:
            vmax = vmax_list[idx]
        elif particle == "exciton":
            vmax = np.max(particle_pop)
        else:
            vmax = 1
        if heatmap_type == "seaborn":
            heatmap = sns.heatmap(
                particle_pop,
                xticklabels=[],
                yticklabels=[],
                cmap=cmaps[particle],
                ax=panel,
                cbar=False,
                vmax=vmax,
                **plot_kwargs,
            )
            if number in [1, 3]:
                heatmap.figure.colorbar(
                    heatmap.collections[0],
                    ax=panel,
                    ticks=[0, vmax / 2, vmax],
                )
        if heatmap_type == "matplotlib":
            im = panel.imshow(
                particle_pop,
                cmap=cmaps[particle],
                aspect="auto",
                vmax=vmax,
                **plot_kwargs,
            )
            # Attach the colorbar to this panel only (was the whole row ax[i]).
            im.figure.colorbar(im, ax=panel)
        # ticks: four time labels from 0 to t_end
        panel.set_xticks([])
        panel.set_yticks([])
        y_len, x_len = particle_pop.shape
        xticks = np.linspace(0, x_len, 4)
        panel.set_xticks(xticks, labels=[int(x) for x in np.linspace(0, self.t_end, 4)])
        yticks = np.arange(y_len) + 0.5
        if number in (0, 1, 2, 3):
            panel.set_yticks(yticks, labels=site_labels[number])
    return fig, ax
[docs]
def plot_pop(self, tb_site, fig=None, ax=None, dpi=None, add_legend=True, **plot_kwargs):
    """
    Plot the population of every particle on a single tight-binding site.

    Parameters
    ----------
    tb_site : str
        Site key; populations are read from ``self.get_pop()`` under
        ``"<particle>_<tb_site>"``.
    fig, ax : optional
        Existing figure and axes; a new pair is created when ``fig`` is None.
    dpi : optional
        Forwarded to ``plt.subplots`` when a new figure is created.
    add_legend : bool
        When True, axis labels and a legend are added.
    **plot_kwargs
        Forwarded to ``Axes.plot``. The per-particle colour and label are only
        applied when no keyword arguments are given at all.

    Returns
    -------
    fig, ax
    """
    # Default styling is used only if the caller supplied no plot kwargs.
    use_default_style = not plot_kwargs
    if fig is None:
        fig, ax = plt.subplots(dpi=dpi)
    populations = self.get_pop()
    for name in self.particles:
        if use_default_style:
            line_style = {"color": COLORS_PARTICLES[name], "label": name}
        else:
            line_style = plot_kwargs
        ax.plot(self.times, populations[name + "_" + tb_site], **line_style)
    # Annotate the panel with the entry stored for this site, centred in time.
    site_text = self.tb_basis_sites_dict[tb_site]
    text_box = dict(facecolor="white", edgecolor="white", boxstyle="round,pad=0.2")
    ax.text(
        self.t_end / 2,
        0.8,
        site_text,
        ha="center",
        va="center",
        color="grey",
        fontsize=15,
        bbox=text_box,
    )
    ax.set_ylim(0, 1.02)
    if not add_legend:
        return fig, ax
    ax.set_ylabel("Population")
    ax.set_xlabel("Time [" + self.t_unit + "]")
    ax.legend()
    return fig, ax
[docs]
def plot_pops(self, fig=None, ax=None, dpi=None, **plot_kwargs):
    """
    Plot the population dynamics of every tight-binding site on a grid of panels.

    Parameters
    ----------
    fig, ax : optional
        Existing figure and axes. When ``fig`` is None a grid of subplots is
        created. ``ax`` is reshaped to ``(rows, cols)``.
    dpi : optional
        Forwarded to ``plt.subplots`` and to ``plot_pop``.
    **plot_kwargs
        Only ``direction`` is used: "horizontal" (default) gives
        ``num_channels`` rows by ``num_sites_per_strand`` columns, "vertical"
        the transpose. Other entries are currently not forwarded to ``plot_pop``.

    Returns
    -------
    fig, ax
        The figure and the 2-D array of axes.
    """
    direction = plot_kwargs.pop("direction", "horizontal")
    if direction == "vertical":
        x_num, y_num = self.num_sites_per_strand, self.num_channels
    else:
        x_num, y_num = self.num_channels, self.num_sites_per_strand
    if fig is None:
        if x_num == 1 and y_num == 1:
            fig, ax = plt.subplots(dpi=dpi)
        else:
            fig, ax = plt.subplots(
                x_num,
                y_num,
                figsize=(3.4 * y_num, 2.1 * x_num),
                sharex=True,
                sharey=True,
                dpi=dpi,
            )
    # A single Axes (1x1 layout) has no .reshape and cannot be indexed, so wrap
    # it first; this also accepts a plain list of caller-supplied axes.
    ax = np.asarray(ax).reshape((x_num, y_num))
    for i in range(x_num):
        ax[i, 0].set_ylabel("Population")
    for j in range(y_num):
        ax[-1, j].set_xlabel("Time [" + self.t_unit + "]")
    # ---
    for i in range(x_num):
        for j in range(y_num):
            # Site keys are always "(channel, site)", independent of the layout.
            if direction == "vertical":
                tb_site = f"({j}, {i})"
            else:
                tb_site = f"({i}, {j})"
            _, ax[i, j] = self.plot_pop(tb_site, fig, ax[i, j], dpi, add_legend=False)
    ax[0, 0].legend(self.particles, loc="upper right")
    return fig, ax
[docs]
def plot_pop_fourier(
    self,
    init_state,
    end_state,
    times,
    t_unit,
    fig=None,
    ax=None,
    dpi=None,
    add_legend=True,
    **plot_kwargs,
):
    """
    Plot populations reconstructed from the Fourier decomposition.

    For every particle the amplitudes, frequencies and average population
    returned by ``self.get_fourier`` are passed to ``get_pop_fourier`` for each
    entry of ``times``.

    Parameters
    ----------
    init_state, end_state
        Forwarded to ``self.get_fourier``.
    times : sequence
        Time points at which the population is evaluated.
    t_unit : str
        Unit of ``times``. Side effect: ``self.unit`` is set to
        ``"rad/" + t_unit`` before the Fourier data is requested.
    fig, ax : optional
        Existing figure and axes; a new pair is created when ``fig`` is None.
    dpi : optional
        Forwarded to ``plt.subplots`` when a new figure is created.
    add_legend : bool
        When True, axis labels and a legend are added.
    **plot_kwargs
        Forwarded to ``Axes.plot``; ``label`` and ``color`` given here override
        the per-particle defaults.

    Returns
    -------
    fig, ax
    """
    if fig is None:
        fig, ax = plt.subplots(dpi=dpi)
    # calculation
    self.unit = "rad/" + t_unit
    amplitudes_dict, frequencies_dict, average_pop_dict = self.get_fourier(
        init_state, end_state, ["amplitude", "frequency", "average_pop"]
    )
    pop_dict = {
        particle: [
            get_pop_fourier(
                t,
                average_pop_dict[particle],
                amplitudes_dict[particle],
                frequencies_dict[particle],
            )
            for t in times
        ]
        for particle in self.particles
    }
    # plotting
    for particle in self.particles:
        # Caller kwargs go last so an explicit color/label no longer raises a
        # duplicate-keyword TypeError.
        line_kwargs = {
            "label": particle,
            "color": COLORS_PARTICLES[particle],
            **plot_kwargs,
        }
        ax.plot(times, pop_dict[particle], **line_kwargs)
    # plot settings
    ax.set_ylim(0, 1.02)
    if add_legend:
        ax.set_ylabel("Population")
        # The x values are the given `times`, so label them with their unit.
        ax.set_xlabel("Time [" + t_unit + "]")
        # Lines already carry labels; legend() avoids order-based mislabelling
        # when the axes holds other artists.
        ax.legend()
    return fig, ax
[docs]
def plot_coh(self, fig=None, ax=None, dpi=None, add_legend=True, **plot_kwargs):
    """
    Plot the coherence of every particle over time.

    Parameters
    ----------
    fig, ax : optional
        Existing figure and axes; a new pair is created when ``fig`` is None.
    dpi : optional
        Forwarded to ``plt.subplots`` when a new figure is created.
    add_legend : bool
        When True, axis labels and a legend are added.
    **plot_kwargs
        Forwarded to ``Axes.plot``; ``color`` and ``label`` are always replaced
        by the per-particle values.

    Returns
    -------
    fig, ax
    """
    if fig is None:
        fig, ax = plt.subplots(dpi=dpi)
    coherence = self.get_coh()
    for name in self.particles:
        # per-particle styling overrides whatever the caller passed
        plot_kwargs.update(color=COLORS_PARTICLES[name], label=name)
        ax.plot(self.times, coherence[name], **plot_kwargs)
    if not add_legend:
        return fig, ax
    ax.set_ylabel("Coherence")
    ax.set_xlabel("Time [" + self.t_unit + "]")
    ax.legend()
    return fig, ax
[docs]
def plot_test_fourier(self, tb_site, fig=None, ax=None, dpi=None, **plot_kwargs):
    """
    Overlay the Fourier-reconstructed and the directly computed population.

    Both curves are drawn on the same axes: first ``plot_pop_fourier`` from
    ``self.init_states[0]`` to ``tb_site`` evaluated at ``self.times``, then
    ``plot_pop`` for ``tb_site``.

    Parameters
    ----------
    tb_site : str
        Site used as end state for the Fourier plot and as site for ``plot_pop``.
    fig, ax : optional
        Existing figure and axes; a new pair is created when ``fig`` is None.
    dpi : optional
        Forwarded to ``plt.subplots`` and to both plot calls.
    **plot_kwargs
        Forwarded to both plot calls.

    Returns
    -------
    fig, ax
    """
    if fig is None:
        fig, ax = plt.subplots(dpi=dpi)
    self.plot_pop_fourier(
        init_state=self.init_states[0],
        end_state=tb_site,
        times=self.times,
        t_unit=self.t_unit,
        fig=fig,
        ax=ax,
        dpi=dpi,
        **plot_kwargs,
    )
    self.plot_pop(tb_site, fig=fig, ax=ax, dpi=dpi, **plot_kwargs)
    return fig, ax
# ----------------------------------------------------------------------
[docs]
def plot_eigv(self, energy_unit="eV", fig=None, ax=None, dpi=None, color=None):
    """
    Draw the eigenvalues as horizontal levels.

    Parameters
    ----------
    energy_unit : str
        Target unit; the eigenvalues are multiplied by
        ``get_conversion(self.unit, energy_unit)``.
    fig, ax : optional
        Existing figure and axes; a new 3.4 x 3.4 figure is created when
        ``fig`` is None.
    dpi : optional
        Forwarded to ``plt.subplots`` when a new figure is created.
    color : optional
        Line colour; black when None.

    Returns
    -------
    fig, ax
    """
    if fig is None:
        fig, ax = plt.subplots(figsize=(3.4, 3.4), dpi=dpi)
    # calculation
    eigv, _ = self.get_eigensystem()
    # Scale a copy: the former in-place `*=` rescaled the array handed out by
    # get_eigensystem() itself.
    energies = np.asarray(eigv) * get_conversion(self.unit, energy_unit)
    # plotting
    x_start, x_end = 0, 1
    if color is None:
        color = "black"
    for energy in energies:
        ax.hlines(y=energy, xmin=x_start, xmax=x_end, color=color)
    # layout: the x axis carries no information, so hide its label and ticks
    ax.set_xlim(x_start, x_end)
    ax.set_xlabel("")
    ax.set_xticks([])
    ax.grid(True, axis="y", linestyle=":", alpha=0.4)
    ax.set_ylabel("Energy in " + energy_unit)
    return fig, ax
[docs]
def plot_eigs(self, eigenstate_idx, fig=None, ax=None, dpi=None):
    """
    Plot the site distribution of one eigenstate for every particle.

    The distribution is the real diagonal of a density matrix: for
    ``self.description == "2P"`` it comes from ``get_reduced_dm_eigs``, for
    ``"1P"`` it is the outer product of the eigenvector with its conjugate.

    Parameters
    ----------
    eigenstate_idx : int
        Column index of the eigenstate in the eigenvector matrix.
    fig, ax : optional
        Existing figure and axes; a new pair is created when ``fig`` is None.
    dpi : optional
        Forwarded to ``plt.subplots`` when a new figure is created.

    Returns
    -------
    fig, ax

    Raises
    ------
    ValueError
        If ``self.description`` is neither "1P" nor "2P".
    AssertionError
        If a non-exciton distribution does not sum to 1 (tolerance 1e-2).
    """
    if fig is None:
        fig, ax = plt.subplots(dpi=dpi)
    # calculation
    _, eigs = self.get_eigensystem()
    for particle in self.particles:
        if self.description == "2P":
            dm = get_reduced_dm_eigs(eigs, eigenstate_idx, self.tb_basis, particle)
        elif self.description == "1P":
            state = eigs[:, eigenstate_idx]
            dm = np.outer(state, state.conj())
        else:
            # Used to fall through with dm=None and fail inside np.diag with an
            # unrelated message; the exception type (ValueError) is unchanged.
            raise ValueError(
                f"Unsupported description {self.description!r}; expected '1P' or '2P'."
            )
        eigs_distribution = np.diag(dm).real
        # The exciton distribution is exempt from the normalisation check.
        if particle != "exciton":
            assert np.allclose(
                np.sum(eigs_distribution), 1, atol=1e-2
            ), "The distribution does not sum to 1."
        ax.plot(
            range(self.num_sites),
            eigs_distribution,
            label=particle,
            color=COLORS_PARTICLES[particle],
        )
    ax.set_xticks(range(self.num_sites))
    ax.set_xticklabels(self.tb_sites_flattened)
    ax.set_ylim(0, 1.02)
    ax.set_title(f"Distribution of Eigenstate {eigenstate_idx}")
    ax.legend()
    return fig, ax
[docs]
def plot_fourier(self, init_state, end_state, x_axis, fig=None, ax=None, dpi=None):
    """
    Scatter the Fourier amplitudes of every particle against frequency or period.

    Frequencies from ``self.get_frequencies`` are scaled by
    ``get_conversion(self.unit, "rad/ps") / (2 * pi)``; periods are
    ``1e3 / scaled_frequency`` and labelled in fs.

    Parameters
    ----------
    init_state, end_state
        Forwarded to ``self.get_amplitudes`` and ``self.get_frequencies``.
    x_axis : str
        "frequency" or "period" (case-insensitive). Any other value draws no
        markers.
    fig, ax : optional
        Existing figure and axes; a new pair is created when ``fig`` is None.
    dpi : optional
        Forwarded to ``plt.subplots`` when a new figure is created.

    Returns
    -------
    fig, ax

    Notes
    -----
    NOTE(review): frequency mode plots ``abs(amplitude)`` while period mode
    plots the raw amplitude - confirm that this asymmetry is intended.
    NOTE(review): after dividing by 2*pi the x values look like ordinary
    frequencies, yet the axis label says "rad/ps" - confirm the label.
    """
    if fig is None:
        fig, ax = plt.subplots(dpi=dpi)
    # calculation
    amplitudes_dict = self.get_amplitudes(init_state, end_state)
    frequencies_dict = self.get_frequencies(init_state, end_state)
    markers = {"electron": "^", "hole": "v", "exciton": "*"}
    mode = x_axis.lower()
    # The conversion factor does not depend on the particle: compute it once.
    conversion = get_conversion(self.unit, "rad/ps") / (2 * np.pi)
    for particle in self.particles:
        amplitudes = amplitudes_dict[particle]
        # Keep the scaled values local instead of writing them back into the
        # dict returned by get_frequencies().
        frequencies = np.array(frequencies_dict[particle]) * conversion
        marker_style = {
            "ls": "",
            "marker": markers[particle],
            "label": particle,
            "color": COLORS_PARTICLES[particle],
            "markersize": 10,
            "alpha": 0.8,
        }
        if mode == "frequency":
            ax.plot(frequencies, np.abs(amplitudes), **marker_style)
        elif mode == "period":
            # Only computed when needed, so frequency mode no longer triggers
            # divide-by-zero warnings for zero frequencies.
            periods = 1e3 / frequencies
            ax.plot(periods, amplitudes, **marker_style)
    # plot settings
    if mode == "frequency":
        ax.set_xlabel("Frequency in rad/ps")
    elif mode == "period":
        ax.set_xlabel("Period in fs")
    ax.set_ylabel("Amplitude")
    ax.legend()
    return fig, ax
# ----------------------------------------------------------------------
def _get_cumulative_average_pop(self, J_list, J_unit):
    """
    Accumulate the average populations site by site for every Coulomb parameter.

    Side effects: ``self.unit`` is set to ``J_unit`` and ``self.coulomb_param``
    is left at the last entry of ``J_list``.

    Parameters
    ----------
    J_list : sequence
        Values assigned to ``self.coulomb_param`` one after another.
    J_unit : str
        Value assigned to ``self.unit`` before the populations are requested.

    Returns
    -------
    numpy.ndarray
        Shape ``(num_sites + 1, num_particles, len(J_list))``. Entry ``k`` holds
        the sum over sites ``0..k``; the last entry is all zeros, so that index
        ``-1`` can serve as the baseline below the first site.
    """
    particle_names = self.particles
    num_J = len(J_list)
    # average population indexed as [particle, J, site]
    average = np.zeros((len(particle_names), num_J, self.num_sites))
    self.unit = J_unit
    for j_pos, coulomb in enumerate(J_list):
        self.coulomb_param = coulomb
        for site_pos, site in enumerate(self.tb_basis):
            site_average = self.get_average_pop(self.init_states[0], site)
            for p_pos, name in enumerate(particle_names):
                average[p_pos, j_pos, site_pos] = site_average[name]
    # running sum over the site axis, moved to the front: [site, particle, J]
    running = np.moveaxis(np.cumsum(average, axis=2), 2, 0)
    # trailing all-zero layer (the "zero population" baseline)
    baseline = np.zeros((1, len(particle_names), num_J))
    return np.concatenate([running, baseline], axis=0)
# cumulative_average_pop = get_cumulative_average_pop(tb_ham, J_list)
[docs]
def plot_average_pop(self, J_list, J_unit="100meV", fig=None, ax=None, dpi=None, **plot_kwargs):
    """
    Plot the cumulative average population versus the Coulomb parameter.

    One panel per particle: for every site a black line marks the population
    accumulated up to that site, and the band down to the previous line is
    filled with the colour looked up in ``COLORS_DNA_BASES``.

    Parameters
    ----------
    J_list : sequence
        Coulomb parameter values (x axis); forwarded to
        ``_get_cumulative_average_pop``.
    J_unit : str
        Forwarded to ``_get_cumulative_average_pop``, which stores it in
        ``self.unit``; that value is used in the x label.
    fig, ax : optional
        Existing figure and axes. When ``fig`` is None a new figure with one
        panel per particle is created. ``ax`` is reshaped to ``(rows, cols)``.
    dpi : optional
        Forwarded to ``plt.subplots`` when a new figure is created.
    **plot_kwargs
        Only ``direction`` ("horizontal" by default, or "vertical") is used.

    Returns
    -------
    fig, ax
        The figure and the 2-D array of axes (previously nothing was returned).
    """
    direction = plot_kwargs.pop("direction", "horizontal")
    num_particles = len(self.particles)
    if direction == "vertical":
        x_num, y_num = num_particles, 1
    else:
        x_num, y_num = 1, num_particles
    if fig is None:
        if x_num == 1 and y_num == 1:
            fig, ax = plt.subplots(dpi=dpi)
        else:
            fig, ax = plt.subplots(
                x_num,
                y_num,
                figsize=(3.4 * y_num, 3.4 * x_num),
                sharex=True,
                sharey=True,
                dpi=dpi,
            )
    # A single Axes (one particle) has no .reshape and cannot be indexed, so
    # wrap it first.
    ax = np.asarray(ax).reshape((x_num, y_num))
    # ---
    dna_seq = self.tb_sites_flattened
    # indexed as [site, particle, J]; the last layer is the all-zero baseline
    cumulative_average_pop = self._get_cumulative_average_pop(J_list, J_unit)
    # One of the grid dimensions is 1, so the flat panel index equals the particle index.
    for particle_idx, (panel, particle) in enumerate(zip(ax.flat, self.particles)):
        for dna_base_idx, dna_base in enumerate(dna_seq):
            upper = cumulative_average_pop[dna_base_idx, particle_idx]
            # For the first site, index -1 selects the all-zero baseline layer.
            lower = cumulative_average_pop[dna_base_idx - 1, particle_idx]
            # black lines
            panel.plot(J_list, upper, color="k", lw=1.5)
            # fill between the black lines
            panel.fill_between(
                J_list,
                lower,
                upper,
                color=COLORS_DNA_BASES[dna_base],
                alpha=0.3,
            )
        panel.set_title(particle.capitalize(), fontsize=15)
    # plot settings
    for i in range(x_num):
        ax[i, 0].set_ylabel("Acc. Population")
    for j in range(y_num):
        # plot the bottom line
        ax[-1, j].plot(J_list, [0] * len(J_list), color="k", lw=1.5)
        ax[-1, j].set_xlabel("J [" + self.unit + "]")
    # plot the DNA bases as letters
    for dna_base_idx, dna_base in enumerate(dna_seq):
        ax[0, 0].text(
            0,
            dna_base_idx / len(dna_seq),
            dna_base,
            fontsize=15,
            color="k",
            alpha=0.8,
        )
    # Return the handles like every other plot_* method of this class.
    return fig, ax
# ----------------------------------------------------------------------