On Calibration of Modern Neural Networks
๐ ๋ ผ๋ฌธ ์ ๋ณด
โOn Calibration of Modern Neural Networksโ (ICML 2017)
๐ ์ฐ๊ตฌ ๊ฐ์
Confidence Calibration์ ๋ชจ๋ธ์ด ์์ธกํ ํ๋ฅ ์ด ์ค์ ์ ๋ต์ผ ๊ฐ๋ฅ์ฑ๊ณผ ์ผ๋ง๋ ์ผ์นํ๋์ง๋ฅผ ๋ํ๋ด๋ ๊ฐ๋ ์ด๋ค. ์๋ฅผ ๋ค์ด, ์ด๋ค ์ด๋ฏธ์ง์ ๋ํด ๋ชจ๋ธ์ด ๊ณ ์์ด์ผ ํ๋ฅ ์ 0.9๋ผ๊ณ ์์ธกํ์ ๋, ์ด๋ฌํ ์์ธก์ด ์ ๋ณด์ (calibrated)๋์ด ์๋ค๋ฉด ์ค์ ๋ก ๊ทธ ์ด๋ฏธ์ง๊ฐ ๊ณ ์์ด์ผ ํ๋ฅ ๋ ์ฝ 90%๊ฐ ๋์ด์ผ ํ๋ค๋ ๊ฒ์ด๋ค.
On Calibration of Modern Neural Networks ๋ ผ๋ฌธ์ ๋น ๋ฅด๊ฒ ๋ฐ์ ํ๊ณ ์ฐ๊ตฌ๋์ด์ ธ ์ค๊ณ ์๋ ResNet, DenseNet ๋ฑ๊ณผ ๊ฐ์ ํ๋์ ์ธ ์ ๊ฒฝ๋ง ๋ชจ๋ธ๋ค์ด ๋์ ๋ถ๋ฅ ์ ํ๋๋ฅผ ๋ฌ์ฑํจ์๋ ๋ถ๊ตฌํ๊ณ , ์คํ๋ ค ํ๋ฅ ๋ณด์ (calibration) ์ฑ๋ฅ์ ๋ ๋๋น ์ก๋ค๋ ์ฌ์ค์ ์คํ์ ์ผ๋ก ๋ณด์ด๊ณ ์๋ค. ๊ณผ๊ฑฐ์ ์์ ์ ๊ฒฝ๋ง ๋ชจ๋ธ๋ค์ ์์ธก ํ๋ฅ ์ด ์ค์ ์ ๋ต ํ๋ฅ ๊ณผ ๋น๊ต์ ์ ์ผ์นํ์ง๋ง, ๊น๊ณ ๋ณต์กํ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง ์ต์ ๋ชจ๋ธ๋ค์ ์์ ์๊ฒ ์์ธก์ ํ๋ฉด์๋ ๊ทธ ํ๋ฅ ์ด ์ค์ ์ ๋ต๋ฅ ๊ณผ ๋ถ์ผ์นํ๋ ๊ฒฝํฅ์ด ์๋ค๋ ์ ์ ์ง๊ณ ์๋ค. ํ ๋ง๋๋ก ๋งํด ์ต์ ๋ชจ๋ธ๋ค์ Overconfident ๋์ด์ ธ ์๋ค๋ ์ ์ ๋ฐ๊ฒฌํ์๋ค.
Figure 1: Confidence histograms (top) and reliability diagrams (bottom) for a 5-layer LeNet (left) and a 110-layer ResNet (right) on CIFAR-100. Refer to the text below for detailed illustration.
์ด๋ฌํ ๋ฌธ์ ๋ ์์จ์ฃผํ, ์๋ฃ ์ง๋จ, ๋ฒ๋ฅ ํ๋จ ๋ฑ๊ณผ ๊ฐ์ด ์์ธก ๊ฒฐ๊ณผ์ ๋ํ ์ ๋ขฐ๋๊ฐ ๋งค์ฐ ์ค์ํ ์์ฉ ๋ถ์ผ์์ ํนํ ๋๋๋ผ์ง ์ ์๊ธฐ์ ๋ชจ๋ธ์ด ๋จ์ํ ์ ํํ ๋ฟ๋ง ์๋๋ผ, ์์ ์ ์์ธก์ด ์ผ๋ง๋ ํ์คํ์ง์ ๋ํ ํํ ๋ํ ์ ๋ขฐํ ์ ์์ด์ผ ํ๋ค.
๋ณธ ๋ ผ๋ฌธ์์๋ ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ๋ค์ํ ์ฌํ ๋ณด์ (post-hoc calibration) ๋ฐฉ๋ฒ๋ค์ ์คํ์ ์ผ๋ก ๋น๊ตํ๊ณ , ๊ทธ ์ค์์๋ Temperature Scaling์ด๋ผ๋ ๋จ ํ๋์ ์ค์นผ๋ผ ํ๋ผ๋ฏธํฐ๋ง์ ์ฌ์ฉํ๋ ๊ฐ๋จํ ๋ฐฉ๋ฒ์ด ๋งค์ฐ ํจ๊ณผ์ ์ด๋ผ๋ ์ฌ์ค์ ๋ฐํ๋๋ค. ๋ณธ ๊ธ์์๋ ์ด์ ์คํ ์ฝ๋๋ ๊ตฌํํ์๋ค.
๐ ์ ๊ฒฝ๋ง์ Overconfidence ์์ธ ๋ถ์
์ต๊ทผ์ ์ ๊ฒฝ๋ง ๋ชจ๋ธ๋ค์ ๋์ ์ ํ๋๋ฅผ ์๋ํ์ง๋ง, ๊ทธ confidence (์์ธก ํ๋ฅ ) ๋ ์ค์ ์ ๋ต๋ฅ ๊ณผ ์ ๋ง์ง ์๋ ๊ฒฝ์ฐ๊ฐ ๋ง๋ค. ์ด ํ์์ miscalibration (๋ถ์์ ํ ๋ณด์ ) ์ด๋ผ๊ณ ํ๋ฉฐ, ๊ทธ ์์ธ๊ณผ ๊ด๋ จ ์์๋ค์ ์ฐ์ ๋ถ์ํ์๋ค.
Figure 2: The effect of network depth (far left), width (middle left), Batch Normalization (middle right), and weight decay (far right) on miscalibration, as measured by ECE (lower is better)
1. ๋ชจ๋ธ ์ฉ๋์ ์ฆ๊ฐ (Model Capacity)
- ์ต๊ทผ ๋ฅ๋ฌ๋ ๋ชจ๋ธ๋ค์ ๋ ์ด์ด ์์ ํํฐ ์๊ฐ ๊ธ๊ฒฉํ ์ฆ๊ฐํ์ฌ, ํ์ต ๋ฐ์ดํฐ๋ฅผ ๋ ์ ๋ง์ถ ์ ์๋ ๋ชจ๋ธ ์ฉ๋(capacity) ์ ๊ฐ์ถ๊ฒ ๋์๋ค.
- ํ์ง๋ง ๋ชจ๋ธ ์ฉ๋์ด ์ปค์ง์๋ก ์คํ๋ ค confidence๊ฐ ์ค์ ์ ํ๋๋ณด๋ค ๊ณผ๋ํ๊ฒ ๋์์ง๋ ๊ณผ์ , ์ฆ overconfidence ํ๋ ๊ฒฝํฅ์ด ๋ํ๋๋ค.
์คํ ๊ฒฐ๊ณผ (ResNet on CIFAR-100):
- ๊น์ด(depth)๋ฅผ ์ฆ๊ฐ์ํค๋ฉด Error์ ์ค์ด๋๋ ECE๊ฐ ์ฆ๊ฐ
- ํํฐ ์(width)๋ฅผ ์ฆ๊ฐ์ํค๋ฉด Error์ ํ์ฐํ ์ค์ด๋๋, ECE๊ฐ ์ฆ๊ฐ
๋์ capacity๋ overfitting์ ์ผ๊ธฐํ ์ ์์ผ๋ฉฐ, ์ด๋ฌํ ๊ฒฝ์ฐ ์ ํ๋๋ ์ข์์ ธ๋ ํ๋ฅ ์ ํ์ง์ ๋๋น ์ง๋ค.
2. Batch Normalization์ ์ํฅ
- Batch Normalization์ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ํ์ต์ ์์ ํ์ํค๊ณ ๋น ๋ฅด๊ฒ ๋ง๋๋ ๊ธฐ๋ฒ์ผ๋ก, ํ๋ ์ํคํ ์ฒ์์ ํ์์ ์ผ๋ก ์ฌ์ฉ๋๋ค.
- ํ์ง๋ง, BN์ ์ฌ์ฉํ ๋ชจ๋ธ๋ค์ ์ ํ๋๋ ์ฌ๋ผ๊ฐ์ง๋ง calibration์ ์คํ๋ ค ๋๋น ์ง๋ ํ์์ด ๋ํ๋๋ค.
์คํ ๊ฒฐ๊ณผ (6-layer ConvNet on CIFAR-100):
- BN์ ์ ์ฉํ ConvNet์ ์ ํ๋๊ฐ ์ฝ๊ฐ ์ฌ๋ผ๊ฐ์ง๋ง(Error ๊ฐ์), ECE๋ ๋๋ ทํ๊ฒ ์ฆ๊ฐ
BN์ ๋ด๋ถ ํ์ฑ ๋ถํฌ๋ฅผ ์ ๊ทํํ์ฌ ํ์ต์ ๋ ์ ๋๊ฒ ํ์ง๋ง, ๊ฒฐ๊ณผ์ ์ผ๋ก ์ถ๋ ฅ ํ๋ฅ ์ด ๋ ๊ณผ์ ๋(overconfident) ์ํ๋ก ๋ํ๋ calibration์๋ ๋ถ์ ์ ์ธ ์ํฅ์ ๋ฏธ์น๊ฒ ๋๋ค.
3. Weight Decay ๊ฐ์์ ์ํฅ
- ์ ํต์ ์ผ๋ก weight decay ๋ ๊ณผ์ ํฉ์ ๋ง๊ธฐ ์ํ ์ ๊ทํ ๋ฐฉ๋ฒ์ผ๋ก ๋๋ฆฌ ์ฌ์ฉ๋์ด ์์ผ๋ฉฐ, overfitting์ ๋ฐฉ์งํ๊ธฐ ์ํด ๊ฐ์ค์น์ ํจ๋ํฐ๋ฅผ ์ฃผ๋ ์ ๊ทํ ๊ธฐ๋ฒ์ด๋ค.
- ์ต๊ทผ์๋ BN์ ์ ๊ทํ ํจ๊ณผ ๋๋ฌธ์ weight decay ์ฌ์ฉ๋์ด ์ค์ด๋๋ ์ถ์ธ์ด๋ค.
- ํ์ง๋ง ์คํ์์๋ weight decay๋ฅผ ์ฆ๊ฐ์ํฌ์๋ก calibration์ ๊ฐ์ ๋๋ ๊ฒฝํฅ์ ๋ณด์ธ๋ค.
์คํ ๊ฒฐ๊ณผ (ResNet-110 on CIFAR-100):
- Weight decay๋ฅผ ์ฆ๊ฐ์ํค๋ฉด ๋ถ๋ฅ ์ ํ๋(Error)๋ ํน์ ๊ตฌ๊ฐ์์ ์ต์ ์ ์ ์ฐ๊ณ ์ดํ ๋ค์ ์ฆ๊ฐ
- ECE๋ weight decay๊ฐ ์ฆ๊ฐํ ์๋ก ๊ฐ์ํ๋ ๊ฒฝํฅ์ ๋ณด์
์ฆ, ์ ํ๋๋ฅผ ์ต๋ํํ๋ weight decay ์ค์ ๊ณผ calibration์ ์ต์ ํํ๋ ์ค์ ์ ์๋ก ๋ค๋ฅผ ์ ์์ผ๋ฉฐ, ์ ํ๋๋ ์ ์ง๋๋๋ผ๋ confidence๋ ์๊ณก๋ ์ ์๋ค.
4. NLL ๊ณผ์ ํฉ ํ์ (Overfitting to NLL)
Figure 3: Test error and NLL of a 110-layer ResNet with stochastic depth on CIFAR-100 during training
- ์คํ์์๋ learning rate๊ฐ ๋ฎ์์ง๋ ๊ตฌ๊ฐ์์ test error๋ ๊ณ์ ์ค์ด๋๋ ๋ฐ๋ฉด, NLL์ ๋ค์ ์ฆ๊ฐํ๋ ํ์์ ํ์ธํ์๋ค.
- ์ด๋ ๋ชจ๋ธ์ด ์ ํ๋๋ ๋์ด์ง๋ง confidence๊ฐ ์ค์ ๋ณด๋ค ๊ณผ๋ํ ์ํ๋ก ํ์ต์ด ์งํ๋๊ณ ์์์ ์๋ฏธํ๋ค.
์คํ ๊ฒฐ๊ณผ (ResNet-110 + stochastic depth on CIFAR-100):
- Epoch 250 ์ดํ learning rate ๊ฐ์
- ์ดํ test error๋ ๊ฐ์ (29% โ 27%)ํ์ง๋ง, NLL์ ์ฆ๊ฐ
์ต์ ์ ๊ฒฝ๋ง์ ํ์ต ํ๋ฐ๋ถ์์ NLL์ ๊ณ์ ์ต์ํํ๋ ค๋ ๊ณผ์ ์์ confidence๋ฅผ ๊ณผ๋ํ๊ฒ ๋์ด๋ ๊ฒฝํฅ์ด ์์ผ๋ฉฐ, ์ด๋ก ์ธํด ์ค์ ์ ๋ต๋ฅ ๋ณด๋ค ๋์ ํ๋ฅ ์ ์ถ๋ ฅํ๋ overconfidentํ ์ํ๋ก calibration ์ค๋ฅ๊ฐ ๋ฐ์ํ๋ค.
๐ Calibration์ ์ ์ ๋ฐ ์ธก์ ๋ฐฉ๋ฒ
๋ณธ ๋ ผ๋ฌธ์์๋ ๋ค์ค ํด๋์ค ๋ถ๋ฅ ๋ฌธ์ ๋ฅผ ๋ค๋ฃจ๊ณ ์์ผ๋ฉฐ, ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ์ฃผ์ด์ง ์ ๋ ฅ \(X \in \mathcal{X}, \quad Y \in \{1, \dots, K\}\) ๋ฅผ ์์ธกํ๋๋ค๊ณ ๊ฐ์ ํ๋ค. ๋ชจ๋ธ์ ์์ธก ํ๋ฅ ์ ๋ค์๊ณผ ๊ฐ์ด ์ ์๋๋ค.
\[h(X) = (\hat{Y}, \hat{P})\]์ฌ๊ธฐ์ \(\hat{Y}\)๋ ์์ธก๋ ํด๋์ค, \(\hat{P}\)๋ ๊ฐ ํด๋์ค์ ๋ํ ํ๋ฅ ๋ถํฌ์ด๋ฉฐ, softmax ์ถ๋ ฅ์ ์ต๋๊ฐ์ผ๋ก ์ ์๋๋ค.
๊ทธ๋ผ ์๋ฒฝํ๊ฒ ๋ณด์ ๋ ๋ชจ๋ธ์ ์ ์๋ ์ด๋ป๊ฒ ๋ ๊น? ๋ณธ ๋ ผ๋ฌธ์์๋ ๋ค์๊ณผ ๊ฐ์ด ์ ์ํ๊ณ ์๋ค.
\[P(\hat{Y} = Y \mid \hat{P} = p) = p, \quad \forall p \in [0, 1]\]์ ์์์ ์ ์ ์๋ฏ์ด ๋ชจ๋ธ์ด ์๋ฒฝํ ๋ณด์ (calibrated)๋์ด ์๋ค๋ ๊ฒ์, ์์ธกํ ํ๋ฅ ๊ฐ์ด ์ค์ ์ ๋ต๋ฅ ๊ณผ ์ผ์นํ๋ ๊ฒ์ ์๋ฏธํ๋ค. ์๋ฅผ ๋ค์ด ๋ชจ๋ธ์ด 100๊ฐ์ ์ํ์ ๋ํด ๋ชจ๋ 0.8์ confidence๋ฅผ ์ถ๋ ฅํ๋ค๋ฉด, ์ค์ ๋ก ๊ทธ ์ค ์ฝ 80๊ฐ๊ฐ ๋ง์์ผ ์๋ฒฝํ ๋ณด์ ๋ ๊ฒ์ด๋ค.
๐ Reliability Diagram (์ ๋ขฐ๋ ๋ค์ด์ด๊ทธ๋จ)
์์ธก ํ๋ฅ \(\hat{P}\)๋ฅผ ๊ตฌ๊ฐ์ผ๋ก ์๊ฒ ๋๋๊ณ , ๊ฐ ๊ตฌ๊ฐ์์์ ์ค์ ์ ๋ต๋ฅ (accuracy)๊ณผ ํ๊ท confidence๋ฅผ ๋น๊ตํ๋ค. ๋ง์ฝ ๋ชจ๋ธ์ด ์ ๋ณด์ ๋์ด์ ธ ์๋ค๋ฉด, ๊ฐ ๊ตฌ๊ฐ์์๋ ์๋์ ๊ด๊ณ์์ด ์ฑ๋ฆฝํด์ผํ๋ค๋ ๊ฒ์ด๋ค.
\[\text{acc}(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \mathbf{1}(\hat{y}_i = y_i)\] \[\text{conf}(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \hat{p}_i\] \[\text{Accuracy}(B_m) \approx \text{Confidence}(B_m)\]์ฌ๊ธฐ์ \(\text{acc}(B_m)\)๋ ๊ตฌ๊ฐ \(B_m\)์ ์ํ๋ ์ํ๋ค์ ์ค์ ์ ๋ต๋ฅ , \(\text{conf}(B_m)\)๋ ๊ตฌ๊ฐ \(B_m\)์ ์ํ๋ ์ํ๋ค์ ํ๊ท confidence๋ฅผ ์๋ฏธํ๋ค. ๋ง์ฝ ๋ชจ๋ธ์ด ์ ๋ณด์ ๋์ด ์๋ค๋ฉด, ๋ ๊ฐ์ ์๋ก ๋น์ทํด์ผ ํ๋ค.
๐ Expected Calibration Error (ECE)
ECE๋ ๋ชจ๋ธ์ ์ ์ฒด calibration ์ฑ๋ฅ์ ์์น์ ์ผ๋ก ์ธก์ ํ๋ ๋ํ์ ์ธ ์งํ๋ก, ๊ฐ bin์ ๋ํด ์์ธก ํ๋ฅ ๊ณผ ์ค์ ์ ๋ต๋ฅ ๊ฐ์ ์ฐจ์ด๋ฅผ ํ๊ท ํ์ฌ ๊ณ์ฐํ๋ค. \(M\)๊ฐ์ bin์ผ๋ก ๋๋๊ณ , ๊ฐ bin \(B_m\)์ ๋ํด ๋ค์๊ณผ ๊ฐ์ด ์ ์๋๋ค.
\[\text{ECE} = \sum_{m=1}^{M} \frac{|B_m|}{n} \left| \text{acc}(B_m) - \text{conf}(B_m) \right|\]๐ Maximum Calibration Error (MCE)
MCE๋ ๊ฐ์ฅ ํฐ ์ค์ฐจ๋ฅผ ๋ณด์ธ bin์ calibration gap์ ์ธก์ ํ๊ฒ ๋๋ฉฐ, ์ฝ๊ฒ ๋งํด โ์ต์ ์ ๋ณด์ ์คํจโ ์ ๋๋ฅผ ๋ํ๋ธ๋ค๊ณ ์๊ฐํ ์ ์๋ค. ์์ ์ด ์ค์ํ ์์คํ ์์ ๋งค์ฐ ์ค์ํ ์งํ๋ก ์ฌ์ฉ๋ ์ ์๋ค.
\[\text{MCE} = \max_{m \in \{1, \dots, M\}} \left| \text{acc}(B_m) - \text{conf}(B_m) \right|\]๐ ๏ธ Calibration Methods
๋ณธ ๋ ผ๋ฌธ์์๋ ์ด๋ฏธ ํ์ต๋ ๋ชจ๋ธ์ ๋ํด ํ๋ฅ ๋ณด์ ์ ์ํ ์ฌํ ์ฒ๋ฆฌ(Post-hoc) ๋ฐฉ๋ฒ๋ค์ ์๊ฐํ๊ณ ์์ผ๋ฉฐ, ์ด๋ค์ ๋ชจ๋ธ์ ์์ธก ๊ตฌ์กฐ๋ ์ ํ๋๋ ์ ์งํ๋ฉด์, ์์ธก ํ๋ฅ (confidence)์ด ์ค์ ์ ๋ต๋ฅ ๊ณผ ๋ ์ ์ผ์นํ๋๋ก ๋ง๋ค์ด์ค๋ค.
๐ 1. Calibrating Binary Models
์ฐ์ ๋ค๋ฃฐ ๋ด์ฉ์ ์ด์ง ๋ถ๋ฅ(binary classification) ์ํฉ์์์ confidence calibration ๊ธฐ๋ฒ๋ค์ ์์๋ณด๊ณ ์ ํ๋ค. ์ด ๊ฒฝ์ฐ ํด๋์ค ๋ผ๋ฒจ์ \(Y \in \{0, 1\}\)์ด๋ฉฐ, ๋ชจ๋ธ์ positive ํด๋์ค(1)์ ๋ํ ํ๋ฅ \(\hat{p}_i\)์ logit \(z_i \in \mathbb{R}\)๋ฅผ ์ถ๋ ฅํ๋ค. ํ๋ฅ ์ ๋ณดํต sigmoid ํจ์๋ฅผ ํตํด ๋ค์๊ณผ ๊ฐ์ด ๊ณ์ฐ๋๋ค:
\[\hat{p}_i = \sigma(z_i) = \frac{1}{1 + \exp(-z_i)}\]Calibration์ \(\hat{p}_i\) ๋๋ \(z_i\)๋ฅผ ์กฐ์ ํ์ฌ ์๋ก์ด calibrated ํ๋ฅ \(\hat{q}_i\)๋ฅผ ์ป๋ ๊ฒ์ด๋ค.
A. Histogram Binning
Histogram Binning์ ๋ชจ๋ธ์ด ์ถ๋ ฅํ๋ ํ๋ฅ ์ ์ผ์ ๊ตฌ๊ฐ(bin)์ผ๋ก ๋๋๊ณ , ๊ฐ ๊ตฌ๊ฐ ๋ด์์์ ์ค์ ์ ๋ต ๋น์จ์ ์๋ก์ด ๋ณด์ ํ๋ฅ ๋ก ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ด๋ค. ๋ชจ๋ \(\hat{p}_i\)๋ฅผ \(M\)๊ฐ์ ๊ตฌ๊ฐ \(B_1, B_2, ..., B_M\)์ผ๋ก ๋๋๋ค. ๊ฐ bin๋ง๋ค ์ ๋ต๋ฅ ์ ๊ณ์ฐํ์ฌ ๊ทธ ๊ฐ์ ๋ณด์ ๋ ํ๋ฅ ๋ก ์ฌ์ฉํ๋ค. ์ด๋ฅผ ์์์ผ๋ก ์ดํด๋ณด๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
- ๊ตฌ๊ฐ ์ ์: \(B_m = (a_m, a_{m+1}]\) where \(0 = a_1 \leq a_2 \leq \dots \leq a_{M+1} = 1\)
- ์ต์ ํ ๋ฌธ์ : \(\min_{\theta_1,\dots,\theta_M} \sum_{m=1}^{M} \sum_{i=1}^{n} \mathbf{1}(a_m \leq \hat{p}_i < a_{m+1}) (\theta_m - y_i)^2\)
์์ ์ต์ ํ ๋ฌธ์ ๋ฅผ ํ๊ฒ๋๋ฉด ๊ฐ bin์ ๋ณด์ ํ๋ฅ ์ bin ๋ด ๋ ์ด๋ธ ํ๊ท ์ผ๋ก ๊ณ์ฐ๋๊ฒ ๋๋ค.
\[\theta_m = \frac{1}{|B_m|} \sum_{i \in B_m} y_i\]์์๋ฅผ ํตํด ์ดํด๋ณด๋๋ก Histogram Binning์ ์ดํด๋ณด๋๋ก ํ๊ฒ ๋ค.
- ์์ธก ํ๋ฅ ๊ณผ ์ค์ ์ ๋ต
์ํ ๋ฒํธ | ์์ธก ํ๋ฅ \(\hat{p}\_i\) | ์ค์ ์ ๋ต \(y\_i\) |
---|---|---|
1 | 0.15 | 0 |
2 | 0.22 | 1 |
3 | 0.35 | 1 |
4 | 0.48 | 0 |
5 | 0.63 | 1 |
6 | 0.72 | 1 |
7 | 0.85 | 1 |
8 | 0.89 | 0 |
9 | 0.95 | 1 |
- ์์ธก ํ๋ฅ ์ bin์ผ๋ก ๋๋ ๊ฒฐ๊ณผ
Bin ๋ฒํธ | ํ๋ฅ ๊ตฌ๊ฐ | ํด๋น ์ํ ๋ฒํธ | ์ ๋ต ๋ชฉ๋ก | ๋ณด์ ํ๋ฅ \(\theta\_m\) |
---|---|---|---|---|
1 | (0.0, 0.3] | 1, 2 | [0, 1] | 0.50 |
2 | (0.3, 0.7] | 3, 4, 5 | [1, 0, 1] | 0.67 |
3 | (0.7, 1.0] | 6, 7, 8, 9 | [1, 1, 0, 1] | 0.75 |
- ๊ฐ bin์์์ ์ค์ ์ ๋ต ํ๊ท (๋ณด์ ํ๋ฅ )
Bin ๋ฒํธ | ์ ๋ต ๋ชฉ๋ก | ๋ณด์ ํ๋ฅ \(\theta\_m\) |
---|---|---|
1 | [0, 1] | 0.50 |
2 | [1, 0, 1] | 0.67 |
3 | [1, 1, 0, 1] | 0.75 |
๊ตฌํ์ด ๊ฐ๋จํ๊ณ ์ด๋ค ๋ชจ๋ธ์ด๋ ์ฌํ ๋ณด์ ์ผ๋ก ์ ์ฉํ ์ ์์ผ๋, ์์ธก์ด ๊ณ๋จ ํจ์์ฒ๋ผ ๋ณ๊ฒฝ๋์ด ๋ถ๋๋ฝ์ง ์์ผ๋ฉฐ bin์ ๋ช ๊ฐ ๊ฐ์ง๋๋์ ๋ฐ๋ผ ์ฑ๋ฅ์ด ํฌ๊ฒ ๋ณํํ ์ ์๋ค.
B. Isotonic Regression
$\hat{p}_i$์ ๋ํด ๋จ์กฐ ์ฆ๊ฐ(monotonic) ํจ์ $f$๋ฅผ ํ์ตํ์ฌ:
\[\hat{q}_i = f(\hat{p}_i)\]- ๋ชฉ์ ํจ์: \(\min_f \sum_{i=1}^n (f(\hat{p}_i) - y_i)^2 \quad \text{subject to } f \text{ is monotonic}\)
์ด ๋ฐฉ๋ฒ์ histogram๋ณด๋ค ์ ์ฐํ์ง๋ง, ๊ณผ์ ํฉ์ ๊ฐ๋ฅ์ฑ์ด ์์ผ๋ฉฐ ์ ๊ทํ๊ฐ ํ์ํ ์ ์๋ค
C. Platt Scaling
๋ชจ๋ธ์ด ์ถ๋ ฅํ logit $z_i$๋ฅผ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํ์ฌ, sigmoid๋ฅผ ์ ์ฉํ ๋ก์ง์คํฑ ํ๊ท๋ฅผ ์ํํจ:
\[\hat{q}_i = \sigma(az_i + b) = \frac{1}{1 + \exp(-(az_i + b))}\]- $a$, $b$๋ validation set์์ NLL์ ์ต์ํํ๋๋ก ํ์ต
- ์ ๊ฒฝ๋ง ํ๋ผ๋ฏธํฐ๋ ๊ณ ์ ๋จ
์ ํ๋๋ ๊ทธ๋๋ก ์ ์งํ๋ฉด์ ํ๋ฅ ์ ์กฐ์ ํ ์ ์๋ ๊ฐ๋จํ ๋ฐฉ๋ฒ.
D. BBQ (Bayesian Binning into Quantiles)
Histogram Binning์ ํ์ฅ์ผ๋ก, ์ฌ๋ฌ binning scheme์ ๊ณ ๋ คํ์ฌ Bayesian model averaging์ ์ํํ๋ค.
calibrated ํ๋ฅ : \(P(\hat{q}_i | \hat{p}_i, D) = \sum_{s \in \mathcal{S}} P(\hat{q}_i | s, \hat{p}_i, D) \cdot P(s | D)\)
์ฌ๊ธฐ์ $s$๋ binning scheme์ด๋ฉฐ, prior๋ Beta ๋ถํฌ, likelihood๋ binomial๋ก ๊ณ์ฐ๋จ
์ ํํ์ง๋ง ๊ตฌํ์ด ๋ณต์กํ๊ณ ๊ณ์ฐ๋์ด ๋ง์.
๐ 2. Extension to Multiclass Models
๋ฌผ๋ก ์
๋๋ค. ์๋๋ ๐ 2. Extension to Multiclass Models
๋ด์ฉ์ ์์ ํฌํจ ๋งํฌ๋ค์ด ํ์์ผ๋ก ํ ๋ฒ์ ๋ณต์ฌํ ์ ์๋๋ก ์ ๋ฆฌํ ๊ฒ์
๋๋ค:
๐ 2. Extension to Multiclass Models
๋ค์ค ํด๋์ค ๋ถ๋ฅ ๋ฌธ์ ์์๋ ๊ฐ ์ ๋ ฅ ์ํ \(x_i\)์ ๋ํด logit ๋ฒกํฐ \(z\)๋ฅผ softmax์ ์ ๋ ฅํ์ฌ ๋ค์๊ณผ ๊ฐ์ด ํ๋ฅ ์ ๊ณ์ฐํ๋ค.
\[\hat{P} = \max_k \left( \frac{e^{z_k}}{\sum_j e^{z_j}} \right)\]A. Matrix Scaling
Logit ๋ฒกํฐ \(z\)์ ๋ํด ์ ํ ๋ณํ \(Wz + b\)๋ฅผ ์ ์ฉํ๊ณ softmax๋ฅผ ์ทจํ๋ค.
\[\hat{q} = \text{softmax}(Wz + b)\]- \(W \in \mathbb{R}^{K \times K}\), \(b \in \mathbb{R}^K\)
- ๋งค์ฐ ์ ์ฐํ ํํ์ด ๊ฐ๋ฅํ์ง๋ง, ํ๋ผ๋ฏธํฐ ์๊ฐ ๋ง์ ๊ณผ์ ํฉ ๊ฐ๋ฅ์ฑ์ด ์กด์ฌํ๋ค.
B. Vector Scaling
Matrix Scaling์ ๊ฐ์ํํ ๋ฒ์ ์ผ๋ก, \(W\)๋ฅผ ๋๊ฐ ํ๋ ฌ \(D\)๋ก ์ ํํ ๊ฒ์ด๋ค.
\[\hat{q} = \text{softmax}(Dz + b)\]- ํ๋ผ๋ฏธํฐ ์๋ \(2K\)๊ฐ๋ก ์ค์ด๋ค๋ฉฐ, ๊ณ์ฐ ํจ์จ์ฑ๊ณผ ๊ณผ์ ํฉ ๊ฐ๋ฅ์ฑ์ด ์ค์ด๋ ๋ค.
C. Temperature Scaling
๊ฐ์ฅ ๊ฐ๋จํ๋ฉด์๋ ๊ฐ๋ ฅํ ๋ณด์ ๊ธฐ๋ฒ์ผ๋ก logit์ \(T\)๋ก ๋๋๊ณ softmax ์ ์ฉํ๋ค.
\[\hat{q} = \text{softmax}(z / T)\]- ( T > 1 ): ํ๋ฅ ์ด ๋ถ์ฐ๋จ
- ( T = 1 ): ์๋ ๋ชจ๋ธ๊ณผ ๋์ผ
- ( T < 1 ): ํ๋ฅ ์ด ๋ sharpํด์ง
๐ฏ ์คํ ๋ชฉ์
ํ๋ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ๋์ ๋ถ๋ฅ ์ ํ๋๋ฅผ ๋ฌ์ฑํ์ง๋ง, ์์ธก ํ๋ฅ ์ ๊ณผ์ (overconfidence)๋๋ ๊ฒฝํฅ์ด ์๋ ๊ฒ์ ์์ ์ดํด๋ณด์์ผ๋ฉฐ, ๋ค์ํ ์ฌํ ํ๋ฅ ๋ณด์ (post-hoc calibration) ๊ธฐ๋ฒ๋ค์ ์ ๋ฆฌํ์๋ค. ์ด๋ค์ ๋น๊ตํ๊ณ ๊ฐ์ฅ ํจ๊ณผ์ ์ธ ๋ฐฉ๋ฒ์ ์๋ณํ๊ณ ์ ํ๋ค.
- ๋ชจ๋ธ: ResNet, Wide ResNet, DenseNet, LeNet, DAN, TreeLSTM
- ๋ฐ์ดํฐ์ : CIFAR-10, CIFAR-100, ImageNet, SVHN, Birds, Cars, 20News, Reuters, SST
- ํ๊ฐ์งํ: Expected Calibration Error (ECE), Maximum Calibration Error (MCE), Negative Log Likelihood (NLL), Error Rate
Figure 5: ECE (%) (with M = 15 bins) on standard vision and NLP datasets before calibration and with various calibration methods. The number following a modelโs name denotes the network depth.
๋ ผ๋ฌธ์ ํ๋ ์ ๊ฒฝ๋ง๋ค์ด ๋์ ์ ํ๋์๋ ๋ถ๊ตฌํ๊ณ ์์ธก ํ๋ฅ ์ด ๊ณผ๋ํ๊ฒ ์์ ๊ฐ(overconfident)์ ๊ฐ์ง๋ค๋ ๋ฌธ์ ๋ฅผ ์ง์ ํ๊ณ ์์ผ๋ฉฐ, Temperature Scaling์ ๋จ ํ๋์ ํ๋ผ๋ฏธํฐ๋ง์ผ๋ก ๋๋ถ๋ถ์ ๊ฒฝ์ฐ ๊ฐ์ฅ ์ฐ์ํ ๋ณด์ ์ฑ๋ฅ(ECE ๊ฐ์)์ ๋ณด์ฌ์ฃผ์๋ค. ๋ค๋ฅธ ๋ณต์กํ ๋ณด์ ๊ธฐ๋ฒ๋ค๋ณด๋ค ๊ฐ๋จํ๊ณ ์์ ์ ์ผ๋ก ์๋ํ๋ฉฐ, ์ ํ๋๋ฅผ ๋จ์ด๋จ๋ฆฌ์ง ์๊ณ ํ๋ฅ ๋ง ์กฐ์ ํ ์ ์๋ค๋ ๊ฒ์ ํ์ธํ ์ ์์์ผ๋ฉฐ, CIFAR-100 ๋ฑ ๋ค์ํ ๋ฐ์ดํฐ์ ์์ ECE๊ฐ ์ต๋ 16% โ 1% ์์ค์ผ๋ก ํฌ๊ฒ ๊ฐ์ํ์๋ค.
๐งช Python ์ค์ต ๊ฐ์
์ง๊ธ๋ถํฐ๋ ์์ ์ดํด๋ณธ ๋ ผ๋ฌธ๊ณผ ๊ด๋ จํ์ฌ ์ค์ต์ ์งํํด๋ณด๊ณ ์ ํ๋ค. ๋ณธ ์คํ์ CIFAR-100 ์ด๋ฏธ์ง ๋ถ๋ฅ ๊ณผ์ ๋ฅผ ๋์์ผ๋ก ResNet-34 ๋ชจ๋ธ์ ์ ๋ขฐ๋ ๋ณด์ (calibration) ์ฑ๋ฅ์ ์ ๋์ ์ผ๋ก ํ๊ฐํ๊ณ ์๊ฐํํ๋ ๊ฒ์ ๋ชฉ์ ์ผ๋ก ํ๋ค. ๋ฅ๋ฌ๋ ๋ถ๋ฅ ๋ชจ๋ธ์ ์ผ๋ฐ์ ์ผ๋ก ๋์ ๋ถ๋ฅ ์ ํ๋๋ฅผ ๋ฌ์ฑํ ์ ์์ผ๋, ์ถ๋ ฅ ํ๋ฅ ์ด ์ค์ ์ ๋ต์ผ ๊ฐ๋ฅ์ฑ์ ๊ณผ๋ํ๊ฐํ๋ ๊ณผ์ (overconfidence) ํ์์ ์์ฃผ ๋ํ๋ธ๋ค. ์ด๋ฌํ ๋ฌธ์ ๋ ์ค์ ์์ฉ์์ ๋ชจ๋ธ์ ๋ถํ์ค์ฑ์ ์ ๋ขฐํ ์ ์๊ฒ ๋ง๋ ๋ค. ์์ ์ดํด๋ณธ ๋ ผ๋ฌธ์์๋ ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํ ๋ค์ํ Post-hoc ๋ณด์ ๊ธฐ๋ฒ๋ค์ ์ ์ํ๊ณ ์์ผ๋ฉฐ, ์ด ์ค ํ๋์ธ Temperature Scaling์ ์ง์ ๊ตฌํํ๊ณ , ๊ทธ ํจ๊ณผ๋ฅผ Reliability Diagram์ ํตํด ์๊ฐ์ ์ผ๋ก ๋ถ์ํด๋ณด๊ณ ์ ํ๋ค.
๋ชจ๋ธ ํ์ต์๋ CIFAR-100 ๋ฐ์ดํฐ์ ๊ณผ ResNet-34 ๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ์ฌ์ฉํ์๋ค. CIFAR-100์ ์ด 100๊ฐ์ ๋ ์ด๋ธ๋ก ๊ตฌ์ฑ๋ ์ปฌ๋ฌ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์ ์ผ๋ก, ๊ฐ ์ด๋ฏธ์ง๋ 32ร32 ํด์๋์ RGB ์ด๋ฏธ์ง์ด๋ค. ๊ฐ ํด๋์ค๋น 500๊ฐ์ ํ์ต ์ด๋ฏธ์ง์ 100๊ฐ์ ํ ์คํธ ์ด๋ฏธ์ง๊ฐ ํฌํจ๋์ด ์์ผ๋ฉฐ, ์ด 50,000๊ฐ์ ํ์ต ์ํ๊ณผ 10,000๊ฐ์ ํ ์คํธ ์ํ์ ํฌํจํ๋ค. ๋ถ๋ฅ ๋ชจ๋ธ๋ก ์ฌ์ฉ๋ ResNet-34๋ Residual Network ๊ณ์ด์ ๋ํ์ ์ธ ๊ตฌ์กฐ ์ค ํ๋๋ก, 34๊ฐ์ ์ธต์ ๊ฐ๋ ์ฌ์ธต ํฉ์ฑ๊ณฑ ์ ๊ฒฝ๋ง์ด๋ค. ์์ฐจ ์ฐ๊ฒฐ(residual connection)์ ํตํด ๊น์ ๋คํธ์ํฌ์์ ๋ฐ์ํ๋ ๊ธฐ์ธ๊ธฐ ์์ค ๋ฌธ์ ๋ฅผ ํจ๊ณผ์ ์ผ๋ก ํด๊ฒฐํ ์ ์์ผ๋ฉฐ, CIFAR-100๊ณผ ๊ฐ์ ์ค๊ฐ ๋์ด๋์ ์ด๋ฏธ์ง ๋ถ๋ฅ ๋ฌธ์ ์ ๋๋ฆฌ ์ฌ์ฉ๋๋ค. ๋ณธ ์คํ์์๋ ์ฌ์ ํ์ต ์์ด ์ฒ์๋ถํฐ CIFAR-100์ ๋ํด ResNet-34๋ฅผ ํ์ต์์ผฐ์ผ๋ฉฐ, ์ถ๋ ฅ์ธต fully-connected layer์ ์ถ๋ ฅ ์ฐจ์์ 100์ผ๋ก ์ค์ ํ์ฌ 100๊ฐ์ ํด๋์ค๋ฅผ ๋ถ๋ฅํ๋๋ก ๊ตฌ์ฑํ์๋ค.
๋ณธ ์คํ์์๋ ํ์ต๋ ResNet-34 ๋ชจ๋ธ์ ๊ธฐ์ค์ผ๋ก T โ \({0.5, 1.0, 1.5, 2.0}\) ๋ฒ์์ ๋ํด Temperature Scaling์ ์ ์ฉํ ํ, ๋ค์๊ณผ ๊ฐ์ ๊ด์ ์์ ๋ณด์ ์ฑ๋ฅ์ ํ๊ฐํ์๋ค:
- Reliability Diagram์ ํตํด confidence vs accuracy ๊ด๊ณ๋ฅผ ์๊ฐํ
- Expected Calibration Error (ECE) ์์น๋ฅผ ๊ณ์ฐํ์ฌ ์ ๋์ ๋ณด์ ์ฑ๋ฅ ํ๊ฐ
- ๊ฐ confidence bin ๋ด์ sample ์ ๋ฐ ์ ํ๋ ๋ณํ ๋ถ์
๊ทธ๋ผ ์ง๊ธ๋ถํฐ๋ ๋จ๊ณ๋ณ๋ก ์ฝ๋๋ฅผ ์ดํด๋ณด๋๋ก ํ๊ฒ ๋ค.
1. CIFAR-100 ๋ฐ์ดํฐ์ ์ ์ ๊ทํ๋ฅผ ์ํ ํ๊ท ๋ฐ ํ์คํธ์ฐจ ๊ณ์ฐ
๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ํ์ตํ ๋, ์ ๋ ฅ ์ด๋ฏธ์ง ๋ฐ์ดํฐ๋ฅผ ์ ๊ทํ(normalization) ํ๋ ๊ฒ์ ๋งค์ฐ ์ค์ํ ์ ์ฒ๋ฆฌ ๊ณผ์ ์ด๋ค. ๋ณดํต ์ ๊ทํ๋ ๊ฐ ์ฑ๋(R, G, B)์ ๋ํด ํ๊ท ์ ๋นผ๊ณ ํ์คํธ์ฐจ๋ก ๋๋๋ ๋ฐฉ์์ผ๋ก ์ด๋ฃจ์ด์ง๋ค. ์ด ๊ณผ์ ์ ํตํด ์ ๋ ฅ ๊ฐ์ ๋ถํฌ๋ฅผ 0์ ์ค์ฌ์ผ๋ก ์ ๊ทํํจ์ผ๋ก์จ, ํ์ต์ด ๋ ์์ ์ ์ผ๋ก ์ด๋ฃจ์ด์ง๊ณ ์๋ ด ์๋๊ฐ ๋นจ๋ผ์ง ์ ์๋ค.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
transform = transforms.ToTensor()
trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
loader = DataLoader(trainset, batch_size=50000, shuffle=False)
data = next(iter(loader))[0] # (50000, 3, 32, 32)
mean = data.mean(dim=(0, 2, 3))
std = data.std(dim=(0, 2, 3))
print("CIFAR-100 ํ๊ท :", mean)
print("CIFAR-100 ํ์คํธ์ฐจ:", std)
์ ์ฝ๋๋ฅผ ์คํํ๊ฒ ๋๋ฉด CIFAR-100 ๋ฐ์ดํฐ์ ์ด ๋ก์ปฌ์ ์ ์ฅ๋์ด์ ธ ์์ง ์์ ๊ฒฝ์ฐ ./data ๊ฒฝ๋ก์ ์ ์ฅํ๊ฒ ๋๋ฉฐ ์ดํ ์ ์ฒด ํ์ต ๋ฐ์ดํฐ์ ๋ํ ํ๊ท ๊ณผ ํ์คํธ์ฐจ๋ฅผ ๊ณ์ฐํ๊ฒ ๋๋ค.
2. CIFAR-100 ๋ฐ์ดํฐ๋ฅผ ์ด์ฉํ ResNet-34 ํ์ต
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
# ---------------- ํ์ดํผํ๋ผ๋ฏธํฐ ์ค์ ----------------
batch_size = 128
epochs = 30
lr = 0.1
save_path = "๋ชจ๋ธ ์ ์ฅ ๊ฒฝ๋ก"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ---------------- ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ ๋ฐ ๋ก๋ฉ ----------------
# cifar100_mean_std.py ์ถ๋ ฅ ๊ฒฐ๊ณผ
mean = (0.5071, 0.4866, 0.4409)
std = (0.2673, 0.2564, 0.2762)
# ํ์ต์ฉ ๋ฐ์ดํฐ์ ๋ํด ๋ฐ์ดํฐ ์ฆ๊ฐ ๋ฐ ์ ๊ทํ ์ํ
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4), # ๋ฌด์์ crop (32x32 ์ ์ง)
transforms.RandomHorizontalFlip(), # ๋ฌด์์ ์ข์ฐ ๋ฐ์
transforms.ToTensor(), # ํ
์ ๋ณํ (0~1)
transforms.Normalize(mean, std) # ์ฑ๋๋ณ ์ ๊ทํ
])
# CIFAR-100 ํ์ต ๋ฐ์ดํฐ์
๋ก๋
trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=8)
# ---------------- ๋ชจ๋ธ ์ ์ ----------------
model = models.resnet34(weights=None)
model.fc = nn.Linear(model.fc.in_features, 100)
model = model.to(device)
# ---------------- ์์ค ํจ์ ๋ฐ ์ตํฐ๋ง์ด์ ----------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) # 10 epoch๋ง๋ค ํ์ต๋ฅ ๊ฐ์
# ---------------- ํ์ต ๋ฃจํ ----------------
for epoch in range(epochs):
model.train() # ํ์ต ๋ชจ๋ ํ์ฑํ
running_loss = 0.0
correct = 0
total = 0
for inputs, labels in trainloader:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad() # ๊ธฐ์ธ๊ธฐ ์ด๊ธฐํ
outputs = model(inputs) # ์์ ํ
loss = criterion(outputs, labels) # ์์ค ๊ณ์ฐ
loss.backward() # ์ญ์ ํ
optimizer.step() # ๊ฐ์ค์น ์
๋ฐ์ดํธ
running_loss += loss.item()
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
acc = 100. * correct / total
print(f"[Epoch {epoch+1}/{epochs}] Loss: {running_loss:.3f}, Train Accuracy: {acc:.2f}%")
scheduler.step()
# ---------------- ํ์ต๋ ๋ชจ๋ธ ์ ์ฅ ----------------
torch.save(model.state_dict(), save_path)
print(f"๐พ Model saved to: {save_path}")
์์ ์ฝ๋๋ ResNet-34๋ฅผ CIFAR-100 ๋ฐ์ดํฐ์ ์ ์ด์ฉํ์ฌ ํ์ตํ๋ ๊ณผ์ ์ ๋ํ๋ด์๋ค. 1๋ฒ์์ ๊ณ์ฐ๋ ํ๊ท ๊ณผ ํ์คํธ์ฐจ ๊ฐ์ transforms.Normalize(mean, std) ํจ์์ ๊ทธ๋๋ก ์ ์ฉํ์๋ค. ์ด ์ ๊ทํ ๊ณผ์ ์ ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ ํ์ดํ๋ผ์ธ์ ํฌํจ๋์ด ์์ผ๋ฉฐ, ๋ฌด์์ ์๋ฅด๊ธฐ(RandomCrop), ์ข์ฐ ๋ฐ์ (RandomHorizontalFlip), ํ ์ ๋ณํ(ToTensor) ์ดํ ๋ง์ง๋ง ๋จ๊ณ์์ ์ํ๋๋ค. ์ ๊ทํ๋ ๋ฐ์ดํฐ๋ ์ดํ ResNet-34 ๋ชจ๋ธ์ ์ ๋ ฅ๋๋ฉฐ, ์ด ๋ชจ๋ธ์ ์ถ๋ ฅ์ธต๋ง CIFAR-100์ 100๊ฐ ํด๋์ค์ ๋ง๊ฒ ์์ ๋ ํํ๋ก ์ฌ์ฉ๋๋ค.
๋ฅ๋ฌ๋ ๋ชจ๋ธ, ํนํ ๊น์ ๊ตฌ์กฐ์ ResNet-34์ ๊ฐ์ ๋ชจ๋ธ์ ์ ๋ ฅ์ ๋ถํฌ๊ฐ ์ง๋์น๊ฒ ํธํฅ๋์ด ์์ ๊ฒฝ์ฐ ํ์ต์ด ์ ๋์ง ์๊ฑฐ๋, ์ด๊ธฐ ํ์ต ๋จ๊ณ์์ ๋งค์ฐ ๋๋ฆฐ ์๋ ด ์๋๋ฅผ ๋ณด์ผ ์ ์๋ค. ๋ฐ๋ผ์ ์ฌ์ ์ ๋ฐ์ดํฐ์ ์ ํต๊ณ ์ ๋ณด๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์ ๊ทํ๋ฅผ ์ํํ๋ ๊ฒ์ ํ์ต ์ฑ๋ฅ์ ์์ ์ํค๋ ํต์ฌ์ ์ธ ์์๋ค. ๊ฒฐ๊ณผ์ ์ผ๋ก, 1๋ฒ์ ์ ๊ทํ ๊ฐ ๊ณ์ฐ์ 2๋ฒ์ ํจ๊ณผ์ ์ธ ๋ชจ๋ธ ํ์ต์ ์ํ ํ์ ์ ์ฒ๋ฆฌ ๊ณผ์ ์ด๋ฉฐ, ์ ์ฒด ํ์ต ํ์ดํ๋ผ์ธ์ ์ ๋ขฐ์ฑ๊ณผ ํจ์จ์ฑ์ ๋์ด๋ ๋ฐ ์ค์ํ ์ญํ ์ ํ๋ค.
๋ชจ๋ธ ํ์ต์ ์ด 30๋ฒ์ epoch์ ๊ฑธ์ณ ์งํ๋๋ฉฐ, ํ ๋ฒ์ epoch๋ง๋ค ์ ์ฒด CIFAR-100 ํ์ต ๋ฐ์ดํฐ๋ฅผ ํ ๋ฒ์ฉ ์ํํ๊ฒ ๋๋ค. ํ์ต ๊ณผ์ ์์ ์์ค ํจ์๋ ๋ค์ค ํด๋์ค ๋ถ๋ฅ์ ์ ํฉํ CrossEntropyLoss๋ฅผ ์ฌ์ฉํ๋ฉฐ, ์ตํฐ๋ง์ด์ ๋ ํ๋ฅ ์ ๊ฒฝ์ฌ ํ๊ฐ๋ฒ(SGD: Stochastic Gradient Descent)์ momentum๊ณผ weight decay๋ฅผ ์ถ๊ฐํ์ฌ ์์ ์ ์ธ ์ต์ ํ๋ฅผ ์ ๋ํ๋ค. ํ์ต๋ฅ ์ ์ด๊ธฐ์ 0.1๋ก ์ค์ ๋๋ฉฐ, StepLR ์ค์ผ์ค๋ฌ๋ฅผ ํตํด 10 ์ํญ๋ง๋ค 1/10์ฉ ๊ฐ์์ํจ๋ค. ์ด๋ ์ด๋ฐ์๋ ๋น ๋ฅด๊ฒ ํ์ตํ๊ณ , ํ๋ฐ์๋ ์ฒ์ฒํ fine-tuningํ๋๋ก ์ ๋ํ๋ ๊ฒ์ด๋ค. ๋ชจ๋ธ์ ํ์ต ๋ชจ๋(model.train())์์ ๊ฐ ๋ฐฐ์น ๋ฐ์ดํฐ๋ฅผ ์์ ํ(forward)์์ผ ์์ธก ๊ฒฐ๊ณผ๋ฅผ ์ถ๋ ฅํ๊ณ , ์ด ๊ฒฐ๊ณผ๋ฅผ ์ค์ ์ ๋ต๊ณผ ๋น๊ตํ์ฌ ์์ค(loss)์ ๊ณ์ฐํ ํ, ์ญ์ ํ(backward)๋ฅผ ํตํด ๊ฐ์ค์น์ ๊ธฐ์ธ๊ธฐ๋ฅผ ๊ตฌํ๊ณ ์ด๋ฅผ ๋ฐํ์ผ๋ก ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธํ๋ค. ๋ํ ๊ฐ epoch๋ง๋ค ๋์ ๋ ์์ค๊ณผ ์ ํ๋๋ฅผ ์ถ๋ ฅํ์ฌ ํ์ต์ด ์ด๋ป๊ฒ ์งํ๋๊ณ ์๋์ง ํ์ธํ ์ ์๋ค. ๋ง์ง๋ง์ผ๋ก ํ์ต์ด ์ข ๋ฃ๋ ํ์๋ ๋ชจ๋ธ์ ํ์ต๋ ๊ฐ์ค์น ํ๋ผ๋ฏธํฐ๋ฅผ .pth ํ์ผ๋ก ์ ์ฅํ์ฌ ํ์ต๋ ๋ชจ๋ธ์ ์ถํ ํ์ฉํ ์ ์๋๋ก ํ๋ค.
3. Reliability Diagram ์๊ฐํ
์ด๋ฒ์๋ ์์ ํ์ต์ด ์๋ฃ๋ ResNet-34 ๋ชจ๋ธ์ ๋ถ๋ฌ์, CIFAR-100 ํ ์คํธ์ ์ ๋ํด ์์ธก์ ์ํํ ํ, ์์ธก์ ์ ๋ขฐ๋(confidence)์ ์ค์ ์ ๋ต ์ฌ๋ถ ๊ฐ์ ๊ด๊ณ๋ฅผ ๋ถ์ํ๊ณ ์๊ฐํํ๋ค. ์ด๋ฅผ ์ํด Expected Calibration Error (ECE) ๋ฅผ ์์น๋ก ๊ณ์ฐํ๊ณ , Reliability Diagram์ ํตํด ๋ชจ๋ธ์ ์ ๋ขฐ๋๊ฐ ์ผ๋ง๋ ์ ๋ณด์ ๋์ด ์๋์ง๋ฅผ ์๊ฐ์ ์ผ๋ก ํ๊ฐํ๋ค.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
# โ
๋๋ฐ์ด์ค ์ค์ ๋ฐ ๋ชจ๋ธ ๋ก๋
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet34(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 100)
model.load_state_dict(torch.load("./snapshots/resnet34_cifar100_exp/resnet34_cifar100.pth", map_location=device))
model = model.to(device)
# โ
CIFAR-100 ํ
์คํธ์
์ค๋น
mean = (0.5071, 0.4866, 0.4409)
std = (0.2673, 0.2564, 0.2762)
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
test_dataset = datasets.CIFAR100(root="./data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
# โ
ECE ๊ณ์ฐ ๋ฐ Reliability Diagram ์์ฑ
def compute_reliability_and_ece(model, dataloader, device, n_bins=15):
model.eval()
bin_boundaries = torch.linspace(0, 1, n_bins + 1).to(device)
bin_corrects = torch.zeros(n_bins).to(device)
bin_confidences = torch.zeros(n_bins).to(device)
bin_counts = torch.zeros(n_bins).to(device)
total_samples = 0
with torch.no_grad():
for images, labels in dataloader:
images, labels = images.to(device), labels.to(device)
logits = model(images)
probs = F.softmax(logits, dim=1)
confs, preds = torch.max(probs, dim=1)
corrects = preds.eq(labels)
total_samples += labels.size(0)
for i in range(n_bins):
in_bin = (confs > bin_boundaries[i]) & (confs <= bin_boundaries[i+1])
bin_counts[i] += in_bin.sum()
if in_bin.sum() > 0:
bin_corrects[i] += corrects[in_bin].float().sum()
bin_confidences[i] += confs[in_bin].sum()
nonzero = bin_counts > 0
accs = bin_corrects[nonzero] / bin_counts[nonzero]
confs = bin_confidences[nonzero] / bin_counts[nonzero]
bin_centers = ((bin_boundaries[:-1] + bin_boundaries[1:]) / 2)[nonzero]
filtered_counts = bin_counts[nonzero]
ece = torch.sum((filtered_counts / total_samples) * torch.abs(accs - confs)).item()
return bin_centers.cpu(), accs.cpu(), confs.cpu(), ece
def draw_reliability_diagram(bin_centers, accs, confs, ece, name, save_dir):
os.makedirs(save_dir, exist_ok=True)
width = 1.0 / len(bin_centers)
plt.figure(figsize=(5, 5))
plt.bar(bin_centers, accs, width=width * 0.9, color='blue', edgecolor='black', alpha=0.8, label="Accuracy")
plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label="Perfect Calibration")
for x, acc, conf in zip(bin_centers, accs, confs):
lower = min(acc, conf)
upper = max(acc, conf)
plt.fill_between([x - width / 2, x + width / 2], lower, upper,
color='red', alpha=0.3, hatch='//')
plt.xlabel("Confidence")
plt.ylabel("Accuracy")
plt.title(f"Reliability Diagram: {name}")
plt.text(0.02, 0.6, f"ECE = {ece * 100:.2f}%", fontsize=12,
bbox=dict(facecolor='lavender', edgecolor='gray'))
plt.legend(loc='upper left')
plt.tight_layout()
plt.savefig(os.path.join(save_dir, f"{name}_reliability.png"))
plt.close()
# โ
์คํ
bin_centers, accs, confs, ece = compute_reliability_and_ece(model, test_loader, device)
draw_reliability_diagram(bin_centers, accs, confs, ece, name="ResNet34_CIFAR100", save_dir="./snapshots/resnet34_cifar100_exp")
์ ์ฝ๋๋ ํ์ต๋ ๋ชจ๋ธ์ CIFAR-100 ํ ์คํธ ๋ฐ์ดํฐ์ ์ ์ฉํ์ฌ ์์ธก์ ์ํํ ํ, ์์ธก ๊ฒฐ๊ณผ์ ๋ํ confidence score์ ์ค์ ์ ๋ต ์ฌ๋ถ๋ฅผ ๋น๊ตํ์ฌ ์ ๋ขฐ๋(calibration)๋ฅผ ํ๊ฐํ๋ ๊ณผ์ ์ ์ํํ๋ค. compute_reliability_and_ece ํจ์๋ confidence ๊ฐ ๋ฒ์๋ฅผ ์ผ์ ํ ๊ฐ๊ฒฉ์ผ๋ก ๋๋ bin์ ๊ธฐ์ค์ผ๋ก ๊ฐ bin ๋ด์ ํ๊ท confidence์ ์ค์ ์ ํ๋(accuracy)๋ฅผ ๊ณ์ฐํ๋ฉฐ, ์ด๋ฅผ ๋ฐํ์ผ๋ก Expected Calibration Error (ECE)๋ฅผ ์์น๋ก ๋ฐํํ๋ค. ์ด ๊ฐ์ด ์์์๋ก ๋ชจ๋ธ์ ์์ธก ํ๋ฅ ์ด ์ค์ ์ ๋ต๋ฅ ๊ณผ ์ ์ผ์นํ๋ค๋ ๊ฒ์ ์๋ฏธํ๋ค.
๋ํ, draw_reliability_diagram ํจ์๋ ์ด๋ฌํ ์ ๋ณด๋ฅผ ๋ฐํ์ผ๋ก ์ ๋ขฐ๋ ๊ทธ๋ํ๋ฅผ ์๊ฐํํ๋ฉฐ, ์ด์์ ์ธ ๊ฒฝ์ฐ์ธ ๋๊ฐ์ (์๋ฒฝํ ๋ณด์ )์ ๊ธฐ์ค์ผ๋ก ๋ชจ๋ธ์ด ๊ณผ์ ํ๊ฑฐ๋ ๊ณผ์์ ํ๋ ๊ตฌ๊ฐ์ ์๊ฐ์ ์ผ๋ก ํ์ธํ ์ ์๋๋ก ํ๋ค. ๋ง๋๋ ๊ฐ confidence ๊ตฌ๊ฐ์ ์ค์ ์ ํ๋๋ฅผ ๋ํ๋ด๋ฉฐ, ํ๋ ๋ง๋์ ํ์ ๋๊ฐ์ ์ฌ์ด์ ๋นจ๊ฐ ์์์ ์ ๋ขฐ๋ ์ค์ฐจ๋ฅผ ์๊ฐ์ ์ผ๋ก ํํํ๋ค. ์ด ๊ฒฐ๊ณผ๋ฅผ ํตํด ๋ชจ๋ธ์ด ์ผ๋ง๋ calibrated ๋์ด ์๋์ง ํ์ธํ ์ ์๊ณ , ์ดํ ๋ณด์ ๊ธฐ๋ฒ(์: temperature scaling)์ ํ์์ฑ์ ํ๊ฐํ๋ ๊ธฐ๋ฐ ์๋ฃ๊ฐ ๋๋ค.
4. Temperature Scaling ์คํ
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch.nn as nn
import os
import matplotlib.pyplot as plt
# Temperature Scaler ์ ์ (ํ์ต ์์ด ๊ณ ์ ๋ T ๊ฐ ์ฌ์ฉ)
class TemperatureScaler(nn.Module):
def __init__(self, temperature: float):
super().__init__()
self.temperature = nn.Parameter(torch.tensor([temperature]), requires_grad=False)
def forward(self, logits):
return logits / self.temperature
# ๋ชจ๋ธ์ Temperature Scaler๋ก ๊ฐ์ธ๋ ๋ํผ
class WrappedModel(nn.Module):
def __init__(self, base_model, temp_scaler):
super().__init__()
self.base_model = base_model
self.temp_scaler = temp_scaler
def forward(self, x):
logits = self.base_model(x)
return self.temp_scaler(logits)
# ๋ค์ํ T ๊ฐ์ ๋ํด ECE ๊ณ์ฐ ๋ฐ Reliability Diagram ์ ์ฅ
def evaluate_multiple_temperatures_with_plots(model, test_loader, device, T_values, output_dir):
os.makedirs(output_dir, exist_ok=True)
ece_list = []
for T in T_values:
print(f"\n๐งช Evaluating T = {T}")
temp_scaler = TemperatureScaler(temperature=T).to(device)
wrapped_model = WrappedModel(model, temp_scaler).to(device)
# ์ ๋ขฐ๋ ํ๊ฐ
bin_centers, accs, confs, bin_counts, total_samples, ece = compute_reliability_and_ece(
wrapped_model, test_loader, device, verbose_under_100=False
)
ece_list.append(ece)
# Reliability Diagram ์ ์ฅ
draw_fancy_reliability_diagram(
bin_centers, accs, confs, bin_counts, total_samples, ece,
name=f"T={T}", output_dir=output_dir
)
return T_values, ece_list
# T์ ๋ฐ๋ฅธ ECE ๋ณํ๋ฅผ ์๊ฐํ
def plot_temperature_vs_ece(T_values, ece_list, save_path):
plt.figure(figsize=(6, 4))
plt.plot(T_values, [ece * 100 for ece in ece_list], marker='o', linestyle='-', color='purple')
plt.xlabel("Temperature (T)")
plt.ylabel("ECE (%)")
plt.title("ECE vs Temperature")
plt.grid(True)
plt.tight_layout()
plt.savefig(save_path)
plt.close()