저는 현재 트랜스포머 전체 구조를 코드화하는 작업중에 있습니다.

https://smartest-suri.tistory.com/48

 

딥러닝 | 트랜스포머(2017) 논문 리뷰 - Attention is all you need

[참고] 본 포스팅은 수리링 본인이 Attention is all you need 논문을 처음부터 끝까지 직접 읽으며 분석하고 리뷰하여 작성했습니다. 불펌 절대 금지! 본문 내용에 잘못된 부분이 있다면 댓글 달아주

smartest-suri.tistory.com


지난 논문 리뷰에서 살펴본 바와 같이 트랜스포머는 '위치 인코딩(positional encoding)'을 통해 통으로 받은 입력에 문맥 정보를 추가하는데요. 본 포스팅에서는 포지셔널 인코딩을 파이토치로 구현하는 과정에서

1.  제가 처음에 쓴 코드에 어떤 문제가 있었고
2.  그걸 어떻게 더 나은 방향으로 개선했으며
3.  개선한 최종 코드 결과물은 어떠한지

를 중점적으로 다루어 보겠습니다.


출처 :  https://arxiv.org/pdf/1706.03762

참고 문헌 목록


문제 1

먼저, 논문의 positional encoding 수식과 kaggle 참조 코드를 참조하면서 첫 코드를 작성했습니다. (내가 왜 그랬을까)

캐글 코드

먼저, 참조한 캐글 코드에서 빨간색 박스 부분이 잘못 되었음을 곧바로 인지했습니다. (여기서 바로 캐글 닫았어야 했는데)

논문 원문에서 발췌

포지셔널 인코딩 수식에 의하면,

  • 임베딩 벡터의 차원이 짝수일 때와 홀수일 때를 나누어 sin, cos에 각각 할당합니다.
  • 싸인 안에 분모값으로 들어가는 10,000의 지수는 짝홀이 짝을 맞춰 같은 값의 짝수로 할당됩니다.
  • 그런데 캐글 코드는 이 부분이 좀 지저분하고, 안 맞습니다.
  • for문을 보시면 range가 0부터 임베딩 벡터의 차원까지 step을 2로 건너뛰어 i 자체가 0, 2, 4, 6...과 같이 짝수로 할당이 됩니다. 이미 i값이 짝수인데, 지수를 보면 거기에 2를 또 곱해서 중복이 되어 4의 배수가 됩니다. 그리고 10,000의 지수에 짝홀이 짝을 맞춰 같은 값의 짝수로 할당되지 않고, cos의 지수가 더 크게 할당됩니다.

그런데 실제로 저렇게 수식을 잘못 썼다고 해도 사실 크리티컬한 성능의 차이는 없을것이라고 여겨진다는 코멘트를 받았습니다. 어쨌든 포지셔널 인코딩의 핵심이 position과 i값에 따라서 삼각함수로부터 다른 임의의 값을 뽑아내는 것에 있고, 그래서 저렇게 써도 어쨌든 비스무리하게 돌아는 갈 것이라는 것입니다. 하지만 논문의 수식을 100% 그대로 재현하고 싶은 저의 입장에서는 굉장히 거슬렸고요. 그래서 일단 아래와 같이 수정을 해주었습니다.

# 위치 인코딩(Positional Embedding)
class PositionalEncoding(nn.Module):
    def __init__(self, max_seq_len, d_model):
        """
        입력 - max_seq_len : input sequence의 최대 길이
              d_model : 임베딩 차원
        """
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        
        pe = torch.zeros(max_seq_len, self.d_model) # 포지셔널 인코딩 벡터 -> 모든 자리에 초기값 0으로 설정
        for pos in range(max_seq_len):
            for i in range(0, self.d_model, 2): # 0, 2, 4... 
                pe[pos, i] = math.sin(pos / (10000 ** (i/self.d_model))) # 짝수 차원 -> 싸인 (0->0, 2->2..)
                pe[pos, i+1] = math.cos(pos/ (10000 ** (i/self.d_model))) # 홀수 차원 -> 코싸인 (1->0, 3->2, 5->4....)
        pe = pe.unsqueeze(0) # [max_seq_len, d_model] 차원 -> [1, max_seq_len, d_model] 차원으로 1차원 앞에 추가해줌 (예 : [6, 4] -> [1, 6, 4])
        # 해주는 이유 : input shape이 [batch_size, seq_len, d_model] 이기 때문이다!! (임베딩 결과값이랑 더해야되니깐 shape 맞춰주는거임)
        self.register_buffer('pe', pe) # pe 벡터를 buffer로 register : state_dict()에는 parameter와 buffer가 있는데, 그 중 buffer로 등록 -> 학습할때 update 되지 않도록 고정

