논문

하루에도 수만개의 글자를 읽고 있습니다. 하루에도 수백장의 종이를 들춰 읽습니다.
이것은 그 읽기에 대한 일기입니다.

Domain Separation Networks

1 Introduction

  • 최근 지도학습 알고리즘의 성공은 일부 large-scale 데이터셋에서 이루어지고 있다.
  • 아쉽게도 이러한 데이터셋을 수집하고 annotating하는 작업은 너무 많은 시간이 많이 드는 일이다.
  • 대안으로, 실제 영상이 아닌 컴퓨터가 생성해낸 영상을 사용하는 방법이 있지만, 이러한 이미지로 학습된 네트워크는 실제 환경(도메인)에서 동작하도록 일반화되지 않게 되어버린다.
  • 이와 관련하여 학습 분포와 테스트 분포가 다른 상황의 시나리오를 가정하고, 두 분포에 대해서 domain-invariance한 표현을 학습하는 문제를 풀어보고자 한다.
  • 이러한 경우, 소스 데이터는 우리가 풀고자 하는 어떤 태스크에 맞게 레이블링되어있고, 소스 도메인에서의 knowledge를 ground-truth가 없는 타겟 도메인으로 transfer하는 문제로 생각할 수 있다.
  • 이 논문에서는 이미지를 분류하고 포즈를 추정하는 문제에 이를 적용하고자 한다.
  • 픽셀로 이루어진 이미지의 두 분포를 생각하였을 때, "low-level"에서의 차이는 노이즈, 해상도, 조명, 컬러로 발생할 수 있다고 생각하였고, "high-level"의 차이는 클래스의 수, 객체 타입, geometric variation, 3D position, pose에 의한 차이라고 생각하였다.
  • 따라서 소스 도메인과 타겟 도메인의 분포 차이는 low-level의 차이가 주된 차이를 야기하고, high-level에서의 파라미터는 비슷한 분포를 가지며, 동일한 label space를 가진다고 가정하엿다.
  • 논문에서 제안하는 방법인 Domain Separation Networks (DSN)은 domain-invariant한 표현을 학습하는 것이 목적이다.
  • 기존의 방법 중 소스 도메인과 타겟 도메인이 서로 공유하는 공통 표현을 사용하는 방법이 있다. 하지만, 이 방법은 공통 표현이 노이즈에 의한 오염에 취약하였다.
  • 이 논문의 방법은 각 도메인의 프라이빗 서브스페이스에서는 해당 도메인만의 고유한 속성인 배경이나 low-level 이미지 특성을 가지도록 하고, 공유 서브 스페이스에서는 두 이미지에서 공통된 표현을 찾을 수 있도록 하였다.
  • 단, 공유 공간과 프라이빗 공간은 서로 othogonal하게 찾도록 유도되어 각 도메인의 고유한 정보를 분리할 수 있도록 하였다.
  • 따라서, 우리가 하려는 태스크에 대해서 각 표현들이 더 의미부여가 가능하도록 하였다.

2 Related work

