///
Search
Duplicate
🥌

PAFNet: An Efficient Anchor-Free Object Detector Guidance

Created
6/28/2021, 8:41:00 AM
발표일자
2021/05/24
발표자
김연태
Tags
ObjectDetection
anchor-free
✅main
포스팅 종류
논문리뷰
사실 처음보는 Anchor free Detector 논문이라 궁금해서 본건데 생각보다 내용이 없다

Abstract

Object Detection 은 매우 다양한 곳에 사용될 수 있지만 많은 딥러닝 기반 방식들이 모델이 크고 긴 Inference time 을 필요로 하기에 Effectiveness와 Efficiency 의 조율이 필요
이를 위해 pre-defined 된 anchor를 제거한 anchor-free 모델 사용해 두가지를 모두 잡는 모델을 만들어 보고자함
TTFNet 이라 불리는 anchor-free 모델을 사용해 이를 서버와 모바일환경에 최적화 시킨 PAFNet (Paddle Anchor Free Network) 을 제안
PaddlePaddle 이 뭔가 했더니 baidu 에서 만든 딥러닝 프레임워크라고 함

1. introduction

Object Detection은 CNN 을 적용하게 되면서 성능이 향상 되었었음
Region Proposal 과 Classification을 한번에 학습하는 One-stage 방식과 분리하여 진행하는 Two-stage 방식으로 발전해 왔는데 최근 속도에 큰 장점이 있는 One-stage 모델이 성능까지 좋아지며 이 방식이 주류를 이루고 있음
하지만 One-stage 모델의 경우 사전에 정의된 고정된 앵커가 존재하는데 이 매우 많은양의 앵커가 야기하는 일반화 성능의 손실이 크고 연산량도 매우 증가됨
이것을 제거하고 직접 위치를 Regression 하는 모델을 anchor-free 모델이라 부름
여기에 설명이 너무 다 나와서 뒤는 짜름

2. Related Work

Anchor Based model
오랫동안 주류로 사용되어온 방식으로 사전 정의된 앵커를 사용
1-Stage : YOLO, SSD, RetinaNet
2-stage : R-cnn
Anchor-free model
CenterNet
빠른 Inference 를 위해 Detector가 bonding box의 중앙만을 예측하게 하고 그외에 다른 속성(너비, 높이, 포즈 등) 은 해당 이미지에서 직접 연산하게 하는 방식
네트워크에서 히트맵같은 것을 생성하고 히트맵의 피크가 객체의 센터를 의미
이것은 많은 후처리절차 (NMS 등) 을 제거하게 함으로 inference time을 감소시키는데 매우 큰 역할을 함
하지만 Regression에서 중앙에만 집중하게 되어 네트워크 수렴이 느려지는 경향이 있음
일반적으로 MS-COCO를 학습하는데 12epoch 정도가 필요하나 CenterNet은 140 epoch 정도 필요하다고 함
TTFNet
CenterNet을 개선하여 더 좋은 밸런스를 갖춘 모델
학습시간을 감소시키기 위해 Gaussian kernel 을 사용하여 훈련 샘플을 인코딩
중심만이 아닌 주변까지 샘플로 사용할 수 있게 함
더 많은 샘플을 사용해 빠르게 수렴할수 있게 해줌
Training, Inference 시간과 Accuracy 에 균형있는 모습을 보임

3. Method

3.1 Architecture