삼각함수 안에 10000의 지수 부분을 전부 i로 바꾸어 주면서 논문의 수식과 통일을 시켜주었습니다.


문제 2

위에서 이미 신뢰를 잃어서 (ㅋㅋㅋㅋ) kaggle 코드를 꺼버리려고 했는데, 일단 forward까지만 참조를 해보자는 마음으로... forward 함수까지 작성을 해보았습니다. 1차 수정한 PositionalEncoding 클래스를 전체 보여드리겠습니다.

# 위치 인코딩(Positional Embedding)
class PositionalEncoding(nn.Module):
    def __init__(self, max_seq_len, d_model):
        """
        입력 - max_seq_len : input sequence의 최대 길이
              d_model : 임베딩 차원
        """
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model
        
        pe = torch.zeros(max_seq_len, self.d_model) # 포지셔널 인코딩 벡터 -> 모든 자리에 초기값 0으로 설정
        for pos in range(max_seq_len):
            for i in range(0, self.d_model, 2): # 0, 2, 4... 
                pe[pos, i] = math.sin(pos / (10000 ** (i/self.d_model))) # 짝수 차원 -> 싸인 (0->0, 2->2..)
                pe[pos, i+1] = math.cos(pos/ (10000 ** (i/self.d_model))) # 홀수 차원 -> 코싸인 (1->0, 3->2, 5->4....)
        pe = pe.unsqueeze(0) # [max_seq_len, d_model] 차원 -> [1, max_seq_len, d_model] 차원으로 1차원 앞에 추가해줌 (예 : [6, 4] -> [1, 6, 4])
        # 해주는 이유 : input shape이 [batch_size, seq_len, d_model] 이기 때문이다!! (임베딩 결과값이랑 더해야되니깐 shape 맞춰주는거임)
        self.register_buffer('pe', pe) # pe 벡터를 buffer로 register : state_dict()에는 parameter와 buffer가 있는데, 그 중 buffer로 등록 -> 학습할때 update 되지 않도록 고정 
        
    def forward(self, x):
    	x = x * math.sqrt(d_model) # 워드임베딩 벡터에 √d_model 곱해줌 (논문 3.4장)
        seq_len = x.size(1) # 각 시퀀스가 몇개의 토큰인지 숫자를 뽑아냄 (max_seq_len이 6이라면 6 이하의 숫자일것)
        x = x + self.pe[:, :seq_len].to(x.device) # 길이 맞춰서 pe랑 더해줌!!!

처음엔 인지하지 못했는데, math가 좀 많이 쓰인 것이 보입니다. math.sin math.cos math.sqrt.......

math 대신 torch를 쓰면 어떨까요?

시간 측정해보기 (math)

time 라이브러리를 불러와서 소요되는 시간을 측정해보면, 1.10에 가까운 값이 나옵니다.

시간 측정해보기 (torch)

https://pytorch.org/tutorials/beginner/translation_transformer.html

 

Language Translation with nn.Transformer and torchtext — PyTorch Tutorials 2.3.0+cu121 documentation

Note Click here to download the full example code Language Translation with nn.Transformer and torchtext This tutorial shows: How to train a translation model from scratch using Transformer. Use torchtext library to access Multi30k dataset to train a Germa

pytorch.org

파이토치 닥스에서 Transformer Tutorial을 찾아내서, 코드를 비교해 봅니다.

