-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathlabel_data.py
158 lines (138 loc) · 5.96 KB
/
label_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# Script used to label the dataset files (data/data*.csv), call using the -h
# option for information.
# The program extracts preliminary features from the data, then sorts the
# results by feature importance and plots them in batches. Segments can be
# labeled by clicking on subplots.
import sys
import numpy as np
import pandas as pd
import argparse
from datetime import datetime
import glob
sys.path.append('lib')
import detect_peaks
from sklearn import preprocessing
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.patches import Rectangle
from peakutils.peak import indexes
from classes.Signal import Signal
from classes.DataSource import DataSource
# Parse arguments
# Command-line interface.
# Valid file indices are 0..N-1, where N is the number of files matching
# data/data*.csv (presumably the files are numbered contiguously from 0 --
# verify against the data directory).
data_file_id_choices = list(range(len(glob.glob("data/data*.csv"))))
parser = argparse.ArgumentParser(description=(
    'Label the dataset for training. '
    'A call should supply a data file index (e.g. 0 for ./data/data0.csv) and the label '
    'type (+/-). Figures of signal segment plots will be displayed, ordered by features that '
    'correlate with high signal to noise ratio. Labels for each signal segment '
    'are generated by clicking on the respective plot. The file ID, start, and end '
    'indices of the segment will be appended as a single new line in '
    'positive_ranges.csv or negative_ranges.csv, depending on the supplied label type.'))
# Both options are required, so no defaults are given: argparse never falls
# back to a default for a required option, and a default outside `choices`
# would only mislead the reader.
parser.add_argument('--file_id', type=int, required=True,
                    choices=data_file_id_choices,
                    help='data file index (e.g. 0 for ./data/data0.csv)')
parser.add_argument('--label_type', type=str, required=True,
                    choices=["+", "-", "positive", "negative"],
                    help='e.g. +/-/positive/negative')
args = parser.parse_args()
FILE_ID = args.file_id
# Normalise the short aliases so the rest of the script only ever sees
# "positive" or "negative" (used to build the output file name).
LABEL_TYPE = {"+": "positive", "-": "negative"}.get(args.label_type,
                                                    args.label_type)
# Helper functions
def onclick(event):
    """Record a label for the subplot that was clicked.

    The click position (``event.x``, ``event.y``) is converted into figure
    coordinates and compared against the bounding boxes stored in the
    module-level ``subplots`` list. For the first not-yet-used subplot that
    contains the point, its ``range`` entry is appended as one row to
    ``data/<LABEL_TYPE>_ranges.csv``, the entry is flagged as used, and a
    confirmation text is drawn over it on the module-level ``fig``.
    """
    to_figure_coords = fig.transFigure.inverted()
    click_x, click_y = to_figure_coords.transform((event.x, event.y))
    for entry in subplots:
        box = entry["pos"]
        # Skip subplots that were not hit or that already carry a label.
        if not box.contains(click_x, click_y) or entry["used"]:
            continue
        row = pd.DataFrame([entry["range"]])
        row.to_csv('data/%s_ranges.csv' % LABEL_TYPE,
                   mode='a', header=False, index=False)
        entry["used"] = True
        # Place the confirmation text at the centre of the subplot, nudged
        # slightly to the left.
        text_x = np.mean([box.x1, box.x0]) - 0.01
        text_y = np.mean([box.y1, box.y0])
        fig.text(text_x,
                 text_y,
                 'Labeled %s' % LABEL_TYPE,
                 horizontalalignment='center',
                 verticalalignment='center',
                 color="green",
                 backgroundcolor="white",
                 fontsize=14)
        fig.canvas.draw()
        break
def _read_ranges(path):
    """Load a labeled-ranges CSV, or an empty table if it is missing/empty."""
    names = ["file_id", "start", "end"]
    try:
        return pd.read_csv(path, header=None, names=names)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # Nothing has been labeled yet; treat that as "no existing labels"
        # instead of crashing before the first label can be recorded.
        return pd.DataFrame(columns=names)


# Load the raw recording selected on the command line and the ranges that
# have already been labeled in earlier sessions.
ds = DataSource()
dataset = ds.read_data_from_file(FILE_ID)
labeled_ds_pos = _read_ranges('data/positive_ranges.csv')
labeled_ds_neg = _read_ranges('data/negative_ranges.csv')

