[기계학습][머신러닝][딥러닝] Vanilla GAN

 

Vanilla GAN

A pytorch implementation of Vanilla GAN using MNIST digits data (https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf)

실험 결과와 코드는 https://github.com/EmjayAhn/GAN-pytorch 에서 확인 할 수 있습니다.

0. 목표

GAN 은 최근 인기를 끌고 있는 generative model 중 하나입니다. 다양한 모델이 쏟아져 나오고 있기에, 이 트렌트를 따라가기 위해서, 가장 기본적인 vanilla gan 을 구현하고, 이해하는 것이 목표입니다. 이번 예제에서는 MNIST digit 을 사용하였으나, 이미지 외에도 다양한 데이터를 활용할 수 있습니다.

1. Model 구조

GAN 의 아버지 Ian Goodfellow가 제안한 이 모델은 두가지 신경망으로 구성 되어 있습니다. 먼저, 우리가 가지고 있는 데이터와 비슷하게 데이터를 생성하는 것을 학습하는 Generator 와 Generator 가 생성한 데이터(fake)와 실제 우리가 가지고 있는 데이터(real)를 fake 인지 real 인지 구분하는 Discriminator 를 구분하는 두 신경망으로 구성되어 있습니다.

1-1. Generator

  • Generator 는 Gaussian Random Noise (mean=0, std=1) 를 입력으로 받아, 이 noise로 부터 data 를 생성해냅니다.
  • 이번 구현에서 Generator는 다음과 같이 작성하였습니다.

  • Dense layer 만 사용했으며, gradient vanishing 현상을 막기 위해 activation 을 거친 후, Batch Normalization 을 추가하였습니다.

1-2. Discriminator

  • Discriminator 는 Generator 가 생성해낸 데이터와 기존에 가지고 있는 진짜 데이터를 입력으로 받아, 진짜 데이터를 1, 가짜데이터를 0으로 학습하는 classifier 입니다.
  • 아래의 Loss Function 을 확인 하겠지만, 진짜 데이터에 대해서는 그 확률 값을 높게 하고, 가짜데이터에서는 그 확률 값을 0에 가깝게 하는 것이 이 모델의 optimize 목표입니다.
  • 이번 구현에서 Discriminator를 다음과 같이 작성하였습니다.

2. Loss Function

2-1. Loss Function 의 해석

Vanilla GAN 의 Loss function 은 위와 같습니다. Loss Function 의 구조 자체는 min-max 최적화로써, Discriminator와 Generator 의 loss 함수를 각각 최적화 해 나아가면서 위 식의 균형 향해 다가가는 것입니다.

  • 먼저, Discriminator에 대한 max 부터 살펴 보면, Real을 입력으로 넣었을 때는, log(D(x))의 기댓값이 최대가 되게 하고, G(z) 즉, 가짜를 가짜라고 할 확률 1-D(G(z))는 최대가 되게끔 학습을 하는 것입니다.
  • Generator 에 대한 min 을 살펴보면, G(z) 는 가우시안 랜덤 변수를 받아 생성된 데이터를 D(G(z)), discriminator에 넣었을 때, 1-D(G(z)), 즉 가짜라고 할 확률을 최소화하게끔 학습하는 것입니다. 이는 결국, Discriminator를 속이기 위해 generator의 최적화가 실행된다는 의미입니다.

이 식에서 중요한 점은, Discriminator 는 학습할 때, Generator 가 생성한 데이터와 진짜 데이터 모두를 보며 학습하지만, Generator 는 그 어디에서도 진짜 데이터가 어떻게 생겼는지는 확인하지 않습니다. 오로지 Discriminator 를 속이기 위해 학습하는 것이지만, 그 결과 우리가 가지고 있는 진짜 데이터와 비슷하게 만들수 있는 모델을 획득하게 될 수 있는 것이라는 점에서 매우 획기적인 모델입니다.

2-2. 실제 구현에서의 변형

