Skip to content

CoMPASS Level 2 - Models

Bayesian Gaussian Mixture Models and GMM-HMM for cognitive state inference.

Overview

This module implements the Level 2 models that identify goal-directed cognitive states:

  • BGMM (Bayesian Gaussian Mixture Model): Clusters behavioral features
  • GMM-HMM: Models temporal dependencies across states
  • Cross-validation for model selection
  • State sequence decoding

The models identify:

  • Oriented states: Goal-directed behavior toward target
  • Non-Oriented states: Exploratory or non-goal-directed behavior

Model Selection

Use run_compass with cross-validation to automatically select the optimal number of states.


Main Modeling Function

run_compass

compass_labyrinth.compass.level_2.model.run_compass

run_compass(
    config: dict,
    df: DataFrame,
    features: list,
    phase_options: list = [5],
    ncomp_options: list[int] = [2, 3],
    k_options: list[int] = [2, 3],
    reg_options: list = [0.0001, 1e-05, 1e-06],
    terminal_values: list = [47],
    bout_col: str = "Bout_ID",
    patience: None | str = None,
    patience_candidates: list = [2, 3, 5, 10],
    verbose: bool = False,
) -> tuple[pd.DataFrame, list]

Run CoMPASS.

Parameters:

  • config (dict) –

    Configuration dictionary.

  • df (DataFrame) –

    Input dataframe.

  • features (list) –

    List of feature column names to use.

  • phase_options (list, default: [5] ) –

    List of phase options to test (default is [5]).

  • ncomp_options (list, default: [2, 3] ) –

    Range of number of components to test (default is [2, 3]).

  • k_options (list, default: [2, 3] ) –

    Range of k values to test (default is [2, 3]).

  • reg_options (list, default: [0.0001, 1e-05, 1e-06] ) –

    List of regularization values to test (default is [1e-4, 1e-5, 1e-6]).

  • terminal_values (list, default: [47] ) –

    List of terminal grid values (default is [47]).

  • bout_col (str, default: 'Bout_ID' ) –

    Name of the bout column (default is "Bout_ID").

  • patience (None or str, default: None ) –

    Patience setting for early stopping (default is None). Set to 'tune' if wanted to apply patience window.

  • patience_candidates (list, default: [2, 3, 5, 10] ) –

    List of patience candidates to test if patience is 'tune' (default is [2, 3, 5, 10]). Only used if patience is set to 'tune'.

  • verbose (bool, default: False ) –

    Whether to print detailed logs during model training (default is False).

Returns:

  • tuple

    A tuple containing: - pd.DataFrame: DataFrame with assigned Level 2 states. - list: List of all CV results for visualization.

