JAX/HaikuでCrystal Graph Convolutional Neural Network (CGCNN)を実装した
Crystal Graph Convolutional Neural Network (CGCNN)
CGCNNは結晶にGraph Neural Network (GNN)を適用した(おそらく)初めての手法。GNNがやたら流行っている有機分子と違って、結晶の場合はそもそもグラフの作り方が自明でない、結晶の周期性のせいで隣接行列を陽に持てないなどの問題があるが、この手法はそのあたりに取り組んだ点で新しかったのだと思う。手法の詳細については論文を読んでもらうとして、今回はCGCNNのpytorch版公式実装をjax/haikuで書き直したので、ハマった点とメモを残しておく。
JAX
JAXは雑に言うと自動微分とJITコンパイラが使えてバックエンドにGPUを指定できるNumpyみたいなやつ。去年くらいから流行っている気がする。最近はMDとかDFTをJAXで実装する話があって面白いなと思っている1。
- https://roberttlange.github.io/posts/2020/03/blog-post-10/
- https://gcucurull.github.io/deep-learning/2020/04/20/jax-graph-neural-networks/
- https://gcucurull.github.io/deep-learning/2020/06/03/jax-sparse-matrix-multiplication/
- https://sjmielke.com/jax-purify.htm
実装
CGCNNを実装したレポジトリには以下のリンクから飛べる。学習モデルの部分にはdm-haikuを使った。
データセットの収集
公式実装がデータセットに使ったMaterials Project上のmp-idを公開してくれていたので、そこからAPIを叩いてデータセットを集めた。ただ一部の物質でmaterial_id
が変更されたらしくて大分ハマった。
前処理
結晶からグラフをつくるのには(元実装と同じく)pymatgen.core.Structure.get_all_neighbors
を使った。最初はpymatgen.analysis.graphs.StructureGraph
を使ってグラフの作り方をいろいろ試せるように書こうとしていたが、やたら重いので止めた。
ちなみにStructureGraph
が重いのは自分が過去に投げたPRのせいで、pymatgen.analysis.local_env.NearNeighbors
で隣接している各サイトの属するユニットセルの番号(jimage
)を計算する部分がO(num_sites2)になっているから。なんでこんな実装になっているかというと、元の結晶のfractional coordinatesが[0, 1)に収まっていないエッジケースに対応する方法がこれしか思いつかなかったからだ(この修正をしたときにまさかGNNの前処理にこの部分が使われ得るとは思ってもいなかった)。上手い実装が思いつく人はPRを投げてください。
Pooling layer
モデルの実装はpytorch版とほぼ同じように書けるが、graph aggregation の部分だけは注意が必要。ひとつのバッチには次数の異なるグラフたちが並んでいるのでpoolingするのは次数に応じてsplitする必要がある。しかしjax.numpy.split(ary, indices_or_sections, axis)
のindices_or_sections
には(jitするなら)定数しか渡せないし、無理やりstatic_argnums
に突っ込んでもバッチごとにjit recompileされて激重になる。解決策としてはsegment_sumを使えばjitできるようになる。
hyperparameter
argparseでhyperparameterを弄るのが好きでないのでconfig用のdataclassを作ってdataclass-jsonでJSONをパースすることにした。
ベンチマーク
元のpytorch実装と同じhyperparameterで論文と同じデータセットでformation energyを学習させたらMAEが116 meV/atomのモデルを作ることをできた。しかし、論文では39 meV/atomまで下がると言っているので完全には結果を再現できていない。自前で用意したデータセットで元実装を学習させたら90 meV/atomで、これでも論文の値より大きいのでhyperparameter tuningとかの問題なのか?
Future works
いろいろやれることが思いつくけど、モチベが続かない気がするのでとりあえず区切りとする。
- jax-mdとの連携: 夢がある。現状pymatgenでグラフを作る部分が無視できない重さなので粒子の隣接リストもjaxで書く必要がありそう。
- 頂点の次数を可変にする: 元実装と自分の実装ではグラフの各頂点の次数が同じになるようにpaddingしている。これに関してはこことかここで議論されている。
-
既存のコードを置き換えるとは思えないが↩