"""Helper functions for reading and modifying flash output files. For more
information on using this module to make figures with matplotlib and yt, please
refer to the following presentation:
"""
import collections

try:
    from collections.abc import MutableSequence, MutableSet, Sequence
except ImportError:  # Python 2: the ABCs live directly in ``collections``
    from collections import MutableSequence, MutableSet, Sequence

import numpy as np

# NOTE(review): the ``try:`` / ``pass`` lines around this handler were missing
# from the source; reconstructed here as an optional import -- confirm which
# imports the original ``try`` block covered.
try:
    from scipy.interpolate import griddata
    import yt.mods
except ImportError:
    pass

from . import analysis
from . import yt_derived_fields
# Human-readable plot labels keyed by field name.  The ``$...$`` spans are
# math markup intended for the plotting backend.
field_labels = {
    'dens': 'Density [g/cm$^3$]',
    'edens': 'Electron Density [e/cm$^3$]',
    'targ': 'Target Material Fraction',
    'tele': 'Electron Temperature ($T_e$) [K]',
    'pele': 'Electron Pressure ($P_e$) [J/cm$^3$]',
    'trad': 'Radiation Temperature [K]',
    'prad': 'Radiation Pressure [J/cm$^3$]',
    'ye': '$Y_e$',
    'abar': r'$\bar{A}$',
}
def _linear(x):
    """Identity transform: return ``x`` unchanged."""
    return x


# Transform applied to each field's values before plotting, keyed by field
# name: ``np.log10`` for log-scaled fields, ``_linear`` (identity) otherwise.
field_scales = {
    'dens': np.log10,
    'edens': np.log10,
    'targ': _linear,
    'tele': np.log10,
    'pele': np.log10,
    'trad': np.log10,
    'prad': np.log10,
    'ye': _linear,
    'abar': _linear,
}
def _cb_decade_ticks(cb):
    """Put one tick per whole decade on a colorbar of log10-scaled data.

    ``cb.vmin`` and ``cb.vmax`` are read as base-10 exponents.  Every integer
    from ``trunc(vmin)`` through ``trunc(vmax)`` becomes a tick labelled
    ``10^d``.

    NOTE(review): assumes ``cb`` is a matplotlib colorbar exposing
    ``set_ticks`` and ``set_ticklabels``.  The source computed the ticks and
    labels but never applied them, so these two calls are a reconstruction --
    confirm against the plotting code that uses ``field_colorbars``.

    Parameters
    ----------
    cb : colorbar-like object
        Must provide ``vmin``, ``vmax``, ``set_ticks()`` and
        ``set_ticklabels()``.

    Returns
    -------
    cb : colorbar-like object
        The same object that was passed in, modified in place.
    """
    ord_min = int(np.trunc(cb.vmin))
    ord_max = int(np.trunc(cb.vmax))
    decades = np.arange(ord_min, ord_max + 1)
    ticklabels = ["$10^{{{0}}}$".format(d) for d in decades]
    # Apply the computed ticks; without this the work above has no effect.
    cb.set_ticks(decades)
    cb.set_ticklabels(ticklabels)
    return cb
# Colorbar post-processing hook keyed by field name.  Each value is called
# with the colorbar and returns it: ``_cb_decade_ticks`` for the fields that
# ``field_scales`` maps through log10, ``_linear`` (no change) for the rest.
field_colorbars = {
    'dens': _cb_decade_ticks,
    'edens': _cb_decade_ticks,
    'targ': _linear,
    'tele': _cb_decade_ticks,
    'pele': _cb_decade_ticks,
    'trad': _cb_decade_ticks,
    'prad': _cb_decade_ticks,
    'ye': _linear,
    'abar': _linear,
}
def _to_hashable(d):
    """Returns a version of the dictionary whose items are hashable.

    Mutable sequences (e.g. lists) become tuples and mutable sets become
    frozensets; every other value is copied over unchanged.  The conversion
    is shallow: containers nested inside a value are not converted.

    Parameters
    ----------
    d : dict
        Mapping to convert.  It is not modified.

    Returns
    -------
    hd : dict
        New dictionary with the same keys as ``d``.
    """
    hd = {}
    for key, value in d.items():
        if isinstance(value, MutableSequence):
            value = tuple(value)
        elif isinstance(value, MutableSet):
            value = frozenset(value)
        # Anything else is stored as-is; only the two cases above convert.
        hd[key] = value
    return hd
