from typing import Any, List, Tuple, Union, Mapping, Iterable, Optional
import warnings
from scipy.stats import zscore
from numpy.linalg import LinAlgError
from matplotlib.axes import Axes as MplAxes
from statsmodels.tools.sm_exceptions import ConvergenceWarning
import numpy as np
import scipy
import pandas as pd
import scanpy as sc
import anndata as ad
import matplotlib
import statsmodels.api as sm
import matplotlib.pyplot as plt
import statsmodels.formula.api as smf
def _adjust_plotheight(scplot):
"""
Fix large gap between title and plot for scanpy plots.
(rather hacky, might not work in all cases)
"""
# modified code from sc.pl.MatrixPlot.make_figure
category_height = scplot.DEFAULT_CATEGORY_HEIGHT
category_width = scplot.DEFAULT_CATEGORY_WIDTH
mainplot_height = len(scplot.categories) * category_height
mainplot_width = len(scplot.var_names) * category_width + scplot.group_extra_size
if scplot.are_axes_swapped:
mainplot_height, mainplot_width = mainplot_width, mainplot_height
height = mainplot_height # + 1 # +1 for labels
# if the number of categories is small use
# a larger height, otherwise the legends do not fit
scplot.height = max([scplot.min_figure_height, height])
scplot.width = mainplot_width + scplot.legends_width
def _ensure_categorical(adata: ad.AnnData, col: str) -> None:
if isinstance(adata.obs[col].dtype, pd.CategoricalDtype):
# nothing todo
return
adata.obs[col] = adata.obs[col].astype(str).astype("category")
return
# TODO add group size similar to dotplot here!
[docs]def plot_mean_intensity(
adata: ad.AnnData,
groupby: str = "cluster",
marker_dict: Optional[Union[Mapping[str, Iterable[str]], Iterable[str]]] = None,
save: Optional[str] = None,
dendrogram: bool = False,
limit_to_groups: Optional[Mapping[str, Union[str, List[str]]]] = None,
type: str = "matrixplot", # noqa: A002
cmap: str = "viridis",
adjust_height: bool = True,
figsize: Tuple[int, int] = (10, 5),
ax: MplAxes = None,
**kwargs: Any,
) -> None:
"""
Show per cluster intensity of each channel.
Intensity is either shown as mean or z-scored intensity,
depending on the ``standard_scale`` keyword argument.
Parameters
----------
adata
Adata containing aggregated information by clusters.
E.g. result of :meth:`FeatureExtractor.get_intensity_adata`.
groupby
column in ``adata.obs`` containing the groups to compare.
marker_dict
Limit/group vars that are shown, either by passing list or dict (adds annotations to plot).
save
Path to save figure to.
dendrogram
Show dendrogram over columns.
limit_to_groups
Dict with obs as keys and groups from obs as values, to subset adata before plotting.
type
Type of plot, either `matrixplot` or `violinplot`.
cmap
Matplotlib colormap to use.
adjust_height
Option to make plots a bit more streamlined.
figsize
Size of figure.
ax
Axis to plot in.
kwargs
Keyword arguments for :func:`sc.pl.stacked_violin`/:func:`sc.pl.matrixplot`.
"""
if limit_to_groups is None:
limit_to_groups = {}
_ensure_categorical(adata, groupby)
# subset data
for key, groups in limit_to_groups.items():
if not isinstance(groups, list):
groups = [groups]
adata = adata[adata.obs[key].isin(groups)]
# group vars together?
if marker_dict is None:
marker_dict = np.array(adata.var.index)
if isinstance(marker_dict, dict):
marker_list = np.concatenate(list(marker_dict.values()))
else:
marker_list = marker_dict
# calculate values to show
color_values = pd.DataFrame(index=adata.var.index)
for g in adata.obs[groupby].cat.categories:
color = "mean intensity"
g_expr = adata[adata.obs[groupby] == g].X
g_size = np.array(adata[adata.obs[groupby] == g].obs["size"])
color_values[g] = np.array((g_expr * g_size[:, np.newaxis]).sum(axis=0) / g_size.sum())
color_values = color_values.loc[marker_list]
standard_scale = kwargs.pop("standard_scale", None)
if standard_scale == "var":
color_values = color_values.apply(zscore, axis=1)
elif standard_scale == "obs":
color_values = color_values.apply(zscore, axis=0)
# plot
if dendrogram:
sc.tl.dendrogram(adata, groupby=groupby)
title = "mean intensity in " + ", ".join([f"{key}: {val}" for key, val in limit_to_groups.items()])
if limit_to_groups == {}:
title = "mean intensity"
if ax is None:
fig, ax = plt.subplots(1, 1, figsize=figsize)
if type == "violinplot":
scplot = sc.pl.stacked_violin(
adata,
var_names=marker_dict,
groupby=groupby, # standard_scale='var',
ax=ax,
dendrogram=dendrogram,
return_fig=True,
title=title,
**kwargs,
)
elif type == "matrixplot":
scplot = sc.pl.matrixplot(
adata,
var_names=marker_dict,
groupby=groupby,
cmap=cmap,
colorbar_title=color,
ax=ax,
return_fig=True,
dendrogram=dendrogram,
values_df=color_values.T,
title=title,
**kwargs,
)
else:
raise NotImplementedError(type)
if adjust_height:
_adjust_plotheight(scplot)
scplot.make_figure()
# add axis labels
scplot.ax_dict["mainplot_ax"].set_xlabel("channels")
scplot.ax_dict["mainplot_ax"].set_ylabel(groupby)
if save is not None:
plt.savefig(save, dpi=100)
[docs]def plot_mean_size(
adata: ad.AnnData,
groupby_row: str = "cluster",
groupby_col: str = "well_name",
normby_row: Optional[str] = None,
normby_col: Optional[str] = None,
ax: MplAxes = None,
figsize: Tuple[int, int] = None,
adjust_height: bool = False,
save: Optional[str] = None,
**kwargs: Any,
) -> None:
"""
Plot mean cluster sizes per cell, grouped by different columns in obs.
Parameters
----------
adata
Adata containing aggregated information by clusters.
E.g. result of :meth:`FeatureExtractor.get_intensity_adata`.
groupby_row
Column in ``adata.obs`` containing the row-wise grouping.
groupby_col
Column in ``adata.obs`` containing the column-wise grouping.
normby_row
Value in ``groupby_row`` to normalise rows by.
normby_col
Value in ``groupby_col`` to normalise columns by.
ax
Axis to plot in.
figsize
Size of figure.
adjust_height
Option to make plots a bit more streamlined.
save
Path to save figure to.
kwargs
Keyword arguments for :func:`sc.pl.matrixplot`.
"""
_ensure_categorical(adata, groupby_row)
_ensure_categorical(adata, groupby_col)
# groupy_col needs to be var
sizes = {
c: adata[adata.obs[groupby_col] == c].obs.groupby(groupby_row).mean()["size"]
for c in adata.obs[groupby_col].cat.categories
}
sizes_adata = ad.AnnData(pd.DataFrame(sizes), dtype=np.float32)
sizes_adata.obs["group"] = sizes_adata.obs.index.astype("category")
# get values to show
values_df = pd.DataFrame(sizes)
title_suffix = ""
if normby_row is not None:
values_df = values_df.divide(values_df.loc[normby_row], axis="columns")
title_suffix += f"\nnormalised by {groupby_row} {normby_row}"
if normby_col is not None:
values_df = values_df.divide(values_df.loc[:, normby_col], axis="rows")
title_suffix += f"\nnormalised by {groupby_col} {normby_col}"
if ax is None:
fig, ax = plt.subplots(1, 1, figsize=figsize)
scplot = sc.pl.matrixplot(
sizes_adata,
var_names=sizes_adata.var_names,
groupby="group",
values_df=values_df,
colorbar_title="mean size\nin group",
title="mean object size" + title_suffix,
ax=ax,
show=False,
return_fig=True,
**kwargs,
)
if adjust_height:
_adjust_plotheight(scplot)
scplot.make_figure()
if save is not None:
plt.savefig(save, dpi=100)
def mixed_model(ref_expr, g_expr, ref_well_name, g_well_name):
"""
Calcuate mixed model for dotplots.
"""
# res_data = {'resid'}
res_data: Mapping[str, List[Any]] = {"df": [], "resid": []}
# if True:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=ConvergenceWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning, message=".*Random effects covariance is singular.*")
# iterate over all channels in the data
pvals = []
for i in range(ref_expr.shape[-1]):
# create dataframe for mixed model
df = pd.DataFrame(index=range(len(ref_expr) + len(g_expr)))
df["mean_expr"] = np.log2(np.concatenate([ref_expr[:, i], g_expr[:, i]]))
df["group"] = [0] * len(ref_expr) + [1] * len(g_expr)
df["well"] = np.concatenate([ref_well_name, g_well_name])
df.replace([np.inf, -np.inf], np.nan, inplace=True)
df = df.dropna()
# display(df)
# sns.distplot(df[df['group']==0]['mean_expr'])
# sns.distplot(df[df['group']==1]['mean_expr'])
model = sm.MixedLM.from_formula("mean_expr ~ group", re_formula="~1", groups="well", data=df)
try:
result = model.fit()
except LinAlgError:
print(f"Singular fit with mixed model for column {i}, replacing with OLS.")
model = smf.ols(formula="mean_expr ~ group", data=df)
result = model.fit()
pvals.append(result.pvalues["group"])
res_data["df"].append(df)
return np.array(pvals).astype("float"), res_data
[docs]def get_intensity_change(
adata: ad.AnnData,
groupby: str,
marker_dict: Optional[Union[Mapping[str, Iterable[str]], Iterable[str]]] = None,
limit_to_groups: Optional[Mapping[str, Union[str, List[str]]]] = None,
reference: Optional[Union[List[str], str]] = None,
reference_group: Optional[str] = None,
color: str = "logfoldchange",
size: str = "mean_reference",
group_sizes_barplot: Optional[str] = None,
pval: str = "ttest",
alpha: float = 0.05,
norm_by_group: Optional[str] = None,
) -> Mapping[str, Any]:
"""
Get data for plotting intensity comparison with :func:`plot_intensity_change`.
Calculate mean intensity differences between perturbations or clusters.
If no reference is given, use all other groups (except the current one) as reference.
Colours show log2-foldchange / mean intensity changes / z-score changes, depending on the ``color`` argument.
Dot size shows mean intensity of reference group that is compared to, or indicates the p-value,
depending on the ``size`` argument.
Parameters
----------
adata
Adata containing aggregated information by clusters.
E.g. result of :meth:`FeatureExtractor.get_intensity_adata`.
groupby
column in ``adata.obs`` containing the groups to compare.
marker_dict
Limit/group vars that are shown, either by passing list or dict (adds annotations to plot).
limit_to_groups
Dict with obs as keys and groups from obs as values, to subset adata before plotting.
reference
Reference cluster/perturbation to compare to.
If not defined, will compare each value in ``groupby`` against the rest.
reference_group
Obs entry that contains reference grouping (by default, ``groupby`` is used).
color
Colour of dots, either `logfoldchange` or `meanchange`.
size
sizes of dots, either `mean_reference` or `pval` (distinguish significant and non-significant dots).
group_sizes_barplot
Mean size of groups shown as a bar plot to the right.
Either None (do not show), `mean` (mean size of groups),
`meanchange` (mean difference of group size from reference), `foldchange`.
pval
Type of test done to determine p-values. Either `ttest` or `mixed_model`.
`mixed_model` calculates a mixed model using wells as random effects and should be preferred.
Note that when using `norm_by_group`, the mixed model will be calculated on the normalised values,
which differs from the model used in the original publication.
alpha
``pval`` threshold above which dots are not shown
norm_by_group
Divide all mean values by the mean values of this group.
This is done separately for the reference and the values to compare to.
Returns
-------
Mapping[str, Any]:
data to input to :func:`plot_intensity_change`.
"""
_ensure_categorical(adata, groupby)
if limit_to_groups is None:
limit_to_groups = {}
# subset data
for key, groups in limit_to_groups.items():
if not isinstance(groups, list):
groups = [groups]
adata = adata[adata.obs[key].isin(groups)]
# which vars to show?
if marker_dict is None:
marker_dict = np.array(adata.var.index)
if isinstance(marker_dict, dict):
marker_list = np.concatenate(list(marker_dict.values()))
else:
marker_list = marker_dict
# subset data to markers that we'd like to show
adata = adata[:, marker_list].copy()
# calculate values to show
color_values = pd.DataFrame(index=adata.var.index) # intensity values shown as colors
p_values = pd.DataFrame(index=adata.var.index) # pvalues (impacting dot sizes)
size_values = pd.DataFrame(index=adata.var.index) # dot sizes
group_size = {} # mean group sizes (for barplot)
p_values_data = {} # additional data returned from mixed model
# define reference
adata_ref = None
if reference_group is None:
reference_group = groupby
if reference is not None:
if not isinstance(reference, list):
reference = [reference]
adata_ref = adata[adata.obs[reference_group].isin(reference)]
# subset adata to not reference
adata = adata[~adata.obs[reference_group].isin(reference)]
assert len(adata) > 0, f"no obs in adata that are not one of {reference} in {reference_group}"
for g in adata.obs[groupby].cat.categories:
# reference expression
if reference is not None:
assert adata_ref is not None
if reference_group != groupby:
# reference expression is the current group in the reference group
# (which is a distinct grouping from the groupby categories)
# print('reference expression is the current group in the reference group')
adata_cur_ref = adata_ref[adata_ref.obs[groupby] == g]
else:
# reference expression is the reference group
# print('reference expression is the reference group')
adata_cur_ref = adata_ref
else:
# reference expression is everything except the current group (classic comparison of groupings)
# print('reference expression is everything except the current group')
adata_cur_ref = adata[adata.obs[groupby] != g]
cur_ref_expr = adata_cur_ref.X
if norm_by_group is not None:
assert reference is not None, "Need a reference for norm by group"
assert reference_group != groupby, "Can only norm by group if reference_group is different to groupby"
assert adata_ref is not None
cur_ref_expr = cur_ref_expr / adata_ref[adata_ref.obs[groupby] == norm_by_group].X
cur_ref_size = np.array(adata_cur_ref.obs["size"])
# group expression
g_expr = adata[adata.obs[groupby] == g].X
if norm_by_group is not None:
g_expr = g_expr / adata[adata.obs[groupby] == norm_by_group].X
g_size = np.array(adata[adata.obs[groupby] == g].obs["size"])
# mean group expression
mean_g = (g_expr * g_size[:, np.newaxis]).sum(axis=0) / g_size.sum()
# mean reference expression
mean_ref = (cur_ref_expr * cur_ref_size[:, np.newaxis]).sum(axis=0) / cur_ref_size.sum()
# mean group size
if group_sizes_barplot == "mean":
group_size[g] = g_size.mean()
elif group_sizes_barplot == "meanchange":
group_size[g] = g_size.mean() - cur_ref_size.mean()
elif group_sizes_barplot == "foldchange":
group_size[g] = g_size.mean() / cur_ref_size.mean()
else:
group_size[g] = 0 # type: ignore[assignment]
# set p values by testing if distribution of intensities is the same (without adjusting for size!)
if pval == "mixed_model":
if norm_by_group == g:
# not not calc p values, all mean intensities will be 0
pvals = np.array([1] * len(adata.var.index))
pvals_data = {}
else:
g_well_name = adata[adata.obs[groupby] == g].obs["well_name"]
pvals, pvals_data = mixed_model(
cur_ref_expr,
g_expr,
ref_well_name=adata_cur_ref.obs["well_name"],
g_well_name=g_well_name,
)
elif pval == "ttest":
pvals_data = {}
_, pvals = scipy.stats.ttest_ind(cur_ref_expr, g_expr, axis=0)
else:
raise NotImplementedError(pval)
p_values[g] = pvals
p_values_data[g] = pvals_data
# set size values to mean intensity of reference group
if size == "mean_reference":
size_values[g] = mean_ref
size_title = "mean intensity\nof reference"
elif size == "pval":
# size values are 1 for significant pvals, and 0.5 for non-significant pvals
sizes = (pvals <= alpha).astype("float")
sizes[sizes == 0] = 0.5
size_values[g] = sizes
size_title = "pvalue"
else:
raise NotImplementedError(size)
# set color values
if color == "logfoldchange":
color_values[g] = np.array(np.log2(mean_g / mean_ref))
elif color == "meanchange":
color_values[g] = np.array(mean_g - mean_ref)
else:
raise NotImplementedError(color)
# get title for plot
lmt_str = ", ".join([f'{key}: {",".join(val)}' for key, val in limit_to_groups.items()])
if limit_to_groups == {}:
lmt_str = "all"
if reference is not None:
title = f'{color} of {lmt_str} wrt {reference_group}: {",".join([str(r) for r in reference])}'
else:
title = f"{color} of {lmt_str} wrt rest"
cbar_title = f"{color} in group"
if norm_by_group is not None:
cbar_title = f"relative {color} in group\nwrt {norm_by_group}"
return_dict = {
"adata": adata,
"color_values": color_values,
"size_values": size_values,
"p_values": p_values,
"p_values_data": p_values_data,
"group_size": group_size,
"marker_dict": marker_dict,
"groupby": groupby,
"alpha": alpha,
"plot_data": {
"title": title,
"colorbar_title": cbar_title,
"size_title": size_title,
"show_unsignificant_dots": size == "pval",
"group_sizes_barplot": group_sizes_barplot,
},
}
return return_dict
[docs]def plot_intensity_change(
adata: ad.AnnData,
color_values: pd.DataFrame,
size_values: pd.DataFrame,
p_values: pd.DataFrame,
p_values_data: Mapping[str, Any],
group_size: Mapping[str, Any],
marker_dict: Optional[Union[Mapping[str, Iterable[str]], Iterable[str]]],
groupby: str,
plot_data: Mapping[str, Any],
alpha: float,
adjust_height: bool = True,
ax: Optional[matplotlib.axes.Axes] = None,
figsize: Iterable[int] = (10, 3),
save: Optional[str] = None,
**kwargs: Any,
) -> None:
"""
Plot mean intensity differences between perturbations or clusters.
Takes returns of :func:`get_intensity_change` as input:
``plot_intensity_change(**get_intensity_change(...))``
Parameters
----------
adjust_height
Option to make plots a bit more streamlined.
ax
Axis to plot in.
figsize
Size of figure.
save
Path to save figure to.
kwargs
Keyword arguments for :func:`sc.pl.dotplot`.
"""
kwargs["vmin"] = kwargs.get("vmin", -1)
kwargs["vmax"] = kwargs.get("vmax", 1)
if ax is None:
fig, ax = plt.subplots(1, 1, figsize=figsize)
scplot = sc.pl.dotplot(
adata,
var_names=marker_dict,
groupby=groupby,
dot_color_df=color_values.T,
cmap="bwr",
colorbar_title=plot_data["colorbar_title"],
size_title=plot_data["size_title"],
title=plot_data["title"],
show=False,
return_fig=True,
ax=ax,
**kwargs,
)
# set dot size
scplot.dot_size_df = size_values.T
# do not show unsignificant dots if size does not indicate this
if not plot_data["show_unsignificant_dots"]:
scplot.dot_size_df[p_values.T > alpha] = 0
# add group sizes
if plot_data["group_sizes_barplot"] is not None:
group_size: pd.Series = pd.Series(data=group_size) # type: ignore[no-redef]
scplot.group_extra_size = 0.8
scplot.plot_group_extra = {
"kind": "group_totals",
"width": 0.8,
"sort": None,
"counts_df": group_size,
"color": None,
}
if adjust_height:
_adjust_plotheight(scplot)
scplot.make_figure()
# add axis labels
scplot.ax_dict["mainplot_ax"].set_xlabel("channel")
scplot.ax_dict["mainplot_ax"].set_ylabel(groupby)
# allow negative values in barplot
if plot_data["group_sizes_barplot"] == "meanchange":
assert isinstance(group_size, pd.Series)
scplot.ax_dict["group_extra_ax"].set_xlim(
(
group_size.min() - np.abs(group_size.min()) * 0.4,
group_size.max() + np.abs(group_size.max() * 0.4),
)
)
if save is not None:
plt.savefig(save, dpi=100)
[docs]def plot_size_change(
adata: ad.AnnData,
groupby_row: str = "cluster",
groupby_col: str = "well_name",
reference_row: Optional[str] = None,
reference_col: Optional[str] = None,
figsize: Optional[Tuple[int, int]] = None,
adjust_height: bool = True,
ax: MplAxes = None,
pval: float = 0.05,
save: Optional[str] = None,
size: str = "mean_size",
limit_to_groups: Optional[Mapping[str, Union[List[str], str]]] = None,
**kwargs: Any,
) -> None:
"""
Plot mean intensity differences between perturbations and clusters.
Parameters
----------
adata
Adata containing aggregated information by clusters.
E.g. result of :meth:`FeatureExtractor.get_intensity_adata`.
groupby_row
Column in ``adata.obs`` containing the row-wise grouping.
groupby_col
Column in ``adata.obs`` containing the column-wise grouping.
reference_row
Reference cluster/perturbation to compare to row-wise.
If not defined, will compare each value in groupby_row against the rest.
reference_col
Reference cluster/perturbation to compare to col-wise.
If not defined, will compare each value in groupby_col against the rest.
figsize
Size of figure.
adjust_height
Option to make plots a bit more streamlined.
ax
Axis to plot in.
pval
``pval`` threshold above which dots are not shown.
save
Path to save figure to.
size
Sizes of dots, either `mean_size` or `pval` (distinguish significant and non-significant dots).
limit_to_groups
Dict with obs as keys and groups from obs as values, to subset adata before plotting.
kwargs
Keyword arguments for :func:`sc.pl.dotplot`.
"""
if limit_to_groups is None:
limit_to_groups = {}
assert (reference_col is None) ^ (reference_row is None), "either reference_row or reference_col must be defined"
_ensure_categorical(adata, groupby_row)
_ensure_categorical(adata, groupby_col)
kwargs["vmin"] = kwargs.get("vmin", -1)
kwargs["vmax"] = kwargs.get("vmax", 1)
# subset data
for key, groups in limit_to_groups.items():
if not isinstance(groups, list):
groups = [groups]
adata = adata[adata.obs[key].isin(groups)]
col_grps = adata.obs[groupby_col].cat.categories
row_grps = adata.obs[groupby_row].cat.categories
# calculate mean sizes to plot later
grp_df = adata.obs.groupby([groupby_row, groupby_col])
sizes_adata = ad.AnnData(grp_df.mean()["size"].unstack())
sizes_adata.obs["group"] = sizes_adata.obs.index.astype("category")
# calculate values to show
data = grp_df.mean()["size"].unstack()
color_values = pd.DataFrame(index=row_grps, columns=col_grps)
p_values = pd.DataFrame(index=row_grps, columns=col_grps)
size_values = pd.DataFrame(index=row_grps, columns=col_grps)
# assign color_values and size_values
if reference_row is not None:
color_values = np.log2(data.divide(data.loc[reference_row], axis="columns"))
size_values.loc[:, :] = np.array(data.loc[reference_row])[np.newaxis, :]
if reference_col is not None:
color_values = np.log2(data.divide(data.loc[:, reference_col], axis="rows"))
size_values.loc[:, :] = np.array(data.loc[:, reference_col])[:, np.newaxis]
# assign p_values
for r in row_grps:
for c in col_grps:
if reference_row is not None:
ref_dist = grp_df.get_group((reference_row, c))["size"]
if reference_col is not None:
ref_dist = grp_df.get_group((r, reference_col))["size"]
cur_dist = grp_df.get_group((r, c))["size"]
_, p = scipy.stats.ttest_ind(ref_dist, cur_dist)
p_values.loc[r, c] = p
# set size values
if size == "mean_size":
# size values is already correct
size_title = "mean size\nof reference (%)"
elif size == "pval":
# size values are 1 for significant pvals, and 0.5 for non-significant pvals
for r in row_grps:
for c in col_grps:
size_values.loc[r, c] = 1 if p_values.loc[r, c] <= pval else 0.5
size_title = "pvalue"
else:
raise NotImplementedError(size)
size_values = size_values.astype("float")
size_values = size_values / size_values.max()
# plot
if ax is None:
fig, ax = plt.subplots(1, 1, figsize=figsize)
lmt_str = ", ".join([f'{key}: {",".join(val)}' for key, val in limit_to_groups.items()])
if limit_to_groups == {}:
lmt_str = "all"
title = f"size logfoldchange of {lmt_str} wrt "
if reference_col is not None:
title += f"{groupby_col}: {reference_col}"
else:
title += f"{groupby_row}: {reference_row}"
scplot = sc.pl.dotplot(
sizes_adata,
var_names=sizes_adata.var_names,
groupby="group",
dot_color_df=color_values,
dot_size_df=size_values,
cmap="bwr",
colorbar_title="logfolchange in group",
size_title=size_title,
title=title,
show=False,
return_fig=True,
ax=ax,
**kwargs,
)
# do not show unsignificant dots if size does not indicate this
if size != "pval":
scplot.dot_size_df[p_values > pval] = 0
if adjust_height:
_adjust_plotheight(scplot)
scplot.make_figure()
if save is not None:
plt.savefig(save, dpi=100)