아놔 첨부터 이거 볼걸. 확실히 다르네.

  • 파이토치 공식 닥스에서는 for문 대신 indexing을 활용하고 있으며
  • math 대신 torch.sin/torch.cos를 사용하고 있습니다.

파이토치 닥스 튜토리얼 코드로 작업에 소요된 시간을 측정해 비교해보니, 0.01이 나옵니다. 샘플로 돌려보기만 했는데도 약 100배의 속도 차이가 난다면, 실제로 모델을 만들었을때 얼마나 큰 성능 저하를 유발하게 될까요? 저는 이쯤해서 참조하고 있던 캐글 코드를 버리고, 파이토치 공식 닥스를 참조하면서 코드를 쓰기로 합니다. (ㅋㅋㅋㅋㅋ)

앞으론 바로바로 공식 문서부터 찾아보는 습관을... 어쨌든 코드의 신뢰도와 정확도, 효율성에 항상 의문을 가지고 바라봐야 한다는 좋은 교훈을 얻었으니, 삽질은 아니었다고 생각합니다 :-) .. 우는거 아님


최종 코드

그럼 이제 pytorch docs 페이지의 Positional Encoding 클래스를 참고해서 2차로 코드를 수정하려는데요.

근데 이번엔 저기 math.log()부분이 거슬려요. 미치겄네.  

여긴 왜 torch를 안쓰고 굳이 math를 썼을까요?
torch에는 log함수가 (설마) 없을까요?

https://docs.python.org/ko/3/library/math.html
https://pytorch.org/docs/stable/generated/torch.log.html

찾아 보니 있습니다.
다른 점이라고 하면, math.log나 torch.log나 똑같이 자연로그를 취해서 반환하는데,
math와 달리 torch는 tensor를 입력하고 tensor를 출력합니다.

확인해보니, torch.log()를 사용할 경우 텐서의 값을 뽑아주는 변환작업이 추가가 되어 오히려 math보다 더 비효율적이 됩니다.

이번에도 간단하게 작동 시간을 비교해 봤는데요.

값을 바로 넣지 못하고 tensor를 넣어준 다음 .item()을 사용해서 다시 그 값을 추출해야 하는 torch.log()보다 math.log()가 훨씬 빠른 것을 확인할 수 있습니다. 그런데 이제 또 거슬리는게(ㅋㅋㅋㅋㅋㅋㅋㅋㅋㅋ)

아니 왜 토치랑 매스랑 똑같은 자연로그를 취해주는데 왜 결과값이 달라요?

여기까지 오니까 살짝 기빨려서 chatgpt-4o한테 물어봤습니다.

네 그렇다고 합니다. 공식 닥스에서 math를 쓴걸 보면 math를 써도 되나본데 맞느냐, torch랑 math랑 서로 값이 좀 다르더라도 성능에 큰 차이가 없는것이냐, 근데 너 어디서 찾아서 그렇게 대답하는거냐, 출처 밝혀라... 등등 집요하게 물어봤습니다.

그래서 결론은 그거 그렇게 별로 안중요하니까 그냥 파이토치 공식 닥스를 믿고 math.log()를 사용하면 된다는 것이었습니다. 이후 √d_model을 곱하기 위해 사용되는 math.sqrt()도 같은 이유로 torch 대신 사용됩니다. 텐서 연산이 아닌 간단한 스칼라 값을 계산할 때에는 math가 더 효율적일 수 있습니다. (개-운)

그럼 이제 다시 본론으로 돌아와서 진도좀 나갈게요. 공식 닥스 참고해서 다시 쓴 positional encoding 클래스입니다. 제가 편한대로 고쳐서 썼기 때문에 닥스 공식문서와 다른 부분이 많습니다.