Generator 에 대한 loss function 은 log(1-(D(G(z)))를 최소화 하는 것입니다. 하지만, log(1-x) 형태의 식은 x가 0일 때, 그 gradient 가 매우 작아, 학습이 매우 오래 걸리는 문제가 있습니다. 이를 해결하기 위해, log(1-D(G(z)))를 G에 대해 최소화 하는 것은 결국, -log(D(G(z)))를 최소화 하는 것과 같고, 이는 log(D(G(z)))를 최대화 하는 것과 같습니다. 따라서 실제 구현에서의 criterion 은 Discriminator 가 사용하는 criterion(여기선, binary cross enntropy loss)을 동일하게 사용합니다.

3. 학습 및 모델 결과

학습 parameter 는 다음과 같습니다.

  • Total epoch: 300
  • batch size : 128
  • z dimension : 100
  • Adam optimizer : lr=0.0002, weight_decay=8e-9

다음은 generator가 학습이 되가면서, 같은 가우시안 랜덤 노이즈에 대해 mnist 와 닮은 데이터를 생성해 나가는 과정입니다.

경축! 아무것도 안하여 에스천사게임즈가 새로운 모습으로 재오픈 하였습니다.
어린이용이며, 설치가 필요없는 브라우저 게임입니다.
https://s1004games.com

(1) 0 epoch

약 400개의 배치를 학습한 후, 찍은 사진이기에 가운데에 아주 미세한 형태는 보이지만, 가우시안 노이즈임을 확인 할 수 있습니다.

(2) 20 epoch

20 epoch 만 되더라도, (마치 뱃속의 아가처럼(?)) 가운데에 어떤 형태가 생성되기 시작하는 것을 확인 할 수 있습니다.

(3) 100 epoch

9, 3, 8, (horizontal flipped) 3, 1.. 아주 힘들게 MNIST 와 비슷해 보이는 숫자를 확인 할 수 있습니다.

4. 결론

이번 구현의 목표는 나의 첫 vanilla gan 을 논문과 여러 자료를 공부해보며, 구현해 보는 것에 있었기에, 이를 완수하고, 실제로 학습 시켰을 때, generator 로써 기능을 할 수 있다는 점에서 유의미 하였습니다. 다양한 repository 에서 서로 다른 framework 를 사용하여, gan 을 구현하는 것을 참조 할 수 있습니다. 하지만, 직접 공부해보고, loss function의 의미를 해석하여 직접 구현해보며 많은 것을 배울 수 있었습니다. 실제로 loss function 위 처럼 바꾸지 않았을 때는 1000 epoch 를 학습하더라도 generator가 학습 되지 않는 실패 경험을 통해, 자세한 논문 리딩과 분석은 구현에 있어 필수적임을 느낄 수 있었습니다.

5. Futher Study

자원의 제약으로 작은 모델 구조와 hyper parameter tuning을 더 하지 못한게 아쉽습니다. generator 가 생성해내는 모양이 조금더 세밀하게 할 수 있는 것을 더 해보고 싶고, MNIST 데이터 뿐만아니라 본 논문의 참조사진처럼 CIFAR10 이나, TFD 데이터에 대해서도 실험해보고 싶습니다.

 

[출처] https://github.com/EmjayAhn/GAN-pytorch/blob/master/README.md

 

본 웹사이트는 광고를 포함하고 있습니다.
광고 클릭에서 발생하는 수익금은 모두 웹사이트 서버의 유지 및 관리, 그리고 기술 콘텐츠 향상을 위해 쓰여집니다.
번호 제목 글쓴이 날짜 조회 수
공지 오라클 기본 샘플 데이터베이스 졸리운_곰 2014.01.02 25085
공지 [SQL컨셉] 서적 "SQL컨셉"의 샘플 데이타 베이스 SAMPLE DATABASE of ORACLE 가을의 곰을... 2013.02.10 24564
공지 [G_SQL] Sample Database 가을의 곰을... 2012.05.20 25943
1025 [postgreSQL] PostgreSQL 계층형 쿼리 구현 방법 졸리운_곰 2023.01.29 35
1024 [postgreSQL] ORACLE쿼리에서 postgreSQL쿼리 변환 졸리운_곰 2023.01.29 26
1023 [postgreSQL] [PostgreSQL] stored function(stored procedures) 사용하기 졸리운_곰 2023.01.23 30
1022 [SQL] CRUD 기본 사용법 file 졸리운_곰 2023.01.23 30
1021 [postgreSQL] [Docker] Docker에 PostgreSQL 설치하기 file 졸리운_곰 2023.01.21 25
1020 [MYSQL] 테이블 스키마 설계 고려사항 졸리운_곰 2022.12.03 33
1019 [MySQL] "아는 만큼 빨라진다" 마이SQL 성능 튜닝 팁 10가지 file 졸리운_곰 2022.11.29 30
1018 [오라클] 오라클 연동 오류 [ORA-01017: invalid username/password; logon denied] 졸리운_곰 2022.11.28 76
1017 [오라클] 제약조건 확인 (FK 찾기) 졸리운_곰 2022.11.28 68
1016 [ADsP] 취업 깡패 ADP 뿌시기! "빅데이터 분석가 최고의 자격증이에요" file 졸리운_곰 2022.11.20 22
1015 [기계학습] [번역] TensorFlow Lite 튜토리얼 3 부 : Raspberry Pi의 음성 인식 졸리운_곰 2022.11.18 7
1014 [기계학습] [번역] TensorFlow Lite 튜토리얼 2 부 : 음성 인식 모델 교육 졸리운_곰 2022.11.18 13
1013 [기계학습] [번역] TensorFlow Lite 튜토리얼 1 부 : Wake Word 기능 추출 졸리운_곰 2022.11.18 10
1012 [기계학습][딥러닝] Generative Adversarial Net (GAN) PyTorch 구현: 손글씨 생성 file 졸리운_곰 2022.11.18 54
1011 [hadoop] Cloudera Quick Start VM in Hyper-V file 졸리운_곰 2022.11.14 14
1010 [기계학습][딥러닝] Flask를 이용하여 파이토치를 REST API로 베포하기 file 졸리운_곰 2022.11.12 44
» [기계학습][머신러닝][딥러닝] Vanilla GAN file 졸리운_곰 2022.11.08 13
1008 [기계학습][머신러닝][딥러닝] Generative Adversarial Net (GAN) PyTorch 구현: 손글씨 생성 file 졸리운_곰 2022.11.08 103
1007 [기계학습][머신러닝][딥러닝] DCGAN 튜토리얼 졸리운_곰 2022.11.08 4
1006 [PyTorch] pytorch 기본 문법 및 코드, 팁 snippets file 졸리운_곰 2022.10.20 30
대표 김성준 주소 : 경기 용인 분당수지 U타워 등록번호 : 142-07-27414
통신판매업 신고 : 제2012-용인수지-0185호 출판업 신고 : 수지구청 제 123호 개인정보보호최고책임자 : 김성준 sjkim70@stechstar.com
대표전화 : 010-4589-2193 [fax] 02-6280-1294 COPYRIGHT(C) stechstar.com ALL RIGHTS RESERVED