Source code in src/compass_labyrinth/compass/level_2/model.py
def run_compass(
    config: dict,
    df: pd.DataFrame,
    features: list,
    phase_options: list = [5],
    ncomp_options: list[int] = [2, 3],
    k_options: list[int] = [2, 3],
    reg_options: list = [1e-4, 1e-5, 1e-6],
    terminal_values: list = [47],
    bout_col: str = "Bout_ID",
    patience: None | str = None,
    patience_candidates: list = [2, 3, 5, 10],
    verbose: bool = False,
) -> tuple[pd.DataFrame, list]:
    """
    Run CoMPASS.

    Parameters
    -----------
    config : dict
        Configuration dictionary.
    df : pd.DataFrame
        Input dataframe.
    features : list
        List of feature column names to use.
    phase_options : list, optional
        List of phase options to test (default is [5]).
    ncomp_options : list, optional
        Range of number of components to test (default is [2, 3]).
    k_options : list, optional
        Range of k values to test (default is [2, 3]).
    reg_options : list, optional
        List of regularization values to test (default is [1e-4, 1e-5, 1e-6]).
    terminal_values : list, optional
        List of terminal grid values (default is [47]).
    bout_col : str, optional
        Name of the bout column (default is "Bout_ID").
    patience : None or str, optional
        Patience setting for early stopping (default is None).
        Set to 'tune' if wanted to apply patience window.
    patience_candidates : list, optional
        List of patience candidates to test if patience is 'tune' (default is [2, 3, 5, 10]).
        Only used if patience is set to 'tune'.
    verbose : bool, optional
        Whether to print detailed logs during model training (default is False).

    Returns
    --------
    tuple
        A tuple containing:
        - pd.DataFrame: DataFrame with assigned Level 2 states.
        - list: List of all CV results for visualization.
    """
    all_results = []
    final_sess_data = []

    df = assign_bouts_per_session(df, terminal_values=terminal_values, bout_col=bout_col)

    for n_phases in phase_options:
        sessions = df.Session.unique()
        phase_labels = range(n_phases)
        phase_map = build_phase_map(df, n_phases)

        for phase_index in phase_labels:
            for test_sess in sessions:
                print(f"\n=== CV | Test: Session {test_sess} - Phase {phase_index+1}/{n_phases} ===")

                test_bouts = phase_map[(test_sess, phase_index)]
                df_test = df[(df.Session == test_sess) & (df[bout_col].isin(test_bouts))]

                train_sessions = [s for s in sessions if s != test_sess]
                df_train_pool = pd.concat(
                    [df[(df.Session == s) & (df[bout_col].isin(phase_map[(s, phase_index)]))] for s in train_sessions]
                )

                inner_sessions = df_train_pool.Session.unique()
                best_log_lik = -np.inf
                best_aic = np.inf
                best_model = None

                best_patience = patience_candidates[0] if patience == "tune" else patience
                patience_results = {}

                for test_patience in (patience_candidates if patience == "tune" else [best_patience]):
                    log_liks, aics, param_labels = [], [], []
                    no_improve = 0
                    best_inner_loglik = -np.inf

                    early_stopped = False

                    for val_sess in inner_sessions:
                        df_val = df_train_pool[df_train_pool.Session == val_sess]
                        df_train = df_train_pool[df_train_pool.Session != val_sess]

                        for ncomp in ncomp_options:
                            for k in k_options:
                                for reg_val in reg_options:
                                    try:
                                        X_train = df_train[features].values
                                        X_val = df_val[features].values

                                        gmm = initialize_bgmm(ncomp, reg_val)
                                        gmm.fit(X_train)
                                        covariances = regularize_covariances(gmm.covariances_, reg_val)

                                        means = np.tile(gmm.means_[None, :, :], (ncomp, 1, 1))
                                        covars = np.tile(covariances[None, :, :, :], (ncomp, 1, 1, 1))
                                        weights = np.tile(gmm.weights_[None, :], (ncomp, 1))

                                        model = initialize_gmmhmm(ncomp, k, means, covars, weights)
                                        model.fit(X_train)

                                        log_lik = model.score(X_val)
                                        aic = compute_aic(model, X_train)
                                        label = f"n={ncomp},k={k},r={reg_val:.0e}"

                                        log_liks.append(log_lik)
                                        aics.append(aic)
                                        param_labels.append(label)

                                        if log_lik > best_inner_loglik:
                                            best_inner_loglik = log_lik
                                            best_model = model
                                            no_improve = 0
                                        else:
                                            no_improve += 1

                                        if test_patience is not None and no_improve >= test_patience:
                                            early_stopped = True
                                            break
                                    except Exception as e:
                                        if verbose:
                                            print(f"Model failed for n={ncomp}, k={k}, reg={reg_val:.0e}: {e}")
                                        continue
                                if early_stopped:
                                    break
                            if early_stopped:
                                break

                    avg_loglik = np.mean(log_liks)
                    patience_results[test_patience] = (avg_loglik, best_model, log_liks, aics, param_labels)

                # Pick best patience if tuning
                if patience == "tune":
                    best_patience, (best_avg_loglik, best_model, log_liks, aics, param_labels) = max(
                        patience_results.items(), key=lambda x: x[1][0]
                    )
                    print(f"Optimal patience for Session {test_sess}, Phase {phase_index+1}: {best_patience}")

                if best_model is not None:
                    X_test = df_test[features].values
                    df_test = df_test.copy()
                    df_test["Level_2_States"] = best_model.predict(X_test)
                    final_sess_data.append(df_test)

                tag = f"Session:{test_sess}_PhaseIndex:{phase_index+1}_NumPhases:{n_phases}_Patience:{best_patience}"
                all_results.append((tag, log_liks, aics, param_labels))

    df_hier = pd.concat(final_sess_data)

    # Save results
    save_path = Path(config["project_path_full"]) / "csvs" / "combined" / "hhmm_state_file.csv"
    df_hier.to_csv(save_path, index=False)
    print(f"HHMM state file saved at: {save_path}")

    return df_hier, all_results