# 위치 인코딩(Positional Embedding)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int,
                       dropout: float,
                       maxlen: int = 5000,
                       device = None):
        super(PositionalEncoding, self).__init__()
        
        # 위치 정의 벡터
        pos = torch.arange(0, maxlen, device = device).reshape(maxlen, 1)
        # 위치에 곱해줄 값 정의
        den = torch.exp(-torch.arange(0, d_model, 2, device = device) * math.log(10000) / d_model)
        # 포지셔널 인코딩 벡터 초기값 설정 (모든 자리 0으로 시작)
        pe = torch.zeros((maxlen, d_model))
        # 포지셔널 인코딩 마지막 차원이 짝수일 때 (슬라이싱 0::2 -> 0부터 시작해서 스텝 2씩이니까 짝수)
        pe[:, 0::2] = torch.sin(pos * den) # 싸인함수
        # 포지셔널 인코딩 마지막 차원이 홀수일 때 (슬라이싱 1::2 -> 1부터 시작해서 스텝 2씩이니까 홀수)
        pe[:, 1::2] = torch.cos(pos * den) # 코싸인함수
        # 차원 추가
        pe = pe.unsqueeze(0) # 임베딩 결과값이랑 더해야되니깐 shape 맞춰주기
        
        self.dropout = nn.Dropout(dropout) # dropout 추가
        self.register_buffer('pe', pe) # pe 벡터를 buffer로 register : state_dict()에는 parameter와 buffer가 있는데, 그 중 buffer로 등록 -> 학습할때 update 되지 않도록 고정 
        
    def forward(self, x: torch.Tensor):
        seq_length = x.size(1) # 입력 시퀀스의 길이 반환
        pe = self.pe[:, :seq_length, :].expand(x.size(0), -1, -1) # 입력 시퀀스의 길이에 맞춰 위치 인코딩 텐서를 슬라이싱
        return self.dropout(x + pe)

(1) pos

먼저 maxlen = 20일 때 pos의 결과를 찍어보면, 다음과 같습니다.

[참고] 저는 모델링을 하면서 벡터가 머릿속에 바로바로 시각화가 안 되면, python IDLE을 켜가지고 이렇게 대충이라도 시각화를 해서 결과를 바로바로 확인하는 버릇이 있습니다. 그럼 좀더 머릿속에서 구체화가 빠르게 됩니다. 그냥.. 이런 간단한건 idle이 편하더라고요.

(2) den

maxlen = 20, d_model = 100일 때 den의 결과를 찍어보면, 다음과 같습니다.

안에 몇개의 값이 있을까요?

총 50개가 있습니다.

den = torch.exp(-torch.arange(0, emb_size, 2) * torch.log(10000) / d_model)
# 보기 편하게 device는 뺐음

den을 구하는 코드는 아래와 같이 하나씩 직접 손으로 써서 계산해서 이해했어요. 참고로 여기 쓰인 torch.exp() 함수는 입력값을 e를 밑으로 하는 지수함수에 대입해서 출력합니다.

참고 : https://pytorch.org/docs/stable/generated/torch.exp.html
[참고] emb_size = d_model 입니다. 제가 나중에 변수명을 바꿨습니다.

삼각함수 안에 들어가는 분모 부분을 den이라는 벡터로 효율적으로 표현하여 pos * den과 같이 아주 간단하게 나타내 주었습니다. 

        # 위치 정의
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        # 위치에 곱해줄 값 정의
        den = torch.exp(-torch.arange(0, d_model, 2) * torch.log(10000) / d_model)
        # 포지셔널 인코딩 벡터 초기값 설정 (모든 자리 0으로 시작)
        pe = torch.zeros((maxlen, d_model))
        # 포지셔널 인코딩 마지막 차원이 짝수일 때 (슬라이싱 0::2 -> 0부터 시작해서 스텝 2씩이니까 짝수)
        pe[:, 0::2] = torch.sin(pos * den) # 싸인함수
        # 포지셔널 인코딩 마지막 차원이 홀수일 때 (슬라이싱 1::2 -> 1부터 시작해서 스텝 2씩이니까 홀수)
        pe[:, 1::2] = torch.cos(pos * den) # 코싸인함수

기존의 캐글 코드와 비교한다면 for - for 더블 iteration 없이 슬라이싱만으로 해당 수식을 표현하며 벡터 내적을 활용하므로, 계산이 훨씬 빠를 수밖에... 아니 캐글이 느릴수밖에 없습니다.


(3) sin, cos