# Cache of yt slice objects keyed by (pf, axis, coord) plus the sorted extra
# keyword arguments, so repeated calls can reuse an existing slice.
slice_cache = {}


def slice(axis, coord, field, pf, bounds=None, resolution=600, method="nearest", **kwargs):
    """Grabs the slice of a certain field (parameter, data) from the flash
    output file pf.  This slice is performed at the coord along the axis.
    This function requires both scipy and yt and returns x, y, and z data
    that is suitable for plotting with matplotlib.imshow().

    Parameters
    ----------
    axis : int
        The axis along which to slice.  Can be 0, 1, or 2 for x, y, z.
    coord : float
        The coordinate along the axis at which to slice.  This is in
        "domain" coordinates.
    field : str
        A field to retrieve, e.g. 'dens'.
    pf : str or yt.data_objects.static_output
        The checkpoint or plot file, may be either a path to the file
        or the yt object returned by yt.mods.load().
    bounds : len-4 sequence of floats or None, optional
        This defines the area within the domain that should be sliced.
        By default, the whole domain is taken.  Applies the bounds in
        axes order, ie if axis=2 then bounds=[xmin, xmax, ymin, ymax].
    resolution : int or len-2 sequence of ints, optional
        If a sequence of integers, this is the number of points (pixels)
        along each dimension.  If a single integer, then this is the
        resolution of the longest axis.  The resolution on the other axis
        is then calculated to preserve the aspect ratio.
    method : str, optional
        Interpolation method flag, passed directly down to
        scipy.interpolate.griddata().
    kwargs : optional
        All other keyword arguments are passed to the slice() method
        on the pf.h object from yt.

    Returns
    -------
    xdat : 2d numpy array
        Meshed data for x.
    ydat : 2d numpy array
        Meshed data for y.
    zdat : 2d numpy array
        Meshed data for z.

    Notes
    -----
    The array at position ``axis`` of the returned 3-tuple holds the
    interpolated field values; the other two are the coordinate meshes of
    the slice plane.

    Raises
    ------
    ValueError
        If resolution is neither an int nor a length-2 sequence.
    """
    # open the file if we got a path.  (str, type(u"")) is str/unicode on
    # Python 2 and plain str on Python 3, where ``basestring`` is undefined.
    opened_here = False
    if isinstance(pf, (str, type(u""))):
        pf = yt.mods.load(pf)
        opened_here = True

    # make kwargs hashable
    kwargs = _to_hashable(kwargs)

    # get the slice object, creating and caching it only on a cache miss.
    # NOTE(review): ``field`` is not part of the key; this presumably relies
    # on the yt slice object serving any field on demand -- confirm.
    slice_key = (pf, axis, coord) + tuple(sorted(kwargs.items()))
    try:
        amr_slice = slice_cache[slice_key]
    except KeyError:
        amr_slice = pf.h.slice(axis, coord, field, **kwargs)
        slice_cache[slice_key] = amr_slice

    # get points: the two in-plane coordinates for the chosen axis
    x = amr_slice['x']
    y = amr_slice['y']
    z = amr_slice['z']
    points = [(y, z), (x, z), (x, y)][axis]

    # get full domain
    # NOTE(review): uses the first and last stored point of each coordinate
    # rather than min/max; this assumes they are the extreme corners -- confirm.
    if bounds is None:
        bounds = (points[0][0], points[0][-1], points[1][0], points[1][-1])

    # get proper resolution
    if isinstance(resolution, int):
        lengths = (bounds[1] - bounds[0], bounds[3] - bounds[2])
        min_ind, min_len = min(enumerate(lengths), key=lambda x: x[1])
        resolution = [resolution, resolution]
        # shrink the shorter side so pixels keep the domain's aspect ratio
        resolution[min_ind] = int(resolution[min_ind] * min_len / lengths[(min_ind + 1) % 2])
    elif not (isinstance(resolution, Sequence) and len(resolution) == 2):
        raise ValueError("resolution must be int or length 2 sequence of ints.")

    # get meshed data
    aa = np.linspace(bounds[0], bounds[1], resolution[0])
    bb = np.linspace(bounds[2], bounds[3], resolution[1])
    adat, bdat = np.meshgrid(aa, bb)
    cdat = griddata(points, amr_slice[field], (adat, bdat), method=method)

    # close the file if we opened it (this only drops the local reference)
    if opened_here:
        del pf

    # return in the proper order: field values go in the ``axis`` position
    dat = [(cdat, adat, bdat), (adat, cdat, bdat), (adat, bdat, cdat)]
    return dat[axis]
