On Calibration of Modern Neural Networks

📄 Paper Information

"On Calibration of Modern Neural Networks" (ICML 2017)

📌 Research Overview

Confidence Calibration์€ ๋ชจ๋ธ์ด ์˜ˆ์ธกํ•œ ํ™•๋ฅ ์ด ์‹ค์ œ ์ •๋‹ต์ผ ๊ฐ€๋Šฅ์„ฑ๊ณผ ์–ผ๋งˆ๋‚˜ ์ผ์น˜ํ•˜๋Š”์ง€๋ฅผ ๋‚˜ํƒ€๋‚ด๋Š” ๊ฐœ๋…์ด๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, ์–ด๋–ค ์ด๋ฏธ์ง€์— ๋Œ€ํ•ด ๋ชจ๋ธ์ด ๊ณ ์–‘์ด์ผ ํ™•๋ฅ ์„ 0.9๋ผ๊ณ  ์˜ˆ์ธกํ–ˆ์„ ๋•Œ, ์ด๋Ÿฌํ•œ ์˜ˆ์ธก์ด ์ž˜ ๋ณด์ •(calibrated)๋˜์–ด ์žˆ๋‹ค๋ฉด ์‹ค์ œ๋กœ ๊ทธ ์ด๋ฏธ์ง€๊ฐ€ ๊ณ ์–‘์ด์ผ ํ™•๋ฅ ๋„ ์•ฝ 90%๊ฐ€ ๋˜์–ด์•ผ ํ•œ๋‹ค๋Š” ๊ฒƒ์ด๋‹ค.

The paper On Calibration of Modern Neural Networks shows experimentally that modern neural networks such as ResNet and DenseNet, despite the rapid progress that has pushed their classification accuracy ever higher, have actually become worse at probability calibration. While older, shallower networks produced predicted probabilities that matched the true accuracy reasonably well, recent deep and complex models tend to make confident predictions whose probabilities disagree with the actual accuracy. In short, the authors found that modern models are overconfident.

Figure 1: Confidence histograms (top) and reliability diagrams (bottom) for a 5-layer LeNet (left) and a 110-layer ResNet (right) on CIFAR-100. Refer to the text below for details.

This problem is especially prominent in applications where trust in the predictions is critical, such as autonomous driving, medical diagnosis, and legal decision support. A model must therefore not only be accurate; its expression of how certain it is must itself be trustworthy.

๋ณธ ๋…ผ๋ฌธ์—์„œ๋Š” ์ด๋Ÿฌํ•œ ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด ๋‹ค์–‘ํ•œ ์‚ฌํ›„ ๋ณด์ •(post-hoc calibration) ๋ฐฉ๋ฒ•๋“ค์„ ์‹คํ—˜์ ์œผ๋กœ ๋น„๊ตํ•˜๊ณ , ๊ทธ ์ค‘์—์„œ๋„ Temperature Scaling์ด๋ผ๋Š” ๋‹จ ํ•˜๋‚˜์˜ ์Šค์นผ๋ผ ํŒŒ๋ผ๋ฏธํ„ฐ๋งŒ์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฐ„๋‹จํ•œ ๋ฐฉ๋ฒ•์ด ๋งค์šฐ ํšจ๊ณผ์ ์ด๋ผ๋Š” ์‚ฌ์‹ค์„ ๋ฐํ˜€๋ƒˆ๋‹ค. ๋ณธ ๊ธ€์—์„œ๋Š” ์ด์˜ ์‹คํ—˜ ์ฝ”๋“œ๋„ ๊ตฌํ˜„ํ•˜์˜€๋‹ค.


๐Ÿ” ์‹ ๊ฒฝ๋ง์˜ Overconfidence ์›์ธ ๋ถ„์„

์ตœ๊ทผ์˜ ์‹ ๊ฒฝ๋ง ๋ชจ๋ธ๋“ค์€ ๋†’์€ ์ •ํ™•๋„๋ฅผ ์ž๋ž‘ํ•˜์ง€๋งŒ, ๊ทธ confidence (์˜ˆ์ธก ํ™•๋ฅ ) ๋Š” ์‹ค์ œ ์ •๋‹ต๋ฅ ๊ณผ ์ž˜ ๋งž์ง€ ์•Š๋Š” ๊ฒฝ์šฐ๊ฐ€ ๋งŽ๋‹ค. ์ด ํ˜„์ƒ์„ miscalibration (๋ถˆ์™„์ „ํ•œ ๋ณด์ •) ์ด๋ผ๊ณ  ํ•˜๋ฉฐ, ๊ทธ ์›์ธ๊ณผ ๊ด€๋ จ ์š”์†Œ๋“ค์„ ์šฐ์„  ๋ถ„์„ํ•˜์˜€๋‹ค.

Figure 2: The effect of network depth (far left), width (middle left), Batch Normalization (middle right), and weight decay (far right) on miscalibration, as measured by ECE (lower is better).

1. ๋ชจ๋ธ ์šฉ๋Ÿ‰์˜ ์ฆ๊ฐ€ (Model Capacity)

  • The number of layers and filters in recent deep learning models has grown rapidly, giving them the capacity to fit the training data ever more closely.
  • However, as model capacity grows, confidence tends to rise above the actual accuracy; that is, the model becomes overconfident.

Experimental results (ResNet on CIFAR-100):

  • ๊นŠ์ด(depth)๋ฅผ ์ฆ๊ฐ€์‹œํ‚ค๋ฉด Error์€ ์ค„์–ด๋“œ๋‚˜ ECE๊ฐ€ ์ฆ๊ฐ€
  • ํ•„ํ„ฐ ์ˆ˜(width)๋ฅผ ์ฆ๊ฐ€์‹œํ‚ค๋ฉด Error์€ ํ™•์—ฐํžˆ ์ค„์–ด๋“œ๋‚˜, ECE๊ฐ€ ์ฆ๊ฐ€

High capacity can induce overfitting, in which case accuracy improves while the quality of the predicted probabilities deteriorates.


2. Batch Normalization์˜ ์˜ํ–ฅ

  • Batch Normalization์€ ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์˜ ํ•™์Šต์„ ์•ˆ์ •ํ™”์‹œํ‚ค๊ณ  ๋น ๋ฅด๊ฒŒ ๋งŒ๋“œ๋Š” ๊ธฐ๋ฒ•์œผ๋กœ, ํ˜„๋Œ€ ์•„ํ‚คํ…์ฒ˜์—์„œ ํ•„์ˆ˜์ ์œผ๋กœ ์‚ฌ์šฉ๋œ๋‹ค.
  • ํ•˜์ง€๋งŒ, BN์„ ์‚ฌ์šฉํ•œ ๋ชจ๋ธ๋“ค์€ ์ •ํ™•๋„๋Š” ์˜ฌ๋ผ๊ฐ€์ง€๋งŒ calibration์€ ์˜คํžˆ๋ ค ๋‚˜๋น ์ง€๋Š” ํ˜„์ƒ์ด ๋‚˜ํƒ€๋‚œ๋‹ค.

Experimental results (6-layer ConvNet on CIFAR-100):

  • BN์„ ์ ์šฉํ•œ ConvNet์€ ์ •ํ™•๋„๊ฐ€ ์•ฝ๊ฐ„ ์˜ฌ๋ผ๊ฐ€์ง€๋งŒ(Error ๊ฐ์†Œ), ECE๋Š” ๋šœ๋ ทํ•˜๊ฒŒ ์ฆ๊ฐ€

BN์€ ๋‚ด๋ถ€ ํ™œ์„ฑ ๋ถ„ํฌ๋ฅผ ์ •๊ทœํ™”ํ•˜์—ฌ ํ•™์Šต์„ ๋” ์ž˜ ๋˜๊ฒŒ ํ•˜์ง€๋งŒ, ๊ฒฐ๊ณผ์ ์œผ๋กœ ์ถœ๋ ฅ ํ™•๋ฅ ์ด ๋” ๊ณผ์‹ ๋œ(overconfident) ์ƒํƒœ๋กœ ๋‚˜ํƒ€๋‚˜ calibration์—๋Š” ๋ถ€์ •์ ์ธ ์˜ํ–ฅ์„ ๋ฏธ์น˜๊ฒŒ ๋œ๋‹ค.


3. The Effect of Reduced Weight Decay

  • ์ „ํ†ต์ ์œผ๋กœ weight decay ๋Š” ๊ณผ์ ํ•ฉ์„ ๋ง‰๊ธฐ ์œ„ํ•œ ์ •๊ทœํ™” ๋ฐฉ๋ฒ•์œผ๋กœ ๋„๋ฆฌ ์‚ฌ์šฉ๋˜์–ด ์™”์œผ๋ฉฐ, overfitting์„ ๋ฐฉ์ง€ํ•˜๊ธฐ ์œ„ํ•ด ๊ฐ€์ค‘์น˜์— ํŒจ๋„ํ‹ฐ๋ฅผ ์ฃผ๋Š” ์ •๊ทœํ™” ๊ธฐ๋ฒ•์ด๋‹ค.
  • ์ตœ๊ทผ์—๋Š” BN์˜ ์ •๊ทœํ™” ํšจ๊ณผ ๋•Œ๋ฌธ์— weight decay ์‚ฌ์šฉ๋Ÿ‰์ด ์ค„์–ด๋“œ๋Š” ์ถ”์„ธ์ด๋‹ค.
  • ํ•˜์ง€๋งŒ ์‹คํ—˜์—์„œ๋Š” weight decay๋ฅผ ์ฆ๊ฐ€์‹œํ‚ฌ์ˆ˜๋ก calibration์€ ๊ฐœ์„ ๋˜๋Š” ๊ฒฝํ–ฅ์„ ๋ณด์ธ๋‹ค.

Experimental results (ResNet-110 on CIFAR-100):

  • As weight decay increases, the classification error reaches an optimum in a certain range and then rises again
  • ECE tends to decrease as weight decay increases

์ฆ‰, ์ •ํ™•๋„๋ฅผ ์ตœ๋Œ€ํ™”ํ•˜๋Š” weight decay ์„ค์ •๊ณผ calibration์„ ์ตœ์ ํ™”ํ•˜๋Š” ์„ค์ •์€ ์„œ๋กœ ๋‹ค๋ฅผ ์ˆ˜ ์žˆ์œผ๋ฉฐ, ์ •ํ™•๋„๋Š” ์œ ์ง€๋˜๋”๋ผ๋„ confidence๋Š” ์™œ๊ณก๋  ์ˆ˜ ์žˆ๋‹ค.


4. NLL ๊ณผ์ ํ•ฉ ํ˜„์ƒ (Overfitting to NLL)

Figure 3: Test error and NLL of a 110-layer ResNet with stochastic depth on CIFAR-100 during training.

  • ์‹คํ—˜์—์„œ๋Š” learning rate๊ฐ€ ๋‚ฎ์•„์ง€๋Š” ๊ตฌ๊ฐ„์—์„œ test error๋Š” ๊ณ„์† ์ค„์–ด๋“œ๋Š” ๋ฐ˜๋ฉด, NLL์€ ๋‹ค์‹œ ์ฆ๊ฐ€ํ•˜๋Š” ํ˜„์ƒ์„ ํ™•์ธํ•˜์˜€๋‹ค.
  • ์ด๋Š” ๋ชจ๋ธ์ด ์ •ํ™•๋„๋Š” ๋†’์ด์ง€๋งŒ confidence๊ฐ€ ์‹ค์ œ๋ณด๋‹ค ๊ณผ๋„ํ•œ ์ƒํƒœ๋กœ ํ•™์Šต์ด ์ง„ํ–‰๋˜๊ณ  ์žˆ์Œ์„ ์˜๋ฏธํ•œ๋‹ค.

Experimental results (ResNet-110 + stochastic depth on CIFAR-100):

  • The learning rate is reduced after epoch 250
  • Afterwards the test error decreases (29% → 27%), but the NLL increases

์ตœ์‹  ์‹ ๊ฒฝ๋ง์€ ํ•™์Šต ํ›„๋ฐ˜๋ถ€์—์„œ NLL์„ ๊ณ„์† ์ตœ์†Œํ™”ํ•˜๋ ค๋Š” ๊ณผ์ •์—์„œ confidence๋ฅผ ๊ณผ๋„ํ•˜๊ฒŒ ๋†’์ด๋Š” ๊ฒฝํ–ฅ์ด ์žˆ์œผ๋ฉฐ, ์ด๋กœ ์ธํ•ด ์‹ค์ œ ์ •๋‹ต๋ฅ ๋ณด๋‹ค ๋†’์€ ํ™•๋ฅ ์„ ์ถœ๋ ฅํ•˜๋Š” overconfidentํ•œ ์ƒํƒœ๋กœ calibration ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•œ๋‹ค.


๐Ÿ“ Calibration์˜ ์ •์˜ ๋ฐ ์ธก์ • ๋ฐฉ๋ฒ•

๋ณธ ๋…ผ๋ฌธ์—์„œ๋Š” ๋‹ค์ค‘ ํด๋ž˜์Šค ๋ถ„๋ฅ˜ ๋ฌธ์ œ๋ฅผ ๋‹ค๋ฃจ๊ณ  ์žˆ์œผ๋ฉฐ, ๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์€ ์ฃผ์–ด์ง„ ์ž…๋ ฅ \(X \in \mathcal{X}, \quad Y \in \{1, \dots, K\}\) ๋ฅผ ์˜ˆ์ธกํ•˜๋Š”๋‹ค๊ณ  ๊ฐ€์ •ํ•œ๋‹ค. ๋ชจ๋ธ์˜ ์˜ˆ์ธก ํ™•๋ฅ ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ •์˜๋œ๋‹ค.

\[h(X) = (\hat{Y}, \hat{P})\]

Here \(\hat{Y}\) is the predicted class and \(\hat{P}\) is its associated confidence, defined as the maximum of the softmax output.

So how is a perfectly calibrated model defined? The paper defines it as follows.

\[P(\hat{Y} = Y \mid \hat{P} = p) = p, \quad \forall p \in [0, 1]\]

์œ„ ์‹์—์„œ ์•Œ ์ˆ˜ ์žˆ๋“ฏ์ด ๋ชจ๋ธ์ด ์™„๋ฒฝํžˆ ๋ณด์ •(calibrated)๋˜์–ด ์žˆ๋‹ค๋Š” ๊ฒƒ์€, ์˜ˆ์ธกํ•œ ํ™•๋ฅ  ๊ฐ’์ด ์‹ค์ œ ์ •๋‹ต๋ฅ ๊ณผ ์ผ์น˜ํ•˜๋Š” ๊ฒƒ์„ ์˜๋ฏธํ•œ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด ๋ชจ๋ธ์ด 100๊ฐœ์˜ ์ƒ˜ํ”Œ์— ๋Œ€ํ•ด ๋ชจ๋‘ 0.8์˜ confidence๋ฅผ ์ถœ๋ ฅํ–ˆ๋‹ค๋ฉด, ์‹ค์ œ๋กœ ๊ทธ ์ค‘ ์•ฝ 80๊ฐœ๊ฐ€ ๋งž์•„์•ผ ์™„๋ฒฝํžˆ ๋ณด์ •๋œ ๊ฒƒ์ด๋‹ค.

๐Ÿ” Reliability Diagram (์‹ ๋ขฐ๋„ ๋‹ค์ด์–ด๊ทธ๋žจ)

Split the predicted probability \(\hat{P}\) into small intervals (bins) and, within each bin, compare the empirical accuracy with the average confidence. If the model is well calibrated, the following relationship should hold in every bin.

\[\text{acc}(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \mathbf{1}(\hat{y}_i = y_i)\]

\[\text{conf}(B_m) = \frac{1}{|B_m|} \sum_{i \in B_m} \hat{p}_i\]

\[\text{acc}(B_m) \approx \text{conf}(B_m)\]

Here \(\text{acc}(B_m)\) is the empirical accuracy of the samples in bin \(B_m\), and \(\text{conf}(B_m)\) is their average confidence. If the model is well calibrated, the two values should be close to each other.

๐Ÿ“ Expected Calibration Error (ECE)

ECE๋Š” ๋ชจ๋ธ์˜ ์ „์ฒด calibration ์„ฑ๋Šฅ์„ ์ˆ˜์น˜์ ์œผ๋กœ ์ธก์ •ํ•˜๋Š” ๋Œ€ํ‘œ์ ์ธ ์ง€ํ‘œ๋กœ, ๊ฐ bin์— ๋Œ€ํ•ด ์˜ˆ์ธก ํ™•๋ฅ ๊ณผ ์‹ค์ œ ์ •๋‹ต๋ฅ  ๊ฐ„์˜ ์ฐจ์ด๋ฅผ ํ‰๊ท ํ•˜์—ฌ ๊ณ„์‚ฐํ•œ๋‹ค. \(M\)๊ฐœ์˜ bin์œผ๋กœ ๋‚˜๋ˆ„๊ณ , ๊ฐ bin \(B_m\)์— ๋Œ€ํ•ด ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ •์˜๋œ๋‹ค.

\[\text{ECE} = \sum_{m=1}^{M} \frac{|B_m|}{n} \left| \text{acc}(B_m) - \text{conf}(B_m) \right|\]

๐Ÿ“ Maximum Calibration Error (MCE)

MCE measures the calibration gap of the worst bin; intuitively, it quantifies the worst-case calibration failure. It can therefore be an important metric in safety-critical systems.

\[\text{MCE} = \max_{m \in \{1, \dots, M\}} \left| \text{acc}(B_m) - \text{conf}(B_m) \right|\]
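To make the two metrics concrete, the following minimal sketch (toy per-bin numbers chosen for illustration, not results from the paper) evaluates both formulas:

```python
import numpy as np

# Toy per-bin statistics for M = 3 bins over n = 100 samples
accs = np.array([0.50, 0.67, 0.75])    # acc(B_m)
confs = np.array([0.20, 0.55, 0.85])   # conf(B_m)
counts = np.array([20, 30, 50])        # |B_m|
n = counts.sum()

gaps = np.abs(accs - confs)            # |acc(B_m) - conf(B_m)| per bin
ece = np.sum((counts / n) * gaps)      # sample-weighted average gap
mce = gaps.max()                       # worst-case gap

print(f"ECE = {ece:.3f}, MCE = {mce:.3f}")  # ECE = 0.146, MCE = 0.300
```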

๐Ÿ› ๏ธ Calibration Methods

๋ณธ ๋…ผ๋ฌธ์—์„œ๋Š” ์ด๋ฏธ ํ•™์Šต๋œ ๋ชจ๋ธ์— ๋Œ€ํ•ด ํ™•๋ฅ  ๋ณด์ •์„ ์œ„ํ•œ ์‚ฌํ›„ ์ฒ˜๋ฆฌ(Post-hoc) ๋ฐฉ๋ฒ•๋“ค์„ ์†Œ๊ฐœํ•˜๊ณ  ์žˆ์œผ๋ฉฐ, ์ด๋“ค์€ ๋ชจ๋ธ์˜ ์˜ˆ์ธก ๊ตฌ์กฐ๋‚˜ ์ •ํ™•๋„๋Š” ์œ ์ง€ํ•˜๋ฉด์„œ, ์˜ˆ์ธก ํ™•๋ฅ (confidence)์ด ์‹ค์ œ ์ •๋‹ต๋ฅ ๊ณผ ๋” ์ž˜ ์ผ์น˜ํ•˜๋„๋ก ๋งŒ๋“ค์–ด์ค€๋‹ค.

📘 1. Calibrating Binary Models

We first look at confidence calibration techniques for the binary classification setting. Here the class label is \(Y \in \{0, 1\}\), and the model outputs the probability \(\hat{p}_i\) of the positive class (1) along with a logit \(z_i \in \mathbb{R}\). The probability is usually computed through the sigmoid function:

\[\hat{p}_i = \sigma(z_i) = \frac{1}{1 + \exp(-z_i)}\]

Calibration์€ \(\hat{p}_i\) ๋˜๋Š” \(z_i\)๋ฅผ ์กฐ์ •ํ•˜์—ฌ ์ƒˆ๋กœ์šด calibrated ํ™•๋ฅ  \(\hat{q}_i\)๋ฅผ ์–ป๋Š” ๊ฒƒ์ด๋‹ค.

A. Histogram Binning

Histogram Binning์€ ๋ชจ๋ธ์ด ์ถœ๋ ฅํ•˜๋Š” ํ™•๋ฅ ์„ ์ผ์ • ๊ตฌ๊ฐ„(bin)์œผ๋กœ ๋‚˜๋ˆ„๊ณ , ๊ฐ ๊ตฌ๊ฐ„ ๋‚ด์—์„œ์˜ ์‹ค์ œ ์ •๋‹ต ๋น„์œจ์„ ์ƒˆ๋กœ์šด ๋ณด์ • ํ™•๋ฅ ๋กœ ์‚ฌ์šฉํ•˜๋Š” ๋ฐฉ๋ฒ•์ด๋‹ค. ๋ชจ๋“  \(\hat{p}_i\)๋ฅผ \(M\)๊ฐœ์˜ ๊ตฌ๊ฐ„ \(B_1, B_2, ..., B_M\)์œผ๋กœ ๋‚˜๋ˆˆ๋‹ค. ๊ฐ bin๋งˆ๋‹ค ์ •๋‹ต๋ฅ ์„ ๊ณ„์‚ฐํ•˜์—ฌ ๊ทธ ๊ฐ’์„ ๋ณด์ •๋œ ํ™•๋ฅ ๋กœ ์‚ฌ์šฉํ•œ๋‹ค. ์ด๋ฅผ ์ˆ˜์‹์œผ๋กœ ์‚ดํŽด๋ณด๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.

  • ๊ตฌ๊ฐ„ ์ •์˜: \(B_m = (a_m, a_{m+1}]\) where \(0 = a_1 \leq a_2 \leq \dots \leq a_{M+1} = 1\)
  • ์ตœ์ ํ™” ๋ฌธ์ œ: \(\min_{\theta_1,\dots,\theta_M} \sum_{m=1}^{M} \sum_{i=1}^{n} \mathbf{1}(a_m \leq \hat{p}_i < a_{m+1}) (\theta_m - y_i)^2\)

์œ„์˜ ์ตœ์ ํ™” ๋ฌธ์ œ๋ฅผ ํ’€๊ฒŒ๋˜๋ฉด ๊ฐ bin์˜ ๋ณด์ • ํ™•๋ฅ ์€ bin ๋‚ด ๋ ˆ์ด๋ธ” ํ‰๊ท ์œผ๋กœ ๊ณ„์‚ฐ๋˜๊ฒŒ ๋œ๋‹ค.

\[\theta_m = \frac{1}{|B_m|} \sum_{i \in B_m} y_i\]

Let us walk through Histogram Binning with a small example.

  • Predicted probabilities and true labels

| Sample | Predicted probability \(\hat{p}_i\) | True label \(y_i\) |
|---|---|---|
| 1 | 0.15 | 0 |
| 2 | 0.22 | 1 |
| 3 | 0.35 | 1 |
| 4 | 0.48 | 0 |
| 5 | 0.63 | 1 |
| 6 | 0.72 | 1 |
| 7 | 0.85 | 1 |
| 8 | 0.89 | 0 |
| 9 | 0.95 | 1 |

  • Predictions grouped into bins, with the mean of the true labels in each bin as its calibrated probability

| Bin | Probability interval | Samples | Labels | Calibrated probability \(\theta_m\) |
|---|---|---|---|---|
| 1 | (0.0, 0.3] | 1, 2 | [0, 1] | 0.50 |
| 2 | (0.3, 0.7] | 3, 4, 5 | [1, 0, 1] | 0.67 |
| 3 | (0.7, 1.0] | 6, 7, 8, 9 | [1, 1, 0, 1] | 0.75 |

  • For example, bin 1 contains the labels [0, 1], so \(\theta_1 = (0 + 1)/2 = 0.50\).

๊ตฌํ˜„์ด ๊ฐ„๋‹จํ•˜๊ณ  ์–ด๋–ค ๋ชจ๋ธ์ด๋“  ์‚ฌํ›„ ๋ณด์ •์œผ๋กœ ์ ์šฉํ•  ์ˆ˜ ์žˆ์œผ๋‚˜, ์˜ˆ์ธก์ด ๊ณ„๋‹จ ํ•จ์ˆ˜์ฒ˜๋Ÿผ ๋ณ€๊ฒฝ๋˜์–ด ๋ถ€๋“œ๋Ÿฝ์ง€ ์•Š์œผ๋ฉฐ bin์„ ๋ช‡ ๊ฐœ ๊ฐ€์ง€๋А๋ƒ์— ๋”ฐ๋ผ ์„ฑ๋Šฅ์ด ํฌ๊ฒŒ ๋ณ€ํ™”ํ•  ์ˆ˜ ์žˆ๋‹ค.

B. Isotonic Regression

$\hat{p}_i$์— ๋Œ€ํ•ด ๋‹จ์กฐ ์ฆ๊ฐ€(monotonic) ํ•จ์ˆ˜ $f$๋ฅผ ํ•™์Šตํ•˜์—ฌ:

\[\hat{q}_i = f(\hat{p}_i)\]
  • ๋ชฉ์  ํ•จ์ˆ˜: \(\min_f \sum_{i=1}^n (f(\hat{p}_i) - y_i)^2 \quad \text{subject to } f \text{ is monotonic}\)

์ด ๋ฐฉ๋ฒ•์€ histogram๋ณด๋‹ค ์œ ์—ฐํ•˜์ง€๋งŒ, ๊ณผ์ ํ•ฉ์˜ ๊ฐ€๋Šฅ์„ฑ์ด ์žˆ์œผ๋ฉฐ ์ •๊ทœํ™”๊ฐ€ ํ•„์š”ํ•  ์ˆ˜ ์žˆ๋‹ค

C. Platt Scaling

๋ชจ๋ธ์ด ์ถœ๋ ฅํ•œ logit $z_i$๋ฅผ ์ž…๋ ฅ์œผ๋กœ ์‚ฌ์šฉํ•˜์—ฌ, sigmoid๋ฅผ ์ ์šฉํ•œ ๋กœ์ง€์Šคํ‹ฑ ํšŒ๊ท€๋ฅผ ์ˆ˜ํ–‰ํ•จ:

\[\hat{q}_i = \sigma(az_i + b) = \frac{1}{1 + \exp(-(az_i + b))}\]
  • \(a\) and \(b\) are learned by minimizing NLL on the validation set
  • The neural network's parameters stay fixed

์ •ํ™•๋„๋Š” ๊ทธ๋Œ€๋กœ ์œ ์ง€ํ•˜๋ฉด์„œ ํ™•๋ฅ ์„ ์กฐ์ •ํ•  ์ˆ˜ ์žˆ๋Š” ๊ฐ„๋‹จํ•œ ๋ฐฉ๋ฒ•.

D. BBQ (Bayesian Binning into Quantiles)

Histogram Binning์˜ ํ™•์žฅ์œผ๋กœ, ์—ฌ๋Ÿฌ binning scheme์„ ๊ณ ๋ คํ•˜์—ฌ Bayesian model averaging์„ ์ˆ˜ํ–‰ํ•œ๋‹ค.

  • Calibrated probability: \(P(\hat{q}_i | \hat{p}_i, D) = \sum_{s \in \mathcal{S}} P(\hat{q}_i | s, \hat{p}_i, D) \cdot P(s | D)\)
  • Here \(s\) is a binning scheme; the prior is a Beta distribution and the likelihood is binomial

BBQ is accurate but complex to implement and computationally expensive.

📘 2. Extension to Multiclass Models

๋‹ค์ค‘ ํด๋ž˜์Šค ๋ถ„๋ฅ˜ ๋ฌธ์ œ์—์„œ๋Š” ๊ฐ ์ž…๋ ฅ ์ƒ˜ํ”Œ \(x_i\)์— ๋Œ€ํ•ด logit ๋ฒกํ„ฐ \(z\)๋ฅผ softmax์— ์ž…๋ ฅํ•˜์—ฌ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ํ™•๋ฅ ์„ ๊ณ„์‚ฐํ•œ๋‹ค.

\[\hat{P} = \max_k \left( \frac{e^{z_k}}{\sum_j e^{z_j}} \right)\]

A. Matrix Scaling

Apply a linear transformation \(Wz + b\) to the logit vector \(z\), followed by a softmax.

\[\hat{q} = \text{softmax}(Wz + b)\]
  • \(W \in \mathbb{R}^{K \times K}\), \(b \in \mathbb{R}^K\)
  • This is very expressive, but the large number of parameters makes overfitting possible.

B. Vector Scaling

Matrix Scaling์„ ๊ฐ„์†Œํ™”ํ•œ ๋ฒ„์ „์œผ๋กœ, \(W\)๋ฅผ ๋Œ€๊ฐ ํ–‰๋ ฌ \(D\)๋กœ ์ œํ•œํ•œ ๊ฒƒ์ด๋‹ค.

\[\hat{q} = \text{softmax}(Dz + b)\]
  • ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜๋Š” \(2K\)๊ฐœ๋กœ ์ค„์–ด๋“ค๋ฉฐ, ๊ณ„์‚ฐ ํšจ์œจ์„ฑ๊ณผ ๊ณผ์ ํ•ฉ ๊ฐ€๋Šฅ์„ฑ์ด ์ค„์–ด๋“ ๋‹ค.

C. Temperature Scaling

๊ฐ€์žฅ ๊ฐ„๋‹จํ•˜๋ฉด์„œ๋„ ๊ฐ•๋ ฅํ•œ ๋ณด์ • ๊ธฐ๋ฒ•์œผ๋กœ logit์„ \(T\)๋กœ ๋‚˜๋ˆ„๊ณ  softmax ์ ์šฉํ•œ๋‹ค.

\[\hat{q} = \text{softmax}(z / T)\]
  • \(T > 1\): the probabilities are softened (spread out)
  • \(T = 1\): identical to the original model
  • \(T < 1\): the probabilities become sharper

Because dividing the logits by a positive \(T\) does not change the argmax, accuracy is unaffected; the paper fits \(T\) by minimizing NLL on a validation set.
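A minimal sketch of fitting \(T\) this way (the function name and optimizer settings are my own choices, not the paper's reference code):

```python
import torch

def fit_temperature(logits_val, labels_val, max_iter=50):
    """Learn T > 0 by minimizing NLL on held-out validation logits."""
    log_T = torch.zeros(1, requires_grad=True)  # optimize log T so T stays positive
    opt = torch.optim.LBFGS([log_T], lr=0.1, max_iter=max_iter)
    nll = torch.nn.CrossEntropyLoss()

    def closure():
        opt.zero_grad()
        loss = nll(logits_val / log_T.exp(), labels_val)
        loss.backward()
        return loss

    opt.step(closure)
    return log_T.exp().item()

# Usage: logits_val is an (N, K) tensor from the frozen model, labels_val is (N,)
# T = fit_temperature(logits_val, labels_val)
# calibrated_probs = torch.softmax(logits_val / T, dim=1)
```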

🎯 Experimental Goals

We saw above that modern deep models achieve high classification accuracy but tend to be overconfident in their predicted probabilities, and we surveyed various post-hoc calibration techniques. The paper compares these methods to identify the most effective one.

  • ๋ชจ๋ธ: ResNet, Wide ResNet, DenseNet, LeNet, DAN, TreeLSTM
  • ๋ฐ์ดํ„ฐ์…‹: CIFAR-10, CIFAR-100, ImageNet, SVHN, Birds, Cars, 20News, Reuters, SST
  • ํ‰๊ฐ€์ง€ํ‘œ: Expected Calibration Error (ECE), Maximum Calibration Error (MCE), Negative Log Likelihood (NLL), Error Rate

Figure 5: ECE (%) (with M = 15 bins) on standard vision and NLP datasets before calibration and with various calibration methods. The number following a model's name denotes the network depth.

๋…ผ๋ฌธ์€ ํ˜„๋Œ€ ์‹ ๊ฒฝ๋ง๋“ค์ด ๋†’์€ ์ •ํ™•๋„์—๋„ ๋ถˆ๊ตฌํ•˜๊ณ  ์˜ˆ์ธก ํ™•๋ฅ ์ด ๊ณผ๋„ํ•˜๊ฒŒ ์ž์‹ ๊ฐ(overconfident)์„ ๊ฐ€์ง„๋‹ค๋Š” ๋ฌธ์ œ๋ฅผ ์ง€์ ํ•˜๊ณ  ์žˆ์œผ๋ฉฐ, Temperature Scaling์€ ๋‹จ ํ•˜๋‚˜์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋งŒ์œผ๋กœ ๋Œ€๋ถ€๋ถ„์˜ ๊ฒฝ์šฐ ๊ฐ€์žฅ ์šฐ์ˆ˜ํ•œ ๋ณด์ • ์„ฑ๋Šฅ(ECE ๊ฐ์†Œ)์„ ๋ณด์—ฌ์ฃผ์—ˆ๋‹ค. ๋‹ค๋ฅธ ๋ณต์žกํ•œ ๋ณด์ • ๊ธฐ๋ฒ•๋“ค๋ณด๋‹ค ๊ฐ„๋‹จํ•˜๊ณ  ์•ˆ์ •์ ์œผ๋กœ ์ž‘๋™ํ•˜๋ฉฐ, ์ •ํ™•๋„๋ฅผ ๋–จ์–ด๋œจ๋ฆฌ์ง€ ์•Š๊ณ  ํ™•๋ฅ ๋งŒ ์กฐ์ •ํ•  ์ˆ˜ ์žˆ๋‹ค๋Š” ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์—ˆ์œผ๋ฉฐ, CIFAR-100 ๋“ฑ ๋‹ค์–‘ํ•œ ๋ฐ์ดํ„ฐ์…‹์—์„œ ECE๊ฐ€ ์ตœ๋Œ€ 16% โ†’ 1% ์ˆ˜์ค€์œผ๋กœ ํฌ๊ฒŒ ๊ฐ์†Œํ•˜์˜€๋‹ค.

🧪 Python Hands-on Overview

From here on we run experiments based on the paper. The goal is to quantitatively evaluate and visualize the calibration of a ResNet-34 model on the CIFAR-100 image classification task. Deep classifiers generally achieve high accuracy, yet their output probabilities frequently overestimate the chance of being correct (overconfidence), which makes the model's uncertainty untrustworthy in real applications. The paper proposes several post-hoc calibration techniques for this problem; here we implement one of them, Temperature Scaling, and analyze its effect visually through reliability diagrams.

๋ชจ๋ธ ํ•™์Šต์—๋Š” CIFAR-100 ๋ฐ์ดํ„ฐ์…‹๊ณผ ResNet-34 ๋ชจ๋ธ ๊ตฌ์กฐ๋ฅผ ์‚ฌ์šฉํ•˜์˜€๋‹ค. CIFAR-100์€ ์ด 100๊ฐœ์˜ ๋ ˆ์ด๋ธ”๋กœ ๊ตฌ์„ฑ๋œ ์ปฌ๋Ÿฌ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ, ๊ฐ ์ด๋ฏธ์ง€๋Š” 32ร—32 ํ•ด์ƒ๋„์˜ RGB ์ด๋ฏธ์ง€์ด๋‹ค. ๊ฐ ํด๋ž˜์Šค๋‹น 500๊ฐœ์˜ ํ•™์Šต ์ด๋ฏธ์ง€์™€ 100๊ฐœ์˜ ํ…Œ์ŠคํŠธ ์ด๋ฏธ์ง€๊ฐ€ ํฌํ•จ๋˜์–ด ์žˆ์œผ๋ฉฐ, ์ด 50,000๊ฐœ์˜ ํ•™์Šต ์ƒ˜ํ”Œ๊ณผ 10,000๊ฐœ์˜ ํ…Œ์ŠคํŠธ ์ƒ˜ํ”Œ์„ ํฌํ•จํ•œ๋‹ค. ๋ถ„๋ฅ˜ ๋ชจ๋ธ๋กœ ์‚ฌ์šฉ๋œ ResNet-34๋Š” Residual Network ๊ณ„์—ด์˜ ๋Œ€ํ‘œ์ ์ธ ๊ตฌ์กฐ ์ค‘ ํ•˜๋‚˜๋กœ, 34๊ฐœ์˜ ์ธต์„ ๊ฐ–๋Š” ์‹ฌ์ธต ํ•ฉ์„ฑ๊ณฑ ์‹ ๊ฒฝ๋ง์ด๋‹ค. ์ž”์ฐจ ์—ฐ๊ฒฐ(residual connection)์„ ํ†ตํ•ด ๊นŠ์€ ๋„คํŠธ์›Œํฌ์—์„œ ๋ฐœ์ƒํ•˜๋Š” ๊ธฐ์šธ๊ธฐ ์†Œ์‹ค ๋ฌธ์ œ๋ฅผ ํšจ๊ณผ์ ์œผ๋กœ ํ•ด๊ฒฐํ•  ์ˆ˜ ์žˆ์œผ๋ฉฐ, CIFAR-100๊ณผ ๊ฐ™์€ ์ค‘๊ฐ„ ๋‚œ์ด๋„์˜ ์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ ๋ฌธ์ œ์— ๋„๋ฆฌ ์‚ฌ์šฉ๋œ๋‹ค. ๋ณธ ์‹คํ—˜์—์„œ๋Š” ์‚ฌ์ „ ํ•™์Šต ์—†์ด ์ฒ˜์Œ๋ถ€ํ„ฐ CIFAR-100์— ๋Œ€ํ•ด ResNet-34๋ฅผ ํ•™์Šต์‹œ์ผฐ์œผ๋ฉฐ, ์ถœ๋ ฅ์ธต fully-connected layer์˜ ์ถœ๋ ฅ ์ฐจ์›์„ 100์œผ๋กœ ์„ค์ •ํ•˜์—ฌ 100๊ฐœ์˜ ํด๋ž˜์Šค๋ฅผ ๋ถ„๋ฅ˜ํ•˜๋„๋ก ๊ตฌ์„ฑํ•˜์˜€๋‹ค.

๋ณธ ์‹คํ—˜์—์„œ๋Š” ํ•™์Šต๋œ ResNet-34 ๋ชจ๋ธ์„ ๊ธฐ์ค€์œผ๋กœ T โˆˆ \({0.5, 1.0, 1.5, 2.0}\) ๋ฒ”์œ„์— ๋Œ€ํ•ด Temperature Scaling์„ ์ ์šฉํ•œ ํ›„, ๋‹ค์Œ๊ณผ ๊ฐ™์€ ๊ด€์ ์—์„œ ๋ณด์ • ์„ฑ๋Šฅ์„ ํ‰๊ฐ€ํ•˜์˜€๋‹ค:

  1. Reliability Diagram์„ ํ†ตํ•ด confidence vs accuracy ๊ด€๊ณ„๋ฅผ ์‹œ๊ฐํ™”
  2. Expected Calibration Error (ECE) ์ˆ˜์น˜๋ฅผ ๊ณ„์‚ฐํ•˜์—ฌ ์ •๋Ÿ‰์  ๋ณด์ • ์„ฑ๋Šฅ ํ‰๊ฐ€
  3. ๊ฐ confidence bin ๋‚ด์˜ sample ์ˆ˜ ๋ฐ ์ •ํ™•๋„ ๋ณ€ํ™” ๋ถ„์„

Let us now walk through the code step by step.

1. CIFAR-100 ๋ฐ์ดํ„ฐ์…‹์˜ ์ •๊ทœํ™”๋ฅผ ์œ„ํ•œ ํ‰๊ท  ๋ฐ ํ‘œ์ค€ํŽธ์ฐจ ๊ณ„์‚ฐ

๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ์„ ํ•™์Šตํ•  ๋•Œ, ์ž…๋ ฅ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ๋ฅผ ์ •๊ทœํ™”(normalization) ํ•˜๋Š” ๊ฒƒ์€ ๋งค์šฐ ์ค‘์š”ํ•œ ์ „์ฒ˜๋ฆฌ ๊ณผ์ •์ด๋‹ค. ๋ณดํ†ต ์ •๊ทœํ™”๋Š” ๊ฐ ์ฑ„๋„(R, G, B)์— ๋Œ€ํ•ด ํ‰๊ท ์„ ๋นผ๊ณ  ํ‘œ์ค€ํŽธ์ฐจ๋กœ ๋‚˜๋ˆ„๋Š” ๋ฐฉ์‹์œผ๋กœ ์ด๋ฃจ์–ด์ง„๋‹ค. ์ด ๊ณผ์ •์„ ํ†ตํ•ด ์ž…๋ ฅ ๊ฐ’์˜ ๋ถ„ํฌ๋ฅผ 0์„ ์ค‘์‹ฌ์œผ๋กœ ์ •๊ทœํ™”ํ•จ์œผ๋กœ์จ, ํ•™์Šต์ด ๋” ์•ˆ์ •์ ์œผ๋กœ ์ด๋ฃจ์–ด์ง€๊ณ  ์ˆ˜๋ ด ์†๋„๊ฐ€ ๋นจ๋ผ์งˆ ์ˆ˜ ์žˆ๋‹ค.

```python
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch

transform = transforms.ToTensor()
trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
loader = DataLoader(trainset, batch_size=50000, shuffle=False)

data = next(iter(loader))[0]  # (50000, 3, 32, 32)
mean = data.mean(dim=(0, 2, 3))
std = data.std(dim=(0, 2, 3))

print("CIFAR-100 ํ‰๊ท :", mean)
print("CIFAR-100 ํ‘œ์ค€ํŽธ์ฐจ:", std)

Running this code downloads CIFAR-100 to ./data if it is not already stored locally, then computes the mean and standard deviation over the entire training set.

2. CIFAR-100 ๋ฐ์ดํ„ฐ๋ฅผ ์ด์šฉํ•œ ResNet-34 ํ•™์Šต

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# ---------------- ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ์„ค์ • ----------------
batch_size = 128
epochs = 30
lr = 0.1
save_path = "๋ชจ๋ธ ์ €์žฅ ๊ฒฝ๋กœ"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------- ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ ๋ฐ ๋กœ๋”ฉ ----------------
# cifar100_mean_std.py ์ถœ๋ ฅ ๊ฒฐ๊ณผ
mean = (0.5071, 0.4866, 0.4409)
std = (0.2673, 0.2564, 0.2762)

# ํ•™์Šต์šฉ ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•ด ๋ฐ์ดํ„ฐ ์ฆ๊ฐ• ๋ฐ ์ •๊ทœํ™” ์ˆ˜ํ–‰
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),         # ๋ฌด์ž‘์œ„ crop (32x32 ์œ ์ง€)
    transforms.RandomHorizontalFlip(),            # ๋ฌด์ž‘์œ„ ์ขŒ์šฐ ๋ฐ˜์ „
    transforms.ToTensor(),                        # ํ…์„œ ๋ณ€ํ™˜ (0~1)
    transforms.Normalize(mean, std)               # ์ฑ„๋„๋ณ„ ์ •๊ทœํ™”
])

# Load the CIFAR-100 training dataset
trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=8)

# ---------------- ๋ชจ๋ธ ์ •์˜ ----------------
model = models.resnet34(weights=None)
model.fc = nn.Linear(model.fc.in_features, 100)
model = model.to(device)

# ---------------- Loss function and optimizer ----------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # decay the learning rate every 10 epochs

# ---------------- Training loop ----------------
for epoch in range(epochs):
    model.train()  # enable training mode
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()              # reset gradients
        outputs = model(inputs)            # forward pass
        loss = criterion(outputs, labels)  # compute the loss
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    acc = 100. * correct / total
    print(f"[Epoch {epoch+1}/{epochs}] Loss: {running_loss:.3f}, Train Accuracy: {acc:.2f}%")
    scheduler.step()

# ---------------- ํ•™์Šต๋œ ๋ชจ๋ธ ์ €์žฅ ----------------
torch.save(model.state_dict(), save_path)
print(f"๐Ÿ’พ Model saved to: {save_path}")

์œ„์˜ ์ฝ”๋“œ๋Š” ResNet-34๋ฅผ CIFAR-100 ๋ฐ์ดํ„ฐ์…‹์„ ์ด์šฉํ•˜์—ฌ ํ•™์Šตํ•˜๋Š” ๊ณผ์ •์„ ๋‚˜ํƒœ๋‚ด์—ˆ๋‹ค. 1๋ฒˆ์—์„œ ๊ณ„์‚ฐ๋œ ํ‰๊ท ๊ณผ ํ‘œ์ค€ํŽธ์ฐจ ๊ฐ’์„ transforms.Normalize(mean, std) ํ•จ์ˆ˜์— ๊ทธ๋Œ€๋กœ ์ ์šฉํ•˜์˜€๋‹ค. ์ด ์ •๊ทœํ™” ๊ณผ์ •์€ ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ ํŒŒ์ดํ”„๋ผ์ธ์— ํฌํ•จ๋˜์–ด ์žˆ์œผ๋ฉฐ, ๋ฌด์ž‘์œ„ ์ž๋ฅด๊ธฐ(RandomCrop), ์ขŒ์šฐ ๋ฐ˜์ „(RandomHorizontalFlip), ํ…์„œ ๋ณ€ํ™˜(ToTensor) ์ดํ›„ ๋งˆ์ง€๋ง‰ ๋‹จ๊ณ„์—์„œ ์ˆ˜ํ–‰๋œ๋‹ค. ์ •๊ทœํ™”๋œ ๋ฐ์ดํ„ฐ๋Š” ์ดํ›„ ResNet-34 ๋ชจ๋ธ์— ์ž…๋ ฅ๋˜๋ฉฐ, ์ด ๋ชจ๋ธ์€ ์ถœ๋ ฅ์ธต๋งŒ CIFAR-100์˜ 100๊ฐœ ํด๋ž˜์Šค์— ๋งž๊ฒŒ ์ˆ˜์ •๋œ ํ˜•ํƒœ๋กœ ์‚ฌ์šฉ๋œ๋‹ค.

๋”ฅ๋Ÿฌ๋‹ ๋ชจ๋ธ, ํŠนํžˆ ๊นŠ์€ ๊ตฌ์กฐ์˜ ResNet-34์™€ ๊ฐ™์€ ๋ชจ๋ธ์€ ์ž…๋ ฅ์˜ ๋ถ„ํฌ๊ฐ€ ์ง€๋‚˜์น˜๊ฒŒ ํŽธํ–ฅ๋˜์–ด ์žˆ์„ ๊ฒฝ์šฐ ํ•™์Šต์ด ์ž˜ ๋˜์ง€ ์•Š๊ฑฐ๋‚˜, ์ดˆ๊ธฐ ํ•™์Šต ๋‹จ๊ณ„์—์„œ ๋งค์šฐ ๋А๋ฆฐ ์ˆ˜๋ ด ์†๋„๋ฅผ ๋ณด์ผ ์ˆ˜ ์žˆ๋‹ค. ๋”ฐ๋ผ์„œ ์‚ฌ์ „์— ๋ฐ์ดํ„ฐ์…‹์˜ ํ†ต๊ณ„ ์ •๋ณด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์ •๊ทœํ™”๋ฅผ ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฒƒ์€ ํ•™์Šต ์„ฑ๋Šฅ์„ ์•ˆ์ •์‹œํ‚ค๋Š” ํ•ต์‹ฌ์ ์ธ ์š”์†Œ๋‹ค. ๊ฒฐ๊ณผ์ ์œผ๋กœ, 1๋ฒˆ์˜ ์ •๊ทœํ™” ๊ฐ’ ๊ณ„์‚ฐ์€ 2๋ฒˆ์˜ ํšจ๊ณผ์ ์ธ ๋ชจ๋ธ ํ•™์Šต์„ ์œ„ํ•œ ํ•„์ˆ˜ ์ „์ฒ˜๋ฆฌ ๊ณผ์ •์ด๋ฉฐ, ์ „์ฒด ํ•™์Šต ํŒŒ์ดํ”„๋ผ์ธ์˜ ์‹ ๋ขฐ์„ฑ๊ณผ ํšจ์œจ์„ฑ์„ ๋†’์ด๋Š” ๋ฐ ์ค‘์š”ํ•œ ์—ญํ• ์„ ํ•œ๋‹ค.

๋ชจ๋ธ ํ•™์Šต์€ ์ด 30๋ฒˆ์˜ epoch์— ๊ฑธ์ณ ์ง„ํ–‰๋˜๋ฉฐ, ํ•œ ๋ฒˆ์˜ epoch๋งˆ๋‹ค ์ „์ฒด CIFAR-100 ํ•™์Šต ๋ฐ์ดํ„ฐ๋ฅผ ํ•œ ๋ฒˆ์”ฉ ์ˆœํšŒํ•˜๊ฒŒ ๋œ๋‹ค. ํ•™์Šต ๊ณผ์ •์—์„œ ์†์‹ค ํ•จ์ˆ˜๋Š” ๋‹ค์ค‘ ํด๋ž˜์Šค ๋ถ„๋ฅ˜์— ์ ํ•ฉํ•œ CrossEntropyLoss๋ฅผ ์‚ฌ์šฉํ•˜๋ฉฐ, ์˜ตํ‹ฐ๋งˆ์ด์ €๋Š” ํ™•๋ฅ ์  ๊ฒฝ์‚ฌ ํ•˜๊ฐ•๋ฒ•(SGD: Stochastic Gradient Descent)์— momentum๊ณผ weight decay๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ ์•ˆ์ •์ ์ธ ์ตœ์ ํ™”๋ฅผ ์œ ๋„ํ•œ๋‹ค. ํ•™์Šต๋ฅ ์€ ์ดˆ๊ธฐ์— 0.1๋กœ ์„ค์ •๋˜๋ฉฐ, StepLR ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ํ†ตํ•ด 10 ์—ํญ๋งˆ๋‹ค 1/10์”ฉ ๊ฐ์†Œ์‹œํ‚จ๋‹ค. ์ด๋Š” ์ดˆ๋ฐ˜์—๋Š” ๋น ๋ฅด๊ฒŒ ํ•™์Šตํ•˜๊ณ , ํ›„๋ฐ˜์—๋Š” ์ฒœ์ฒœํžˆ fine-tuningํ•˜๋„๋ก ์œ ๋„ํ•˜๋Š” ๊ฒƒ์ด๋‹ค. ๋ชจ๋ธ์€ ํ•™์Šต ๋ชจ๋“œ(model.train())์—์„œ ๊ฐ ๋ฐฐ์น˜ ๋ฐ์ดํ„ฐ๋ฅผ ์ˆœ์ „ํŒŒ(forward)์‹œ์ผœ ์˜ˆ์ธก ๊ฒฐ๊ณผ๋ฅผ ์ถœ๋ ฅํ•˜๊ณ , ์ด ๊ฒฐ๊ณผ๋ฅผ ์‹ค์ œ ์ •๋‹ต๊ณผ ๋น„๊ตํ•˜์—ฌ ์†์‹ค(loss)์„ ๊ณ„์‚ฐํ•œ ํ›„, ์—ญ์ „ํŒŒ(backward)๋ฅผ ํ†ตํ•ด ๊ฐ€์ค‘์น˜์˜ ๊ธฐ์šธ๊ธฐ๋ฅผ ๊ตฌํ•˜๊ณ  ์ด๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์—…๋ฐ์ดํŠธํ•œ๋‹ค. ๋˜ํ•œ ๊ฐ epoch๋งˆ๋‹ค ๋ˆ„์ ๋œ ์†์‹ค๊ณผ ์ •ํ™•๋„๋ฅผ ์ถœ๋ ฅํ•˜์—ฌ ํ•™์Šต์ด ์–ด๋–ป๊ฒŒ ์ง„ํ–‰๋˜๊ณ  ์žˆ๋Š”์ง€ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค. ๋งˆ์ง€๋ง‰์œผ๋กœ ํ•™์Šต์ด ์ข…๋ฃŒ๋œ ํ›„์—๋Š” ๋ชจ๋ธ์˜ ํ•™์Šต๋œ ๊ฐ€์ค‘์น˜ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ .pth ํŒŒ์ผ๋กœ ์ €์žฅํ•˜์—ฌ ํ•™์Šต๋œ ๋ชจ๋ธ์„ ์ถ”ํ›„ ํ™œ์šฉํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•œ๋‹ค.

3. Reliability Diagram ์‹œ๊ฐํ™”

์ด๋ฒˆ์—๋Š” ์•ž์„œ ํ•™์Šต์ด ์™„๋ฃŒ๋œ ResNet-34 ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์™€, CIFAR-100 ํ…Œ์ŠคํŠธ์…‹์— ๋Œ€ํ•ด ์˜ˆ์ธก์„ ์ˆ˜ํ–‰ํ•œ ํ›„, ์˜ˆ์ธก์˜ ์‹ ๋ขฐ๋„(confidence)์™€ ์‹ค์ œ ์ •๋‹ต ์—ฌ๋ถ€ ๊ฐ„์˜ ๊ด€๊ณ„๋ฅผ ๋ถ„์„ํ•˜๊ณ  ์‹œ๊ฐํ™”ํ•œ๋‹ค. ์ด๋ฅผ ์œ„ํ•ด Expected Calibration Error (ECE) ๋ฅผ ์ˆ˜์น˜๋กœ ๊ณ„์‚ฐํ•˜๊ณ , Reliability Diagram์„ ํ†ตํ•ด ๋ชจ๋ธ์˜ ์‹ ๋ขฐ๋„๊ฐ€ ์–ผ๋งˆ๋‚˜ ์ž˜ ๋ณด์ •๋˜์–ด ์žˆ๋Š”์ง€๋ฅผ ์‹œ๊ฐ์ ์œผ๋กœ ํ‰๊ฐ€ํ•œ๋‹ค.

```python
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

# โœ… ๋””๋ฐ”์ด์Šค ์„ค์ • ๋ฐ ๋ชจ๋ธ ๋กœ๋“œ
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet34(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 100)
model.load_state_dict(torch.load("./snapshots/resnet34_cifar100_exp/resnet34_cifar100.pth", map_location=device))
model = model.to(device)

# ✅ Prepare the CIFAR-100 test set
mean = (0.5071, 0.4866, 0.4409)
std = (0.2673, 0.2564, 0.2762)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
test_dataset = datasets.CIFAR100(root="./data", train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# ✅ Compute the ECE and generate the reliability diagram
def compute_reliability_and_ece(model, dataloader, device, n_bins=15):
    model.eval()
    bin_boundaries = torch.linspace(0, 1, n_bins + 1).to(device)
    bin_corrects = torch.zeros(n_bins).to(device)
    bin_confidences = torch.zeros(n_bins).to(device)
    bin_counts = torch.zeros(n_bins).to(device)
    total_samples = 0

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            probs = F.softmax(logits, dim=1)
            confs, preds = torch.max(probs, dim=1)
            corrects = preds.eq(labels)

            total_samples += labels.size(0)

            for i in range(n_bins):
                in_bin = (confs > bin_boundaries[i]) & (confs <= bin_boundaries[i+1])
                bin_counts[i] += in_bin.sum()
                if in_bin.sum() > 0:
                    bin_corrects[i] += corrects[in_bin].float().sum()
                    bin_confidences[i] += confs[in_bin].sum()

    nonzero = bin_counts > 0
    accs = bin_corrects[nonzero] / bin_counts[nonzero]
    confs = bin_confidences[nonzero] / bin_counts[nonzero]
    bin_centers = ((bin_boundaries[:-1] + bin_boundaries[1:]) / 2)[nonzero]
    filtered_counts = bin_counts[nonzero]

    ece = torch.sum((filtered_counts / total_samples) * torch.abs(accs - confs)).item()
    return bin_centers.cpu(), accs.cpu(), confs.cpu(), ece

def draw_reliability_diagram(bin_centers, accs, confs, ece, name, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    width = 1.0 / len(bin_centers)

    plt.figure(figsize=(5, 5))
    plt.bar(bin_centers, accs, width=width * 0.9, color='blue', edgecolor='black', alpha=0.8, label="Accuracy")
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label="Perfect Calibration")
    for x, acc, conf in zip(bin_centers, accs, confs):
        lower = min(acc, conf)
        upper = max(acc, conf)
        plt.fill_between([x - width / 2, x + width / 2], lower, upper,
                         color='red', alpha=0.3, hatch='//')

    plt.xlabel("Confidence")
    plt.ylabel("Accuracy")
    plt.title(f"Reliability Diagram: {name}")
    plt.text(0.02, 0.6, f"ECE = {ece * 100:.2f}%", fontsize=12,
             bbox=dict(facecolor='lavender', edgecolor='gray'))
    plt.legend(loc='upper left')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f"{name}_reliability.png"))
    plt.close()

# ✅ Run: compute calibration statistics and save the diagram
bin_centers, accs, confs, ece = compute_reliability_and_ece(model, test_loader, device)
draw_reliability_diagram(bin_centers, accs, confs, ece, name="ResNet34_CIFAR100", save_dir="./snapshots/resnet34_cifar100_exp")
```

์œ„ ์ฝ”๋“œ๋Š” ํ•™์Šต๋œ ๋ชจ๋ธ์„ CIFAR-100 ํ…Œ์ŠคํŠธ ๋ฐ์ดํ„ฐ์— ์ ์šฉํ•˜์—ฌ ์˜ˆ์ธก์„ ์ˆ˜ํ–‰ํ•œ ํ›„, ์˜ˆ์ธก ๊ฒฐ๊ณผ์— ๋Œ€ํ•œ confidence score์™€ ์‹ค์ œ ์ •๋‹ต ์—ฌ๋ถ€๋ฅผ ๋น„๊ตํ•˜์—ฌ ์‹ ๋ขฐ๋„(calibration)๋ฅผ ํ‰๊ฐ€ํ•˜๋Š” ๊ณผ์ •์„ ์ˆ˜ํ–‰ํ•œ๋‹ค. compute_reliability_and_ece ํ•จ์ˆ˜๋Š” confidence ๊ฐ’ ๋ฒ”์œ„๋ฅผ ์ผ์ •ํ•œ ๊ฐ„๊ฒฉ์œผ๋กœ ๋‚˜๋ˆˆ bin์„ ๊ธฐ์ค€์œผ๋กœ ๊ฐ bin ๋‚ด์˜ ํ‰๊ท  confidence์™€ ์‹ค์ œ ์ •ํ™•๋„(accuracy)๋ฅผ ๊ณ„์‚ฐํ•˜๋ฉฐ, ์ด๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ Expected Calibration Error (ECE)๋ฅผ ์ˆ˜์น˜๋กœ ๋ฐ˜ํ™˜ํ•œ๋‹ค. ์ด ๊ฐ’์ด ์ž‘์„์ˆ˜๋ก ๋ชจ๋ธ์˜ ์˜ˆ์ธก ํ™•๋ฅ ์ด ์‹ค์ œ ์ •๋‹ต๋ฅ ๊ณผ ์ž˜ ์ผ์น˜ํ•œ๋‹ค๋Š” ๊ฒƒ์„ ์˜๋ฏธํ•œ๋‹ค.

The draw_reliability_diagram function then visualizes this information as a reliability graph: relative to the ideal diagonal (perfect calibration), it shows where the model is overconfident or underconfident. The bars show the empirical accuracy in each confidence bin, and the red hatched area between the blue bars and the gray diagonal marks the calibration gap. This plot shows how well calibrated the model is and provides the basis for judging whether a calibration technique (e.g., temperature scaling) is needed.

4. Temperature Scaling Experiment

```python
import torch
import torch.nn as nn
import os
import matplotlib.pyplot as plt

# Temperature Scaler ์ •์˜ (ํ•™์Šต ์—†์ด ๊ณ ์ •๋œ T ๊ฐ’ ์‚ฌ์šฉ)
class TemperatureScaler(nn.Module):
    def __init__(self, temperature: float):
        super().__init__()
        self.temperature = nn.Parameter(torch.tensor([temperature]), requires_grad=False)

    def forward(self, logits):
        return logits / self.temperature

# ๋ชจ๋ธ์„ Temperature Scaler๋กœ ๊ฐ์‹ธ๋Š” ๋ž˜ํผ
class WrappedModel(nn.Module):
    def __init__(self, base_model, temp_scaler):
        super().__init__()
        self.base_model = base_model
        self.temp_scaler = temp_scaler

    def forward(self, x):
        logits = self.base_model(x)
        return self.temp_scaler(logits)

# Compute the ECE and save a reliability diagram for each T value
def evaluate_multiple_temperatures_with_plots(model, test_loader, device, T_values, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    ece_list = []

    for T in T_values:
        print(f"\n๐Ÿงช Evaluating T = {T}")
        temp_scaler = TemperatureScaler(temperature=T).to(device)
        wrapped_model = WrappedModel(model, temp_scaler).to(device)

        # ์‹ ๋ขฐ๋„ ํ‰๊ฐ€
        bin_centers, accs, confs, bin_counts, total_samples, ece = compute_reliability_and_ece(
            wrapped_model, test_loader, device, verbose_under_100=False
        )
        ece_list.append(ece)

        # Save a reliability diagram (reusing draw_reliability_diagram from step 3)
        draw_reliability_diagram(
            bin_centers, accs, confs, ece,
            name=f"T={T}", save_dir=output_dir
        )

    return T_values, ece_list

# T์— ๋”ฐ๋ฅธ ECE ๋ณ€ํ™”๋ฅผ ์‹œ๊ฐํ™”
def plot_temperature_vs_ece(T_values, ece_list, save_path):
    plt.figure(figsize=(6, 4))
    plt.plot(T_values, [ece * 100 for ece in ece_list], marker='o', linestyle='-', color='purple')
    plt.xlabel("Temperature (T)")
    plt.ylabel("ECE (%)")
    plt.title("ECE vs Temperature")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()
```
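Finally, as a usage sketch (assuming model, test_loader, and device are still in scope from step 3; the output directory is illustrative), the sweep over the T values from the experiment setup can be run as follows:

```python
# Sweep the fixed T values from the experiment setup and plot ECE vs. T
T_values = [0.5, 1.0, 1.5, 2.0]
output_dir = "./snapshots/resnet34_cifar100_exp/temperature"  # illustrative path

T_values, ece_list = evaluate_multiple_temperatures_with_plots(
    model, test_loader, device, T_values, output_dir
)
plot_temperature_vs_ece(
    T_values, ece_list,
    save_path=os.path.join(output_dir, "ece_vs_temperature.png")
)
```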
This post is licensed under CC BY 4.0 by the author.