- Junghwan Park
Loss landscape를 여행하는 연구자를 위한 안내서
알세미를 포함하여 현실에 주어진 문제를 풀고자 할 때 데이터가 부족한 경우가 많습니다. 이런 경우에 대응하기 위한 AI modeling 연구도 활발히 진행되고 있습니다. 앞선 글에서 살펴본 방법은 크게 2가지였습니다. 첫 번째는 여러 task로부터 prior knowledge를 추출하여 학습 시키는 방법입니다 [두 번째 글]. 대표적으로 Transfer learning, Meta learning이 있습니다. 두 번째로는 물리 현상을 잘 설명해주는 방정식을 사용하여 neural network를 학습시키는 방법입니다 [세 번째 글]. 대표적으로 Hamiltonian equation을 사용하여 Hamiltonian mechanics를 neural network에 가르쳐주는 방법이 있습니다([1]). 이런 방법들은 generalization을 위한 알고리즘이기도 합니다([1], [10], [11]). 이번에는 loss landscape 측면에서 generalization을 해석하는 방법을 알아보겠습니다.
Loss landscape와 Generalization
Loss landscape 란? Loss landscape는 파라미터에 따른 손실 값(loss)의 변화를 나타냅니다. Neural network를 훈련한다는 것은 loss landscape를 미끄러져 내려가는 과정이라고 해석할 수 있습니다. 그러므로 loss landscape를 잘 알면 neural network 학습을 이해하는데 도움이 됩니다. 예를 들어 loss landscape 시각화 연구 논문인 'Visualizing the loss landscape of neural nets'에서는 skip connection([4])이 왜 좋은 성능을 내는지 그림 1을 통해 한눈에 보여줍니다([3]). Skip connection은 모델의 성능을 개선해 주거나 학습 속도를 빠르게 해주는 방법입니다. 그림 1의 (a), (b)는 각각 skip connection을 사용하지 않았을 때와 사용했을 때의 loss landscape를 보여줍니다. Skip connection을 안 썼을 때(그림 1.a)는 loss landscape가 굉장히 복잡한 모양으로 보입니다. 그래서 학습 시작점(initialization)이 어디냐에 영향을 크게 받으며 훈련이 힘들 것으로 유추할 수 있습니다. 반면 skip connection을 썼을 때(그림 1.b)는 거의 볼록한 모양으로 보입니다. 때문에 최적점까지 쉽게 미끄러져 내려가 학습이 잘 될 수 있을 것으로 추측할 수 있습니다. 이처럼 loss landscape는 neural network 이해에 직관적인 도움을 줍니다.

Generalization 이란? Generalization은 학습 데이터로 훈련했을 때 보지 못한 데이터(테스트 데이터)도 잘 맞추는 것을 말합니다. 보통 학습 데이터와 테스트 데이터는 분포가 다릅니다. 그래서 학습 데이터로 훈련했을 때 테스트 데이터를 잘 못 맞추는 경우가 발생합니다. 모델의 구조나 학습 방법 등에도 영향을 받습니다. 모델의 generalization 성능을 평가할 때는 generalization gap이라는 용어를 사용합니다. 같은 분포에서 추출된 학습 데이터와 테스트 데이터 사이의 성능 차이를 generalization gap이라고 부릅니다([12]). 특히 학습 데이터가 적으면 generalization gap이 커질 수 있기 때문에 알세미 AI팀에서도 관심 있게 보는 주제입니다.
Flat Minima & Sharp Minima
Generalization gap은 loss landscape와도 깊은 연관이 있습니다. 아래의 그림 2를 보며 이해해 보겠습니다.

x축은 파라미터, y축은 손실 값(loss)을 나타냅니다. 검정 선은 학습 데이터에 대한 loss landscape function (training function)을, 빨간 점선은 테스트 데이터에 대한 loss landscape function (testing function)을 나타냅니다. 훈련 데이터와 테스트 데이터의 분포가 다르기 때문에 조금 이동한 모양을 보여줍니다. 학습을 통해 training function의 최적값(minimum point)까지 파라미터를 업데이트했다고 가정해 보겠습니다. 그럼 당연히 training function은 minimum loss 값을 가지게 됩니다. 반면 testing function은 오른쪽으로 이동하여 존재하므로 학습 데이터로 훈련하여 얻은 파라미터로는 최적의 loss를 갖지 못합니다. 이 차이는 minima의 모양에 따라 커지거나 작아질 수 있습니다. 그림 2 왼쪽과 오른쪽에는 각각 flat minima, sharp minima가 있습니다. 위와 같이 training function의 최적값(minimum point)까지 파라미터를 업데이트했다고 가정해 보겠습니다. Flat minima일 때 testing function의 y값(loss)은 training function의 y값(loss)과 차이가 크지 않음을 알 수 있습니다. 반면, sharp minima일 때 testing function의 y값(loss)은 training function의 y값(loss)과 차이가 많이 나는 것을 알 수 있습니다. 그렇기 때문에 sharp minima를 피하고 flat minima를 찾을 수 있다면 generalization gap을 줄이고 테스트 데이터에 대해서도 좋은 성능을 기대할 수 있습니다([2], [3], [6]).
Practical Method
Sharp minima를 피하고 Flat minima로 훈련하는 것이 좋다는 것은 알 수 있었습니다. 그런데 어떻게 Flat minima를 찾을 수 있을까요? 몇 가지 방법을 소개해 드리겠습니다. 그림 1을 다시 보면 skip connection을 사용하지 않았을 때(그림 1.a) 보다 skip connection을 사용했을 때(그림 1.b)가 더 flat 한 minima를 가지는 것을 볼 수 있습니다. 이처럼 어떤 구조의 neural network를 쓰는지가 generalization gap을 줄이는데 큰 영향을 끼칠 수 있습니다([2], [3]).

