import warnings
import matplotlib.pyplot as plt
from matplotlib import collections as mc
from matplotlib import patheffects
import numpy as np
import pandas as pd
from scipy import stats
import admix
from admix.data import quantile_normalize
from admix.data import lambda_gc
from typing import Dict
def pca(
    df_pca: pd.DataFrame,
    x: str = "PC1",
    y: str = "PC2",
    label_col: str = None,
    label_order: list = None,
    s=5,
    legend_loc="on data",
    alpha=None,
    ax=None,
):
    """PCA scatter plot, optionally colored by a label column.

    Parameters
    ----------
    df_pca : pd.DataFrame
        dataframe with PCA components
    x : str, optional
        column plotted on the x-axis, by default "PC1"
    y : str, optional
        column plotted on the y-axis, by default "PC2"
    label_col : str, optional
        column name for labels, by default None. When None, all rows are drawn
        as a single unlabeled group and no legend is added.
    label_order : list, optional
        order in which the label groups are drawn; by default the order of
        first appearance in ``label_col``
    s : float, optional
        marker size, by default 5
    legend_loc : str, optional
        if "on data", each label is also written at the median (x, y) of its
        group, by default "on data"
    alpha : float or dict, optional
        marker transparency. A dict maps label -> alpha; labels missing from
        the dict use 1.0. By default 1.0
    ax : matplotlib.axes.Axes, optional
        axes to plot on, by default the current axes
    """
    if alpha is None:
        alpha = 1.0
    else:
        assert isinstance(alpha, (int, float, dict))
    if ax is None:
        ax = plt.gca()
    ax.set_xlabel(x)
    ax.set_ylabel(y)
    if label_col is None:
        # No grouping column: draw every row with one style (a per-label
        # alpha dict has no meaning here, so fall back to opaque markers).
        ax.scatter(
            df_pca[x],
            df_pca[y],
            s=s,
            alpha=1.0 if isinstance(alpha, dict) else alpha,
        )
        return
    # label_col is present from here on
    if label_order is None:
        label_order = df_pca[label_col].unique()
    for label in label_order:
        group = df_pca.loc[df_pca[label_col] == label, :]
        label_alpha = alpha.get(label, 1.0) if isinstance(alpha, dict) else alpha
        ax.scatter(group[x], group[y], s=s, label=label, alpha=label_alpha)
    if legend_loc == "on data":
        # write each label at the median position of its group
        all_pos = (
            df_pca[[x, y, label_col]]
            .groupby(label_col, observed=True)
            .median()
            .sort_index()
        )
        for label, x_pos, y_pos in all_pos.itertuples():
            ax.text(
                x_pos,
                y_pos,
                label,
                path_effects=[patheffects.withStroke(linewidth=2.5, foreground="w")],
                verticalalignment="center",
                horizontalalignment="center",
            )
    legend = ax.legend()
    # `legendHandles` was renamed `legend_handles` in matplotlib 3.7 and the
    # old name was removed in 3.9; support both.
    handles = getattr(legend, "legend_handles", None)
    if handles is None:
        handles = legend.legendHandles
    # legend markers are drawn opaque and enlarged regardless of plot alpha / size
    for handle in handles:
        handle.set_alpha(1)
        handle.set_sizes([30])
def joint_pca(
    df_pc,
    x="PC1",
    y="PC2",
    sample_alpha=0.1,
    axes=None,
    figsize=(8.5, 4),
    label_col="SUPERPOP",
    sample_label="SAMPLE",
):
    """Joint PCA plot: reference groups alone (left) and with the samples (right).

    Parameters
    ----------
    df_pc : pd.DataFrame
        dataframe with PCA components; must contain the columns ``x``, ``y``
        and ``label_col``
    x : str, optional
        x-axis column, by default "PC1"
    y : str, optional
        y-axis column, by default "PC2"
    sample_alpha : float, optional
        transparency of rows labeled ``sample_label`` in the right panel,
        by default 0.1
    axes : sequence of two matplotlib.axes.Axes, optional
        axes to plot on; if None, a new figure with two side-by-side axes is
        created
    figsize : tuple, optional
        figure size, used only when a new figure is created, by default (8.5, 4)
    label_col : str, optional
        column holding the group label, by default "SUPERPOP"
    sample_label : str, optional
        value of ``label_col`` marking the samples that are excluded from the
        left panel, by default "SAMPLE"

    Returns
    -------
    (fig, axes) when a new figure was created, otherwise None
    """
    # validate before drawing anything so a bad column name fails cleanly
    assert set([x, y]).issubset(
        df_pc.columns
    ), f"{x} and {y} must be in the columns of df_pc"
    new_axes = axes is None
    if new_axes:
        fig, axes = plt.subplots(figsize=figsize, dpi=150, ncols=2)
    # left panel: every group except the samples (pca also sets the axis labels)
    pca(
        df_pc[df_pc[label_col] != sample_label],
        x=x,
        y=y,
        label_col=label_col,
        ax=axes[0],
    )
    # right panel: all rows, with the samples drawn semi-transparent
    pca(
        df_pc,
        x=x,
        y=y,
        label_col=label_col,
        alpha={sample_label: sample_alpha},
        ax=axes[1],
    )
    if new_axes:
        return fig, axes
