JAX/HaikuでCrystal Graph Convolutional Neural Network (CGCNN)を実装した

Crystal Graph Convolutional Neural Network (CGCNN)

CGCNNは結晶にGraph Neural Network (GNN)を適用した(おそらく)初めての手法。GNNがやたら流行っている有機分子と違って、結晶の場合はそもそもグラフの作り方が自明でない、結晶の周期性のせいで隣接行列を陽に持てないなどの問題があるが、この手法はそのあたりに取り組んだ点で新しかったのだと思う。手法の詳細については論文を読んでもらうとして、今回はCGCNNのpytorch版公式実装をjax/haikuで書き直したので、ハマった点とメモを残しておく。

JAX

JAXは雑に言うと自動微分JITコンパイラが使えてバックエンドにGPUを指定できるNumpyみたいなやつ。去年くらいから流行っている気がする。最近はMDとかDFTをJAXで実装する話があって面白いなと思っている1

実装

CGCNNを実装したレポジトリには以下のリンクから飛べる。学習モデルの部分にはdm-haikuを使った。

github.com

データセットの収集

公式実装がデータセットに使ったMaterials Project上のmp-idを公開してくれていたので、そこからAPIを叩いてデータセットを集めた。ただ一部の物質でmaterial_idが変更されたらしくて大分ハマった。

lan496.hatenadiary.jp

前処理

結晶からグラフをつくるのには(元実装と同じく)pymatgen.core.Structure.get_all_neighborsを使った。最初はpymatgen.analysis.graphs.StructureGraphを使ってグラフの作り方をいろいろ試せるように書こうとしていたが、やたら重いので止めた。

ちなみにStructureGraphが重いのは自分が過去に投げたPRのせいで、pymatgen.analysis.local_env.NearNeighborsで隣接している各サイトの属するユニットセルの番号(jimage)を計算する部分がO(num_sites2)になっているから。なんでこんな実装になっているかというと、元の結晶のfractional coordinatesが[0, 1)に収まっていないエッジケースに対応する方法がこれしか思いつかなかったからだ(この修正をしたときにまさかGNNの前処理にこの部分が使われ得るとは思ってもいなかった)。上手い実装が思いつく人はPRを投げてください。

github.com

Pooling layer

モデルの実装はpytorch版とほぼ同じように書けるが、graph aggregation の部分だけは注意が必要。ひとつのバッチには次数の異なるグラフたちが並んでいるのでpoolingするのは次数に応じてsplitする必要がある。しかしjax.numpy.split(ary, indices_or_sections, axis)indices_or_sectionsには(jitするなら)定数しか渡せないし、無理やりstatic_argnumsに突っ込んでもバッチごとにjit recompileされて激重になる。解決策としてはsegment_sumを使えばjitできるようになる。

hyperparameter

argparseでhyperparameterを弄るのが好きでないのでconfig用のdataclassを作ってdataclass-jsonJSONをパースすることにした。

ベンチマーク

元のpytorch実装と同じhyperparameterで論文と同じデータセットでformation energyを学習させたらMAEが116 meV/atomのモデルを作ることをできた。しかし、論文では39 meV/atomまで下がると言っているので完全には結果を再現できていない。自前で用意したデータセットで元実装を学習させたら90 meV/atomで、これでも論文の値より大きいのでhyperparameter tuningとかの問題なのか?

Future works

いろいろやれることが思いつくけど、モチベが続かない気がするのでとりあえず区切りとする。

  • jax-mdとの連携: 夢がある。現状pymatgenでグラフを作る部分が無視できない重さなので粒子の隣接リストもjaxで書く必要がありそう。
  • 頂点の次数を可変にする: 元実装と自分の実装ではグラフの各頂点の次数が同じになるようにpaddingしている。これに関してはこことかここで議論されている。

  1. 既存のコードを置き換えるとは思えないが