def slice_gradient(axis, coord, field, pf, bounds=None, resolution=600, method="nearest", **kwargs):
    """Grabs the gradient of a slice of a certain field (parameter, data)
    from the flash output file pf.  This slice is performed at the coord
    along the axis.
    This function requires both scipy and yt and returns data that is
    suitable for plotting with matplotlib.imshow().

    Parameters
    ----------
    axis : int
        The axis along which to slice.  Can be 0, 1, or 2 for x, y, z.
    coord : float
        The coordinate along the axis at which to slice.  This is in
        "domain" coordinates.
    field : str
        A field to retrieve, e.g. 'dens'.
    pf : str or yt.data_objects.static_output
        The checkpoint or plot file, may be either a path to the file
        or the yt object returned by yt.mods.load().
    bounds : len-4 sequence of floats or None, optional
        This defines the area within the domain that should be sliced.
        By default, the whole domain is taken.  Applies the bounds in
        axes order, ie if axis=2 then bounds=[xmin, xmax, ymin, ymax].
    resolution : int or len-2 sequence of ints, optional
        If a sequence of integers, this is the number of points (pixels)
        along each dimension.  If a single integer, then this is the
        resolution of the longest axis.  The resolution on the other axis
        is then calculated to preserve the aspect ratio.
    method : str, optional
        Interpolation method flag, passed directly down to
        scipy.interpolate.griddata().
    kwargs : optional
        All other keyword arguments are passed to the slice() method
        on the pf.h object from yt.

    Returns
    -------
    adat : 2d numpy array
        Meshed data for the first non-axis axis, eg x when axis=z.
    bdat : 2d numpy array
        Meshed data for the second non-axis axis, eg y when axis=z.
    dfdadat : 2d numpy array
        Gradient of field along a.
    dfdbdat : 2d numpy array
        Gradient of field along b.
    magdat : 2d numpy array
        Magnitude of the gradient of field.

    Raises
    ------
    ValueError
        If resolution is neither an int nor a length-2 sequence.
    """
    # open the file if we got a path.  (str, type(u"")) is str/unicode on
    # Python 2 and plain str on Python 3, where ``basestring`` is undefined.
    opened_here = False
    if isinstance(pf, (str, type(u""))):
        pf = yt.mods.load(pf)
        opened_here = True

    # make kwargs hashable
    kwargs = _to_hashable(kwargs)

    # get the slice object, creating and caching it only on a cache miss
    slice_key = (pf, axis, coord) + tuple(sorted(kwargs.items()))
    try:
        amr_slice = slice_cache[slice_key]
    except KeyError:
        amr_slice = pf.h.slice(axis, coord, field, **kwargs)
        slice_cache[slice_key] = amr_slice

    # get a, b, f: in-plane coordinates, their cell widths, and field values
    x = amr_slice['x']
    y = amr_slice['y']
    z = amr_slice['z']
    a, b = [(y, z), (x, z), (x, y)][axis]
    dx = amr_slice['dx']
    dy = amr_slice['dy']
    dz = amr_slice['dz']
    da, db = [(dy, dz), (dx, dz), (dx, dy)][axis]
    f = amr_slice[field]

    # calc gradient along a: for each distinct b, difference the values that
    # share that b and divide by the cell width along a.
    # NOTE(review): assumes the entries sharing a coordinate are stored in
    # order along the other axis -- confirm.
    dfda = np.empty_like(f)
    for bu in np.unique(b):
        mask = bu == b
        dfda[mask] = np.gradient(f[mask]) / da[mask]

    # calc gradient along b, same scheme with the roles of a and b swapped
    dfdb = np.empty_like(f)
    for au in np.unique(a):
        mask = au == a
        dfdb[mask] = np.gradient(f[mask]) / db[mask]

    # calc gradient magnitude
    mag = np.sqrt(dfda**2 + dfdb**2)

    # get full domain
    # NOTE(review): uses the first and last stored point of each coordinate
    # rather than min/max; this assumes they are the extreme corners -- confirm.
    if bounds is None:
        bounds = (a[0], a[-1], b[0], b[-1])

    # get proper resolution
    if isinstance(resolution, int):
        lengths = (bounds[1] - bounds[0], bounds[3] - bounds[2])
        min_ind, min_len = min(enumerate(lengths), key=lambda x: x[1])
        resolution = [resolution, resolution]
        # shrink the shorter side so pixels keep the domain's aspect ratio
        resolution[min_ind] = int(resolution[min_ind] * min_len / lengths[(min_ind + 1) % 2])
    elif not (isinstance(resolution, Sequence) and len(resolution) == 2):
        raise ValueError("resolution must be int or length 2 sequence of ints.")

    # get meshed data
    aa = np.linspace(bounds[0], bounds[1], resolution[0])
    bb = np.linspace(bounds[2], bounds[3], resolution[1])
    adat, bdat = np.meshgrid(aa, bb)
    dfdadat = griddata((a, b), dfda, (adat, bdat), method=method)
    dfdbdat = griddata((a, b), dfdb, (adat, bdat), method=method)
    magdat = griddata((a, b), mag, (adat, bdat), method=method)

    # close the file if we opened it (this only drops the local reference)
    if opened_here:
        del pf

    # return in the proper order
    return adat, bdat, dfdadat, dfdbdat, magdat