Model Initialization

initialize_bgmm

compass_labyrinth.compass.level_2.model.initialize_bgmm

initialize_bgmm(n_components, reg_val, random_state=3)
Source code in src/compass_labyrinth/compass/level_2/model.py
def initialize_bgmm(n_components, reg_val, random_state=3):
    return BayesianGaussianMixture(
        n_components=n_components, covariance_type="full", reg_covar=reg_val, random_state=random_state, max_iter=500
    )

initialize_gmmhmm

compass_labyrinth.compass.level_2.model.initialize_gmmhmm

initialize_gmmhmm(
    n_states,
    n_mix,
    means,
    covariances,
    weights,
    random_state=3,
)
Source code in src/compass_labyrinth/compass/level_2/model.py
def initialize_gmmhmm(n_states, n_mix, means, covariances, weights, random_state=3):
    model = GMMHMM(n_components=n_states, n_mix=n_mix, covariance_type="full", random_state=random_state, n_iter=500)
    model.means_ = means[:, :n_mix, :]
    model.covars_ = covariances[:, :n_mix, :, :]
    model.weights_ = weights[:, :n_mix]
    return model

Model Evaluation

compute_aic

compass_labyrinth.compass.level_2.model.compute_aic

compute_aic(model, X)
Source code in src/compass_labyrinth/compass/level_2/model.py
def compute_aic(model, X):
    log_likelihood = model.score(X)
    n_states = model.n_components
    n_features = model.means_.shape[-1]
    n_mixtures = model.n_mix
    num_params = (
        (n_states - 1)  # start prob
        + n_states * (n_states - 1)  # transition matrix
        + n_states * n_mixtures * n_features  # means
        + n_states * n_mixtures * n_features  # covariances
        + n_states * n_mixtures  # weights
    )
    return 2 * num_params - 2 * log_likelihood

visualize_cv_results

compass_labyrinth.compass.level_2.model.visualize_cv_results

visualize_cv_results(
    config: dict,
    all_results: list,
    save_fig: bool = True,
    show_fig: bool = True,
    return_fig: bool = False,
) -> None | list[plt.Figure]

Visualize cross-validation results.

Parameters:

  • config (dict) –

    Configuration dictionary.

  • all_results (list) –

    List of tuples containing CV results.

  • save_fig (bool, default: True ) –

    Whether to save the figures (default is True).

  • show_fig (bool, default: True ) –

    Whether to show the figures (default is True).

  • return_fig (bool, default: False ) –

    Whether to return the figures (default is False).

Returns:

  • None or list of plt.Figure

    List of figures if return_fig is True, otherwise None.

Source code in src/compass_labyrinth/compass/level_2/model.py
def visualize_cv_results(
    config: dict,
    all_results: list,
    save_fig: bool = True,
    show_fig: bool = True,
    return_fig: bool = False,
) -> None | list[plt.Figure]:
    """
    Visualize cross-validation results.

    Parameters
    -----------
    config : dict
        Configuration dictionary.
    all_results : list
        List of tuples containing CV results.
    save_fig : bool, optional
        Whether to save the figures (default is True).
    show_fig : bool, optional
        Whether to show the figures (default is True).
    return_fig : bool, optional
        Whether to return the figures (default is False).

    Returns
    --------
    None or list of plt.Figure
        List of figures if return_fig is True, otherwise None.
    """
    all_figs = list()
    for tag, log_liks, aics, param_labels in all_results:
        fig = plt.figure(figsize=(14, 5))

        plt.subplot(1, 2, 1)
        sns.lineplot(x=np.arange(len(log_liks)), y=log_liks, marker="o")
        plt.xticks(ticks=np.arange(len(param_labels)), labels=param_labels, rotation=90)
        plt.title(f"{tag} - Log-Likelihoods")
        plt.xlabel("Param Config (n,k,reg)")
        plt.ylabel("Log-Likelihood")

        plt.subplot(1, 2, 2)
        sns.lineplot(x=np.arange(len(aics)), y=aics, marker="o")
        plt.xticks(ticks=np.arange(len(param_labels)), labels=param_labels, rotation=90)
        plt.title(f"{tag} - AIC")
        plt.xlabel("Param Config (n,k,reg)")
        plt.ylabel("AIC")

        plt.tight_layout()

        # Save figure
        if save_fig:
            prefix = tag.replace(":", "-").replace(",", "_").replace(" ", "_")
            save_path = Path(config["project_path_full"]) / "figures" / f"level_2_cv_performance_{prefix}.pdf"
            plt.savefig(save_path, bbox_inches="tight", dpi=300)
            print(f"Figure saved at: {save_path}")

        # Show figure
        if show_fig:
            plt.show()

        # Return figure
        if return_fig:
            all_figs.append(fig)

    if return_fig:
        return all_figs

