Skip to content

Commit 1671763

Browse files
committed
Refactor _ds_to_dlc_style_df
1 parent 1bd0b9f commit 1671763

File tree

1 file changed

+10
-16
lines changed

1 file changed

+10
-16
lines changed

movement/io/save_poses.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,24 +37,25 @@ def _ds_to_dlc_style_df(
3737
pandas.DataFrame
3838
3939
"""
40-
is_3d = "z" in columns.get_level_values("coords")
41-
if is_3d:
42-
tracks_with_scores = ds.position.data
43-
else:
44-
# Concatenate the pose tracks and confidence scores into one array
45-
tracks_with_scores = np.concatenate(
40+
# Keep position data as is, if data is 3D (i.e. contains 'z' coordinate)
41+
# Otherwise, concatenate position and confidence scores into one array
42+
tracks = (
43+
ds.position.data
44+
if "z" in columns.get_level_values("coords")
45+
else np.concatenate(
4646
(
4747
ds.position.data,
4848
ds.confidence.data[:, np.newaxis, ...],
4949
),
5050
axis=1,
5151
)
52+
)
5253
# Reverse the order of the dimensions except for the time dimension
53-
transpose_order = [0] + list(range(tracks_with_scores.ndim - 1, 0, -1))
54-
tracks_with_scores = tracks_with_scores.transpose(transpose_order)
54+
transpose_order = [0] + list(range(tracks.ndim - 1, 0, -1))
55+
tracks = tracks.transpose(transpose_order)
5556
# Create DataFrame with multi-index columns
5657
df = pd.DataFrame(
57-
data=tracks_with_scores.reshape(ds.sizes["time"], -1),
58+
data=tracks.reshape(ds.sizes["time"], -1),
5859
index=np.arange(ds.sizes["time"], dtype=int),
5960
columns=columns,
6061
dtype=float,
@@ -132,21 +133,16 @@ def to_dlc_style_df(
132133
else base_coords + ["likelihood"]
133134
)
134135
individuals = ds.coords["individuals"].data.tolist()
135-
136136
if split_individuals:
137137
df_dict = {}
138-
139138
for individual in individuals:
140139
individual_data = ds.sel(individuals=individual)
141-
142140
index_levels = ["scorer", "bodyparts", "coords"]
143141
columns = pd.MultiIndex.from_product(
144142
[scorer, bodyparts, coords], names=index_levels
145143
)
146-
147144
df = _ds_to_dlc_style_df(individual_data, columns)
148145
df_dict[individual] = df
149-
150146
logger.info(
151147
"Converted poses dataset to DeepLabCut-style DataFrames "
152148
"per individual."
@@ -157,9 +153,7 @@ def to_dlc_style_df(
157153
columns = pd.MultiIndex.from_product(
158154
[scorer, individuals, bodyparts, coords], names=index_levels
159155
)
160-
161156
df_all = _ds_to_dlc_style_df(ds, columns)
162-
163157
logger.info("Converted poses dataset to DeepLabCut-style DataFrame.")
164158
return df_all
165159

0 commit comments

Comments
 (0)