Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 0c0ad06

Browse files
committedSep 8, 2024
python Decision Tree, add need visualization the process
1 parent da925c9 commit 0c0ad06

File tree

4 files changed

+565
-1
lines changed

4 files changed

+565
-1
lines changed
 

‎DecisionTreeClassifier.png

480 KB
Loading

‎DecisionTreeClassifier.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# --- python Decision Tree, add need visualization the process
2+
3+
# To visualize the decision-making process of the Decision Tree, we can use `plot_tree` from `scikit-learn`. This function generates a plot of the tree structure, showing the conditions at each node and the splits made during training.
4+
5+
# Here’s an extended version of the previous example with a visualization of the decision tree:
6+
7+
# ### Full Example with Visualization:
8+
9+
# ```python
10+
# Import required libraries
11+
import matplotlib.pyplot as plt
12+
from sklearn.datasets import load_iris
13+
from sklearn.model_selection import train_test_split
14+
from sklearn.tree import DecisionTreeClassifier, plot_tree
15+
from sklearn import metrics
16+
17+
# Load the dataset (Iris dataset in this case)
18+
iris = load_iris()
19+
X = iris.data # Features
20+
y = iris.target # Labels
21+
22+
# Split the dataset into training and testing sets
23+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
24+
25+
# Initialize the Decision Tree Classifier
26+
clf = DecisionTreeClassifier()
27+
28+
# Train the model
29+
clf.fit(X_train, y_train)
30+
31+
# Make predictions
32+
y_pred = clf.predict(X_test)
33+
34+
# Evaluate the model
35+
accuracy = metrics.accuracy_score(y_test, y_pred)
36+
print(f"Accuracy: {accuracy * 100:.2f}%")
37+
38+
# Visualize the decision tree
39+
plt.figure(figsize=(12,8))
40+
plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
41+
plt.show()
42+
# ```
43+
44+
# ### Key Explanation:
45+
# 1. **`plot_tree` Function**:
46+
# - `clf`: The trained decision tree classifier.
47+
# - `feature_names`: Names of the features (columns) in the dataset.
48+
# - `class_names`: Labels of the target variable (class names).
49+
# - `filled=True`: Colors the nodes to indicate different classes.
50+
#
51+
# 2. **`plt.figure(figsize=(12,8))`**: Resizes the plot for better visualization.
52+
53+
# ### Visualization Output:
54+
# The visualization will show a tree where each node contains:
55+
# - The feature used for splitting.
56+
# - The threshold for the split.
57+
# - The Gini impurity or entropy of the node.
58+
# - The number of samples at each node.
59+
# - The class distribution and the predicted class for leaf nodes.
60+
61+
# This provides a clear visual of the decision process, showing how the tree splits data at each step.
62+
63+
# Let me know if you'd like further customization for the tree's visualization!
64+
65+
## python DecisionTreeClassifier.py
66+
# Matplotlib is building the font cache; this may take a moment.
67+
# Accuracy: 100.00%

‎poetry.lock

Lines changed: 496 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ torch = "^2.4.1"
1111
torchvision = "^0.19.1"
1212
tensorboard = "^2.17.1"
1313
numpy = "^2.1.1"
14+
scikit-learn = "^1.5.1"
15+
matplotlib = "^3.9.2"
1416

1517
[build-system]
1618
requires = ["poetry-core"]

0 commit comments

Comments
 (0)
Please sign in to comment.