# Cache of yt ray objects keyed by (pf, p1, p2) plus the sorted extra keyword
# arguments, so repeated calls can reuse an existing ray.
ray_cache = {}


def lineout(p1, p2, field, pf, **kwargs):
    """Grabs a line out (ray) of a certain field (parameter, data) from the
    flash output file pf.
    This function requires yt and returns x, y, z and value data that is
    suitable for plotting with matplotlib.plot().

    Parameters
    ----------
    p1 : three-tuple of floats
        The first point in the line-out.
    p2 : three-tuple of floats
        The second point in the line-out.
    field : str
        A field to retrieve, e.g. 'dens'.
    pf : str or yt.data_objects.static_output
        The checkpoint or plot file, may be either a path to the file
        or the yt object returned by yt.mods.load().
    kwargs : optional
        All other keyword arguments are passed to the ray() method
        on the pf.h object from yt.

    Returns
    -------
    x : 1D numpy array
        x data along the ray.
    y : 1D numpy array
        y data along the ray.
    z : 1D numpy array
        z data along the ray.
    v : 1D numpy array
        Field values along the ray.
    """
    # open the file if we got a path.  (str, type(u"")) is str/unicode on
    # Python 2 and plain str on Python 3, where ``basestring`` is undefined.
    opened_here = False
    if isinstance(pf, (str, type(u""))):
        pf = yt.mods.load(pf)
        opened_here = True

    # make hashable
    p1 = tuple(p1)
    p2 = tuple(p2)
    kwargs = _to_hashable(kwargs)

    # get the ray object, creating and caching it only on a cache miss
    ray_key = (pf, p1, p2) + tuple(sorted(kwargs.items()))
    try:
        amr_ray = ray_cache[ray_key]
    except KeyError:
        amr_ray = pf.h.ray(p1, p2, field, **kwargs)
        ray_cache[ray_key] = amr_ray

    # get points
    x = amr_ray['x']
    y = amr_ray['y']
    z = amr_ray['z']
    v = amr_ray[field]

    # close the file if we opened it (this only drops the local reference)
    if opened_here:
        del pf

    # return in the proper order
    return x, y, z, v