regularize_covariances

compass_labyrinth.compass.level_2.model.regularize_covariances

regularize_covariances(covariances, reg_val=1e-06)
Source code in src/compass_labyrinth/compass/level_2/model.py
def regularize_covariances(covariances, reg_val=1e-6):
    return np.array([cov + np.eye(cov.shape[0]) * reg_val for cov in covariances])

State Analysis

get_unique_states

compass_labyrinth.compass.level_2.model.get_unique_states

get_unique_states(df, state_col='Level_2_States')

Return sorted unique states from the specified column.

Source code in src/compass_labyrinth/compass/level_2/model.py
def get_unique_states(df, state_col="Level_2_States"):
    """Return sorted unique states from the specified column."""
    return sorted(df[state_col].dropna().unique())

generate_state_color_map

compass_labyrinth.compass.level_2.model.generate_state_color_map

generate_state_color_map(unique_states, palette='tab10')

Assign each state a color from the selected Seaborn palette.

Source code in src/compass_labyrinth/compass/level_2/model.py
def generate_state_color_map(unique_states, palette="tab10"):
    """Assign each state a color from the selected Seaborn palette."""
    colors = sns.color_palette(palette, len(unique_states))
    return {state: colors[i] for i, state in enumerate(unique_states)}

assign_reward_orientation

compass_labyrinth.compass.level_2.model.assign_reward_orientation

assign_reward_orientation(
    df: DataFrame,
    angle_col: str = "Targeted_Angle_smooth_abs",
    level_2_state_col: str = "Level_2_States",
    session_col: str = "Session",
) -> pd.DataFrame

Assigns reward orientation labels ('Reward Oriented' or 'Non-Reward Oriented') to Level 2 states per session, based on the median Targeted_Angle_smooth within each state and relative to session median.

Parameters:

  • df (DataFrame) –

    Input dataframe with columns for session, level 2 state, and angle.

  • angle_col (str, default: 'Targeted_Angle_smooth_abs' ) –

    Column name representing the smoothed targeted angle.

  • level_2_state_col (str, default: 'Level_2_States' ) –

    Column name for level 2 HMM states.

  • session_col (str, default: 'Session' ) –

    Column name for session identifier.

Returns:

  • DataFrame

    Updated dataframe with 'Reward_Oriented' column.

