Source code for admix.plot._plot

import warnings
import matplotlib.pyplot as plt
from matplotlib import collections as mc
from matplotlib import patheffects

import numpy as np
import pandas as pd
from scipy import stats

import admix
from admix.data import quantile_normalize
from admix.data import lambda_gc

from typing import Dict


def pca(
    df_pca: pd.DataFrame,
    x: str = "PC1",
    y: str = "PC2",
    label_col: str = None,
    label_order: list = None,
    s=5,
    legend_loc="on data",
    alpha=None,
    ax=None,
):
    """Scatter plot of two PCA components, optionally grouped by a label column.

    Parameters
    ----------
    df_pca : pd.DataFrame
        dataframe with PCA components (one row per sample)
    x : str, optional
        column used for the x-axis, by default "PC1"
    y : str, optional
        column used for the y-axis, by default "PC2"
    label_col : str, optional
        column name for labels, by default None. When None, all samples are
        drawn as a single unlabelled series and no legend is created.
    label_order : list, optional
        order in which the label groups are drawn, by default the order of
        first appearance in ``df_pca[label_col]``
    s : float, optional
        marker size, by default 5
    legend_loc : str, optional
        when equal to "on data", each label is additionally written at the
        median position of its group, by default "on data"
    alpha : float or dict, optional
        a single transparency for every group, or a dict mapping
        label -> transparency (labels missing from the dict use 1.0),
        by default None (fully opaque)
    ax : matplotlib.axes.Axes, optional
        axes to draw on, by default the current axes
    """
    if alpha is None:
        alpha = 1.0
    else:
        assert isinstance(
            alpha, (int, float, dict)
        ), "alpha must be a number or a dict mapping label -> alpha"
    if ax is None:
        ax = plt.gca()

    ax.set_xlabel(x)
    ax.set_ylabel(y)

    if label_col is None:
        # No grouping requested. This must be handled before `label_col` is
        # used to index the dataframe, otherwise the default call fails.
        scalar_alpha = 1.0 if isinstance(alpha, dict) else alpha
        ax.scatter(df_pca[x], df_pca[y], s=s, alpha=scalar_alpha)
        return

    if label_order is None:
        label_order = df_pca[label_col].unique()

    for label in label_order:
        group = df_pca.loc[df_pca[label_col] == label, :]
        if isinstance(alpha, dict):
            label_alpha = alpha.get(label, 1.0)
        else:
            label_alpha = alpha
        ax.scatter(group[x], group[y], s=s, label=label, alpha=label_alpha)

    if legend_loc == "on data":
        # write every label at the median coordinate of its group
        all_pos = (
            pd.DataFrame(df_pca[[x, y, label_col]])
            .groupby(label_col, observed=True)
            .median()
            .sort_index()
        )

        for label, x_pos, y_pos in all_pos.itertuples():
            ax.text(
                x_pos,
                y_pos,
                label,
                # white outline keeps the text readable on top of the dots
                path_effects=[patheffects.withStroke(linewidth=2.5, foreground="w")],
                verticalalignment="center",
                horizontalalignment="center",
            )

    legend = ax.legend()
    # `legendHandles` was renamed to `legend_handles` in Matplotlib 3.7 and
    # the old name was later removed; support both spellings.
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    for lh in handles:
        # legend markers stay opaque and readable regardless of plot settings
        lh.set_alpha(1)
        lh.set_sizes([30])