def shock_on_lineout(p1, p2, field, pf, threshold=1e-6, min_threshold=1e-36, **kwargs):
    """Finds the shock of a certain field (parameter, data) along a line out (ray)
    in a flash output file pf.  Currently this assumes that p1 is closer to the
    center of the shock than p2.  Moreover the geometry from the FLASH simulation
    must be cartesian or 2D cylindrical (rz).

    Parameters
    ----------
    p1 : three-tuple of floats
        The first point in the line-out.
    p2 : three-tuple of floats
        The second point in the line-out.
    field : str
        A field to retrieve, e.g. 'dens'.
    pf : str or yt.data_objects.static_output
        The checkpoint or plot file, may be either a path to the file
        or the yt object returned by yt.mods.load().
    threshold : float, optional
        The value which gradients must exceed to be considered a
        shock.  Required for ignoring low-level noise.
    min_threshold : float, optional
        If a shock is not found at a given threshold level, the threshold
        value is reduced by an order of magnitude until a shock is found.
        This continues until a minimum threshold is reached.
    kwargs : optional
        All other keyword arguments are passed to the ray() method
        on the pf.h object from yt.

    Returns
    -------
    shock_p : three-tuple of floats
        Shock position.
    shock_v : field type (float)
        Shock peak value.

    See Also
    --------
    flash.output.lineout :
        Used to get the lineout from two points and a data file.
    flash.analysis.shock_detect :
        Used to find the shock itself.
    """
    # Forward the extra keyword arguments, as the docstring promises.
    x, y, z, v = lineout(p1, p2, field, pf, **kwargs)

    # Order every sample by its distance from the origin.
    r = np.sqrt(x**2 + y**2 + z**2)
    sort_index = r.argsort()
    r = r[sort_index]
    v = v[sort_index]

    # NOTE(review): the continuation line of this call was missing from the
    # source; forwarding min_threshold is assumed from this function's own
    # signature -- confirm against analysis.shock_detect.
    shock_r, shock_v, shock_i = analysis.shock_detect(r, v, threshold=threshold,
                                                      min_threshold=min_threshold)

    # shock_i indexes the sorted arrays, so sort the coordinates the same way.
    shock_p = (x[sort_index][shock_i], y[sort_index][shock_i], z[sort_index][shock_i])
    return shock_p, shock_v
def load_laser_dat(filename):
    """Reads in a LaserEnergyProfile.dat file as a numpy structured array.

    The first two lines of the file are skipped; each remaining line must
    hold the seven columns listed below, in order.

    Parameters
    ----------
    filename : str or file-like
        The string path to the file or an object which implements the file
        interface.

    Returns
    -------
    dat : ndarray
        This structured array has the following dtype and units:

        ========== ====== ======= ========
        Field      name   dtype   units
        ========== ====== ======= ========
        Step       step   int32   unitless
        Time       t      float64 s
        dt         dt     float64 s
        Energy In  E_in   float64 erg
        Energy Out E_out  float64 erg
        dE In      dE_in  float64 erg
        dE Out     dE_out float64 erg
        ========== ====== ======= ========

        The "name" column is the key used to index the array,
        e.g. ``dat['E_in']``.
    """
    datdt = np.dtype([('step', np.int32),
                      ('t', np.float64),
                      ('dt', np.float64),
                      ('E_in', np.float64),
                      ('E_out', np.float64),
                      ('dE_in', np.float64),
                      ('dE_out', np.float64),
                      ])
    dat = np.loadtxt(filename, dtype=datdt, skiprows=2)
    return dat