def qq(pval, label=None, ax=None, bootstrap_ci=False):
    """QQ plot of p-values: observed vs expected -log10(p).

    Also prints lambda GC, as computed by ``admix.data.lambda_gc``.

    Parameters
    ----------
    pval : array-like
        p-values; NaN entries are dropped
    label : str, optional
        legend label for the scattered points, by default None
    ax : matplotlib.axes, optional
        axes to plot on, by default the current axes
    bootstrap_ci : bool, optional
        whether to also compute a bootstrap confidence interval for lambda GC,
        by default False

    Returns
    -------
    lgc : float
        lambda GC
    lgc_ci : sequence of two floats
        (lower, upper) bootstrap interval; only returned when ``bootstrap_ci``
        is truthy, in which case the return value is ``(lgc, lgc_ci)``

    Raises
    ------
    ValueError
        if ``pval`` contains no non-NaN values
    """
    if ax is None:
        ax = plt.gca()
    pval = np.asarray(pval, dtype=float)
    pval = pval[~np.isnan(pval)]
    if pval.size == 0:
        raise ValueError("pval must contain at least one non-NaN value")
    # expected p-values come from quantile_normalize (admix.data) applied to
    # -pval, converted back through the normal survival function
    expected_log = -np.log10(stats.norm.sf(quantile_normalize(-pval)))
    ax.scatter(expected_log, -np.log10(pval), s=2, label=label)
    lim = np.max(expected_log)
    ax.plot([0, lim], [0, lim], "r--")
    ax.set_xlabel(r"Expected -$\log_{10}(p)$")
    ax.set_ylabel(r"Observed -$\log_{10}(p)$")
    # a single truthiness check keeps computation and return in agreement
    if bootstrap_ci:
        lgc, lgc_ci = lambda_gc(pval, bootstrap_ci=True)
        print(f"lambda GC: {lgc:.3g} [{lgc_ci[0]:.3g}, {lgc_ci[1]:.3g}]")
        return lgc, lgc_ci
    lgc = lambda_gc(pval, bootstrap_ci=False)
    print(f"lambda GC: {lgc:.3g}")
    return lgc