Source code in src/compass_labyrinth/compass/level_2/model.py
def assign_reward_orientation(
    df: pd.DataFrame,
    angle_col: str = "Targeted_Angle_smooth_abs",
    level_2_state_col: str = "Level_2_States",
    session_col: str = "Session",
) -> pd.DataFrame:
    """
    Assigns reward orientation labels ('Reward Oriented' or 'Non-Reward Oriented') to Level 2 states per session,
    based on the median Targeted_Angle_smooth within each state and relative to session median.

    Parameters
    -----------
    df : pd.DataFrame
        Input dataframe with columns for session, level 2 state, and angle.
    angle_col : str
        Column name representing the smoothed targeted angle.
    level_2_state_col : str
        Column name for level 2 HMM states.
    session_col : str
        Column name for session identifier.

    Returns
    --------
    pd.DataFrame
        Updated dataframe with 'Reward_Oriented' column.
    """
    df = df.copy()
    df["Reward_Oriented"] = np.nan

    for sess in df[session_col].unique():
        df_sess = df[df[session_col] == sess]
        state_medians = df_sess.groupby(level_2_state_col)[angle_col].median()
        session_median = df_sess[angle_col].median()

        if len(state_medians) == 2:
            reward_state = state_medians.idxmin()
            non_reward_state = state_medians.idxmax()
            df.loc[(df[session_col] == sess) & (df[level_2_state_col] == reward_state), "Reward_Oriented"] = (
                "Reward Oriented"
            )
            df.loc[(df[session_col] == sess) & (df[level_2_state_col] == non_reward_state), "Reward_Oriented"] = (
                "Non-Reward Oriented"
            )

        elif len(state_medians) == 3:
            sorted_states = state_medians.sort_values()
            reward_state = sorted_states.index[0]
            non_reward_state = sorted_states.index[2]
            middle_state = sorted_states.index[1]

            if state_medians[middle_state] <= session_median:
                middle_label = "Reward Oriented"
            else:
                middle_label = "Non-Reward Oriented"

            df.loc[(df[session_col] == sess) & (df[level_2_state_col] == reward_state), "Reward_Oriented"] = (
                "Reward Oriented"
            )
            df.loc[(df[session_col] == sess) & (df[level_2_state_col] == non_reward_state), "Reward_Oriented"] = (
                "Non-Reward Oriented"
            )
            df.loc[(df[session_col] == sess) & (df[level_2_state_col] == middle_state), "Reward_Oriented"] = (
                middle_label
            )

    return df

assign_hhmm_state

compass_labyrinth.compass.level_2.model.assign_hhmm_state

assign_hhmm_state(
    df: DataFrame,
    level_1_state_col: str,
    level_2_state_col: str,
) -> pd.DataFrame

Assigns a final HHMM (Hierarchical Hidden Markov Model) state to the dataframe. The final HHMM state is based on the combination of level 1 and level 2 states.

Parameters:

  • df (DataFrame) –

    Dataframe containing the level_1_state_col and level_2_state_col columns.

  • level_1_state_col (str) –

    The name of the column representing the first-level HMM state.

  • level_2_state_col (str) –

    The name of the column representing the second-level state (reward-oriented or not).

Returns:

  • df ( DataFrame ) –

    DataFrame with an additional 'HHMM State' column indicating the final HHMM state.

Source code in src/compass_labyrinth/compass/level_2/model.py
def assign_hhmm_state(
    df: pd.DataFrame,
    level_1_state_col: str,
    level_2_state_col: str,
) -> pd.DataFrame:
    """
    Assigns a final HHMM (Hierarchical Hidden Markov Model) state to the dataframe.
    The final HHMM state is based on the combination of level 1 and level 2 states.

    Parameters
    -----------
    df : pd.DataFrame
        Dataframe containing the level_1_state_col and level_2_state_col columns.
    level_1_state_col : str
        The name of the column representing the first-level HMM state.
    level_2_state_col : str
        The name of the column representing the second-level state (reward-oriented or not).

    Returns
    --------
    df : pd.DataFrame
        DataFrame with an additional 'HHMM State' column indicating the final HHMM state.
    """
    # Define the conditions for assigning the HHMM states
    conds = [
        (df[level_1_state_col] == 1) & (df[level_2_state_col] == "Non-Reward Oriented"),
        (df[level_1_state_col] == 1) & (df[level_2_state_col] == "Reward Oriented"),
        (df[level_1_state_col] == 2) & (df[level_2_state_col] == "Non-Reward Oriented"),
        (df[level_1_state_col] == 2) & (df[level_2_state_col] == "Reward Oriented"),
    ]

    # Define the labels for the corresponding HHMM states
    labels = [
        "Surveillance, Non-Reward Oriented",
        "Surveillance, Reward Oriented",
        "Ambulatory, Non-Reward Oriented",
        "Ambulatory, Reward Oriented",
    ]

    # Set the default value as 'NaN' (string) to match the data type of the labels
    df["HHMM State"] = np.select(conds, labels, default="NaN")

    return df

