Source code for mat_discover.utils.plotting

"""Various plotting functions for cluster properties and UMAP visualization."""
from os.path import join
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.preprocessing import MinMaxScaler

# import seaborn as sns


# TODO: change to square plots
def umap_cluster_scatter(std_emb, labels, figure_dir="figures"):
    """Plot UMAP embeddings colored by cluster IDs.

    Points whose label is ``-1`` are treated as unclassified. When any are
    present they are drawn a second time and a legend below the axes reports
    their fraction. The figure is saved to
    ``join(figure_dir, "umap-cluster-scatter")`` and then shown.

    Parameters
    ----------
    std_emb : 2d array
        UMAP embedded coordinates.
    labels : 1d array
        Cluster IDs associated with the UMAP coordinates.
    figure_dir : str, optional
        Directory the figure is saved into, by default "figures".

    Returns
    -------
    fig : Figure
        Handle to Matplotlib Figure.
    """
    # TODO: update plotting commands to have optional arguments (e.g. std_emb and labels)
    std_emb = np.asarray(std_emb)
    labels = np.asarray(labels)
    cmap = plt.cm.nipy_spectral

    unclass_ids = labels == -1
    unclass_frac = np.sum(unclass_ids) / len(labels)

    # plt.figure() (not the plt.Figure class) registers the figure with pyplot,
    # so the plt.* calls below draw on the same figure that is returned.
    fig = plt.figure()
    plt.scatter(
        std_emb[:, 0], std_emb[:, 1], c=labels, s=0.1, cmap=cmap, label=labels
    )
    plt.axis("off")

    if unclass_frac != 0.0:
        unclass_scatter = plt.scatter(
            std_emb[unclass_ids, 0],
            std_emb[unclass_ids, 1],
            c=labels[unclass_ids],
            s=0.1,
            cmap=cmap,
            label=labels[unclass_ids],
        )
        # How to put the legend out of the plot: https://stackoverflow.com/a/4701285/13697228
        plt.legend(
            [unclass_scatter],
            [f"Unclassified: {unclass_frac:.1%}"],
            loc="upper center",
            bbox_to_anchor=(0.5, -0.05),
        )

    plt.tight_layout()
    plt.gca().set_aspect("equal", "box")
    plt.savefig(join(figure_dir, "umap-cluster-scatter"))
    plt.show()

    # TODO: add a cluster-ID colorbar whose integer tick labels don't overlap
    # (skip some based on the number of labels).
    return fig
def cluster_count_hist(labels, figure_dir="figures"):
    """Plot histogram of cluster counts, colored by cluster IDs.

    One bar is drawn per unique label; its height is the number of entries
    carrying that label. The figure is saved to
    ``join(figure_dir, "cluster-count-hist")`` and then shown.

    Parameters
    ----------
    labels : 1d array
        Cluster IDs.
    figure_dir : str, optional
        Directory the figure is saved into, by default "figures".

    Returns
    -------
    fig : Figure
        Handle to Matplotlib Figure.
    """
    unique_labels, counts = np.unique(labels, return_counts=True)

    # Rescale the unique IDs to [0, 1] so they index the colormap; flatten so
    # the colormap returns one RGBA row per bar, i.e. shape (n, 4).
    scl_vals = MinMaxScaler().fit_transform(unique_labels.reshape(-1, 1))
    color = plt.cm.nipy_spectral(scl_vals.ravel())

    # plt.figure() (not the plt.Figure class) registers the figure with pyplot,
    # so the plt.* calls below draw on the same figure that is returned.
    fig = plt.figure()
    plt.bar(unique_labels, counts, color=color)
    plt.xlabel("cluster ID")
    plt.ylabel("number of compounds")
    plt.tight_layout()
    plt.savefig(join(figure_dir, "cluster-count-hist"))
    plt.show()
    return fig
