Edited parity plot to allow color to represent 3rd dimension. Still need to improve colorbar axis name.

This commit is contained in:
titusquah
2020-06-16 09:27:19 -06:00
parent 77e4cefd2c
commit 97ca231afc
4 changed files with 168 additions and 114 deletions

View File

@@ -64,7 +64,8 @@ class REEPS:
The ordering of the columns needs to be:
[h_i, h_eq, z_i, z_eq, {RE_1}_aq_i, {RE_1}_aq_eq, {RE_1}_d_eq,
[h_i, h_eq, z_i, z_eq,
{RE_1}_aq_i, {RE_1}_aq_eq, {RE_1}_d_eq,
{RE_2}_aq_i, {RE_2}_aq_eq, {RE_2}_d_eq,...
{RE_N}_aq_i, {RE_N}_aq_eq, {RE_N}_d_eq]
@@ -347,15 +348,17 @@ class REEPS:
self.update_predicted_dict()
@staticmethod
def slsqp_optimizer(objective, x_guess):
def scipy_minimize(objective, x_guess, optimizer_kwargs=None):
""" The default optimizer for REEPS
Uses scipy.minimize with options
Uses scipy.minimize
By default, options are
.. code-block:: python
default_kwargs= {"method": 'SLSQP',
"bounds": [(1e-1, 1e1)*len(x_guess)],
"bounds": [(1e-1, 1e1)]*len(x_guess),
"constraints": (),
"options": {'disp': True,
'maxiter': 1000,
@@ -363,14 +366,16 @@ class REEPS:
:param objective: (func) the objective function
:param x_guess: (np.ndarray) the initial guess (always 1)
:param optimizer_kwargs: (dict) dictionary of options for minimize
:returns: (np.ndarray) Optimized parameters
"""
optimizer_kwargs = {"method": 'SLSQP',
"bounds": [(1e-1, 1e1)] * len(x_guess),
"constraints": (),
"options": {'disp': True,
'maxiter': 1000,
'ftol': 1e-6}}
if optimizer_kwargs is None:
optimizer_kwargs = {"method": 'SLSQP',
"bounds": [(1e-1, 1e1)] * len(x_guess),
"constraints": (),
"options": {'disp': True,
'maxiter': 1000,
'ftol': 1e-6}}
res = minimize(objective, x_guess, **optimizer_kwargs)
est_parameters = res.x
return est_parameters
@@ -450,7 +455,7 @@ class REEPS:
"""Change list of Cantera solutions by inputting
new xml file name and phase names
Also runs set_in_moles to set initial molality to 1 g/L
Also runs set_in_moles to set feed volume to 1 L
:param phases_xml_filename: (str) xml file with parameters
for equilibrium calc
@@ -656,7 +661,7 @@ class REEPS:
This function also calls update_predicted_dict
:param feed_vol: (float) feed volume of mixture (g/L)
:param feed_vol: (float) feed volume of mixture (L)
"""
phases_copy = self._phases.copy()
exp_df = self._exp_df.copy()
@@ -809,7 +814,7 @@ class REEPS:
"with at least 2 arguments: "
"f(objective_func,x_guess, kwargs)")
if optimizer == 'SLSQP':
optimizer = self.slsqp_optimizer
optimizer = self.scipy_minimize
self._optimizer = optimizer
return None
@@ -939,7 +944,9 @@ class REEPS:
i = 0
for species_name in opt_dict.keys():
for thermo_prop in opt_dict[species_name].keys():
opt_dict[species_name][thermo_prop] *= x[i]
if not np.isnan(
x[i]): # if nan, do not update xml with nan
opt_dict[species_name][thermo_prop] *= x[i]
i += 1
self.update_xml(opt_dict, temp_xml_file_path)
@@ -969,9 +976,12 @@ class REEPS:
with optimizer. Returns dictionary with opt_dict structure
:param objective_function: (function) function to compute objective
If 'None', last set objective or default function is used
:param optimizer: (function) function to perform optimization
:param optimizer_kwargs: (dict) arguments for optimizer
:param objective_kwargs: (dict) arguments for objective function
If 'None', last set optimizer or default is used
:param optimizer_kwargs: (dict) optional arguments for optimizer
:param objective_kwargs: (dict) optional arguments
for objective function
:returns opt_dict: (dict) optimized opt_dict. Has identical structure
as opt_dict
"""
@@ -1043,6 +1053,8 @@ class REEPS:
def parity_plot(self,
compared_value=None,
color_axis=None,
plot_title=None,
save_path=None,
print_r_squared=False):
"""
@@ -1052,6 +1064,16 @@ class REEPS:
:param compared_value: (str) Quantity to compare predicted and
experimental data. Can be any column containing "eq" in exp_df i.e.
h_eq, z_eq, {RE}_d_eq, etc.
:param plot_title: (str or boolean)
If None (default): Plot title will be generated from compared_value
Recommend to just explore. If h_eq, plot_title is
"H^+ eq conc".
If str: Plot title will be plot_title string
If "False": No plot title
:param color_axis: (dict)
:param save_path: (str) save path for parity plot
:param print_r_squared: (boolean) To plot or not to plot r-squared
value. Prints 2 places past decimal
@@ -1071,33 +1093,57 @@ class REEPS:
name_breakdown = re.findall('[^_\W]+', compared_value)
compared_species = name_breakdown[0]
if compared_species == 'h':
species_name = '$H^+$'
default_title = '$H^+$ eq. conc. (mol/L)'
elif compared_species == 'z':
species_name = extractant_name
default_title = '{0} eq. conc. (mol/L)'.format(extractant_name)
else:
phase = name_breakdown[1]
if phase == 'aq':
re_charge = re_charges[re_species_list.index(compared_species)]
species_name = '$%s^{%d+}$' % (compared_species, re_charge)
default_title = '$%s^{%d+}$ eq. conc. (mol/L)' \
% (compared_species, re_charge)
elif phase == 'd':
species_name = '{0} distribution ratio'.format(
default_title = '{0} distribution ratio'.format(
compared_species)
else:
species_name = '{0} complex'.format(compared_species)
default_title = '{0} complex eq. conc. (mol/L)'.format(
compared_species)
fig, ax = plt.subplots()
p1 = sns.scatterplot(meas, pred, color="r",
label="{0} eq. conc. (mol/L)".format(
species_name),
legend=False)
if color_axis is None:
sns.scatterplot(meas, pred, color="r",
legend=False)
else:
key = list(color_axis.keys())[0]
value = list(color_axis.values())[0]
if key == 'predicted':
y = self.get_predicted_dict()[value]
elif key == 'measured':
y = self.get_exp_df()[value].values
else:
raise Exception('color_axis must be a dictionary with key'
'"predicted" or "measured"')
y = np.array(y)
meas = np.array(meas)
pred = np.array(pred)
p1 = ax.scatter(meas, pred, c=y, alpha=1, cmap='viridis')
c_bar = fig.colorbar(p1, format='%.2f')
# Fix next line. value is just the dictionary value.
c_bar.set_label(value, rotation=270, labelpad=20)
sns.lineplot(min_max_data, min_max_data, color="b", label="")
if print_r_squared:
p1.text(min_max_data[0],
ax.text(min_max_data[0],
min_max_data[1] * 0.9,
'$R^2$={0:.2f}'.format(self.r_squared(compared_value)))
plt.legend(loc='lower right')
else:
plt.legend()
# plt.legend(loc='lower right')
# else:
# plt.legend()
ax.set(xlabel='Measured', ylabel='Predicted')
if plot_title is None:
ax.set_title(default_title)
elif isinstance(plot_title, str):
ax.set_title(plot_title)
plt.show()
if save_path is not None:
plt.savefig(save_path, bbox_inches='tight')
@@ -1110,7 +1156,7 @@ class REEPS:
:param compared_value: (str) Quantity to compare predicted and
experimental data. Can be any column containing "eq" in exp_df i.e.
h_eq, z_eq, {RE}_d_eq, etc.
h_eq, z_eq, {RE}_d_eq, etc. default is {RE}_aq_eq
"""
exp_df = self.get_exp_df()
predicted_dict = self.get_predicted_dict()