from __future__ import annotations
import textwrap
from copy import deepcopy
from typing import Iterable, TYPE_CHECKING
import numpy as np
import matplotlib.pyplot as plt
from bmlite import IDAResult
if TYPE_CHECKING: # pragma: no cover
from ._simulation import Simulation
if not hasattr(np, 'concat'): # pragma: no cover
np.concat = np.concatenate
class BaseSolution(IDAResult):
"""Base SPM solution."""
def __init__(self) -> None:
"""
The base solution class is a parent class to both the StepSolution
and CycleSolution classes. Inheriting from this class gives each
solution instance a 'vars' dictionary, access to the 'plot' method,
and ensures that the slicing of the solution vector into 'vars' is
consistent between all solutions.
"""
self.vars = {}
self._postvars = False
def __repr__(self) -> str: # pragma: no cover
"""
Return a readable repr string.
Returns
-------
readable : str
A console-readable instance representation.
"""
classname = self.__class__.__name__
def wrap_string(label: str, value: list, width: int):
if isinstance(value, Iterable):
value = list(value)
else:
value = [value]
indent = ' '*(len(label) + 1)
if classname == 'StepSolution' and len(value) == 1:
text = label + f"{value[0]!r}"
else:
text = label + "[" + ", ".join(f"{v!r}" for v in value) + "]"
return textwrap.fill(text, width=width, subsequent_indent=indent)
data = [
wrap_string(' success=', self.success, 79),
wrap_string(' status=', self.status, 79),
wrap_string(' nfev=', self.nfev, 79),
wrap_string(' njev=', self.njev, 79),
wrap_string(' vars=', self.vars.keys(), 79),
]
summary = f" solvetime={self.solvetime},"
for d in data:
summary += f"\n{d},"
readable = f"{classname}(\n{summary}\n)"
return readable
def post(self) -> None:
from .postutils import post
sim = self._sim
# domain variables
self.vars['an'] = sim.an.to_dict(self)
self.vars['ca'] = sim.ca.to_dict(self)
self.vars['el'] = sim.el.to_dict(self)
postvars = post(self)
self.vars['an']['sdot'] = postvars['sdot_an']
self.vars['ca']['sdot'] = postvars['sdot_ca']
self._postvars = True
def simple_plot(self, x: str, y: str, **kwargs) -> None:
"""
Plot any two basic 1D variables in 'vars' against each other, i.e.,
time, current, voltage, and power.
Parameters
----------
x : str
A variable key in 'vars' to be used for the x-axis.
y : str
A variable key in 'vars' to be used for the y-axis.
**kwargs : dict, optional
Keyword arguments to pass through to `plt.plot()`. For more info
please refer to documentation for `maplotlib.pyplot.plot()`.
"""
from .._utils import ExitHandler
plt.figure()
plt.plot(self.vars[x], self.vars[y], **kwargs)
variable, units = x.split('_')
xlabel = variable.capitalize() + ' [' + units + ']'
variable, units = y.split('_')
ylabel = variable.capitalize() + ' [' + units + ']'
plt.xlabel(xlabel)
plt.ylabel(ylabel)
if not plt.isinteractive():
ExitHandler.register_atexit(plt.show)
def complex_plot(self, *args: str) -> None:
"""
Generates requested plots based on `*args`.
Parameters
----------
*args : str
Use any number of the following arguments to see the described
plots:
================= ===============================================
arg Description
================= ===============================================
'potentials' anode, cathode, and electrolyte potentials [V]
'intercalation' anode/cathode particle Li fractions vs. r and t
'pixels' pixel plots for solid Li concentrations
================= ===============================================
"""
if not self._postvars:
self.post()
if 'potentials' in args:
from .postutils import potentials
potentials(self)
if 'intercalation' in args:
from .postutils import intercalation
intercalation(self)
if 'pixels' in args:
from .postutils import pixels
pixels(self)
def to_dict(self) -> dict:
"""
Creates a dict with all spatial, time, and state variables separated
into 1D and 2D arrays. The keys are given below.
========= =======================================================
Key Value [units] (*type*)
========= =======================================================
r_a r mesh for anode particles [m] (*1D array*)
r_c r mesh for cathode particles [m] (*1D array*)
t saved solution times [s] (*1D array*)
phis_a anode electrode potentials at t [V] (*1D array*)
cs_a electrode Li at t, r_a [kmol/m3] (*2D array*)
phis_c cathode electrode potentials at t [V] (*1D array*)
cs_c electrode Li at t, r_c [kmol/m3] (*2D array*)
phie electrolyte potentials at t [V] (*1D array*)
j_a anode Faradaic current at t [kmol/m2/s] (*1D array*)
j_c cathode Faradaic current at t [kmol/m2/s] (*1D array*)
========= =======================================================
Returns
-------
sol_dict : dict
A dictionary containing the solution
"""
if not self._postvars:
self.post()
vars = {
'r_a': self.vars['an']['r'],
'r_c': self.vars['ca']['r'],
't': self.vars['time_s'],
'phis_a': self.vars['an']['phis'],
'cs_a': self.vars['an']['cs'],
'phis_c': self.vars['ca']['phis'],
'cs_c': self.vars['ca']['cs'],
'phie': self.vars['el']['phie'],
'j_a': self.vars['an']['sdot'],
'j_c': self.vars['ca']['sdot'],
}
return vars
def save_sliced(self, savename: str, overwrite: bool = False) -> None:
"""
Save a `.npz` file with all spatial, time, and state variables
separated into 1D and 2D arrays. The keys are given below. The index
order of the 2D arrays is given with the value descriptions.
========= =======================================================
Key Value [units] (*type*)
========= =======================================================
r_a r mesh for anode particles [m] (*1D array*)
r_c r mesh for cathode particles [m] (*1D array*)
t saved solution times [s] (*1D array*)
phis_a anode electrode potentials at t [V] (*1D array*)
cs_a electrode Li at t, r_a [kmol/m3] (*2D array*)
phis_c cathode electrode potentials at t [V] (*1D array*)
cs_c electrode Li at t, r_c [kmol/m3] (*2D array*)
phie electrolyte potentials at t [V] (*1D array*)
j_a anode Faradaic current at t [kmol/m2/s] (*1D array*)
j_c cathode Faradaic current at t [kmol/m2/s] (*1D array*)
========= =======================================================
Parameters
----------
savename : str
Either a file name or the absolute/relative file path. The `.npz`
extension will be added to the end of the string if it is not
already there. If only the file name is given, the file will be
saved in the user's current working directory.
overwrite : bool, optional
A flag to overwrite an existing `.npz` file with the same name
if one exists. The default is `False`.
"""
import os
if '.npz' not in savename:
savename += '.npz'
if os.path.exists(savename) and not overwrite:
raise FileExistsError(savename + ' already exists. Use overwrite'
' flag or delete the file and try again.')
sol_dict = self.to_dict()
np.savez(savename, **sol_dict)
def _fill_vars(self) -> None:
"""
Fills the 'vars' dictionary by slicing the SolverReturn solution
states. Users should generally only access the solution via 'vars'
since names are more intuitive than interpreting 'y' directly.
"""
sim = self._sim
# domain variables - placeholders
self.vars['an'] = 'Run soln.post() to populate'
self.vars['ca'] = 'Run soln.post() to populate'
self.vars['el'] = 'Run soln.post() to populate'
# stored time
time_s = self.t
self.vars['time_s'] = time_s
self.vars['time_min'] = time_s / 60.
self.vars['time_h'] = time_s / 3600.
# common variables
voltage_V = sim.ca._boundary_voltage(self)
current_A = sim.ca._boundary_current(self)
self.vars['current_A'] = current_A
self.vars['current_C'] = current_A / sim.bat.cap
self.vars['voltage_V'] = voltage_V
self.vars['power_W'] = current_A*voltage_V
def _verify(self, plot: bool = False, atol: float = 1e-1,
rtol: float = 2e-2) -> dict:
"""
Verifies the solution is mathematically consistent. This is primarily
for testing purposes.
Specifically, this compares the boundary current is consistent with the
reactions in each electrode at each time step, and that solid-phase
lithium was conserved. If the verification fails, you can visualize the
checks by using the `plot` flag. Figures shaded grey indicate that its
respective test failed.
Parameters
----------
plot : bool, optional
A flag to show plots of the verifications. The default is False.
atol : float, optional
Absolute tolerance for comparisons. The default is 1e-1.
rtol : float, optional
Relative tolerance for comparisons. The default is 1e-2.
Returns
-------
checks : bool
A dictionary of keys describing each check and boolean values to
specify whether each check passed or not.
"""
from .._utils import ExitHandler
from ..plotutils import format_ticks
from .postutils import _solid_phase_Li
sim = self._sim
c, bat, an, ca = sim.c, sim.bat, sim.an, sim.ca
if not self._postvars:
self.post()
i_mod = self.vars['current_A'] / bat.area
Li_ed_0, Li_ed_t = _solid_phase_Li(self)
j_an_tot = self.vars['an']['sdot']*an.A_s*an.thick*c.F
j_ca_tot = self.vars['ca']['sdot']*ca.A_s*ca.thick*c.F
checks = {
'j_a': np.allclose(i_mod, j_an_tot, rtol=rtol, atol=atol),
'j_c': np.allclose(i_mod, -j_ca_tot, rtol=rtol, atol=atol),
'cs': np.allclose(1., Li_ed_t / Li_ed_0, rtol=rtol, atol=atol),
}
if plot:
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=[12, 3],
layout='constrained')
# Faradaic currents
ax[0].set_ylabel(r'$i_{\rm ext} - j_{\rm an}$ [A/m$^2$]')
ax[1].set_ylabel(r'$i_{\rm ext} + j_{\rm ca}$ [A/m$^2$]')
ax[0].plot(self.t, i_mod - j_an_tot, '-C3')
ax[1].plot(self.t, i_mod + j_ca_tot, '-C2')
ymin = min([ax[i].get_ylim()[0] for i in range(1)])
ymax = max([ax[i].get_ylim()[1] for i in range(1)])
for i in range(1):
ax[i].set_ylim([ymin, ymax])
# Lithium conservation
ax[2].set_ylabel(r'$C_{\rm Li,s} \ / \ C_{\rm Li,s}^0$ [$-$]')
ax[2].plot(self.t, Li_ed_t / Li_ed_0, '-k')
# formatting
for i in range(3):
ax[i].set_xlabel(r'$t$ [s]')
format_ticks(ax[i])
# shade bad checks
for i, val in enumerate(checks.values()):
if not val:
ax[i].patch.set_facecolor('grey')
ax[i].patch.set_alpha(0.5)
fig.get_layout_engine().set(wspace=0.1, hspace=0.1)
if not plt.isinteractive():
ExitHandler.register_atexit(plt.show)
return checks
[docs]
class StepSolution(BaseSolution):
"""Single-step solution."""
def __init__(self, sim: Simulation, idasoln: IDAResult,
timer: float) -> None:
"""
A solution instance for a single experimental step.
Parameters
----------
sim : Simulation
The simulation instance that was run to produce the solution.
idasoln : IDAResult
The unformatted solution returned by IDASolver.
timer : float
Amount of time it took for IDASolver to perform the integration.
"""
super().__init__()
self._sim = sim.copy()
self.message = idasoln.message
self.success = idasoln.success
self.status = idasoln.status
self.t = idasoln.t
self.y = idasoln.y
self.yp = idasoln.yp
self.i_events = idasoln.i_events
self.t_events = idasoln.t_events
self.y_events = idasoln.y_events
self.yp_events = idasoln.yp_events
self.nfev = idasoln.nfev
self.njev = idasoln.njev
self._timer = timer
self._fill_vars()
@property
def solvetime(self) -> str:
"""
Print a statement specifying how long IDASolver spent integrating.
Returns
-------
solvetime : str
An f-string with the solver integration time in seconds.
"""
return f"{self._timer:.3f} s"
[docs]
class CycleSolution(BaseSolution):
"""All-step solution."""
def __init__(self, *soln: StepSolution, t_shift: float = 1e-3) -> None:
"""
A solution instance with all experiment steps stitch together into
a single cycle.
Parameters
----------
*soln : StepSolution
All unpacked StepSolution instances to stitch together. The given
steps should be given in the same sequential order that they were
run.
t_shift : float
Time (in seconds) to shift step solutions by when stitching them
together. If zero the end time of each step overlaps the starting
time of its following step. The default is 1e-3.
"""
super().__init__()
self._solns = soln
self._sim = soln[0]._sim.copy()
t_size = np.sum([soln.t.size for soln in self._solns])
sv_size = self._sim._sv0.size
self.message = []
self.success = []
self.status = []
self.t = np.empty([t_size])
self.y = np.empty([t_size, sv_size])
self.yp = np.empty([t_size, sv_size])
self.t_events = None
self.y_events = None
self.yp_events = None
self.nfev = []
self.njev = []
self._timers = []
first = 0
for soln in self._solns:
soln_size = soln.t.size
last = first + soln_size
if first > 0:
shift_t = self.t[first - 1] + soln.t + t_shift
else:
shift_t = soln.t
if soln.t_events and first > 0:
shift_t_events = self.t[first - 1] + soln.t_events + t_shift
elif soln.t_events:
shift_t_events = soln.t_events
self.message.append(soln.message)
self.success.append(soln.success)
self.status.append(soln.status)
self.t[first:last] = shift_t
self.y[first:last, :] = soln.y
self.yp[first:last, :] = soln.yp
first = last
if soln.t_events:
if self.t_events is None:
self.t_events = shift_t_events
self.y_events = soln.y_events
self.yp_events = soln.yp_events
else:
self.t_events = np.concat([self.t_events, shift_t_events])
self.y_events = np.concat([self.y_events, soln.y_events])
self.yp_events = np.concat([self.yp_events, soln.yp_events])
self.nfev.append(soln.nfev)
self.njev.append(soln.njev)
self._timers.append(soln._timer)
self._fill_vars()
@property
def solvetime(self) -> str:
"""
Print a statement specifying how long IDASolver spent integrating.
Returns
-------
solvetime : str
An f-string with the total solver integration time in seconds.
"""
return f"{sum(self._timers):.3f} s"
[docs]
def get_steps(self, idx: int | tuple) -> StepSolution | CycleSolution:
"""
Return a subset of the solution.
Parameters
----------
idx : int | tuple
The step index (int) or first/last indices (tuple) to return.
Returns
-------
:class:`StepSolution` | :class:`CycleSolution`
The returned solution subset. A StepSolution is returned if 'idx'
is an int, and a CycleSolution will be returned for the range of
requested steps when 'idx' is a tuple.
"""
if isinstance(idx, int):
return deepcopy(self._solns[idx])
elif isinstance(idx, (tuple, list)):
solns = self._solns[idx[0]:idx[1] + 1]
return CycleSolution(*solns)