def target_scatter(std_emb, target, figure_dir="figures", color_unit=None):
    """Plot UMAP embedding locations colored by target values.

    Colors use a logarithmic normalization (``matplotlib.colors.LogNorm``).
    The figure is saved to ``join(figure_dir, "target-scatter")`` and then
    shown.

    Parameters
    ----------
    std_emb : 2d array
        UMAP embedding coordinates.
    target : 1d array
        Target properties corresponding to `std_emb`.
    figure_dir : str, optional
        Directory the figure is saved into, by default "figures".
    color_unit : str, optional
        Unit appended in parentheses to the colorbar label. If None (the
        default), the label is just "target".

    Returns
    -------
    fig : Figure
        Handle to Matplotlib Figure.
    """
    # plt.figure() (not the plt.Figure class) registers the figure with pyplot,
    # so the plt.* calls below draw on the same figure that is returned.
    fig = plt.figure()
    plt.scatter(
        std_emb[:, 0],
        std_emb[:, 1],
        c=target,
        s=0.1,
        cmap="Spectral_r",
        norm=mpl.colors.LogNorm(),
    )
    plt.axis("off")

    label = "target"
    # Only append the unit when one was actually supplied.
    if color_unit is not None:
        label = f"{label} ({color_unit})"
    plt.colorbar(label=label, orientation="horizontal")

    plt.tight_layout()
    plt.gca().set_aspect("equal", "box")
    plt.savefig(join(figure_dir, "target-scatter"))
    plt.show()
    return fig
def dens_scatter(x, y, pdf_sum, figure_dir="figures"):
    """Plot DensMAP densities at the `x` and `y` embedding coordinates.

    The figure is saved to ``join(figure_dir, "dens-scatter")`` and then
    shown.

    Parameters
    ----------
    x : 1d array
        x-coordinates
    y : 1d array
        y-coordinates
    pdf_sum : 1d array
        probabilities evaluated at each of the `x` and `y` coordinate pairs.
    figure_dir : str, optional
        Directory the figure is saved into, by default "figures".

    Returns
    -------
    fig : Figure
        Handle to Matplotlib Figure.

    See Also
    --------
    mat_discover_.mvn_prob_sum : used to obtain `x`, `y`, and `pdf_sum`
    """
    # TODO: add callouts to specific locations (high-scoring compounds)
    # plt.figure() (not the plt.Figure class) registers the figure with pyplot,
    # so the plt.* calls below draw on the same figure that is returned.
    fig = plt.figure()
    plt.scatter(x, y, c=pdf_sum)
    plt.axis("off")
    plt.tight_layout()
    plt.colorbar(label="Density", orientation="horizontal")
    plt.gca().set_aspect("equal", "box")
    plt.savefig(join(figure_dir, "dens-scatter"))
    plt.show()
    return fig
def dens_targ_scatter(std_emb, target, x, y, pdf_sum, figure_dir="figures"):
    """Plot overlay of density scatter and target scatter plots.

    The density points are drawn first; the target-colored embedding points
    are drawn on top of them with ``alpha=0.15``. The figure is saved to
    ``join(figure_dir, "dens-targ-scatter")`` and then shown.

    Parameters
    ----------
    std_emb : 2d array
        UMAP embedding coordinates.
    target : 1d array
        Target properties corresponding to `std_emb`.
    x : 1d array
        x-coordinates
    y : 1d array
        y-coordinates
    pdf_sum : 1d array
        probabilities evaluated at each of the `x` and `y` coordinate pairs.
    figure_dir : str, optional
        Directory the figure is saved into, by default "figures".

    Returns
    -------
    fig : Figure
        Handle to Matplotlib Figure.

    See Also
    --------
    dens_scatter : density scatter plot
    target_scatter : target scatter plot
    """
    # plt.figure() (not the plt.Figure class) registers the figure with pyplot,
    # so the plt.* calls below draw on the same figure that is returned.
    fig = plt.figure()

    # Density layer underneath.
    plt.scatter(x, y, c=pdf_sum)

    # Semi-transparent target layer on top.
    plt.scatter(
        std_emb[:, 0],
        std_emb[:, 1],
        c=target,
        s=2,
        cmap="Spectral",
        edgecolors="none",
        alpha=0.15,
    )
    plt.axis("off")
    plt.tight_layout()
    plt.gca().set_aspect("equal", "box")
    plt.savefig(join(figure_dir, "dens-targ-scatter"))
    plt.show()
    return fig