State Sequence Visualization

plot_state_sequence_for_session

compass_labyrinth.compass.level_2.model.plot_state_sequence_for_session

plot_state_sequence_for_session(
    df_session: DataFrame,
    state_col: str = "Level_2_States",
    color_map: None | dict = None,
    title_prefix: str = "State Sequence",
) -> plt.Figure

Plot the state sequence using color bars for one session.

Source code in src/compass_labyrinth/compass/level_2/model.py
def plot_state_sequence_for_session(
    df_session: pd.DataFrame,
    state_col: str = "Level_2_States",
    color_map: None | dict = None,
    title_prefix: str = "State Sequence",
) -> plt.Figure:
    """Plot the state sequence using color bars for one session."""
    df_session = df_session.reset_index(drop=True).copy()
    df_session["color"] = df_session[state_col].map(color_map)

    fig, ax = plt.subplots(figsize=(10, 3))
    for idx, row in df_session.iterrows():
        rect = patches.Rectangle((idx, 0), 1, 1, color=row["color"])
        ax.add_patch(rect)

    ax.set_xlim(df_session.index.min(), df_session.index.max() + 1)
    ax.set_yticks([])
    ax.set_title(f"{title_prefix} - Session {df_session['Session'].iloc[0]}")

    legend_handles = [patches.Patch(color=color_map[state], label=f"State {state}") for state in color_map]
    ax.legend(
        handles=legend_handles,
        title="States",
        bbox_to_anchor=(0.5, -0.15),
        loc="upper center",
        borderaxespad=0.0,
    )

    plt.tight_layout()
    return fig

plot_state_sequences

compass_labyrinth.compass.level_2.model.plot_state_sequences

plot_state_sequences(
    config: dict,
    df: DataFrame,
    genotype: str = "WT-WT",
    state_col: str = "Level_2_States",
    sessions_to_plot: str | list | int = "all",
    title_prefix: str = "State Sequence",
    save_fig: bool = True,
    show_fig: bool = True,
    return_fig: bool = False,
) -> None | list[plt.Figure]

Plot state sequences for specified sessions and genotype.

Source code in src/compass_labyrinth/compass/level_2/model.py
def plot_state_sequences(
    config: dict,
    df: pd.DataFrame,
    genotype: str = "WT-WT",
    state_col: str = "Level_2_States",
    sessions_to_plot: str | list | int = "all",  # Can be 'all', a list of session IDs, or an int (top n)
    title_prefix: str = "State Sequence",
    save_fig: bool = True,
    show_fig: bool = True,
    return_fig: bool = False,
) -> None | list[plt.Figure]:
    """Plot state sequences for specified sessions and genotype."""
    df_geno = df[df["Genotype"] == genotype]
    unique_states = get_unique_states(df_geno, state_col=state_col)
    color_map = generate_state_color_map(unique_states)

    # Determine which sessions to plot
    all_sessions = df_geno["Session"].unique()
    if isinstance(sessions_to_plot, int):
        selected_sessions = all_sessions[:sessions_to_plot]
    elif isinstance(sessions_to_plot, list):
        selected_sessions = [s for s in sessions_to_plot if s in all_sessions]
    else:
        selected_sessions = all_sessions

    # Plot each selected session
    all_figs = []
    for sess_id in selected_sessions:
        df_sess = df_geno[df_geno["Session"] == sess_id][[state_col, "Session"]]
        fig = plot_state_sequence_for_session(
            df_sess,
            state_col=state_col,
            color_map=color_map,
            title_prefix=title_prefix,
        )

        # Save figure
        if save_fig:
            save_path = Path(config["project_path_full"]) / "figures" / f"state_sequence_session_{sess_id}.pdf"
            fig.savefig(save_path, bbox_inches="tight", dpi=300)
            print(f"Figure saved at: {save_path}")

        # Show figure
        if show_fig:
            plt.show()

        # Return figure
        if return_fig:
            all_figs.append(fig)

    if return_fig:
        return all_figs

plot_hhmm_state_sequence

compass_labyrinth.compass.level_2.model.plot_hhmm_state_sequence