def manhattan(
    pval,
    chrom=None,
    pos=None,
    axh_y=-np.log10(5e-8),
    s=0.1,
    label=None,
    ax=None,
    color="#3b76af",
):
    """Manhattan plot of p-values.

    The x-axis is the chromosome layout when ``chrom`` is given, otherwise the
    position in Mb when ``pos`` is given, otherwise the SNP index.

    Parameters
    ----------
    pval : array-like
        p-values, one per SNP
    chrom : array-like, optional
        integer chromosome of each SNP; dots alternate color by chromosome
        parity and each tick sits at the mean SNP index of its chromosome, so
        SNPs should be grouped by chromosome. Cannot be combined with ``pos``.
    pos : array-like, optional
        position of each SNP, plotted divided by 1e6 (Mb). Cannot be combined
        with ``chrom``.
    axh_y : float, optional
        y value of the dashed horizontal significance line, by default
        -log10(5e-8); pass None to omit it
    s : float, optional
        dot size, by default 0.1
    label : str, optional
        legend label for the plotted SNPs, by default None
    ax : matplotlib.axes, optional
        axes to plot on, by default the current axes
    color : str, optional
        dot color when ``chrom`` is not given, by default "#3b76af"
    """
    if ax is None:
        ax = plt.gca()
    assert (chrom is None) or (pos is None), "chrom and pos cannot be both provided"
    # accept lists / Series as well as arrays
    pval = np.asarray(pval)
    neg_log_p = -np.log10(pval)

    if chrom is not None:
        chrom = np.asarray(chrom)
        assert len(chrom) == len(pval)
        snp_index = np.arange(len(pval))
        color_list = ["#1b9e77", "#d95f02"]
        # alternate colors for even / odd chromosomes; attach `label` only to
        # the first non-empty group so the legend shows a single entry
        pending_label = label
        for mod, dot_color in enumerate(color_list):
            index = np.flatnonzero(chrom % 2 == mod)
            if index.size == 0:
                continue
            ax.scatter(
                snp_index[index],
                neg_log_p[index],
                s=s,
                color=dot_color,
                label=pending_label,
            )
            pending_label = None
        # one tick per chromosome, centered on its SNPs
        unique_chrom = np.unique(chrom)
        ax.set_xticks([np.flatnonzero(chrom == c).mean() for c in unique_chrom])
        ax.set_xticklabels(list(unique_chrom), rotation=90, fontsize=8)
        ax.set_xlabel("Chromosome")
    elif pos is not None:
        pos = np.asarray(pos)
        assert len(pos) == len(pval)
        ax.scatter(
            pos / 1e6,
            neg_log_p,
            s=s,
            label=label,
            facecolor=color,
            marker="o",
        )
        ax.set_xlabel("SNP position (Mb)")
    else:
        ax.scatter(np.arange(len(pval)), neg_log_p, s=s, label=label, c=color)
        ax.set_xlabel("SNP index")
    ax.set_ylabel(r"-$\log_{10}(P)$")
    if axh_y is not None:
        ax.axhline(y=axh_y, color="r", ls="--")
def susie(pip, dict_cs, pos=None, ax=None, cmap="Set1"):
    """Plot per-SNP PIP values and highlight the SNPs of each credible set.

    Presumably intended for SuSiE fine-mapping output (PIP = posterior
    inclusion probability) — confirm against callers.

    Parameters
    ----------
    pip : array-like
        PIP of each SNP, drawn as small gray dots on a [-0.05, 1.05] y-axis
    dict_cs : dict
        credible-set name -> positional indices (into ``pip``) of its SNPs;
        each set is overdrawn with its own color
    pos : array-like, optional
        position of each SNP, plotted divided by 1e6 (Mb); by default the SNP
        index is used
    ax : matplotlib.axes.Axes, optional
        axes to plot on, by default the current axes
    cmap : str or matplotlib colormap, optional
        colormap for the credible sets, by default "Set1". Listed colormaps
        are cycled when there are more sets than colors; continuous colormaps
        are sampled evenly.
    """
    cmap = plt.get_cmap(cmap)
    if ax is None:
        ax = plt.gca()
    # positional indexing for both pip and positions, whatever the input type
    pip = np.asarray(pip)
    if pos is None:
        x_coord = np.arange(len(pip))
        pos_provided = False
    else:
        pos = np.asarray(pos)
        assert len(pos) == len(pip)
        x_coord = pos / 1e6
        pos_provided = True
    ax.scatter(x=x_coord, y=pip, s=3, color="gray")
    listed_colors = getattr(cmap, "colors", None)
    n_cs = len(dict_cs)
    for i, cs_idx in enumerate(dict_cs.values()):
        if listed_colors is not None:
            cs_color = listed_colors[i % len(listed_colors)]
        else:
            cs_color = cmap(i / max(n_cs - 1, 1))
        ax.scatter(
            x=x_coord[cs_idx],
            y=pip[cs_idx],
            s=15,
            edgecolors=cs_color,
            facecolors=cs_color,
            alpha=0.6,
        )
    if pos_provided:
        ax.set_xlabel("SNP position (Mb)")
    else:
        ax.set_xlabel("SNP index")
    ax.set_ylabel("PIP")
    ax.set_ylim(-0.05, 1.05)
