Javaをよくできる子にするために

昨日の続きです。一点書き忘れたのは、オブジェクトの生成が無視されることはないようだ、という点です。既存のJavaのインターフェースはオブジェクトを返すようになっているので、それらに乗っ取るならオブジェクトを、たとえばIntegerを返したくなるのが心情です。残念ながら、これらがただのintまで落ちることを期待するのは、現状では裏切られるようです。少なくとも外から実験的に見た結果ですが。


もし、Integerを返すようなインターフェースがあり、しかもわざわざnewして返しているのなら、即刻intを返す同等のインターフェースを用意するべきです。たとえば、Listを実装してIteratorの気の利いたインプリはあなたの期待を裏切ります。これは最適化できるような気がするんですが、そこまでJITコンパイラはがんばってくれないようです。

今日はオブジェクト生成は省略してくれるかと言うことです。お題は、newです。結論から言うと、newは省略してくれないらしい、ということです。つまり、なるべくnewしなくてもいいような設計にする必要があります。しかし、一方で、実は意外とnewさえしなければ遅そうと思っている方法がかなり速くなります。

import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;

abstract class IntList implements List<Integer> {
    abstract public int getInt(int index);

    public Integer get(int index) {
        return getInt(index);
    }
}

public class Range extends IntList {
    private final int begin;
    private final int end;

    public Range(int begin, int end) {
        this.begin = begin;
        this.end = end;
    }

    public int getInt(int index) {
        return index + begin;
    }

    public boolean isEmpty() {
        return begin == end;
    }

    public Iterator<Integer> iterator() {
        return new Iterator<Integer>() {
            private int cur = begin;

            public boolean hasNext() {
                return cur < end;
            }

            public Integer next() {
                return cur++;
            }

            public void remove() {
            }
        };
    }

    public int size() {
        return end - begin;
    }

    public static void main(String[] args) {
        {
            long begin = System.currentTimeMillis();
            List<Integer> list = new Range(0, 100000);
            int t = 0;
            for (int j = 0; j < 1000; j++) {
                for (int i : list) {
                    t += i;
                }
            }
            long end = System.currentTimeMillis();
            System.out.println(((double) end - begin) / 1000 + " sec");
        }

        {
            long begin = System.currentTimeMillis();
            List<Integer> list = new Range(0, 100000);
            int t = 0;
            for (int j = 0; j < 1000; j++) {
                for (int i = 0; i < list.size(); i++) {
                    t += list.get(i);
                }
            }
            long end = System.currentTimeMillis();
            System.out.println(((double) end - begin) / 1000 + " sec");
        }

        {
            long begin = System.currentTimeMillis();
            IntList list = new Range(0, 100000);
            int t = 0;
            for (int j = 0; j < 1000; j++) {
                for (int i = 0; i < list.size(); i++) {
                    t += list.getInt(i);
                }
            }
            long end = System.currentTimeMillis();
            System.out.println(((double) end - begin) / 1000 + " sec");
        }

        {
            long begin = System.currentTimeMillis();
            Range list = new Range(0, 100000);
            int t = 0;
            for (int j = 0; j < 1000; j++) {
                for (int i = 0; i < list.size(); i++) {
                    t += list.getInt(i);
                }
            }
            long end = System.currentTimeMillis();
            System.out.println(((double) end - begin) / 1000 + " sec");
        }

        {
            long begin = System.currentTimeMillis();
            int t = 0;
            for (int j = 0; j < 1000; j++) {
                for (int i = 0; i < 100000; i++) {
                    t += i;
                }
            }
            long end = System.currentTimeMillis();
            System.out.println(((double) end - begin) / 1000 + " sec");
        }

    }

    public boolean add(Integer o) {
        throw new UnsupportedOperationException();
    }

    public void add(int index, Integer element) {
        throw new UnsupportedOperationException();
    }

    public boolean addAll(Collection<? extends Integer> c) {
        throw new UnsupportedOperationException();
    }

    public boolean addAll(int index, Collection<? extends Integer> c) {
        throw new UnsupportedOperationException();
    }

    public void clear() {
        throw new UnsupportedOperationException();
    }

    public boolean contains(Object o) {
        throw new UnsupportedOperationException();
    }

    public boolean containsAll(Collection<?> c) {
        throw new UnsupportedOperationException();
    }

    public int indexOf(Object o) {
        throw new UnsupportedOperationException();
    }

    public int lastIndexOf(Object o) {
        throw new UnsupportedOperationException();
    }

    public ListIterator<Integer> listIterator() {
        throw new UnsupportedOperationException();
    }

    public ListIterator<Integer> listIterator(int index) {
        throw new UnsupportedOperationException();
    }

    public boolean remove(Object o) {
        throw new UnsupportedOperationException();
    }

    public Integer remove(int index) {
        throw new UnsupportedOperationException();
    }

    public boolean removeAll(Collection<?> c) {
        throw new UnsupportedOperationException();
    }

    public boolean retainAll(Collection<?> c) {
        throw new UnsupportedOperationException();
    }

    public Integer set(int index, Integer element) {
        throw new UnsupportedOperationException();
    }

    public List<Integer> subList(int fromIndex, int toIndex) {
        throw new UnsupportedOperationException();
    }

    public Object[] toArray() {
        throw new UnsupportedOperationException();
    }

    public <T> T[] toArray(T[] a) {
        throw new UnsupportedOperationException();
    }
}

実行

1.688 sec
1.75 sec
0.64 sec
0.203 sec
0.188 sec

Pythonにはxrangeという関数があります。これは、xrange(0, 100)とすると、0から99までのリストを仮想的に作ったのと同じになります。実際にはIteratorを作るだけで、100*4byteのメモリは食いません。これと同等なものを実装したのがRangeクラスです。さて、こいつをJavaのforeachで回したくなります。そこで、iteratorメソッドから無名関数を返します。ところが、iteratorメソッドはインターフェースにObject(ここではInteger)を使っています。しょうがないのでオブジェクトを生成してもらいます(Javagenerics)。ここでの予想は、Integerを生成しても、あとでintだけを使うことがわかって入れば、生成をせずにその分のオーバーヘッドが省略されるだろうという予想です。残念ながらこれは裏切られます。単なるforループの10倍の実行時間がかかりました。これは最適化できそうなんで、がんばって欲しいものです。

どうすればいいでしょう。話は簡単です。Integerを使わずにintを使うインターフェースを用意してやれば良いのです。abstractクラスのIntListはint getInt(int index);というインターフェースをもっています。ただのgetのint版。これを使えば速くなりそうです。また、前回の教訓に従えば、実際の実行型で宣言すればさらに速くなりそうです。

実行結果は、上からforeach, List::get, IntList::getInt, Range::getInt, ただのループです。最後には、ただのループ(intの足し算)とほぼ同性能になりました。これをみれば、インライン展開は結構がんばっているという私の感想もご理解いただけるのではないでしょうか。