plot_hhmm_state_sequence(
    config: dict,
    df: DataFrame,
    session_col: str = "Session",
    state_col: str = "HHMM State",
    session_id: None | int = None,
    title_prefix: str = "State Sequence",
    colors: None | dict = None,
    save_fig: bool = True,
    show_fig: bool = True,
    return_fig: bool = False,
) -> None | plt.Figure

Plots a rectangular sequence of HHMM states for a given session.

Parameters:

  • config (dict) –

    Configuration dictionary for the project.

  • df (DataFrame) –

    DataFrame containing session and HHMM state columns.

  • session_col (str, default: 'Session' ) –

    Name of the column indicating session.

  • state_col (str, default: 'HHMM State' ) –

    Name of the column containing HHMM state labels.

  • session_id (None | int, default: None ) –

    Specific session to plot. If None, plots all sessions.

  • title_prefix (str, default: 'State Sequence' ) –

    Custom title prefix for plots.

  • colors (None | dict, default: None ) –

    Dictionary mapping HHMM states to colors. If None, default colors are used.

  • save_fig (bool, default: True ) –

    Whether to save the figure.

  • show_fig (bool, default: True ) –

    Whether to display the figure.

  • return_fig (bool, default: False ) –

    Whether to return the figure object.

Returns:

  • None or Figure

    The figure object if return_fig is True, otherwise None.

Source code in src/compass_labyrinth/compass/level_2/model.py
def plot_hhmm_state_sequence(
    config: dict,
    df: pd.DataFrame,
    session_col: str = "Session",
    state_col: str = "HHMM State",
    session_id: None | int = None,
    title_prefix: str = "State Sequence",
    colors: None | dict = None,
    save_fig: bool = True,
    show_fig: bool = True,
    return_fig: bool = False,
) -> None | plt.Figure:
    """
    Plots a rectangular sequence of HHMM states for a given session.

    Parameters
    -----------
    config : dict
        Configuration dictionary for the project.
    df : pd.DataFrame
        DataFrame containing session and HHMM state columns.
    session_col : str
        Name of the column indicating session.
    state_col : str
        Name of the column containing HHMM state labels.
    session_id : None | int
        Specific session to plot. If None, plots all sessions.
    title_prefix : str
        Custom title prefix for plots.
    colors : None | dict
        Dictionary mapping HHMM states to colors. If None, default colors are used.
    save_fig : bool
        Whether to save the figure.
    show_fig : bool
        Whether to display the figure.
    return_fig : bool
        Whether to return the figure object.

    Returns
    --------
    None or plt.Figure
        The figure object if return_fig is True, otherwise None.
    """
    sessions_to_plot = [session_id] if session_id is not None else df[session_col].unique()

    if colors is None:
        # Get unique states from all sessions to be plotted and generate color map
        unique_states = sorted(df[df[session_col].isin(sessions_to_plot)][state_col].dropna().unique())
        colors = generate_state_color_map(unique_states)
    all_figs = []
    for sess in sessions_to_plot:
        test = df.loc[df[session_col] == sess, [state_col]].reset_index(drop=True)
        test["color"] = test[state_col].map(colors)

        fig, ax = plt.subplots(figsize=(10, 3))
        for idx, row in test.iterrows():
            rect = patches.Rectangle((idx, 0), 1, 1, color=row["color"])
            ax.add_patch(rect)

        plt.yticks([])
        plt.xlim(test.index.min(), test.index.max() + 1)
        plt.title(f"{title_prefix} - Session {sess}")
        handles = [patches.Patch(color=color, label=label) for label, color in colors.items()]
        ax.legend(handles=handles, title="States", bbox_to_anchor=(0.5, -0.15), loc="upper center", borderaxespad=0.0)
        plt.tight_layout()

        # Save figure
        if save_fig:
            save_path = Path(config["project_path_full"]) / "figures" / f"hhmm_state_sequence_session_{sess}.pdf"
            plt.savefig(save_path, bbox_inches="tight", dpi=300)
            print(f"Figure saved at: {save_path}")

        # Show figure
        if show_fig:
            plt.show()

        # Return figure
        if return_fig:
            all_figs.append(fig)

    if return_fig:
        return all_figs