def lanc(
    dset: admix.Dataset = None,
    lanc: np.ndarray = None,
    ax=None,
    max_indiv: int = None,
):
    """
    Plot local ancestry.

    Each of the two haplotypes of every plotted individual is drawn as a
    horizontal line split into segments of constant local ancestry, one color
    per ancestry value.

    Parameters
    ----------
    dset: admix.Dataset
        A dataset providing the local ancestry matrix (``dset.lanc``) and SNP
        positions (``dset.snp.POS``). Takes precedence over ``lanc``; the
        x-axis is then the SNP position in Mb.
    lanc: np.ndarray
        A numpy array of shape (n_snp, n_indiv, 2), used when ``dset`` is
        None; the x-axis is then the SNP index.
    ax: matplotlib.Axes
        A matplotlib axes object to plot on. If None, uses the current axes.
    max_indiv: int
        The maximum number of individuals to plot.
        If None, will plot the first 10 individuals (warning if there are more).

    Returns
    -------
    ax: matplotlib.Axes
        The axes that were plotted on.

    Raises
    ------
    ValueError
        if the local ancestry matrix has no SNPs
    """
    # if dataset is provided, use it to extract lanc and positions
    if dset is not None:
        lanc = dset.lanc.compute()
        pos = dset.snp.POS.values
        BP_POS = True
    else:
        assert lanc is not None, "either dataset or lanc must be provided"
        pos = np.arange(lanc.shape[0])
        BP_POS = False
    assert lanc.shape[2] == 2, "lanc must be of shape (n_snp, n_indiv, 2)"
    n_snp, n_indiv = lanc.shape[0:2]
    if n_snp == 0:
        raise ValueError("lanc must contain at least one SNP")
    # append a dummy position one past the last SNP so that a segment running
    # to the end of the chromosome has a coordinate for its stop
    pos = np.concatenate([pos, [pos[-1] + 1]])
    x_coord = pos / 1e6 if BP_POS else pos

    if max_indiv is not None:
        n_plot_indiv = min(max_indiv, n_indiv)
    else:
        n_plot_indiv = min(n_indiv, 10)
        if n_plot_indiv < n_indiv:
            warnings.warn(
                f"Only the first {n_plot_indiv} are plotted. To plot more individuals, increase `max_indiv`"
            )
    if ax is None:
        ax = plt.gca()

    seg_start, seg_stop, seg_label, seg_row = [], [], [], []
    for i_indiv in range(n_plot_indiv):
        for i_ploidy in range(2):
            hap = np.asarray(lanc[:, i_indiv, i_ploidy])
            # first SNP of every run of identical ancestry ...
            run_start = np.concatenate(
                [[0], np.flatnonzero(hap[1:] != hap[:-1]) + 1]
            ).astype(int)
            # ... and one past its last SNP (the next run's start, or the dummy)
            run_stop = np.append(run_start[1:], n_snp)
            # the two haplotypes sit just below / above the individual's row
            row_y = i_indiv - 0.1 + i_ploidy * 0.2
            seg_start.extend(x_coord[run_start])
            seg_stop.extend(x_coord[run_stop])
            seg_label.extend(hap[run_start])
            seg_row.extend([row_y] * len(run_start))
    df_plot = pd.DataFrame(
        {"start": seg_start, "stop": seg_stop, "label": seg_label, "row": seg_row}
    )
    cmap = plt.get_cmap("tab10")
    for i, (anc, group) in enumerate(df_plot.groupby("label")):
        segments = [
            [(x0, r), (x1, r)]
            for x0, x1, r in zip(group["start"], group["stop"], group["row"])
        ]
        ax.add_collection(
            mc.LineCollection(
                segments,
                linewidths=2,
                label=anc,
                color=cmap(i),
            )
        )
    ax.legend()
    ax.autoscale()
    if BP_POS:
        ax.set_xlabel("SNP position (Mb)")
    else:
        ax.set_xlabel("SNP index")
    ax.set_ylabel("Individuals")
    ax.set_yticks([])
    ax.set_yticklabels([])
    return ax