3 Method

  • 우리의 목표는 레이블이 있는 소스 도메인의 데이터셋과 레이블이 없는 타겟 도메인의 데이터셋이 주어졌을 때 소스 도메인의 데이터를 분류하면서도, 타겟 도메인에서도 적용될 수 있는 일반화된 분류기를 학습하는 것이다.
  • 이전의 연구되었던 방법과 마찬가지로 소스 도메인과 타겟 도메인이 서로 비슷한 표현을 가지도록 훈련하는방법을 사용한다. 하지만, 이전의 방법들은 공유 공간에서의 표현에 영향을 많이 주는 노이즈를 허용하는 방법이었다.
  • 이 논문에서는 도메인의 두가지 표현으로 나누어 프라이빗 요소와 공유 요소로 나눈 뒤, 이 둘은 서로 조합하는 모델을 사용한다.
  • 프라이빗 요소 표현은 하나의 도메인에 특정되는 요소이며, 공유 요소 표현은 두 도메인이 모두 공유하는 표현이다.
  • 이와 같이 표현을 두개의 요소로 분리하기 위해서, 두개의 파트별로 비종속적인 로스 함수를 사용하였다.
  • 이러한 방법을 사용하여 프라이빗 표현도 유용하게 사용할 수 있는 일반화 능력을 가진 학습을 할 수 있었다.
  • 로스 함수의 조합은 공통 표현은 서로 비슷하면서 프라이빗 표현은 서로 다르도록 유도하는 것이 목적이다.
  • 두개의 표현 공간으로 나누었기 때문에, 각 도메인이 가지는 고유한 특성에 영향을 덜 받아 오염되지 않는 공통 표현을 얻을 수 있고, 이러한 표현을 입력으로 갖는 분류기는 일반화를 더 잘 할 수 있게 된다.
  • $\mathbf{X}_s = \{ (\mathbf{x}_i^s, \mathbf{y}_i^s) \}_{i=0}^{N_s} $는 레이블된 $N_s$개의 데이터셋이고, 이것은 소스 도메인에서 샘플링 $\mathbf{x}_i^s \sim \mathcal{D}_s$된 것이다.
  • $ \mathbf{X}^t = \{ \mathbf{x}_i^{N_t} \}_{i=0}^{N_t} $는 레이블이 없는 $N_t$개의 데이터셋으로 타겟 도메인에서 샘플링 $\mathbf{x}_i^t \sim \mathcal{D}_T$된 것이다.
  • $E_c(\mathbf{x};\theta_c)$는 이미지 $\mathbf{x}$를 공유 혹은 공통 표현에서의 특징 $\mathbf{h}_c$로 보내는 함수이다.
  • $E_p(\mathbf{x};\theta_p)$는 이미지 $\mathbf{x}$를 각 도메인의 프라이빗 표현에서의 특징 $\mathbf{h}_p$로 보내는 함수이다.
  • $D(\mathbf{h};\theta_d)$는 표현 $\mathbf{h}$를 받아 복원 이미지 $\hat{\mathbf{x}}$로 디코딩하는 함수이다.
  • $G(\mathbf{h}, \theta_g)$는 태스크에 대한 함수로, $\mathbf{h}$를 받아 prediction된 결과 $\hat{\mathbf{y}}$를 계산하는 함수이다.

3.1 Learning

  • DSN를 인퍼런스하면, $\hat{\mathbf{x}} = D(E_c(\mathbf{x}) + E_p(\mathbf{x})) $와 $\hat{\mathbf{y}} = G(E_c(\mathbf{x}))$를 얻을 수 있다.
  • 학습의 목적은 파라미터 $\Theta = \{ \theta_c, \theta_p, \theta_d, \theta_g \}$에 대해 다음의 로스를 최소화하는 것이다.

$\mathcal{L} = \mathcal{L}_{\text{task}} + \alpha\mathcal{L}_{\text{recon}} + \beta \mathcal{L}_{\text{difference}} + \gamma \mathcal{L}_{\text{similarity}}$

  • 여기서 $\alpha, \beta, \gamma$는 각 로스 항 끼리의 상호작용을 조절하는 가중치이다.
  • 분류 로스 $\mathcal{L}_{\text{task}}$는 우리가 궁극적으로 관심있는 출력 레이블을 학습하도록 모델링한다.
  • 타겟 도메인에는 레이블이 없기 때문에 이 로스는 소스 도메인에만 적용된다.

$\mathcal{L}_{\text{task}} = - \sum_{i=0}^{N_s} \mathbf{y}_i^s \cdot \log \hat{\mathbf{y}}_i^s$

  • $\mathbf{y}_i^s$는 클래스 레이블을 one-hot encoding으로 만든 것이고, $\hat{\mathbf{y}}_i^s$는 $G(E_c(\mathbf{x}_i^s))$ 에서 나온 softmax prediction이다.
  • 복원 로스를 계산하는데에는 scale-invariant mean squared error term을 사용하였다.

$ \mathcal{L}_{\text{recon}} = \sum_{i=1}^{N_s} \mathcal{L}_{\text{si_mse}}(\mathbf{x}_i^s, \hat{\mathbf{x}}_i^s) + \sum_{i=1}^{N_s} \mathcal{L}_{\text{si_mse}}(\mathbf{x}_i^t, \hat{\mathbf{x}}_i^t)$