pe의 shape을 찍어봤습니다. maxlen이 20이고 d_model를 100으로 두었으니, 당연히 (20, 100)이 나오네요.

        # 포지셔널 인코딩 마지막 차원이 짝수일 때 (슬라이싱 0::2 -> 0부터 시작해서 스텝 2씩이니까 짝수)
        pe[:, 0::2] = torch.sin(pos * den) # 싸인함수
        # 포지셔널 인코딩 마지막 차원이 홀수일 때 (슬라이싱 1::2 -> 1부터 시작해서 스텝 2씩이니까 홀수)
        pe[:, 1::2] = torch.cos(pos * den) # 코싸인함수

슬라이싱을 통해서 d_model, 즉 마지막 차원의 짝수 번째 요소와 홀수 번째 요소를 지정하면서, 원래 (20, 100)이였던 pe가 절반인 벡터 (20, 50) 두개로 나뉘었어요. 

pos의 크기는 (20, 1)이고 den의 크기는 (50)이므로 두개를 벡터 내적하면 (20, 50)의 쉐입이 나옵니다. 이 내적한 값에...

  • 싸인함수를 취해서 pe의 짝수 차원에 갈아끼워 줍니다.
  • 코싸인함수를 취해서 pe의 홀수 차원에 갈아끼워 줍니다.

shape이 같기 때문에 어렵지 않게 호로록 가능합니다.


(4) unsqueeze

마지막으로 unsqueeze를 통해서 나중에 임베딩 벡터와 더해줄 때 shape이 맞도록 해줍니다.


(5) forward

진짜 마지막으로 한번만 더 볼게요.

# 위치 인코딩(Positional Embedding)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int,
                       dropout: float,
                       maxlen: int = 5000,
                       device = None):
        super(PositionalEncoding, self).__init__()
        
        # 위치 정의 벡터
        pos = torch.arange(0, maxlen, device = device).reshape(maxlen, 1)
        # 위치에 곱해줄 값 정의
        den = torch.exp(-torch.arange(0, d_model, 2, device = device) * math.log(10000) / d_model)
        # 포지셔널 인코딩 벡터 초기값 설정 (모든 자리 0으로 시작)
        pe = torch.zeros((maxlen, d_model))
        # 포지셔널 인코딩 마지막 차원이 짝수일 때 (슬라이싱 0::2 -> 0부터 시작해서 스텝 2씩이니까 짝수)
        pe[:, 0::2] = torch.sin(pos * den) # 싸인함수
        # 포지셔널 인코딩 마지막 차원이 홀수일 때 (슬라이싱 1::2 -> 1부터 시작해서 스텝 2씩이니까 홀수)
        pe[:, 1::2] = torch.cos(pos * den) # 코싸인함수
        # 차원 추가
        pe = pe.unsqueeze(0) # 임베딩 결과값이랑 더해야되니깐 shape 맞춰주기
        
        self.dropout = nn.Dropout(dropout) # dropout 추가
        self.register_buffer('pe', pe) # pe 벡터를 buffer로 register : state_dict()에는 parameter와 buffer가 있는데, 그 중 buffer로 등록 -> 학습할때 update 되지 않도록 고정 
        
    def forward(self, x: torch.Tensor):
        seq_length = x.size(1) # 입력 시퀀스의 길이 반환
        pe = self.pe[:, :seq_length, :].expand(x.size(0), -1, -1) # 입력 시퀀스의 길이에 맞춰 위치 인코딩 텐서를 슬라이싱
        return self.dropout(x + pe)

forward에서 바뀐 점은 다음과 같습니다

  1. 캐글 코드와 비교했을 때 √d_model를 곱해주는 코드를 기본 워드임베딩 클래스 모듈로 이동했습니다. 파이토치 공식 닥스를 참고하여 dropout이 추가되었습니다.
  2. 파이토치 공식 닥스와 비교했을 때 조금 더 여러줄의 코드로 나누어서 (스스로) 이해하기 편하게 작성했습니다.

캐글 코드
닥스 코드