def admixture(
    a: np.ndarray,
    labels=None,
    label_orders=None,
    ax=None,
):
    """
    Plot admixture proportions as stacked bars, one bar per individual.

    Parameters
    ----------
    a: np.ndarray
        A numpy array of shape (n_indiv, n_pop); column ``i`` is stacked on
        top of columns ``0..i-1`` for each individual.
    labels: array-like
        A label for each individual. When given, individuals are grouped by
        label, groups are separated by vertical lines spanning y in [0, 1],
        and the labels are used as x tick labels.
    label_orders: list
        The left-to-right order of the label groups; must contain exactly the
        unique values of ``labels``. By default the sorted unique labels.
    ax: matplotlib.Axes
        A matplotlib axes object to plot on. If None, uses the current axes.

    Returns
    -------
    ax: matplotlib.Axes
        The axes that were plotted on.
    """
    a = np.asarray(a)
    n_indiv, n_pop = a.shape
    dict_label_range = {}
    # reorder rows so individuals sharing a label are contiguous
    if labels is not None:
        # an ndarray is needed for the element-wise `labels == label` masks
        labels = np.asarray(labels)
        unique_labels = np.unique(labels)
        if label_orders is not None:
            assert set(label_orders) == set(
                unique_labels
            ), "label_orders must cover all unique labels"
            unique_labels = label_orders
        reordered_a = []
        cumsum = 0
        for label in unique_labels:
            mask = labels == label
            n_label = int(np.count_nonzero(mask))
            reordered_a.append(a[mask, :])
            # half-open [first, last + 1) bar index range of this group
            dict_label_range[label] = [cumsum, cumsum + n_label]
            cumsum += n_label
        a = np.vstack(reordered_a)
    if ax is None:
        ax = plt.gca()
    cmap = plt.get_cmap("tab10")
    bar_x = np.arange(n_indiv)
    bottom = np.zeros(n_indiv)
    for i_pop in range(n_pop):
        # cycle through the palette instead of clipping to its last color
        pop_color = cmap(i_pop % cmap.N)
        ax.bar(
            bar_x,
            height=a[:, i_pop],
            width=1,
            bottom=bottom,
            facecolor=pop_color,
            edgecolor=pop_color,
        )
        bottom += a[:, i_pop]
    ax.tick_params(axis="both", left=False, labelleft=False)
    if labels is not None:
        seps = sorted(np.unique(np.concatenate(list(dict_label_range.values()))))
        # interior group boundaries only; ymax=1 presumes rows of `a` sum to 1
        for sep in seps[1:-1]:
            ax.vlines(sep - 0.5, ymin=0, ymax=1, color="black")
        ax.set_xticks([np.mean(r) for r in dict_label_range.values()])
        ax.set_xticklabels(list(dict_label_range))
    else:
        ax.get_xaxis().set_ticks([])
    for side in ["top", "right", "bottom", "left"]:
        ax.spines[side].set_visible(False)
    return ax