Stochastic Gradient Descent (SGD)를 사용하는 것만으로도 flat minima를 찾는데 도움을 줍니다. 딥러닝에서는 보통 데이터를 많이 쓰기 때문에 전체 데이터를 한번에 넣어서 학습하지 않고 여러 개로 쪼개어(batch) 사용하게 됩니다. 파라미터를 업데이트 할 때 전체 데이터가 아닌 일부 데이터를 사용하므로 gradient에 noise가 추가되어 학습됩니다. 이런 noise가 sharp minima를 탈출할 수 있도록 도와줍니다([8]). Sharp minima의 경우 noise에 민감하여 탈출이 쉬운 반면 flat minima는 탈출이 어렵기 때문에 flat minima로 수렴할 확률이 높기 때문입니다. 그래서 배치(batch) 크기를 줄이거나 학습률 조절(learning rate scheduling)을 사용하여 gradient noise를 적절히 이용하는 것이 flat minima를 찾아가는데 도움을 줄 수 있습니다([2], [7]). 비슷한 방법으로 학습 중에 perturbation을 줘서 sharp minima를 탈출시키는 방법도 있습니다. Loss가 증가하도록 weight 혹은 input에 perturbation을 주는 방법입니다([9]). Sharp minima라면 탈출할 것이고 flat minima라면 그렇지 않을 것입니다.
모델 훈련 시 flat minima의 특징을 갖는 곳을 찾도록 알고리즘을 구성한 논문도 있습니다([6]). Flat minima에서는 loss가 주위 영역과 별로 차이가 나지 않는 특징이 있습니다. 즉, Perturbation에 강건하여 loss가 잘 변하지 않는다면 flat minima로 볼 수 있습니다. 이러한 영역을 찾도록 수식으로 만들어 학습하는 방법입니다. 이 알고리즘을 사용했을 때 테스트 데이터에 대해서 좋은 성능을 보여 당시 ImageNet([13]) state of the art를 달성하기도 했습니다.
본 포스팅에서는 minima와 flatness 기초 개념을 다루었습니다. 딥러닝을 사용해 학습을 해보면 안 될 것 같은데 잘 되는 경우가 생깁니다😮. 반대로 진짜 안되는 경우도 있습니다😢. 이런 현상이 왜 생기는지 알아내기 위해 generalization과 loss landscape 연구가 활발히 진행되고 있습니다. 실제 학습에 사용하기 위해 Hessian의 eigenvalue들을 구해서 principal curvature를 분석하는 등 기하학적인 연구들도 포함됩니다. 많은 데이터를 추출하기 어려운 scientific machine learning 분야에서 모델의 기본 성능을 높이기 위해 꼭 염두에 두어야 할 연구들입니다. 알세미 AI팀에서도 반도체 모델링에서 generalization이 잘 되기 위한 좋은 minima의 기하학적 특성이 무엇일지 토론하며 재미있는 연구를 이어가고 있습니다.
Reference
[1] Hamiltonian neural networks, S. Greydanus et al., NeurIPS 2019 Hamiltonian Neural Networks
[2] On large-batch training for deep learning: Generalization gap and sharp minima, N. S. Keskar et al., ICLR 2017 On Large-Batch Training for Deep Learning: Generalization Gap and...
[3] Visualizing the loss landscape of neural nets, H. Li et al., NeurIPS 2018 Visualizing the Loss Landscape of Neural Nets
[4] Deep Residual Learning for Image Recognition, K. He et al., CVPR 2016 Deep Residual Learning for Image Recognition
[5] Train longer, generalize better: closing the generalization gap in large batch training of neural networks, E. Hoffer et al., NeurIPS 2017 Train longer, generalize better: closing the generalization gap in...
[6] Sharpness-Aware Minimization for Efficiently Improving Generalization, P. Foret et al., ICLR 2021 Sharpness-Aware Minimization for Efficiently Improving Generalization
[7] Towards Flatter Loss Surface via Non Monotonic Learning Rate Scheduling, S. Seong et al., AUAI 2018 Towards Flatter Loss Surface via Nonmonotonic Learning Rate Scheduling
[8] A Bayesian Perspective on Generalization and Stochastic Gradient Descent, SL. Smith et al., ICLR 2018 A Bayesian Perspective on Generalization and Stochastic Gradient Descent
[9] Adversarial Weight Perturbation Helps Robust Generalization, D. Wu et al., NeurIPS 2020 Adversarial Weight Perturbation Helps Robust Generalization
[10] Meta Learning, Yee Whye Teh, The Machine Learning Summer School 2020, Max Planck Institute for Intelligent Systems Tübingen Machine Learning Summer School 2020 Meta Learning, part 1 - Yee Whye Teh - MLSS 2020, Tübingen
[11] How transferable are features in deep neural networks?, J. Yosinski et al., NeurIPS 2014 How transferable are features in deep neural networks?
[12] Predicting the Generalization Gap in Deep Neural Networks, Y. Jiang, Google AI Blog Predicting the Generalization Gap in Deep Neural Networks
[13] ImageNet: A Large-Scale Hierarchical Image Database, J. Deng et al., CVPR 2009 https://ieeexplore.ieee.org/document/5206848https://ieeexplore.ieee.org/document/5206848