こんにちは.人工知能研究所のMichael, Wing, 小御門、内野、丸橋です。私たちは「AIを用いた知識発見」に関する研究開発を行っています。このたび,テキサスA&M大学のData Integration, Visualization, and Exploration (DIVE) Laboratory (https://people.tamu.edu/~sji/)との共同研究成果である「グラフ分類のためのデータ拡張技術」が機械学習の主要な国際会議であるICLR 2023に採択されたので,その内容を紹介します。
対象論文
- 学会名:International Conference on Learning Representations(ICLR)
- 開催日:2023年5月1日~5月5日
- 開催場所:ルワンダ共和国 キガリ
- 論文タイトル:Automated Data Augmentations for Graph Classification
- 著者(富士通):Michael Curtis McThrow、Wing Yee Au、小御門道、内野寛治、丸橋弘治
- 著者(テキサスA&M大学):Youzhi Luo, Shuiwang Ji openreview.net
概要
分子やソーシャルネットワーク、金融トランザクションなどの多くの現実世界のオブジェクトは、自然にグラフとして表現できます。これらグラフ構造データを対象とした分類モデルの開発は非常に重要です。近年、ディープラーニングの活用により、このような分類モデルは著しく進歩しています 。グラフデータに特化したディープニューラルネットワーク(DNN)であるグラフニューラルネットワーク (GNN) は、不正検知や分子特性予測のような多くのグラフ表現学習とグラフ分類タスクに適用されています。
しかし、DNNと同様、 GNNは容易にオーバーフィットし、小さなデータセット上で満足な性能を達成できません。この問題に対処するために、データ拡張技術を使用して、より多くのデータサンプルを生成することが重要です。データ拡張は画像処理や自然言語処理の性能を改善するのに有効であることが実証されています。画像の場合には、フリッピング、クロッピング、色シフト、スケーリング、回転、および弾性変形など様々な画像変換技術が利用されています。また、テキストの場合には、同義語置換、位置スワップ、逆翻訳などの変換技術が利用されています。これらのデータ拡張技術は、DNNを訓練する際のオーバーフィットを低減し、ロバスト性を向上させるため広く利用されています。
データ拡張に必要な特性は、ラベル不変性です。これは、データ拡張のために行う変換処理の前後でラベル関連情報が保持されることを要求します。反転や回転などの一般的に使用される画像拡張技術は元の画像のほとんどすべての情報を保存できるので、ラベル不変性は画像に対して比較的容易に達成できます。しかしながら、グラフはエッジによって接続された複数のノードで形成された非ユークリッドデータです。グラフの構造をわずかに変更しただけでも、グラフ内の重要な情報が破壊される可能性があります。従って、グラフに対してラベル不変な変換を設計することは非常に困難です。現在、最も一般的に使用されているグラフ拡張技術は、グラフのノードとエッジのランダムな変換に基づいています。しかし、このようなランダム変換は必ずしもラベル不変ではありません。なぜなら、重要なラベル関連情報がランダムに損なわれる可能性があるためです。したがって、これらのグラフ拡張技術によってグラフ分類モデルのパフォーマンスが向上するとは限りません。
既存のグラフデータ拡張技術はラベル不変性を考慮していませんが、私たちはこの重要な問題を解決する、強化学習を用いてラベル不変性を担保するグラフデータ拡張技術 GraphAugを提案しました。GraphAugは学習可能なモデルを使用して、データ拡張のためのグラフ変換を自動的に選択します。また強化学習により推定されたラベルが不変である確率を最大化するようにモデルを最適化します。 GraphAugは、教師付きグラフ分類を対象とするラベル不変性を考慮したグラフデータ拡張に関する最初の研究です。私たちはGraphAugが多値グラフ分類タスク上で既存のグラフデータ拡張技術より性能が優れていることを実験で示しました。GraphAugのコードは、DIG [1] ライブラリで使用できます。富士通研究所では実際に本グラフデータ拡張技術を顧客課題に適用しています。
逐次変換によるグラフデータ拡張
グラフデータ拡張を逐次変換過程と考えます。与えられた訓練データセットから抽出したグラフに対し、変換生成モデルgで生成した変換関数を作用させ、変換グラフを生成します。具体的には、番目のステップ () を考えたとき、一つ前のステップで得られたグラフと変換生成モデルgに基づく変換を生成し、に変換を適用することでを得ます。GraphAugでははすべて以下の3つのグラフ変換カテゴリから選択します。
- ノード特徴マスキング (MaskNF):ノード特徴ベクトルのいくつかの値を0に設定
- ノードドロップ(DropNode):入力グラフから一部のノードを削除
- エッジ摂動 (PerturbEdge) :入力グラフから既存のエッジを削除または新規エッジを追加
変換生成モデルgは3つの部分で構成されています。グラフから特徴を抽出するためのGNNベースのエンコーダ、データ拡張カテゴリを生成するためのGRU [2] モデル、および確率を計算するための4つの多層パーセプトロン (MLP) モデルです。エンコーダとしてGIN [3] モデルを用いました。t番目のステップでは、まずグラフに仮想ノードを追加します。その仮想ノードを他のすべてのノードに接続します。ここでは、グラフレベルの情報を抽出するために仮想ノードを使用します。まずに対する全てのノード埋め込みベクトルを獲得するためにGNNエンコーダにより多層メッセージパッシングを行います。次にGRUモデルとMLPモデルを用いて仮想ノードの埋込みベクトルから各グラフ変換カテゴリを選択するための確率分布を計算します。確率分布から変換カテゴリサンプリングし、に応じてすべてのグラフ特徴に対し以下の変換を行い、新しいグラフを生成します。
- もしがMaskNFである場合、MLPモデルは対応するノード埋め込みベクトルから各ノード特徴ベクトルをマスキングする確率を計算し、そのBernoulli分布からサンプリングすることによってマスキングするかどうかを決定します。
- もしがDropNodeである場合、MLPモデルは、対応するノード埋め込みベクトルから各ノードをドロップする確率を計算し、そのBernoulli分布からサンプリングすることによって、ドロップするかどうかを決定します。
- もしがPerturbEdgeである場合、MLPモデルは、2つの対応するノード埋め込みの連結から各エッジを摂動させる確率を計算し、そのBernoulli分布からサンプリングすることによって、エッジの削除または追加を決定します。
強化学習によるラベル不変性最適化
どのグラフ特徴が重要で、保持されるべきかを示す正解ラベルがないので、教師有り学習によりラベル不変性のあるグラフ変換生成モデルを学習することはできません。この問題に取り組むために、強化学習を用いてモデルを暗黙的に最適化します。
逐次的なグラフデータ拡張をマルコフ決定過程 (MDP) として定式化します。マルコフ性は自然に満たされます。すなわち、任意の変換ステップにおける出力グラフは、以前に実行された変換ではなく、入力グラフにのみ依存します。具体的には、t番目のステップを考えます。最後のステップで得られたグラフを現在の状態として、からへの変換を状態遷移と考えます。gによって生成されたがアクションに対応します。
上記強化学習環境におけるフィードバック報酬信号として報酬生成モデルsから予測されたラベル不変確率を用います。報酬生成モデルsにはグラフマッチングネットワーク[4] を使用します。グラフからへの逐次変換プロセスが終わると、s はを入力として、をラベルが不変である確率として出力します。次に、変換生成モデルg は、ポリシー勾配によって最適化されます。変換生成モデルgの学習より前に訓練データセットから手動でサンプリングしたグラフペア上で報酬生成モデルsを訓練します。変換生成モデルgのトレーニング中、報酬生成モデルsは報酬の生成にのみ使用されるため、パラメータは固定します。
実験結果
人工グラフデータセットによる実験
2つの人工グラフデータセットCOLORSとTRIANGLESを利用して、GraphAUGが実際にラベル不変なデータを生成し、一様な変換に比べ高精度を達成できることを示しました。最初に報酬生成モデルが収束するまで学習し、次に変換生成モデルを学習します。GraphAUGの変換生成モデルがラベル不変な生成データを学習できるか確認するために、バリデーションデータをグラフデータ拡張し、その前後でラベル不変な割合を計算しました。
ラベル不変の割合の変化曲線を図1に示します。これらの曲線は、訓練が進むにつれて、著者らのモデルが、高いラベル不変割合を示すデータを学習できることを示しています。言い換えれば、著者らのデータ拡張モデルは訓練後にラベル不変なデータを生成できることを実証しました。
グラフベンチマークデータセットを用いた実験
さらに、 MUTAG, NCI 109, NCI 1, PROTEINS, IMDB‐BINARY,およびCOLLABを含むTUDatasetsベンチマークに含まれる6つの広く使用されているデータセットを利用し、従来のグラフ拡張技術に対するGraphAugの利点を実証しました。またOGBベンチマークに含まれる大規模分子グラフデータセットであるogbg‐molhivデータセットに関する実験も行いました。著者らは、 GINモデルの分類性能の向上度合いによって異なるデータ拡張技術を比較しました。分類の評価指標として、 TUDatasetsベンチマークの6つのデータセットに対しては精度を用い、 ogbg‐molhivデータセットに対してはROC‐AUCを用いました。
GINモデルを用いた7つのデータセットに対する評価結果を表1に記載します。表1によれば、著者らのGraphAugは、 7つのデータセットに対するすべてのグラフデータ拡張技術の中で最高の性能を達成することができました。特に、 MUTAG, NCI 109, NCI 1,ogbg‐molhivを含む分子データセットに対して、ランダム変換に基づくグラフデータ拡張技術を使用すると、分類精度が大きく低下しました。ランダムな変換には問題があることが分かります。一方GraphAugは、ミックスアップ法やグラフ自己教師有学習を用いたより高度なデータ拡張技術より良い精度を実現しました。表1の結果よりGraphAugの有効性を検証しました。
参考資料
- [1] Liu, M., Luo, Y., Wang, L., Xie, Y., Yuan, H., Gui, S., ... & Ji, S. (2021). DIG: A turnkey library for diving into graph deep learning research. The Journal of Machine Learning Research, 22(1), 10873-10881.
- [2] Cho, K., Merrienboer, B., Gulcehre, C., Bougares, F., Schwenk, H., & Bengio, Y. (2014). Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation. In EMNLP.
- [3] Xu, K., Hu, W., Leskovec, J., & Jegelka, S. How Powerful are Graph Neural Networks?. In International Conference on Learning Representations.
- [4] Li, Y., Gu, C., Dullien, T., Vinyals, O., & Kohli, P. (2019, May). Graph matching networks for learning the similarity of graph structured objects. In International conference on machine learning (pp. 3835-3845). PMLR.