def joint_pca(
    df_pc,
    x="PC1",
    y="PC2",
    sample_alpha=0.1,
    axes=None,
    figsize=(8.5, 4),
    label_col="SUPERPOP",
    sample_label="SAMPLE",
):
    """Joint PCA plot: reference groups alone, and reference groups plus samples.

    Parameters
    ----------
    df_pc : pd.DataFrame
        dataframe with PCA components and a label column
    x : str, optional
        column used for the x-axis, by default "PC1"
    y : str, optional
        column used for the y-axis, by default "PC2"
    sample_alpha : float, optional
        transparency used for rows labelled ``sample_label`` in the second
        panel, by default 0.1
    axes : sequence of matplotlib.axes.Axes, optional
        two axes to draw on; when None a new 1x2 figure is created
    figsize : tuple, optional
        figure size used when a new figure is created, by default (8.5, 4)
    label_col : str, optional
        column holding the group label, by default "SUPERPOP"
    sample_label : str, optional
        value of ``label_col`` identifying the study samples, by default
        "SAMPLE"

    Returns
    -------
    (fig, axes) when the figure was created here; None when ``axes`` was
    supplied by the caller.
    """
    # Validate first so that a bad column name fails with a clear message
    # before any figure is created or anything is drawn.
    assert set([x, y]).issubset(
        df_pc.columns
    ), f"{x} and {y} must be in the columns of df_pc"

    new_axes = axes is None
    if new_axes:
        fig, axes = plt.subplots(figsize=figsize, dpi=150, ncols=2)

    # first panel: every row except the study samples
    admix.plot.pca(
        df_pc[df_pc[label_col] != sample_label],
        x=x,
        y=y,
        label_col=label_col,
        ax=axes[0],
    )
    axes[0].set_xlabel(x)
    axes[0].set_ylabel(y)

    # second panel: all rows, with the study samples drawn semi-transparent
    admix.plot.pca(
        df_pc,
        x=x,
        y=y,
        label_col=label_col,
        alpha={sample_label: sample_alpha},
        ax=axes[1],
    )
    axes[1].set_xlabel(x)
    axes[1].set_ylabel(y)

    if new_axes:
        return fig, axes


def qq(pval, label=None, ax=None, bootstrap_ci=False):
    """QQ plot of p-values.

    Parameters
    ----------
    pval : np.ndarray
        p-values, array-like; NaN entries are dropped
    label : str, optional
        label attached to the scatter series, by default None
    ax : matplotlib.axes, optional
        axes to draw on, by default the current axes
    bootstrap_ci : bool, optional
        whether to also request a bootstrap confidence interval for the
        lambda GC from ``lambda_gc``, by default False

    Returns
    -------
    lgc, or (lgc, lgc_ci) when ``bootstrap_ci`` is truthy. The value is also
    printed.

    Raises
    ------
    ValueError
        if no non-NaN p-value is left to plot
    """
    if ax is None:
        ax = plt.gca()
    pval = np.array(pval)
    pval = pval[~np.isnan(pval)]
    if pval.size == 0:
        raise ValueError("qq() requires at least one non-NaN p-value")

    expected_pval = stats.norm.sf(quantile_normalize(-pval))
    expected_logp = -np.log10(expected_pval)
    ax.scatter(expected_logp, -np.log10(pval), s=2, label=label)

    # y = x reference line
    lim = max(expected_logp)
    ax.plot([0, lim], [0, lim], "r--")
    # raw strings: "\l" is not a valid escape sequence in a normal literal
    ax.set_xlabel(r"Expected -$\log_{10}(p)$")
    ax.set_ylabel(r"Observed -$\log_{10}(p)$")

    # A single truthiness test decides both how lambda_gc is called and what
    # is returned, so `lgc_ci` can never be referenced while unbound.
    if bootstrap_ci:
        lgc, lgc_ci = lambda_gc(pval, bootstrap_ci=True)
        print(f"lambda GC: {lgc:.3g} [{lgc_ci[0]:.3g}, {lgc_ci[1]:.3g}]")
        return lgc, lgc_ci

    lgc = lambda_gc(pval, bootstrap_ci=False)
    print(f"lambda GC: {lgc:.3g}")
    return lgc


