
Slider in scatter_3d and scatter makes some data points go missing #4768

Open
@Vilin97

Description


Only 2 of the 4 categories are plotted when I use the animation slider. The points in the other categories do not appear at all, and moving the slider changes which categories are plotted. For example, in the MWE below only TP and FP show up when the slider is below 0.9, and at 0.9 only TN and FN show up.

The same behavior occurs with both 2D (px.scatter) and 3D (px.scatter_3d) scatter plots. The MWE below uses the 3D version.
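To see which categories actually exist at each slider position, the sketch below recomputes the categories from the `due` and `predicted_probabilities` values of the MWE below for every threshold. The random `Dim1`/`Dim2`/`Dim3` columns do not affect a row's category, so they are left out. `categories_per_threshold` is a hypothetical helper written only for this illustration.

```python
import numpy as np

# Labels and scores copied from the MWE. The random Dim1..Dim3 coordinates are
# omitted because they do not influence which category a row falls into.
due = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0])
probs = np.array([0.9, 0.8, 0.4, 0.2, 0.6, 0.7, 0.1, 0.5, 0.3, 0.95])
CATEGORIES = ["TP", "TN", "FP", "FN"]


def categories_per_threshold(due, probs, thresholds):
    """Hypothetical helper: map each threshold to the sorted list of categories present."""
    present = {}
    for threshold in thresholds:
        predicted = (probs >= threshold).astype(int)
        conditions = [
            (due == 1) & (predicted == 1),  # TP
            (due == 0) & (predicted == 0),  # TN
            (due == 0) & (predicted == 1),  # FP
            (due == 1) & (predicted == 0),  # FN
        ]
        labels = np.select(conditions, CATEGORIES, default="Unknown")
        present[float(threshold)] = sorted(set(labels.tolist()))
    return present


# Same thresholds as the MWE; print which categories each animation frame contains.
for threshold, cats in categories_per_threshold(due, probs, np.arange(0, 1.1, 0.1)).items():
    print(f"{threshold:.1f}: {cats}")
```

At threshold 0.0 every row is predicted positive, so the first frame contains only TP and FP points, while later frames contain other categories. As far as we understand plotly express, the figure's initial traces are built from the first animation frame, and plotly's animation documentation notes that animations are designed to work well when categorical values mapped to color are constant across frames. Here a point's category changes with the threshold, which may explain why only some categories are drawn.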


import numpy as np
import pandas as pd
import plotly.express as px

def plot_scatter_3d_mwe():
    # Create a small DataFrame with fake data
    data = {
        'Dim1': np.random.rand(10),
        'Dim2': np.random.rand(10),
        'Dim3': np.random.rand(10),
        'due': [1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
        'serial_number': range(10),
        'predicted_probabilities': [0.9, 0.8, 0.4, 0.2, 0.6, 0.7, 0.1, 0.5, 0.3, 0.95]
    }

    df = pd.DataFrame(data)
    thresholds = np.arange(0, 1.1, 0.1)
    all_frames = []

    for threshold in thresholds:
        # Recalculate predictions based on the threshold
        predicted = (df['predicted_probabilities'] >= threshold).astype(int)
        
        # Create the 4 categories for coloring: TP, TN, FP, FN
        conditions = [
            (df['due'] == 1) & (predicted == 1),  # TP
            (df['due'] == 0) & (predicted == 0),  # TN
            (df['due'] == 0) & (predicted == 1),  # FP
            (df['due'] == 1) & (predicted == 0),  # FN
        ]
        categories = ['TP', 'TN', 'FP', 'FN']
        
        # Assign the categories to a new column
        df['category'] = np.select(conditions, categories, default='Unknown')
        df['threshold'] = threshold  # Add threshold as a column for animation frame
        
        all_frames.append(df.copy())

    # Concatenate all frames for animation
    df_all_frames = pd.concat(all_frames)

    # Plot the scatter 3D with the categories as color and animate over thresholds
    fig = px.scatter_3d(df_all_frames,
                        x='Dim1', y='Dim2', z='Dim3',
                        color='category',
                        animation_frame='threshold',
                        animation_group='serial_number')

    fig.show()

# Call the function
plot_scatter_3d_mwe()
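In the MWE, the first frame (threshold 0.0) predicts every row as positive, so it contains only TP and FP points. One possible workaround, sketched here and not verified against this bug (it is not an official plotly fix), is to make every category present in every frame so that each frame carries the same four traces in the same order. The sketch pads each threshold with one placeholder row per missing category whose coordinates are NaN, so nothing extra is drawn, and pins the trace order with px's `category_orders` argument. `pad_missing_categories`, `plot_padded`, and the negative placeholder `serial_number` values are hypothetical names and choices made for this example.

```python
import numpy as np
import pandas as pd

CATEGORIES = ["TP", "TN", "FP", "FN"]


def pad_missing_categories(df, frame_col="threshold", color_col="category",
                           coord_cols=("Dim1", "Dim2", "Dim3"), id_col="serial_number"):
    """Hypothetical helper: for every (frame, category) pair with no rows, append one
    placeholder row with NaN coordinates so that every frame contains every category."""
    present = set(zip(df[frame_col], df[color_col]))
    placeholders = []
    for frame in df[frame_col].unique():
        for i, cat in enumerate(CATEGORIES):
            if (frame, cat) not in present:
                row = {col: np.nan for col in coord_cols}  # NaN coordinates are not drawn
                row[frame_col] = frame
                row[color_col] = cat
                # Assumption: negative ids never occur in the real serial_number column.
                # One fixed id per category keeps placeholders stable across frames.
                row[id_col] = -(i + 1)
                placeholders.append(row)
    if not placeholders:
        return df
    return pd.concat([df, pd.DataFrame(placeholders)], ignore_index=True)


def plot_padded(df_all_frames):
    """Hypothetical replacement for the px.scatter_3d call in the MWE."""
    import plotly.express as px  # imported lazily so the padding helper runs without plotly

    padded = pad_missing_categories(df_all_frames)
    return px.scatter_3d(
        padded,
        x="Dim1", y="Dim2", z="Dim3",
        color="category",
        animation_frame="threshold",
        animation_group="serial_number",
        # Fix the trace order so trace i is the same category in every frame.
        category_orders={"category": CATEGORIES},
    )
```

In the MWE this would replace the px.scatter_3d call, e.g. `fig = plot_padded(df_all_frames)` followed by `fig.show()`. For the 2D case, the same padding should apply to px.scatter with only the x/y columns passed as `coord_cols`.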

Metadata

Labels: P3, backlog, bug (something broken)