def compare_pval(
    x_pval: np.ndarray,
    y_pval: np.ndarray,
    xlabel: str = None,
    ylabel: str = None,
    ax=None,
    s: int = 5,
):
    """Compare two sets of p-values on the -log10 scale.

    Scatters -log10(y_pval) against -log10(x_pval), and draws the y = x line
    and the least-squares line through the origin, whose slope appears in the
    legend.

    Parameters
    ----------
    x_pval: np.ndarray
        The p-value for the first variable.
    y_pval: np.ndarray
        The p-value for the second variable (same length as ``x_pval``).
    xlabel: str
        The label for the first variable.
    ylabel: str
        The label for the second variable.
    ax: matplotlib.Axes
        A matplotlib axes object to plot on. If None, uses the current axes.
    s: int
        Marker size, by default 5.

    Raises
    ------
    ValueError
        if no pair of p-values has finite -log10 values on both sides
    """
    if ax is None:
        ax = plt.gca()
    x_pval = np.asarray(x_pval, dtype=float)
    y_pval = np.asarray(y_pval, dtype=float)
    nonnan = ~np.isnan(x_pval) & ~np.isnan(y_pval)
    with np.errstate(divide="ignore", invalid="ignore"):
        x_log = -np.log10(x_pval)
        y_log = -np.log10(y_pval)
    # p = 0 maps to +inf, which would make the axis limit infinite and break
    # the regression; drop such pairs (matplotlib would not draw them anyway)
    finite = nonnan & np.isfinite(x_log) & np.isfinite(y_log)
    n_dropped = int(np.count_nonzero(nonnan & ~finite))
    if n_dropped > 0:
        warnings.warn(
            f"{n_dropped} p-value pair(s) with non-finite -log10 (e.g. p = 0) are not shown"
        )
    if not finite.any():
        raise ValueError("no p-value pairs with finite -log10 values to compare")
    x_log, y_log = x_log[finite], y_log[finite]
    ax.scatter(x_log, y_log, s=s)
    lim = max(x_log.max(), y_log.max()) * 1.1
    ax.plot([0, lim], [0, lim], "k--", alpha=0.5, lw=1, label="y=x")
    # regression line through the origin (no intercept term)
    slope = np.linalg.lstsq(x_log[:, None], y_log[:, None], rcond=None)[0].item()
    ax.axline(
        (0, 0),
        slope=slope,
        color="black",
        ls="--",
        lw=1,
        label=f"y={slope:.2f} x",
    )
    ax.legend()
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
def rg_posterior(
    xs: np.ndarray,
    dict_loglik: Dict[str, np.ndarray],
    ci=(0.5, 0.95),
    s=11,
    colors="black",
    markers="o",
    ax=None,
):
    """
    Plot, per trait, the posterior mode and two highest-density intervals of r_admix.

    Parameters
    ----------
    xs: np.ndarray
        grid of x coordinates at which the log-likelihoods are evaluated
    dict_loglik: Dict[str, np.ndarray]
        trait -> log-likelihoods at ``xs`` (same length as ``xs``). The mode
        is the grid point with the largest log-likelihood; intervals come from
        ``admix.data.hdi``.
    ci: Sequence[float]
        exactly two interval levels, smaller first, by default (0.5, 0.95);
        the smaller interval is drawn with the thicker line
    s: float
        marker size for the mode, by default 11
    colors: str or sequence
        one color for all traits, or one color per trait in the order of
        ``dict_loglik``, e.g. ["darkblue"] * (n_trait - 1) + ["darkred"]
    markers: str or sequence
        one marker for all traits, or one marker per trait in the order of
        ``dict_loglik``, e.g. ["o"] * (n_trait - 1) + ["^"]
    ax: matplotlib.Axes
        axes to plot on, by default the current axes
    """
    if ax is None:
        ax = plt.gca()
    assert len(ci) == 2, "Currently must plot 2 CIs"
    assert ci[0] < ci[1], "Smaller CI should come first"
    assert np.all([len(xs) == len(dict_loglik[t]) for t in dict_loglik])
    # reversed so the first trait of dict_loglik ends up at the top (largest y)
    trait_list = list(dict_loglik.keys())[::-1]
    n_trait = len(trait_list)
    # per-trait colors / markers must be reversed the same way, whatever
    # sequence type the caller passed
    if isinstance(colors, str):
        colors = [colors] * n_trait
    else:
        colors = list(colors)[::-1]
    if isinstance(markers, str):
        markers = [markers] * n_trait
    else:
        markers = list(markers)[::-1]
    dict_mode = {trait: xs[dict_loglik[trait].argmax()] for trait in trait_list}
    # ci level -> trait -> [distance below mode, distance above mode]
    dict_ci_err: Dict[float, Dict[str, list]] = {ci[0]: dict(), ci[1]: dict()}
    for trait in trait_list:
        mode = dict_mode[trait]
        for each_ci in ci:
            hdi = admix.data.hdi(xs, dict_loglik[trait], ci=each_ci)
            assert not isinstance(
                hdi, list
            ), f"HPDI for {trait} contains multiple intervals {hdi}, indicating lack of data. Please rerun this function after remove this trait."
            dict_ci_err[each_ci][trait] = [mode - hdi[0], hdi[1] - mode]
    mode = np.array([dict_mode[trait] for trait in trait_list])
    lw_list = [2.5, 1.0]
    for i, each_ci in enumerate(ci):
        ci_low = [dict_ci_err[each_ci][trait][0] for trait in trait_list]
        ci_high = [dict_ci_err[each_ci][trait][1] for trait in trait_list]
        ax.errorbar(
            y=np.arange(len(mode)),
            x=mode,
            xerr=(ci_low, ci_high),
            fmt=" ",
            lw=lw_list[i],
            ecolor=colors,
        )
        # draw the mode markers once, together with the first interval
        if i == 0:
            for j, trait in enumerate(trait_list):
                ax.scatter(
                    x=mode[j],
                    y=j,
                    marker=markers[j],
                    color=colors[j],
                    s=s,
                )
    for y in np.arange(len(mode)):
        ax.axhline(y=y, color="gray", ls="dotted", lw=0.5, alpha=0.8)
    ax.set_xlim(0, 1.1)
    ax.set_xlabel("Highest probability density of $r_{admix}$")
    ax.set_yticks(np.arange(len(mode)))
    ax.set_ylim(-0.5, len(mode) - 0.5)
    ax.set_yticklabels(
        trait_list,
        fontsize=8,
    )
    # annotation
    ax.tick_params(left=False, pad=-1)
    ax.axvline(x=1.0, color="red", ls="--", lw=0.8, alpha=0.4)
    ax.set_title("Estimated $r_{admix}$", fontsize=10, x=0.5)