def manhattan(
    pval,
    chrom=None,
    pos=None,
    axh_y=-np.log10(5e-8),
    s=0.1,
    label=None,
    ax=None,
    color="#3b76af",
):
    """Manhattan plot of p-values.

    Parameters
    ----------
    pval : np.ndarray
        p-values, array-like
    chrom : np.ndarray, optional
        array-like of integer chromosome numbers, one per SNP. When given,
        SNPs are placed by index and coloured by chromosome parity.
    pos : np.ndarray, optional
        array-like, position for each SNP; if provided, position (in Mb)
        is used for the x-axis. Cannot be combined with ``chrom``.
    axh_y : float, optional
        height of the horizontal genome-wide significance line, by default
        -np.log10(5e-8); pass None to omit the line
    s : float, optional
        dot size, by default 0.1
    label : str, optional
        label attached to the scatter series, by default None
    ax : matplotlib.axes, optional
        axes, by default None (current axes)
    color : str, optional
        dot colour used when ``chrom`` is not given, by default "#3b76af"
    """
    if ax is None:
        ax = plt.gca()
    assert (chrom is None) or (pos is None), "chrom and pos cannot be both provided"

    # accept any array-like (lists included) for the vectorised operations below
    pval = np.asarray(pval)
    neg_log_p = -np.log10(pval)

    pos_provided = pos is not None
    if pos_provided:
        pos = np.asarray(pos)
        assert len(pos) == len(pval)
    else:
        pos = np.arange(len(pval))

    if chrom is None:
        if pos_provided:
            ax.scatter(
                pos / 1e6,
                neg_log_p,
                s=s,
                label=label,
                facecolor=color,
                marker="o",
            )
            ax.set_xlabel("SNP position (Mb)")
        else:
            # use snp index
            ax.scatter(pos, neg_log_p, s=s, label=label, c=color)
            ax.set_xlabel("SNP index")
    else:
        chrom = np.asarray(chrom)
        assert len(chrom) == len(pval)
        color_list = ["#1b9e77", "#d95f02"]
        # plot dots for even (mod == 0) and odd (mod == 1) chromosomes
        for mod in range(2):
            index = np.where(chrom % 2 == mod)[0]
            ax.scatter(
                pos[index],
                neg_log_p[index],
                s=s,
                color=color_list[mod],
                label=label,
            )

        # one tick per chromosome, centred on the SNP indices it covers
        xticks = []
        xticklabels = []
        for chrom_i in np.unique(chrom):
            xticks.append(np.where(chrom == chrom_i)[0].mean())
            xticklabels.append(chrom_i)

        ax.set_xticks(xticks)
        ax.set_xticklabels(xticklabels, rotation=90, fontsize=8)
        ax.set_xlabel("Chromosome")

    # raw string: "\l" is not a valid escape sequence in a normal literal
    ax.set_ylabel(r"-$\log_{10}(P)$")
    if axh_y is not None:
        ax.axhline(y=axh_y, color="r", ls="--")


def susie(pip, dict_cs, pos=None, ax=None, cmap="Set1"):
    """Plot posterior inclusion probabilities (PIP) and highlight credible sets.

    Parameters
    ----------
    pip : np.ndarray
        array-like, one PIP value per SNP
    dict_cs : dict
        mapping credible-set name -> indices of the SNPs in that set; the
        indices are used to select entries of ``pip`` and ``pos``
    pos : np.ndarray, optional
        array-like, position for each SNP; if provided, position (in Mb) is
        used for the x-axis, otherwise the SNP index is used
    ax : matplotlib.axes, optional
        axes to draw on, by default the current axes
    cmap : str, optional
        name of the colormap providing one colour per credible set,
        by default "Set1"
    """
    cmap = plt.get_cmap(cmap)
    if ax is None:
        ax = plt.gca()

    # accept any array-like (lists included) for the fancy indexing below
    pip = np.asarray(pip)
    pos_provided = pos is not None
    if pos_provided:
        pos = np.asarray(pos)
        assert len(pos) == len(pip)
        pos = pos / 1e6
    else:
        pos = np.arange(len(pip))

    # all SNPs in the background
    ax.scatter(x=pos, y=pip, s=3, color="gray")

    for i, cs in enumerate(dict_cs):
        cs_pos = dict_cs[cs]
        # Index the colormap's lookup table and wrap around, so that more
        # credible sets than palette entries (or a colormap without a
        # `.colors` attribute) no longer raises.
        cs_color = cmap(i % cmap.N)
        ax.scatter(
            x=pos[cs_pos],
            y=pip[cs_pos],
            s=15,
            edgecolors=cs_color,
            facecolors=cs_color,
            alpha=0.6,
        )

    if pos_provided:
        ax.set_xlabel("SNP position (Mb)")
    else:
        ax.set_xlabel("SNP index")
    ax.set_ylabel("PIP")
    ax.set_ylim(-0.05, 1.05)


