"""
General-purpose plotting utilities for scientific visualization.
This module provides flexible plotting functions for creating publication-ready figures
across different analysis workflows. Functions are designed for reusability with
consistent styling and support for both simple and complex data visualizations.
Key Features
------------
- Multi-dimensional array handling for comparative analysis
- Automatic styling with scientific color schemes and markers
- Flexible axis control and customization options
- Publication-ready output in multiple formats (PDF, SVG)
- Memory-efficient plotting for large datasets
Design Philosophy
-----------------
Functions prioritize flexibility over rigid interfaces, using **kwargs for extensive
customization. This approach supports diverse scientific plotting needs while maintaining
consistent visual output across the LAMMPSKit ecosystem.
Styling Standards
-----------------
- Color palette: ['b', 'r', 'g', 'k'] (blue, red, green, black)
- Line styles: ['--', '-.', ':', '-'] (dashed, dash-dot, dotted, solid)
- Markers: ['o', '^', 's', '*'] (circle, triangle, square, star)
- Font sizes: 8pt labels, 7pt legends, 7pt ticks for compact publication layout
Performance Notes
-----------------
Memory usage scales with data size and number of cases. For large datasets (>10^5 points),
consider data downsampling before plotting. Figure generation is optimized for batch
processing workflows.
Examples
--------
Simple comparative plot:
>>> import numpy as np
>>> from lammpskit.plotting import plot_multiple_cases
>>> x = np.array([1, 2, 3])
>>> y = np.array([[1, 4, 9], [1, 8, 27]]) # Two cases as an example
>>> labels = ['Case 1', 'Case 2']
>>> fig = plot_multiple_cases(x, y, labels, 'X values', 'Y values', 'comparison', 8, 6)
Electrochemical analysis plot:
>>> z_bins = np.linspace(-10, 40, 50) # Electrode-to-electrode z positions
>>> atom_counts = np.array([[10, 15, 20], [5, 12, 18]]) # Hf, O, Ta counts
>>> labels = ['SET state', 'RESET state']
>>> fig = plot_multiple_cases(atom_counts, z_bins, labels,
... 'Atom count', 'Z position (Å)', 'distribution', 10, 8)
"""
import os
import numpy as np
import matplotlib.pyplot as plt
from typing import List
[docs]
def plot_multiple_cases(
x_arr: np.ndarray,
y_arr: np.ndarray,
labels: List[str],
xlabel: str,
ylabel: str,
output_filename: str,
xsize: float,
ysize: float,
output_dir: str = os.getcwd(),
**kwargs,
) -> plt.Figure:
"""
Create comparative plots for multiple datasets with publication-ready styling.
Versatile plotting function for scientific data visualization supporting various array
dimensions and comparison scenarios. Handles both single-case and multi-case analysis
with automatic styling, customizable limits, and dual-format output. Optimized for
electrochemical cell analysis and general MD simulation data visualization.
Parameters
----------
x_arr : np.ndarray
X-axis data for plotting. Supports multiple dimensions:
- 1D: Single x-series for all cases
- 2D: Different x-series for each case (shape: n_cases, n_points)
y_arr : np.ndarray
Y-axis data for plotting. Supports multiple dimensions:
- 1D: Single y-series (used with single case or shared across cases)
- 2D: Different y-series for each case (shape: n_cases, n_points)
labels : List[str]
Legend labels for each case. Length should match number of cases in data arrays.
xlabel : str
X-axis label with units. Example: 'Z position (Å)', 'Time (ps)'
ylabel : str
Y-axis label with units. Example: 'Atom count', 'Displacement (Å)'
output_filename : str
Base filename for saved figures (extensions added automatically).
Example: 'atomic_distribution', 'filament_evolution'
xsize : float
Figure width in inches. Note: Function overrides with hardcoded value (1.6).
ysize : float
Figure height in inches. Note: Function overrides with hardcoded value (3.2).
output_dir : str, optional
Output directory for saved figures. Created if doesn't exist (default: cwd).
**kwargs : dict, optional
Advanced customization options:
Axis Limits:
xlimit : tuple (xmin, xmax) - Set both x-axis limits
ylimit : tuple (ymin, ymax) - Set both y-axis limits
xlimitlo : float - Set x-axis lower limit only
xlimithi : float - Set x-axis upper limit only
ylimitlo : float - Set y-axis lower limit only
ylimithi : float - Set y-axis upper limit only
Reference Lines:
xaxis : bool - Add horizontal line at y=0
yaxis : bool - Add vertical line at x=0
Styling:
markerindex : int - Override automatic color/marker cycling
Statistical Analysis:
ncount : np.ndarray - Atom counts per bin for average calculations
Shape: (n_cases, n_bins). Prints weighted averages.
Returns
-------
fig : matplotlib.figure.Figure
Figure object for further customization or display.
Note: Figure is automatically saved and closed for memory efficiency.
Notes
-----
Array Dimension Handling:
- x_arr.ndim=1, y_arr.ndim=1: Single case plot
- x_arr.ndim=1, y_arr.ndim=2: Shared x-axis, multiple y-series
- x_arr.ndim=2, y_arr.ndim=1: Multiple x-series, shared y-axis
- x_arr.ndim=2, y_arr.ndim=2: Full multi-case plot (most common)
Performance Characteristics:
- Memory usage: O(max(x_size, y_size))
- Rendering time: O(n_cases * n_points)
- File I/O: Dual output (PDF + SVG) for versatility
Output Format:
- PDF: Vector format for publications and presentations
- SVG: Web-compatible vector format for interactive displays
- Both saved with tight bounding boxes for clean appearance
Common Usage Patterns in LAMMPSKit:
-----------------------------------
Electrochemical analysis (atom distributions):
>>> plot_multiple_cases(distributions['hafnium'], z_bin_centers, labels,
... 'Hf atoms #', 'z position (A)', 'hf_distribution', 8, 6)
Displacement analysis:
>>> plot_multiple_cases(zdisp, binposition, labels,
... 'z displacement (A)', 'z position (A)', 'z_disp', 8, 6,
... yaxis=True) # Add y=0 reference line
Charge distribution with axis limits:
>>> plot_multiple_cases(charge_data, z_positions, labels,
... 'Net charge', 'z position (A)', 'charge_dist', 8, 6,
... ylimithi=70, xlimithi=15, xlimitlo=-20)
Examples
--------
Basic multi-case comparison:
>>> import numpy as np
>>> z_pos = np.linspace(0, 30, 50) # Electrode positions
>>> hf_counts = np.array([[5, 10, 15], [8, 12, 18]]) # Two voltage states
>>> labels = ['0.5V', '1.0V']
>>> fig = plot_multiple_cases(hf_counts, z_pos, labels,
... 'Hf atom count', 'Z position (Å)',
... 'hafnium_analysis', 10, 8)
Single case with reference lines:
>>> displacement = np.random.normal(0, 1, 100)
>>> positions = np.linspace(-10, 40, 100)
>>> fig = plot_multiple_cases(displacement, positions, ['Displacement'],
... 'Displacement (Å)', 'Z position (Å)',
... 'displacement_profile', 8, 6,
... yaxis=True, xaxis=True)
Multi-dimensional array example:
>>> # 3 cases, 4 elements each
>>> element_counts = np.random.randint(1, 20, (3, 4))
>>> elements = ['Hf', 'Ta', 'O', 'Electrode']
>>> case_labels = ['SET', 'Intermediate', 'RESET']
>>> fig = plot_multiple_cases(element_counts, elements, case_labels,
... 'Element count', 'Element type',
... 'element_comparison', 12, 8)
"""
nrows = 1
ncolumns = 1
xsize = 1.6
ysize = 3.2
print("before subplots")
plt.ioff()
fig, axes = plt.subplots(nrows, ncolumns, squeeze=False, constrained_layout=False, figsize=(xsize, ysize))
print("before axes flatten")
axes = axes.flatten()
print("before tight layout")
fig.tight_layout()
# plt.rcParams['xtick.labelsize'] = 6
# plt.rcParams['ytick.labelsize'] = 6
colorlist = ["b", "r", "g", "k"]
linestylelist = ["--", "-.", ":", "-"]
markerlist = ["o", "^", "s", "*"]
print("reached now plotting point")
# Plot each case depending on array dimensions
if x_arr.ndim > 1 and y_arr.ndim > 1:
for i in range(len(x_arr)):
j = kwargs.get("markerindex", i)
axes[0].plot(
x_arr[i],
y_arr[i],
label=labels[i],
color=colorlist[j],
linestyle=linestylelist[j],
marker=markerlist[j],
markersize=5,
linewidth=1.2,
alpha=0.75,
)
elif x_arr.ndim > 1 and y_arr.ndim == 1:
for i in range(len(x_arr)):
j = kwargs.get("markerindex", i)
axes[0].plot(
x_arr[i],
y_arr,
label=labels[i],
color=colorlist[j],
linestyle=linestylelist[j],
marker=markerlist[j],
markersize=5,
linewidth=1.2,
alpha=0.75,
)
elif x_arr.ndim == 1 and y_arr.ndim > 1:
for i in range(len(y_arr)):
j = kwargs.get("markerindex", i)
axes[0].plot(
x_arr,
y_arr[i],
label=labels[i],
color=colorlist[j],
linestyle=linestylelist[j],
marker=markerlist[j],
markersize=5,
linewidth=1.2,
alpha=0.75,
)
else:
j = kwargs.get("markerindex", 0)
axes[0].plot(
x_arr,
y_arr,
label=labels,
color=colorlist[j],
linestyle=linestylelist[j],
marker=markerlist[j],
markersize=5,
linewidth=1.2,
alpha=0.75,
)
# Optionally plot atom bin counts and print averages
if "ncount" in kwargs:
atoms_per_bin_count = kwargs["ncount"]
for i in range(len(x_arr)):
# Calculate weighted average for statistical reporting
average = np.sum(x_arr[i] * atoms_per_bin_count[i]) / np.sum(atoms_per_bin_count[i])
print(f"\n The average for {labels[i]} in {output_filename} is {average} \n")
# Axis limits and lines
if "xlimit" in kwargs:
print("x axis is limited")
axes[0].set_xlim(kwargs["xlimit"])
if "ylimit" in kwargs:
print("y axis is limited")
axes[0].set_ylim(kwargs["ylimit"])
if "xlimithi" in kwargs:
print("x hi axis is limited")
axes[0].set_xlim(right=kwargs["xlimithi"])
if "ylimithi" in kwargs:
print("y hi axis is limited")
axes[0].set_ylim(top=kwargs["ylimithi"])
if "xlimitlo" in kwargs:
print("x lo axis is limited")
axes[0].set_xlim(left=kwargs["xlimitlo"])
if "ylimitlo" in kwargs:
print("y lo axis is limited")
axes[0].set_ylim(bottom=kwargs["ylimitlo"])
if "xaxis" in kwargs:
axes[0].axhline(y=0, color=colorlist[-1], linestyle=linestylelist[-1], linewidth=1, label="y=0")
if "yaxis" in kwargs:
axes[0].axvline(x=0, color=colorlist[-1], linestyle=linestylelist[-1], linewidth=1, label="x=0")
print("reached axes labelling point")
axes[0].set_ylabel(ylabel, fontsize=8)
axes[0].legend(loc="upper center", fontsize=7)
axes[0].adjustable = "datalim"
axes[0].set_aspect("auto")
axes[0].tick_params(axis="both", which="major", labelsize=7)
axes[0].set_aspect("auto")
axes[0].set_xlabel(xlabel, fontsize=8)
# plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=None, hspace=0)
# plt.suptitle(f'{datatype} {dataindexname[dataindex]}', fontsize=8)
# plt.show()
plt.ioff()
print("reached file saving point")
output_filename_pdf = output_filename + ".pdf"
os.makedirs(output_dir, exist_ok=True)
savepath = os.path.join(output_dir, output_filename_pdf)
fig.savefig(savepath, bbox_inches="tight", format="pdf")
output_filename_svg = output_filename + ".svg"
savepath = os.path.join(output_dir, output_filename_svg)
fig.savefig(savepath, bbox_inches="tight", format="svg")
plt.close()
return fig