$ \mathcal{L}_{\text{si_mse}}(\mathbf{x}, \hat{\mathbf{x}}) = \frac{1}{k}||\mathbf{x} - \hat{\mathbf{x}}||_2^2 - \frac{1}{k^2}([\mathbf{x} - \hat{\mathbf{x}}]\cdot \mathbf{1}_k)^2$

  • 여기서 $k$는 입력 $x$의 픽셀 수이고, $\mathbf{1}_k$는 1로 이루어진 길이 $k$의 벡터이다.
  • $||\cdot||^2_2$는 squared $L_2$-norm이다.
  • 이러한 모델을 통해서 객체의 전체적인 모양을 생성해 내도록 유도하였다.
  • Difference 로스는 두 도메인에 모두 적용되는데, 이는 공통 표현으로의 인코더와 프라이빗 표현 인코더가 입력에 대해서 서로 다른 양상을 보이도록 학습하도록 한다.
  • 로스는 soft subspace othogonality를 특정을 가지도록 각 도메인의 프라이빗 표현과 공통 표현을 제한요소를 주었다.
  • $\mathbf{H}_c^s$와 $\mathbf{H}_p^s$를 각 도메인에서 나온 여러 샘플들의 공통 표현을 행으로 갖는 행렬이라고 정의하고, 마찬가지로 프라이빗 표현으로 이루어진 행렬 $\mathbf{H}_p^s, \mathbf{H}_p^t$를 정의한다면, 공통 표현과 프라이빗 표현이 서로 orthogonality를 가질 수 있도록 difference 로스를 다음과 같이 정의하엿다.

$ \mathcal{L}_{\text{difference}} = ||{\mathbf{H}_c^s}^T \mathbf{H}_p^s ||_F^2 + ||{\mathbf{H}_c^t}^T \mathbf{H}_p^t ||_F^2$

  • 여기서 $||\cdot||_F^2$는 Frobenius norm이다.
  • 마지막으로 similarity 로스는 공유 인코더에 나온 $\mathbf{h}_c^s, \mathbf{h}_c^t$가 가능한 도메인의 특정을 반영하지 않고 서로 비슷하게 나오도록 유도한다.
  • similarity 로스는 다음에서 자세히 알아보겠다.

3.2 Similarity Losses

  • 이전 연구에서 제안되었던 Domain adversarial similarity 로스는 표현이 어느 도메인에서 나온 것인지 분간하기 힘들도록 한다.
  • 이러한 "confusion"은 Gradient Reversal Layer (GRL)을 통해 학습될 수 있다.
  • GRL은 출력은 똑같이 내보내면서, 그래디언트의 부호만을 반대로 바꾸는 레이어이다.
  • 도메인 분류기 $Z(Q(\mathbf{h}_c) ; \theta_z) \rightarrow \hat{d}$는 공통 표현 $\mathbf{h}_c$을 어느 도메인인지 구분하는 레이블 $\hat{d} \in \{ 0, 1 \}$을 예측하는 분류기이다.
  • 학습을 진행하면 소스와 타겟 도메인에서 나온 이미지의 인코딩을 구분할 수 있도록 $Z$의 능력을 향상시키도록 유도가 될 것이다.
  • 하지만 GRL을 통과하여 부호가 바뀐 그래디언트는 공통 표현 인코더에서 도메인을 구분할 수 있는 성능을 떨어트리도록 인코더를 학습시킨다.

$$ \mathcal{L}_{\text{similarity}}^{\text{DANN}} = \sum_{i=0}^{N_s + N_t} \{ d_i \log \hat{d}_i + (1- d_i) \log(1-\hat{d}_i) \}$$

  • 다른 로스로 Maximum Mean Discrepancy (MMD) 로스는 커널 기반의 두 샘플 페어의 거리를 계산하는 함수이다.

$$ \mathcal{L}_{\text{similarity}}^{\text{MMD}} = \frac{1}{(N^s)^2} \sum_{i,j=0}^{N^s} \kappa (\mathbf{h}_{ci}^s, \mathbf{h}_{cj}^s) - \frac{2}{N^s N^t} \sum_{i,j=0}^{N^s, N^t} \kappa (\mathbf{h}_{ci}^s, \mathbf{h}_{cj}^t) - \frac{1}{(N^t)^2} \sum_{i,j=0}^{N^t} \kappa (\mathbf{h}_{ci}^t, \mathbf{h}_{cj}^t)$$

  • $\kappa(\cdot, \cdot)$은 PSD 커널 함수인데, 여기서는 RBF 커널 $\kappa(x_i, x_j) = \sum_n \eta-n \exp \{ -1 \frac{1}{2 \sigma_n} ||\mathbf{x}_i \mathbf{x}_j ||^2$의 선형 조합을 사용하였다.

4 Evaluation

5 Conclusion

 


Add a Comment Trackback