def lanc(
    dset: admix.Dataset = None,
    lanc: np.ndarray = None,
    ax=None,
    max_indiv: int = None,
) -> None:
    """
    Plot local ancestry.

    Every haplotype is drawn as a horizontal line split into runs of constant
    local ancestry; the two haplotypes of an individual are drawn slightly
    below and above the individual's row.

    Parameters
    ----------
    dset: admix.Dataset
        A dataset containing the local ancestry matrix (``dset.lanc``) and
        SNP positions (``dset.snp.POS``). When given, the x-axis is in Mb.
    lanc: np.ndarray
        A numpy array of shape (n_snp, n_indiv, 2). Only used when ``dset``
        is None; the x-axis is then the SNP index.
    ax: matplotlib.Axes
        A matplotlib axes object to plot on. If None, the current axes is used.
    max_indiv: int
        The maximum number of individuals to plot.
        If None, will plot the first 10 individuals

    Returns
    -------
    None
    """
    # if dataset is provided, use it to extract lanc
    if dset is not None:
        # NOTE(review): assumes the computed matrix converts cleanly to a
        # numpy array - confirm against admix.Dataset
        lanc = np.asarray(dset.lanc.compute())
        pos = np.asarray(dset.snp.POS.values)
        bp_pos = True
    else:
        assert lanc is not None, "either dataset or lanc must be provided"
        lanc = np.asarray(lanc)
        pos = np.arange(lanc.shape[0])
        bp_pos = False

    # validate the shape before anything indexes into it
    assert (
        lanc.ndim == 3 and lanc.shape[2] == 2
    ), "lanc must be of shape (n_snp, n_indiv, 2)"
    n_snp, n_indiv = lanc.shape[0:2]
    assert n_snp > 0, "lanc must contain at least one SNP"

    # append dummy snp at the end so that the last run has an end coordinate
    pos = np.concatenate([pos, [pos[-1] + 1]])
    if bp_pos:
        pos = pos / 1e6

    if max_indiv is not None:
        n_plot_indiv = min(max_indiv, n_indiv)
    else:
        n_plot_indiv = min(n_indiv, 10)
    if n_plot_indiv < n_indiv:
        warnings.warn(
            f"Only the first {n_plot_indiv} are plotted. "
            "To plot more individuals, increase `max_indiv`"
        )

    if ax is None:
        ax = plt.gca()

    # ancestry label -> list of [(x_start, y), (x_stop, y)] line segments
    segments_by_label = dict()
    for i_indiv in range(n_plot_indiv):
        for i_ploidy in range(2):
            a = lanc[:, i_indiv, i_ploidy]
            # np.where gives i such that a[i + 1] != a[i]; the new run starts
            # at i + 1, so shift by one to obtain the first index of each run
            run_starts = np.where(a[1:] != a[:-1])[0] + 1
            bounds = np.concatenate([[0], run_starts, [len(a)]])
            row = i_indiv - 0.1 + i_ploidy * 0.2
            for start_idx, stop_idx in zip(bounds[:-1], bounds[1:]):
                segment = [
                    (float(pos[start_idx]), row),
                    (float(pos[stop_idx]), row),
                ]
                # the run [start_idx, stop_idx) carries the ancestry a[start_idx]
                segments_by_label.setdefault(a[start_idx], []).append(segment)

    cmap = plt.get_cmap("tab10")
    # sorted labels give every ancestry a stable colour across calls
    for i, label in enumerate(sorted(segments_by_label)):
        lc = mc.LineCollection(
            segments_by_label[label],
            linewidths=2,
            label=label,
            color=cmap(i % cmap.N),
        )
        ax.add_collection(lc)
    ax.legend()
    # collections do not update the data limits on their own
    ax.autoscale()

    if bp_pos:
        ax.set_xlabel("SNP position (Mb)")
    else:
        ax.set_xlabel("SNP index")
    ax.set_ylabel("Individuals")
    ax.set_yticks([])
    ax.set_yticklabels([])


