-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
854 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
from datetime import datetime as date | ||
from glob import glob | ||
import os | ||
from loguru import logger | ||
import argparse | ||
import cv2 | ||
|
||
from edgeyolo.detect import Detector, TRTDetector, draw | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
parser = argparse.ArgumentParser("EdgeYOLO Detect parser") | ||
parser.add_argument("-w", "--weights", type=str, default="edgeyolo_coco.pth", help="weight file") | ||
parser.add_argument("-c", "--conf-thres", type=float, default=0.25, help="confidence threshold") | ||
parser.add_argument("-n", "--nms-thres", type=float, default=0.55, help="nms threshold") | ||
parser.add_argument("--fp16", action="store_true", help="fp16") | ||
parser.add_argument("--no-fuse", action="store_true", help="do not fuse model") | ||
parser.add_argument("--input-size", type=int, nargs="+", default=[640, 640], help="input size: [height, width]") | ||
parser.add_argument("-s", "--source", type=str, default="E:/videos/test.avi", help="video source or image dir") | ||
parser.add_argument("--trt", action="store_true", help="is trt model") | ||
parser.add_argument("--legacy", action="store_true", help="if img /= 255 while training, add this command.") | ||
parser.add_argument("--use-decoder", action="store_true", help="support original yolox model v0.2.0") | ||
parser.add_argument("--batch-size", type=int, default=1, help="batch size") | ||
parser.add_argument("--no-label", action="store_true", help="do not draw label") | ||
parser.add_argument("--save-dir", type=str, default="./imgs/coco", help="image result save dir") | ||
|
||
args = parser.parse_args() | ||
exist_save_dir = os.path.isdir(args.save_dir) | ||
|
||
# detector setup | ||
detector = TRTDetector if args.trt else Detector | ||
detect = detector( | ||
weight_file=args.weights, | ||
conf_thres=args.conf_thres, | ||
nms_thres=args.nms_thres, | ||
input_size=args.input_size, | ||
fuse=not args.no_fuse, | ||
fp16=args.fp16, | ||
use_decoder=args.use_decoder | ||
) | ||
if args.trt: | ||
args.batch_size = detect.batch_size | ||
|
||
# source loader setup | ||
if os.path.isdir(args.source): | ||
|
||
class DirCapture: | ||
|
||
def __init__(self, dir_name): | ||
self.imgs = [] | ||
for img_type in ["jpg", "png", "jpeg", "bmp", "webp"]: | ||
self.imgs += sorted(glob(os.path.join(dir_name, f"*.{img_type}"))) | ||
|
||
def isOpened(self): | ||
return bool(len(self.imgs)) | ||
|
||
def read(self): | ||
print(self.imgs[0]) | ||
now_img = cv2.imread(self.imgs[0]) | ||
self.imgs = self.imgs[1:] | ||
return now_img is not None, now_img | ||
|
||
source = DirCapture(args.source) | ||
delay = 0 | ||
else: | ||
source = cv2.VideoCapture(int(args.source) if args.source.isdigit() else args.source) | ||
delay = 1 | ||
|
||
all_dt = [] | ||
dts_len = 300 // args.batch_size | ||
success = True | ||
|
||
# start inference | ||
while source.isOpened() and success: | ||
|
||
frames = [] | ||
for _ in range(args.batch_size): | ||
success, frame = source.read() | ||
if not success: | ||
if not len(frames): | ||
cv2.destroyAllWindows() | ||
break | ||
else: | ||
while len(frames) < args.batch_size: | ||
frames.append(frames[-1]) | ||
else: | ||
frames.append(frame) | ||
|
||
if not len(frames): | ||
break | ||
|
||
results = detect(frames, args.legacy) | ||
dt = detect.dt | ||
all_dt.append(dt) | ||
if len(all_dt) > dts_len: | ||
all_dt = all_dt[-dts_len:] | ||
print(f"\r{dt * 1000 / args.batch_size:.1f}ms " | ||
f"average:{sum(all_dt) / len(all_dt) / args.batch_size * 1000:.1f}ms", end=" ") | ||
|
||
key = -1 | ||
imgs = draw(frames, results, detect.class_names, 2, draw_label=not args.no_label) | ||
# print([im.shape for im in frames]) | ||
for img in imgs: | ||
# print(img.shape) | ||
cv2.imshow("EdgeYOLO result", img) | ||
key = cv2.waitKey(delay) | ||
if key in [ord("q"), 27]: | ||
break | ||
elif key == ord(" "): | ||
delay = 1 - delay | ||
elif key == ord("s"): | ||
if not exist_save_dir: | ||
os.makedirs(args.save_dir, exist_ok=True) | ||
file_name = f"{str(date.now()).split('.')[0].replace(':', '').replace('-', '').replace(' ', '')}.jpg" | ||
cv2.imwrite(os.path.join(args.save_dir, file_name), img) | ||
logger.info(f"image saved to {file_name}.") | ||
if key in [ord("q"), 27]: | ||
cv2.destroyAllWindows() | ||
break |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import yaml | ||
import argparse | ||
import os.path as osp | ||
import os | ||
from loguru import logger | ||
import torch | ||
|
||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser("EdgeYOLO onnx2tensorrt parser") | ||
parser.add_argument("-o", "--onnx", type=str, default="yolov7.onnx", help="ONNX file") | ||
parser.add_argument("-y", "--yaml", type=str, default="yolov7.yaml", help="export params file") | ||
parser.add_argument("-w", "--workspace", type=int, default=8, help="export memory workspace(GB)") | ||
parser.add_argument("--fp16", action="store_true", help="fp16") | ||
parser.add_argument("--int8", action="store_true", help="int8") | ||
parser.add_argument("--best", action="store_true", help="best") | ||
parser.add_argument("-d", "--dist-path", type=str, default="export_output/tensorrt") | ||
parser.add_argument("--batch-size", type=int, default=0, help="batch-size") | ||
return parser.parse_args() | ||
|
||
|
||
def main(): | ||
args = get_args() | ||
|
||
assert osp.isfile(args.onnx), f"No such file named {args.onnx}." | ||
assert osp.isfile(args.yaml), f"No such file named {args.yaml}." | ||
|
||
os.makedirs(args.dist_path, exist_ok=True) | ||
|
||
name = args.onnx.replace("\\", "/").split("/")[-1][:-len(args.onnx.split(".")[-1])] | ||
|
||
engine_file = osp.join(args.dist_path, name + "engine").replace("\\", "/") | ||
pt_file = osp.join(args.dist_path, name + "pt").replace("\\", "/") | ||
cls_file = osp.join(args.dist_path, name + "txt").replace("\\", "/") | ||
params = yaml.load(open(args.yaml).read(), yaml.Loader) | ||
command = f"trtexec --onnx={args.onnx}" \ | ||
f"{' --fp16' if args.fp16 else ' --int8' if args.int8 else ' --best' if args.best else ''} " \ | ||
f"--saveEngine={engine_file} --workspace={args.workspace*1024} " \ | ||
f"--batch={args.batch_size if not args.batch_size > 0 else params['batch_size'] if 'batch_size' in params else 1}" | ||
|
||
logger.info("start converting onnx to tensorRT engine file.") | ||
os.system(command) | ||
|
||
if not osp.isfile(engine_file): | ||
logger.error("tensorRT engine file convertion failed.") | ||
return | ||
|
||
logger.info(f"tensorRT engine saved to {engine_file}") | ||
|
||
try: | ||
data = { | ||
"model": { | ||
"engine": bytearray(open(engine_file, "rb").read()), | ||
"input_names": params["input_name"], | ||
"output_names": params["output_name"] | ||
}, | ||
"names": params["names"], | ||
"img_size": params["img_size"], | ||
"batch_size": params["batch_size"] | ||
} | ||
class_str = "" | ||
for name in params["names"]: | ||
class_str += name + "\n" | ||
with open(cls_file, "w") as cls_f: | ||
cls_f.write(class_str[:-1]) | ||
logger.info(f"class names txt pt saved to {cls_file}") | ||
torch.save(data, pt_file) | ||
logger.info(f"tensorRT pt saved to {pt_file}") | ||
except Exception as e: | ||
logger.error(f"convert2pt error: {e}") | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# parameters | ||
nc: 10 # number of classes | ||
depth_multiple: 1.0 # models depth multiple | ||
width_multiple: 1.0 # layer channel multiple | ||
|
||
# anchor-box-free | ||
anchors: | ||
- [8, 8] # P3/8 | ||
- [16, 16] # P4/16 | ||
- [32, 32] # P5/32 | ||
|
||
# edgeyolo backbone | ||
backbone: | ||
# [from, number, module, args] | ||
[[-1, 1, Conv, [32, 3, 1]], # 0 | ||
|
||
[-1, 1, Conv, [64, 3, 2]], # 1-P1/2 | ||
[-1, 1, Conv, [64, 3, 1]], | ||
|
||
[-1, 1, Conv, [128, 3, 2]], # 3-P2/4 | ||
[-1, 1, Conv, [64, 1, 1]], | ||
[-2, 1, Conv, [64, 1, 1]], | ||
[-1, 1, Conv, [64, 3, 1]], | ||
[-1, 1, Conv, [64, 3, 1]], | ||
[-1, 1, Conv, [64, 3, 1]], | ||
[-1, 1, Conv, [64, 3, 1]], | ||
[[-1, -3, -5, -6], 1, Concat, [1]], | ||
[-1, 1, Conv, [256, 1, 1]], # 11 | ||
|
||
[-1, 1, MP, []], | ||
[-1, 1, Conv, [128, 1, 1]], | ||
[-3, 1, Conv, [128, 1, 1]], | ||
[-1, 1, Conv, [128, 3, 2]], | ||
[[-1, -3], 1, Concat, [1]], # 16-P3/8 | ||
[-1, 1, Conv, [128, 1, 1]], | ||
[-2, 1, Conv, [128, 1, 1]], | ||
[-1, 1, Conv, [128, 3, 1]], | ||
[-1, 1, Conv, [128, 3, 1]], | ||
[-1, 1, Conv, [128, 3, 1]], | ||
[-1, 1, Conv, [128, 3, 1]], | ||
[[-1, -3, -5, -6], 1, Concat, [1]], | ||
[-1, 1, Conv, [512, 1, 1]], # 24 | ||
|
||
[-1, 1, MP, []], | ||
[-1, 1, Conv, [256, 1, 1]], | ||
[-3, 1, Conv, [256, 1, 1]], | ||
[-1, 1, Conv, [256, 3, 2]], | ||
[[-1, -3], 1, Concat, [1]], # 29-P4/16 | ||
[-1, 1, Conv, [256, 1, 1]], | ||
[-2, 1, Conv, [256, 1, 1]], | ||
[-1, 1, Conv, [256, 3, 1]], | ||
[-1, 1, Conv, [256, 3, 1]], | ||
[-1, 1, Conv, [256, 3, 1]], | ||
[-1, 1, Conv, [256, 3, 1]], | ||
[[-1, -3, -5, -6], 1, Concat, [1]], | ||
[-1, 1, Conv, [1024, 1, 1]], # 37 | ||
|
||
[-1, 1, MP, []], | ||
[-1, 1, Conv, [512, 1, 1]], | ||
[-3, 1, Conv, [512, 1, 1]], | ||
[-1, 1, Conv, [512, 3, 2]], | ||
[[-1, -3], 1, Concat, [1]], # 42-P5/32 | ||
[-1, 1, Conv, [256, 1, 1]], | ||
[-2, 1, Conv, [256, 1, 1]], | ||
[-1, 1, Conv, [256, 3, 1]], | ||
[-1, 1, Conv, [256, 3, 1]], | ||
[-1, 1, Conv, [256, 3, 1]], | ||
[-1, 1, Conv, [256, 3, 1]], | ||
[[-1, -3, -5, -6], 1, Concat, [1]], | ||
[-1, 1, Conv, [1024, 1, 1]], # 50 | ||
] | ||
|
||
# edgeyolo head | ||
head: | ||
[[-1, 1, SPPCSPC, [512]], # 51 | ||
|
||
[-1, 1, Conv, [256, 1, 1]], | ||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], | ||
[37, 1, Conv, [256, 1, 1]], # route backbone P4 | ||
[[-1, -2], 1, Concat, [1]], | ||
|
||
[-1, 1, Conv, [256, 1, 1]], | ||
[-2, 1, Conv, [256, 1, 1]], | ||
[-1, 1, Conv, [128, 3, 1]], | ||
[-1, 1, Conv, [128, 3, 1]], | ||
[-1, 1, Conv, [128, 3, 1]], | ||
[-1, 1, Conv, [128, 3, 1]], | ||
[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], | ||
[-1, 1, Conv, [256, 1, 1]], # 63 | ||
|
||
[-1, 1, Conv, [128, 1, 1]], | ||
[-1, 1, nn.Upsample, [None, 2, 'nearest']], | ||
[24, 1, Conv, [128, 1, 1]], # route backbone P3 | ||
[[-1, -2], 1, Concat, [1]], | ||
|
||
[-1, 1, Conv, [128, 1, 1]], | ||
[-2, 1, Conv, [128, 1, 1]], | ||
[-1, 1, Conv, [64, 3, 1]], | ||
[-1, 1, Conv, [64, 3, 1]], | ||
[-1, 1, Conv, [64, 3, 1]], | ||
[-1, 1, Conv, [64, 3, 1]], | ||
[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], | ||
[-1, 1, Conv, [128, 1, 1]], # 75 | ||
|
||
[-1, 1, MP, []], | ||
[-1, 1, Conv, [128, 1, 1]], | ||
[-3, 1, Conv, [128, 1, 1]], | ||
[-1, 1, Conv, [128, 3, 2]], | ||
[[-1, -3, 63], 1, Concat, [1]], | ||
|
||
[-1, 1, Conv, [256, 1, 1]], | ||
[-2, 1, Conv, [256, 1, 1]], | ||
[-1, 1, Conv, [128, 3, 1]], | ||
[-1, 1, Conv, [128, 3, 1]], | ||
[-1, 1, Conv, [128, 3, 1]], | ||
[-1, 1, Conv, [128, 3, 1]], | ||
[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], | ||
[-1, 1, Conv, [256, 1, 1]], # 88 | ||
|
||
[-1, 1, MP, []], | ||
[-1, 1, Conv, [256, 1, 1]], | ||
[-3, 1, Conv, [256, 1, 1]], | ||
[-1, 1, Conv, [256, 3, 2]], | ||
[[-1, -3, 51], 1, Concat, [1]], | ||
|
||
[-1, 1, Conv, [512, 1, 1]], | ||
[-2, 1, Conv, [512, 1, 1]], | ||
[-1, 1, Conv, [256, 3, 1]], | ||
[-1, 1, Conv, [256, 3, 1]], | ||
[-1, 1, Conv, [256, 3, 1]], | ||
[-1, 1, Conv, [256, 3, 1]], | ||
[[-1, -2, -3, -4, -5, -6], 1, Concat, [1]], | ||
[-1, 1, Conv, [512, 1, 1]], # 101 | ||
|
||
[75, 1, RepConv, [256, 3, 1]], # 102 | ||
[88, 1, RepConv, [512, 3, 1]], # 103 | ||
[101, 1, RepConv, [1024, 3, 1]], # 104 | ||
|
||
[[102,103,104], 1, YOLOXDetect, [nc, anchors, Conv]], | ||
] |
Oops, something went wrong.