# Cut the recording into consecutive, non-overlapping windows of `step` rows
# and extract a feature vector from each one.
step = 256
offset = 0
start, end = offset, dataset.shape[0]
validate_hr_range = LABEL_TYPE == "positive"
features = []
# `end - step + 1` keeps the final window when the number of rows is an exact
# multiple of `step` (the slice end is exclusive, so start + step == end is
# still a complete window).
for start in range(offset, end - step + 1, step):
    window = dataset.iloc[start:start + step]
    signal = Signal(window.ppg.values, window.timestamp.values)
    feature_vector = signal.extract_features(validate_HR_range=validate_hr_range)
    # Windows for which extract_features returns None are skipped.
    if feature_vector is not None:
        features.append(feature_vector + [signal, start, start + step])
# Sort by features in ascending order, in order of feature importance
columns = ["mean_HF", "HF/LF", "VLF/LF", "peak_var", "signal", "start", "end"]
sort_column_order = [columns[i] for i in [2, 1, 3, 0]]
features = pd.DataFrame(features, columns=columns).sort_values(
    sort_column_order, ascending=True)


def _is_labeled(ranges, file_id, seg_start, seg_end):
    """Return True if `ranges` has a row equal to (file_id, seg_start, seg_end)."""
    # Compare column by column: DataFrame.isin with a flat list would accept
    # the three values in any column, not only in their own.
    match = ((ranges["file_id"] == file_id)
             & (ranges["start"] == seg_start)
             & (ranges["end"] == seg_end))
    return bool(match.any())


def _maximize_window(manager):
    """Try to maximize the figure window; give up silently if unsupported."""
    attempts = (
        lambda: manager.window.showMaximized(),
        lambda: manager.full_screen_toggle(),
        lambda: manager.window.state('zoomed'),
    )
    for attempt in attempts:
        try:
            attempt()
        except Exception:
            # This call is not available on the active backend; maximizing
            # is cosmetic, so just try the next variant.
            continue
        return


# Plot the sorted segments in batches of `num_figure_subplots`, one figure per
# batch. Each figure blocks in plt.show() until it is closed; clicks on its
# subplots are handled by onclick, which reads the globals `fig`/`subplots`.
num_figure_subplots = 30
subplot_cols = 3
# Integer division: plt.subplot requires an integer number of rows.
subplot_rows = num_figure_subplots // subplot_cols
counter = 0
k = 0
while num_figure_subplots*k < features.shape[0] and k < 100:
    fig = plt.figure(k+1, figsize=(15, 10))
    subplots = []
    # Slicing (rather than indexing 0..29) keeps the last, possibly partial,
    # batch from running past the end of `features`.
    batch = features.iloc[num_figure_subplots*k:num_figure_subplots*(k+1)]
    for i in range(batch.shape[0]):
        feat = batch.iloc[i]
        segment = feat["signal"]
        start = feat["start"]
        end = feat["end"]
        # Keep the Signal object in its own name: the scaled outputs are
        # plain arrays and do not offer the Signal methods used below.
        signal = preprocessing.scale(segment.highpass_filter(1))
        signal_filtered = preprocessing.scale(segment.bandpass_filter(0.8, 2.5))
        start_time = pd.Timestamp(segment.timestamp_in_datetime(0))
        end_time = pd.Timestamp(segment.timestamp_in_datetime(-1))
        # Evenly spaced time axis between the first and last timestamp.
        t = pd.to_datetime(np.linspace(start_time.value, end_time.value, step))
        ax = plt.subplot(subplot_rows, subplot_cols, i+1)
        alpha = 1
        used = False
        label = None
        if _is_labeled(labeled_ds_pos, FILE_ID, start, end):
            label = "+"
        if _is_labeled(labeled_ds_neg, FILE_ID, start, end):
            label = "-"
        if label is not None:
            # Fade segments labeled in a previous session and mark them as
            # used so that onclick ignores clicks on them.
            alpha = 0.35
            ax.text(0.5, 0.5, 'Already labeled %s' % label,
                    horizontalalignment='center',
                    verticalalignment='center',
                    transform=ax.transAxes,
                    fontsize=14)
            used = True
        subplots.append({"pos": ax.get_position(),
                         "range": [FILE_ID, start, end],
                         "used": used,
                         "figure_id": k+1})
        # Both traces are already standardized above.
        ax.plot(t, signal, alpha=alpha)
        ax.plot(t, signal_filtered, color='r', alpha=alpha)
        ax.xaxis_date()
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
        ax.yaxis.set_visible(False)
    cid = fig.canvas.mpl_connect('button_press_event', onclick)
    figManager = plt.get_current_fig_manager()
    _maximize_window(figManager)
    plt.show()
    counter += num_figure_subplots
    k += 1