def admixture(
    a: np.ndarray,
    labels=None,
    label_orders=None,
    ax=None,
):
    """
    Plot admixture proportions as a stacked bar chart (one bar per individual).

    Parameters
    ----------
    a: np.ndarray
        A numpy array of shape (n_indiv, n_pop); column ``i`` is the height
        of the ``i``-th stacked component for every individual.
    labels: list
        A list (or array) of labels, one for each individual. When given,
        individuals are grouped by label, groups are separated by vertical
        lines and the x-axis is annotated with the labels.
    label_orders: list
        The order in which the label groups are laid out. Must contain
        exactly the unique values of ``labels``. By default the sorted
        unique labels are used.
    ax: matplotlib.Axes
        A matplotlib axes object to plot on. If None, the current axes is used.

    Returns
    -------
    ax: matplotlib.Axes
        The axes that was drawn on.
    """
    a = np.asarray(a)
    n_indiv, n_pop = a.shape

    # reorder based on labels
    if labels is not None:
        # element-wise comparison below requires an array: `list == label`
        # would evaluate to a single False
        labels = np.asarray(labels)
        dict_label_range = dict()
        reordered_a = []
        unique_labels = np.unique(labels)
        if label_orders is not None:
            assert set(label_orders) == set(
                unique_labels
            ), "label_orders must cover all unique labels"
            unique_labels = label_orders

        cumsum = 0
        for label in unique_labels:
            mask = labels == label
            n_label = int(np.count_nonzero(mask))
            reordered_a.append(a[mask, :])
            # [first bar, one past the last bar] of this label group
            dict_label_range[label] = [cumsum, cumsum + n_label]
            cumsum += n_label
        a = np.vstack(reordered_a)

    if ax is None:
        ax = plt.gca()

    cmap = plt.get_cmap("tab10")
    bottom = np.zeros(n_indiv)
    for i_pop in range(n_pop):
        ax.bar(
            np.arange(n_indiv),
            height=a[:, i_pop],
            width=1,
            bottom=bottom,
            facecolor=cmap(i_pop),
            edgecolor=cmap(i_pop),
        )
        bottom += a[:, i_pop]

    ax.tick_params(axis="both", left=False, labelleft=False)

    if labels is not None:
        # interior group boundaries (np.unique returns them sorted)
        seps = np.unique(np.concatenate(list(dict_label_range.values())))
        for x in seps[1:-1]:
            # bars are centred on integers, so the boundary sits half a bar left
            ax.vlines(x - 0.5, ymin=0, ymax=1, color="black")

        ax.set_xticks([np.mean(dict_label_range[label]) for label in dict_label_range])
        ax.set_xticklabels([label for label in dict_label_range])
    else:
        ax.get_xaxis().set_ticks([])

    for side in ["top", "right", "bottom", "left"]:
        ax.spines[side].set_visible(False)

    return ax


def compare_pval(
    x_pval: np.ndarray,
    y_pval: np.ndarray,
    xlabel: str = None,
    ylabel: str = None,
    ax=None,
    s: int = 5,
):
    """Compare two sets of p-values on the -log10 scale.

    Draws the scatter of the paired values, the y = x line and a
    least-squares line through the origin.

    Parameters
    ----------
    x_pval: np.ndarray
        The p-value for the first variable.
    y_pval: np.ndarray
        The p-value for the second variable.
    xlabel: str
        The label for the first variable.
    ylabel: str
        The label for the second variable.
    ax: matplotlib.Axes
        A matplotlib axes object to plot on. If None, the current axes is used.
    s: int
        Marker size, by default 5.

    Raises
    ------
    ValueError
        if no pair of finite -log10 p-values is left to plot
    """
    if ax is None:
        ax = plt.gca()
    x_pval = np.asarray(x_pval, dtype=float)
    y_pval = np.asarray(y_pval, dtype=float)

    nonnan_idx = ~np.isnan(x_pval) & ~np.isnan(y_pval)
    # p == 0 maps to +inf; silence the divide warning and filter below
    with np.errstate(divide="ignore"):
        x_logp = -np.log10(x_pval[nonnan_idx])
        y_logp = -np.log10(y_pval[nonnan_idx])

    # Non-finite values cannot be drawn and would break both the axis limit
    # and the least-squares fit, so they are excluded (with a warning).
    finite_idx = np.isfinite(x_logp) & np.isfinite(y_logp)
    n_dropped = int(np.count_nonzero(~finite_idx))
    if n_dropped > 0:
        warnings.warn(
            f"{n_dropped} pair(s) with non-finite -log10(p) are excluded from the plot"
        )
    x_logp, y_logp = x_logp[finite_idx], y_logp[finite_idx]
    if x_logp.size == 0:
        raise ValueError("no pair of finite p-values to compare")

    ax.scatter(x_logp, y_logp, s=s)
    lim = max(x_logp.max(), y_logp.max()) * 1.1
    ax.plot([0, lim], [0, lim], "k--", alpha=0.5, lw=1, label="y=x")

    # add a regression line (no intercept: forced through the origin)
    slope = np.linalg.lstsq(x_logp[:, None], y_logp[:, None], rcond=None)[0].item()

    ax.axline(
        (0, 0),
        slope=slope,
        color="black",
        ls="--",
        lw=1,
        label=f"y={slope:.2f} x",
    )
    ax.legend()

    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)


