Following its README, clone the repository.
To run deep_sort_app.py as instructed, you need to download the Multi-Object-Tracking (MOT) dataset.
The MOT16 sequence data is at the very bottom of the page, where it says "Download".
Download it and extract it into a MOT16 directory.
The README says: "download pre-generated detections and the CNN checkpoint file from here."
What you download is a detections directory; put it under the resources directory.
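Before running anything, it can help to check that the files landed where the run command further down expects them. The sketch below is a hypothetical helper (`check_layout` is not part of deep_sort). It only mirrors the paths from that command, assumes the sequence is MOT16-06, and assumes the standard MOT16 layout where each sequence keeps its frames in an `img1/` subfolder.

```python
from pathlib import Path


def check_layout(root, sequence="MOT16-06"):
    """Return the expected input paths that are missing under `root`.

    Hypothetical helper: the paths mirror the deep_sort_app.py command
    used in this post; img1/ is the usual MOT16 frame directory.
    """
    expected = [
        # Extracted MOT16 sequence (frames live in img1/).
        Path(root) / "MOT16" / "test" / sequence / "img1",
        # Pre-generated detections copied under resources/.
        Path(root) / "resources" / "detections" / "MOT16_POI_test" / f"{sequence}.npy",
    ]
    return [str(p) for p in expected if not p.exists()]


if __name__ == "__main__":
    missing = check_layout(".")
    print("layout OK" if not missing else "missing: " + ", ".join(missing))
```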
I'm not sure whether it was a Python version issue, but indexing the returned tuple raised an error.
So, following this post, I modified linear_assignment.py:
diff --git a/deep_sort/linear_assignment.py b/deep_sort/linear_assignment.py
index 178456c..d24ec44 100644
--- a/deep_sort/linear_assignment.py
+++ b/deep_sort/linear_assignment.py
@@ -1,7 +1,7 @@
 # vim: expandtab:ts=4:sw=4
 from __future__ import absolute_import
 import numpy as np
-from sklearn.utils.linear_assignment_ import linear_assignment
+from scipy.optimize import linear_sum_assignment as linear_assignment
 from . import kalman_filter
@@ -55,16 +55,16 @@ def min_cost_matching(
     cost_matrix = distance_metric(
         tracks, detections, track_indices, detection_indices)
     cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
-    indices = linear_assignment(cost_matrix)
+    row_indices, col_indices = linear_assignment(cost_matrix)
     matches, unmatched_tracks, unmatched_detections = [], [], []
     for col, detection_idx in enumerate(detection_indices):
-        if col not in indices[:, 1]:
+        if col not in col_indices:
             unmatched_detections.append(detection_idx)
     for row, track_idx in enumerate(track_indices):
-        if row not in indices[:, 0]:
+        if row not in row_indices:
             unmatched_tracks.append(track_idx)
-    for row, col in indices:
+    for row, col in zip(row_indices, col_indices):
         track_idx = track_indices[row]
         detection_idx = detection_indices[col]
         if cost_matrix[row, col] > max_distance:
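The first change is needed because the linear_assignment function moved: scikit-learn deprecated sklearn.utils.linear_assignment_ in 0.21 and removed it in 0.23. Its replacement, scipy.optimize.linear_sum_assignment, returns a tuple (row indices, column indices) instead of an N×2 index array, which is why the indexing in min_cost_matching has to change too. There was also a numpy int problem; in the end it ran under Python 3.5 (the lowest Python 3 version conda offered).
Update: the real cause was the numpy version.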
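To see what the patched bookkeeping in min_cost_matching does with scipy's tuple result, here is a minimal standalone sketch. The function name `match` is hypothetical, and plain row/column numbers stand in for deep_sort's track_indices and detection_indices; the gating step (pushing costs above max_distance to max_distance + 1e-5) is copied from the diff.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def match(cost_matrix, max_distance):
    """Toy version of the patched min_cost_matching logic.

    Rows are tracks, columns are detections (hypothetical stand-ins for
    deep_sort's track_indices / detection_indices).
    """
    cost = cost_matrix.copy()
    # Gating as in deep_sort: too-expensive pairs may still be assigned,
    # but are rejected again after the assignment.
    cost[cost > max_distance] = max_distance + 1e-5
    # scipy returns a tuple of two 1-D index arrays, not an (N, 2) array,
    # so the result is unpacked instead of sliced with [:, 0] / [:, 1].
    row_indices, col_indices = linear_sum_assignment(cost)

    matches = []
    unmatched_tracks = [r for r in range(cost.shape[0]) if r not in row_indices]
    unmatched_detections = [c for c in range(cost.shape[1]) if c not in col_indices]
    for row, col in zip(row_indices, col_indices):
        if cost[row, col] > max_distance:
            # Assigned only because of the padding cost: treat as unmatched.
            unmatched_tracks.append(int(row))
            unmatched_detections.append(int(col))
        else:
            matches.append((int(row), int(col)))
    return matches, unmatched_tracks, unmatched_detections


if __name__ == "__main__":
    costs = np.array([[0.1, 0.9, 0.8],
                      [0.7, 0.2, 0.95]])
    print(match(costs, max_distance=0.5))
```

With two tracks and three detections, both tracks get their cheap detection and the third detection is left unmatched; a pair whose only option exceeds max_distance ends up in both unmatched lists rather than in matches.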
Then run it as the README describes:
python deep_sort_app.py \
--sequence_dir=./MOT16/test/MOT16-06 \
--detection_file=./resources/detections/MOT16_POI_test/MOT16-06.npy \
--min_confidence=0.3 \
--nn_budget=100 \
--display=True
and you can see the tracking results. Later, I also confirmed that the modified code runs fine with Python 3.9.16 and numpy 1.23.5.
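That numpy 1.23.5 works fits with the `np.int` alias: NumPy deprecated it in 1.20 and removed it in 1.24, so code that still calls `.astype(np.int)` breaks on newer versions. Assuming that alias is the "numpy int problem" mentioned above, the usual fix is to use the built-in `int`. The sketch below shows this while picking one frame's detections out of the .npy matrix and applying `--min_confidence`. The function name is hypothetical, and the column layout (frame in column 0, box x, y, w, h in 2:6, confidence in 6, appearance feature from column 10) is an assumption based on how deep_sort_app.py reads these files.

```python
import numpy as np


def detections_for_frame(detection_mat, frame_idx, min_confidence):
    """Select one frame's rows from a deep_sort-style detection matrix.

    Hypothetical helper; assumed layout: col 0 = frame number,
    cols 2:6 = box (x, y, w, h), col 6 = confidence, cols 10: = feature.
    """
    # Built-in int instead of np.int (alias removed in NumPy 1.24).
    frame_indices = detection_mat[:, 0].astype(int)
    rows = detection_mat[frame_indices == frame_idx]
    # Same role as --min_confidence: drop low-confidence detections.
    rows = rows[rows[:, 6] >= min_confidence]
    return rows[:, 2:6], rows[:, 6], rows[:, 10:]


if __name__ == "__main__":
    # Real usage would be np.load("./resources/detections/MOT16_POI_test/MOT16-06.npy").
    demo = np.zeros((2, 12))
    demo[:, 0] = [1, 1]
    demo[:, 6] = [0.9, 0.1]
    boxes, confidences, features = detections_for_frame(demo, 1, 0.3)
    print(boxes.shape, confidences, features.shape)
```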
The later part of the README, which regenerates the detections, requires TensorFlow.
It was originally written for TensorFlow 1.x.
To use TensorFlow 2.4.1 instead, make these changes:
diff --git a/tools/freeze_model.py b/tools/freeze_model.py
index e89ad29..63b9fc5 100644
--- a/tools/freeze_model.py
+++ b/tools/freeze_model.py
@@ -1,7 +1,7 @@
 # vim: expandtab:ts=4:sw=4
 import argparse
-import tensorflow as tf
-import tensorflow.contrib.slim as slim
+import tensorflow.compat.v1 as tf
+import tf_slim as slim
 def _batch_norm_fn(x, scope=None):
@@ -193,6 +193,7 @@ def parse_args():
 def main():
     args = parse_args()
+    tf.disable_v2_behavior()
     with tf.Session(graph=tf.Graph()) as session:
         input_var = tf.placeholder(
Freezing the model (tools/freeze_model.py) needs both parts of that diff: besides switching to tensorflow.compat.v1, the tensorflow.contrib.slim import has to be replaced with the standalone tf_slim package (installed with pip as tf-slim), because TF 2.x no longer ships tensorflow.contrib.