서버와 모바일 각각 제안된 네트워크 구조를 설명
PAFNet for Server Side
Backborn 은 ResNet50-vd 를 사용
ResNet50과 비슷한 inference 속도를 가지나 성능이 좋음
Decoupling operation을 사용해 원본의 1/4 크기로 Up-sample
2, 3, 4 번째 레이어에 Skip connection 추가
Localization Branch
Fl:H/4×W/4×CF_l : H/4 \times W/4 \times C 의 형태를 가짐
GmG_m(Ground Thruth, 중앙지점이 1로 표시된 맵)를 그냥 사용하지 않고 Gaussian Kernel을 적용하여 GmG^\prime_m을 생성
Gaussian Filter
기존 CenterNet에서는 중앙지점 1자리만 Positive Sample로 사용되었으나 Gaussian filter로 인해 주변 Sample이 Negative 샘플로 같이 학습됨
GmG^\prime_mFlF_l에서 해당하는 클래스맵을 Focal Loss를 (heatmap 용으로 변경) 사용하여 Localization Loss를 계산
Regression Branch
Fr:H/4×W/4×4F_r : H/4 \times W/4 \times 4 의 형태를 가짐
Localization 과 마찬가지로 Gaussian Filter 적용
이후 모든 FrF_r에 있는 모든 점과의 GIoU를 계산
GIoU : IoU를 loss로 쓰기 위해 변형을 넣은 Loss
최종 Loss
AGS Module
PAFNet의 Regression Loss 는 목표 위치에 대한 모든 위치에서의 손실의 가중합계이고 가중치는 Gaussinal kernel 에 의해 제공됨
AGS는 localization과 regression의 훈련을 일관성 있게 하기 위해 제공됨
카테고리에 관련 없이 유의한 특성을 얻기 위해 채널 방향으로 max 연산을 사용해 1개 채널로 압축하고 이후 softmax를 사용
regression branch는 가우시안 필터에 해당하는 부분에 의해서만 연산되기에 위 softmax 결과 에 해당 가우시안 필터를 사용하여 마스킹
실험적으로 AGS모듈은 매우 작은 추가 메모리로 네트워크의 정확도에 있어 매우 큰 영향을 주었고 tranining시에만 사용되기에 inference시에 영향이 없음
PAFNet-lite for Mobile Side
모바일에선 모델크기를 최소화하고 inference 속도를 올리는데 중점을 둠
MobileNetV3를 사용해 봤으나 조금 더 좋은 구조로 변형함
Upsampling Network와 Detection Head에 사용됨

3.2 Selection of Trick

기존에 사용되던 다양한 기법들을 효율적으로 적용할 방법 모색
Better Pretrain Model
ResNet50 대신 ResNet50-vd-ssld 모델을 사용
Inference 에는 영향을 끼치지 않음
모바일의 경우 MObileNetV3 을 사용해 knowledge Distillation 사용
Exponential Moving Average
EMA 는 Detection 에서 자주 사용되는 방식
WEMA=λWEMA+(1λ)WW_{EMA} = \lambda W_{EMA} + (1-\lambda) W
CutMix
분류에서 사용되는 MixUp의 개선버전
MixUp : 두개의 데이터를 일정 비율 λ\lambda로 interpolation 하여 새로운 데이터를 생성하는 Augmentation 방식중 하나
CutMix는 위 그림처럼 다른 그림을 잘라 붙이는 형태
일반적으로 Classification 이나 Detection에서 좋은 성능을 보임
GridMask
데이터의 일부를 삭제하는 방식
Dropout이나 정규화 텀을 추가하는것과 비슷한 원리
Deformable Convolution Network

4. Experiment

MS-COCO 데이터셋 사용

4.1 Implementation Detail

Training
Backborn : ResNet50-vd
image resize : 512×512512 \times 512
wloc=1,wreg=5w_{loc} = 1, w_{reg} = 5
Syncronized SGD for 15K iterations with 0.015 lr
11.25K 와 13.75K 에서 1/10로 감소
minibatch = 12, 8 GPU
Mobile side
backborn : MobileNetV3
image resize : 320×320320\times320
Inference
mAP 와 IoU 는 0.5에서 0.95로 설정
1 Tesla V100 GPU with 1 batch size
Mobile은 4x Kirin 990 ARM CPU

4.2 Ablation Study

5. Conclusion

서버의 경우 기존 anchor-free Sota 모델보다 정확도나 속도 면에서 전부 앞서있음
모바일쪽에 프레임워크를 최적화함