포지셔널 인코딩 코드화! 여기까지입니다. 진짜 이거 하나를 이렇게 딥하게 팔줄은 저도 몰랐는데요. 확실히 인간은 삽질을 통해 발전하는게 맞다... 남의 코드 많이 들여다 보되... 절대로 믿지는 말아라... 특히 캐글..... 이라는 좋은 교훈을 얻었습니다.

time 모듈을 통해서 시간을 측정하고 계산 효율성을 판단하는 일도 재미있었습니다. 이렇게 해볼 수 있도록 힌트를 주신 SK플래닛 T아카데미 ASAC 5기 권강사님께 무한 감사의 말씀을 전하며.............(리스펙 그 잡채) 혹시라도 처음부터 끝까지 전부 다 읽어주신 분이 계시다면, 정말 감사합니다. :-)

포스팅 끝! 본문 코드 오류가 발견될시 꼭 댓글로 알려주세요. 

 

 

 

 


[번외] 임베딩 전체과정
# 위치 인코딩(Positional Embedding)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int,
                       dropout: float,
                       maxlen: int = 5000,
                       device = None):
        super(PositionalEncoding, self).__init__()
        
        # 위치 정의 벡터
        pos = torch.arange(0, maxlen, device = device).reshape(maxlen, 1)
        # 위치에 곱해줄 값 정의
        den = torch.exp(-torch.arange(0, d_model, 2, device = device) * math.log(10000) / d_model)
        # 포지셔널 인코딩 벡터 초기값 설정 (모든 자리 0으로 시작)
        pe = torch.zeros((maxlen, d_model))
        # 포지셔널 인코딩 마지막 차원이 짝수일 때 (슬라이싱 0::2 -> 0부터 시작해서 스텝 2씩이니까 짝수)
        pe[:, 0::2] = torch.sin(pos * den) # 싸인함수
        # 포지셔널 인코딩 마지막 차원이 홀수일 때 (슬라이싱 1::2 -> 1부터 시작해서 스텝 2씩이니까 홀수)
        pe[:, 1::2] = torch.cos(pos * den) # 코싸인함수
        # 차원 추가
        pe = pe.unsqueeze(0) # 임베딩 결과값이랑 더해야되니깐 shape 맞춰주기
        
        self.dropout = nn.Dropout(dropout) # dropout 추가
        self.register_buffer('pe', pe) # pe 벡터를 buffer로 register : state_dict()에는 parameter와 buffer가 있는데, 그 중 buffer로 등록 -> 학습할때 update 되지 않도록 고정 
        
    def forward(self, x: torch.Tensor):
        seq_length = x.size(1) # 입력 시퀀스의 길이 반환
        pe = self.pe[:, :seq_length, :].expand(x.size(0), -1, -1) # 입력 시퀀스의 길이에 맞춰 위치 인코딩 텐서를 슬라이싱
        return self.dropout(x + pe)

# 워드 임베딩 -> 파이토치 nn.Embeding : https://wikidocs.net/64779
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, d_model):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model
    
    def forward(self, tokens: torch.Tensor):
        # 토큰 임베딩에 √d_model 곱해주기 (논문 3.4장에 그러랍디다)
        out = self.embedding(tokens.long()) * math.sqrt(self.d_model)
        # self.long()는 self.to(torch.int64)와 같은 역할
        return out

# "트랜스포머 임베딩" 만들어주기 (임베딩 + 포지셔널 인코딩)
class TransformerEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_len, drop_prob, device):
        super(TransformerEmbedding, self).__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, drop_prob, max_len, device)
        
    def forward(self, x):
        tok_emb = self.tok_emb(x)
        pos_enc = self.pos_enc(tok_emb) # 두개 더하는건 이미 pos에서 했음
        return pos_enc

잘 됐는지 테스트

if __name__ == "__main__":
    vocab_size = 10000
    d_model = 512
    max_len = 5000
    drop_prob = 0.1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model = TransformerEmbedding(vocab_size, d_model, max_len, drop_prob, device)
    input_tokens = torch.randint(0, vocab_size, (32,100)).to(device)
    output = model(input_tokens)
    print(output.shape)

output

확인완

+ Recent posts