def rg_posterior(
    xs: np.ndarray,
    dict_loglik: Dict[str, np.ndarray],
    ci=(0.5, 0.95),
    s=11,
    colors="black",
    markers="o",
    ax=None,
):
    """
    Plot the posterior mode and two credible intervals for each trait.

    Parameters
    ----------
    xs: np.ndarray
        list of x coordinates (grid on which the log-likelihoods are given)
    dict_loglik: Dict[np.ndarray]
        trait -> list of log-likelihoods, each of the same length as ``xs``
    ci: sequence of two floats
        the two interval levels to plot, smaller first; exactly two levels
        are currently required
    s: float
        marker size of the mode, by default 11
    colors: str or list
        a single colour for every trait, or a list with one colour per trait
        in the order of ``dict_loglik``
    markers: str or list
        a single marker for every trait, or a list with one marker per trait
        in the order of ``dict_loglik``
    ax: matplotlib.Axes
        A matplotlib axes object to plot on. If None, the current axes is used.
    """
    if ax is None:
        ax = plt.gca()
    assert len(ci) == 2, "Currently must plot 2 CIs"
    assert ci[0] < ci[1], "Smaller CI should come first"
    assert np.all([len(xs) == len(dict_loglik[t]) for t in dict_loglik])

    # traits are drawn bottom-up, so reverse to keep the first trait on top;
    # per-trait colours / markers are reversed accordingly
    trait_list = list(dict_loglik.keys())[::-1]
    if isinstance(colors, list):
        colors = colors[::-1]
    elif isinstance(colors, str):
        colors = [colors] * len(trait_list)

    if isinstance(markers, list):
        markers = markers[::-1]
    elif isinstance(markers, str):
        markers = [markers] * len(trait_list)

    # posterior mode: grid point with the highest log-likelihood
    dict_mode = {trait: xs[np.argmax(dict_loglik[trait])] for trait in trait_list}
    dict_ci_err: Dict[float, Dict] = {ci[0]: dict(), ci[1]: dict()}
    for trait in trait_list:
        trait_mode = dict_mode[trait]
        for each_ci in ci:
            hdi = admix.data.hdi(xs, dict_loglik[trait], ci=each_ci)
            assert not isinstance(
                hdi, list
            ), f"HPDI for {trait} contains multiple intervals {hdi}, indicating lack of data. Please rerun this function after remove this trait."
            # errorbar expects distances from the centre, not absolute bounds
            dict_ci_err[each_ci][trait] = [trait_mode - hdi[0], hdi[1] - trait_mode]

    mode = np.array([dict_mode[trait] for trait in trait_list])
    rows = np.arange(len(mode))
    # thick bar for the narrower interval, thin bar for the wider one
    lw_list = [2.5, 1.0]
    for i, each_ci in enumerate(ci):
        ci_low = [dict_ci_err[each_ci][trait][0] for trait in trait_list]
        ci_high = [dict_ci_err[each_ci][trait][1] for trait in trait_list]
        ax.errorbar(
            y=rows,
            x=mode,
            xerr=(ci_low, ci_high),
            fmt=" ",
            lw=lw_list[i],
            ecolor=colors,
        )
        if i == 0:
            # draw the mode markers once
            for j, trait in enumerate(trait_list):
                ax.scatter(
                    x=mode[j],
                    y=j,
                    marker=markers[j],
                    color=colors[j],
                    s=s,
                )
    for y in rows:
        ax.axhline(y=y, color="gray", ls="dotted", lw=0.5, alpha=0.8)
    ax.set_xlim(0, 1.1)
    ax.set_xlabel("Highest probability density of $r_{admix}$")
    ax.set_yticks(rows)
    ax.set_ylim(-0.5, len(mode) - 0.5)
    ax.set_yticklabels(
        trait_list,
        fontsize=8,
    )
    # annotation
    ax.tick_params(left=False, pad=-1)
    ax.axvline(x=1.0, color="red", ls="--", lw=0.8, alpha=0.4)
    ax.set_title("Estimated $r_{admix}$", fontsize=10, x=0.5)