def group_cv_parity(ytrue, ypred, labels, figure_dir="figures"):
    """Leave-one-cluster-out cross-validation parity plot colored by `labels`.

    Predicted values are scattered against true values and a dashed parity
    line is drawn from ``(0, 0)`` to ``(mx, mx)``, where ``mx`` is the
    largest non-NaN value across `ytrue` and `ypred`. The figure is saved to
    ``join(figure_dir, "group-cv-parity")`` and then shown.

    Parameters
    ----------
    ytrue : 1d array
        True target values.
    ypred : 1d array
        Predicted target values.
    labels : 1d array
        Cluster IDs.
    figure_dir : str, optional
        Directory the figure is saved into, by default "figures".

    Returns
    -------
    fig : Figure
        Handle to Matplotlib Figure.
    """
    labels = np.asarray(labels)

    # Rescale the cluster IDs to [0, 1] so they index the colormap; flatten so
    # the colormap returns one RGBA row per point, i.e. shape (n, 4).
    scl_vals = MinMaxScaler().fit_transform(labels.reshape(-1, 1))
    color = plt.cm.nipy_spectral(scl_vals.ravel())

    mx = np.nanmax([ytrue, ypred])

    # Return a real Figure, as documented and as the sibling plotting
    # functions do, rather than the scatter artist.
    fig = plt.figure()
    plt.scatter(ytrue, ypred, c=color)
    # Parity (y = x) reference line spanning 0..mx on both axes.
    plt.plot([0, mx], [0, mx], "--", linewidth=1)
    plt.xlabel(r"$E_\mathregular{avg,true}$ (GPa)")
    plt.ylabel(r"$E_\mathregular{avg,pred}$ (GPa)")
    plt.tight_layout()
    plt.savefig(join(figure_dir, "group-cv-parity"))
    plt.show()
    return fig
def matplotlibify(
    fig, size=24, width_inches=3.5, height_inches=3.5, dpi=142, targ_dpi=300
):
    """Restyle a figure in place so it looks more like a Matplotlib plot.

    Modified from:
    https://medium.com/swlh/formatting-a-plotly-figure-with-matplotlib-style-fa56ddd97539

    Parameters
    ----------
    fig : object
        Figure exposing ``update_layout``, ``update_xaxes``, ``update_yaxes``
        and ``layout.width`` (presumably a Plotly figure -- confirm against
        callers). It is modified in place.
    size : int, optional
        Font size used for the layout and tick labels, by default 24.
    width_inches : float, optional
        Figure width in inches, by default 3.5.
    height_inches : float, optional
        Figure height in inches, by default 3.5.
    dpi : float, optional
        Dots per inch used to convert the inch dimensions to pixels,
        by default 142.
    targ_dpi : float, optional
        Target resolution used to compute `scale`, by default 300.

    Returns
    -------
    fig : object
        The same figure object that was passed in.
    scale : float
        ``width_inches / (fig.layout.width / dpi) * (targ_dpi / dpi)``;
        presumably meant as the ``scale`` argument when exporting the figure
        at `targ_dpi` -- confirm against callers.
    """
    font_dict = dict(family="Arial", size=size, color="black")

    fig.update_layout(
        font=font_dict,
        plot_bgcolor="white",
        width=width_inches * dpi,
        height=height_inches * dpi,
        margin=dict(r=40, t=20, b=10),
    )

    fig.update_yaxes(
        showline=True,  # draw the axis line
        linecolor="black",  # axis line color
        linewidth=2.4,  # axis line width
        ticks="inside",  # draw ticks inside the axis
        tickfont=font_dict,  # tick label font
        mirror="allticks",  # mirror axis line and ticks onto the opposite side
        tickwidth=2.4,  # tick width
        tickcolor="black",  # tick color
    )
    fig.update_xaxes(
        showline=True,
        showticklabels=True,
        linecolor="black",
        linewidth=2.4,
        ticks="inside",
        tickfont=font_dict,
        mirror="allticks",
        tickwidth=2.4,
        tickcolor="black",
    )

    # Read the width back from the figure rather than reusing the value
    # computed above.
    width_default_px = fig.layout.width
    scale = width_inches / (width_default_px / dpi) * (targ_dpi / dpi)
    return fig, scale