2022年10月24日月曜日

SQLAlchemyでテーブルレコードを逐次取得してメモリ使用量を減らす

SQLAlchemyでデータベースのテーブルからデータを取得するときに、テーブルレコードの件数が多くなるほど、Pythonプロセスのメモリ使用量が増える。場合によっては、システムのメモリ不足に陥ることもあり得る。SQLAlchemyでは、レコードを逐次取得することでメモリ使用量を減らすことができるので、その方法をまとめておく。


SQLAlchemyのstream_resultsオプションについて

SQLAlchemyはデフォルトでclient side cursorsで動作し、取得したデータをすべてメモリに保持する。一方、server side cursorsで動作した場合は、クライアント側で必要とされる分だけ逐次保持される。つまり、server side cursorsを使うとメモリ使用量を抑えることができる。Working with Engines and ConnectionsのUsing Server Side Cursors (a.k.a. stream results)に説明があり、以下はその抜粋。

A client side cursor here means that the database driver fully fetches all rows from a result set into memory before returning from a statement execution. Drivers such as those of PostgreSQL and MySQL/MariaDB generally use client side cursors by default. A server side cursor, by contrast, indicates that result rows remain pending within the database server’s state as result rows are consumed by the client. 

SQLAlchemyでは、execution_optionsメソッドでstream_results=Trueとすることで、server side cursorsで動作させられる。


環境

WSL2(Ubuntu20.04)とデータベースはMariaDB 10.5。

$ lsb_release -dr
Description:    Ubuntu 20.04.5 LTS
Release:        20.04
$ python3 -V
Python 3.8.10
$ mariadbd --version
mariadbd  Ver 10.5.17-MariaDB-1:10.5.17+maria~ubu2004 for debian-linux-gnu on x86_64 (mariadb.org binary distribution)

SQLAlchemyのほかに、DBAPIとしてmysqlclient、メモリ使用量を確認するためにメモリプロファイラーのmemory_profilerをインストールする。

$ pip3 install mysqlclient sqlalchemy memory_profiler

インストールされたバージョン。

$ pip show mysqlclient | grep Version
Version: 2.1.1

$ pip show sqlalchemy | grep Version
Version: 1.4.41

$ pip show memory_profiler | grep Version
Version: 0.60.0


WSL2のDNS設定変更

後述のテストデータベースの準備でgit cloneをするが、そのときに「Could not resolve host: github.com」というメッセージが表示されてcloneできない場合は、UbuntuのDNS設定を変更する(WSL2 DNS stops working #4285)。cloneに問題がない場合は、この項は不要なので、次のテストデータベースの準備へ進む。

WSL2 Ubuntuの/etc/wsl.confに以下記述を追加する。

$ sudo vi /etc/wsl.conf
+ [network]
+ generateResolvConf = false

Ubuntuのターミナルを終了して、一度WSLを停止する。

PS > wsl --shutdown

Ubuntuを起動して、/etc/resolv.confに以下を追記する。

$ sudo vi /etc/resolv.con
+ nameserver 8.8.8.8


テストデータベースの準備

SQLAlchemyでメモリ使用量を確認するためには、それなりのレコード件数のあるテーブルが必要になる。このテスト用データベースとして、Other MySQL Documentationにあるemployee dataを使う。これはGitHubのtest_dbで公開されている。

まずはこのリポジトリをcloneする。

$ git clone https://github.com/datacharmer/test_db.git

このリポジトリにはemployeesというデータベースがあり、以下のようにデータベースをインポートできる。

$ cd test_db
$ mysql -u username -p < employees.sql

インポートが完了したら、テーブルを確認。

$ mysql -u username -p -D employees

MariaDB [employees]> show tables;
+----------------------+
| Tables_in_employees  |
+----------------------+
| current_dept_emp     |
| departments          |
| dept_emp             |
| dept_emp_latest_date |
| dept_manager         |
| employees            |
| salaries             |
| titles               |
+----------------------+


メモリ使用量の比較

stream_results=Trueを設定した場合としない場合でメモリ使用量を比較してみる。以下のようなstream_sample1.pyとdatabase.pyを用意する。以下のコードでは、salariesテーブルから全件を取得して、salaryカラムの平均を求める。stream_results=Trueを設定する場合はfetch_strem関数を、設定しない場合はfetch関数を使用する。

import statistics

from sqlalchemy import MetaData, Table, select
from memory_profiler import profile

from database import db


def _get_table_object(tablename):
    # テーブル名からテーブルオブジェクトを取得
    meta_data = MetaData(bind=db.engine)
    meta_data.reflect(only=[tablename])
    return meta_data.tables[tablename]


@profile
def fetch(tablename):
    con = db.engine.connect()
    table_obj = _get_table_object(tablename)
    stmt = select(table_obj)
    results = con.execute(stmt)
    return statistics.fmean([result['salary'] for result in results])


@profile
def stream_fetch(tablename):
    con = db.engine.connect()
    table_obj = _get_table_object(tablename)
    stmt = select(table_obj)
    results = con.execution_options(stream_results=True).execute(stmt)
    return statistics.fmean([result['salary'] for result in results])


@profile
def main():
    tablename = 'salaries'
    salary = fetch(tablename)
    # salary = stream_fetch(tablename)
    print(f'salary={salary}')


if __name__ == '__main__':
    main()
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session


class Database:
    def __init__(self):
        db_user = 'username' # DBユーザー名
        db_passwd = 'password' # DB接続パスワード
        db_name = 'employees' # DB名
        self.engine = create_engine(
            f'mysql+mysqldb://{db_user}:{db_passwd}@localhost:3306/{db_name}?charset=utf8mb4'
        )

        self.session = scoped_session(
            sessionmaker(autocommit=False, autoflush=True, bind=self.engine)
        )


db = Database()

以下はfetch関数(stream_results=True設定なし)を使用した場合の結果。SELECT文を実行したところで600MBほどメモリを使用している。

$ time python3 sample_stream1.py
Filename: sample_stream1.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    88     35.6 MiB     35.6 MiB           1   @profile
    89                                         def fetch(tablename):
    90     36.7 MiB      1.2 MiB           1       con = db.engine.connect()
    91     37.1 MiB      0.3 MiB           1       table_obj = _get_table_object(tablename)
    92     37.1 MiB      0.0 MiB           1       stmt = select(table_obj)
    93    634.7 MiB    597.6 MiB           1       results = con.execute(stmt)
    94    659.2 MiB     24.5 MiB     2844050       return statistics.fmean([result['salary'] for result in results])


salary=63810.744836143705
Filename: sample_stream1.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   106     35.6 MiB     35.6 MiB           1   @profile
   107                                         def main():
   108     35.6 MiB      0.0 MiB           1       tablename = 'salaries'
   109     63.3 MiB     27.7 MiB           1       salary = fetch(tablename)
   110                                             # salary = stream_fetch(tablename)
   111     63.3 MiB      0.0 MiB           1       print(f'salary={salary}')



real    3m23.536s
user    2m58.915s
sys     0m23.535s

以下はstream_fetch関数(stream_results=True設定)を使用した場合。SELECT文を実行したところではメモリを消費しておらず、メモリ消費のピークはsalaryカラムの平均を求める箇所で約150MB。stream_results=Trueを設定しない場合に比べて、メモリ使用のピークは4分の1程度となっている。topコマンドで確認した限り、MariaDBサーバーのプロセスが使用するメモリ使用量はどちらでも変わらないので、stream_results=Trueを設定することで、Pythonプロセスで減ったメモリ使用量がそのままシステムのメモリ使用量の減少になる。ちなみに、実行速度とのトレードオフになるが、salaryの平均値を求める際にリストでなくジェネレーター式を使うとさらにメモリ消費を抑えられる。

$ time python3 sample_stream1.py
Filename: sample_stream1.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    96     35.5 MiB     35.5 MiB           1   @profile
    97                                         def stream_fetch(tablename):
    98     36.7 MiB      1.2 MiB           1       con = db.engine.connect()
    99     37.0 MiB      0.3 MiB           1       table_obj = _get_table_object(tablename)
   100     37.0 MiB      0.0 MiB           1       stmt = select(table_obj)
   101     37.0 MiB      0.0 MiB           1       results = con.execution_options(stream_results=True).execute(stmt)
   102    147.1 MiB      1.6 MiB     2844050       return statistics.fmean([result['salary'] for result in results])


salary=63810.744836143705
Filename: sample_stream1.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   105     35.5 MiB     35.5 MiB           1   @profile
   106                                         def main():
   107     35.5 MiB      0.0 MiB           1       tablename = 'salaries'
   108                                             # salary = fetch(tablename)
   109     38.2 MiB      2.7 MiB           1       salary = stream_fetch(tablename)
   110     38.2 MiB      0.0 MiB           1       print(f'salary={salary}')



real    3m0.012s
user    2m35.694s
sys     0m23.953s


server side cursorsの注意点

server side cursorsで動作させる場合は注意が必要な点がある。server side cursorsの場合、同じ接続で複数の操作を実行できない。例えば、以下のコードではsalariesテーブルから全件取得したあとにsalariesテーブルからemp_no=10001のレコードを削除しようとしているが、「Commands out of sync; you can't run this command now」というエラーになる。stream_results=Trueを設定していない場合はエラーにならない。

import statistics

from sqlalchemy import MetaData, Table, select, delete
from memory_profiler import profile

from database import db


def _get_table_object(tablename):
    # テーブル名からテーブルオブジェクトを取得
    meta_data = MetaData(bind=db.engine)
    meta_data.reflect(only=[tablename])
    return meta_data.tables[tablename]


@profile
def stream_fetch(tablename):
    con = db.engine.connect()
    table_obj = _get_table_object(tablename)
    stmt = select(table_obj)
    results = con.execution_options(stream_results=True).execute(stmt)

    # 以下はエラーになる
    # stream_results=Trueとした同じ接続でテーブル操作を行おうとすると以下エラーになる
    # sqlalchemy.exc.ProgrammingError: (MySQLdb.ProgrammingError) (2014, "Commands out of sync; you can't run this command now")
    stmt = delete(table_obj).where(table_obj.c.emp_no == 10001)
    con.execute(stmt)

    return statistics.fmean([result['salary'] for result in results])


@profile
def main():
    tablename = 'salaries'
    salary = stream_fetch(tablename)
    print(f'salary={salary}')


if __name__ == '__main__':
    main()

複数処理を実行したい場合は、以下のようにデータベース接続を新たに作成する。

import statistics

from sqlalchemy import MetaData, Table, select, delete
from memory_profiler import profile

from database import db


def _get_table_object(tablename):
    # テーブル名からテーブルオブジェクトを取得
    meta_data = MetaData(bind=db.engine)
    meta_data.reflect(only=[tablename])
    return meta_data.tables[tablename]


@profile
def stream_fetch(tablename):
    con = db.engine.connect()
    table_obj = _get_table_object(tablename)
    stmt = select(table_obj)
    results = con.execution_options(stream_results=True).execute(stmt)

    # 以下はエラーにならない
    con2 = db.engine.connect()
    stmt2 = delete(table_obj).where(table_obj.c.emp_no == 10001)
    con2.execute(stmt2)

    return statistics.fmean([result['salary'] for result in results])


@profile
def main():
    tablename = 'salaries'
    salary = stream_fetch(tablename)
    print(f'salary={salary}')


if __name__ == '__main__':
    main()

上記コードではstream_results=Trueを設定してSELECTで取得したあとにDELETEしているが、resultsはDELETE前のSELECT時点のデータとなる。


yeild_perオプションについて

stream_results=Trueだけを指定した場合、はじめは少ないバッファーサイズから徐々にmax_row_buffer(デフォルト1000)を上限として増やしていく。それに対して、バッファーサイズを一定にしたい場合は、yeild_perを使う(Fetching Large Result Sets with Yield Per)。yeild_perを使うと、stream_results=Trueの場合と同様にserver side cursorsで動作する。

以下コードは、stream_sample1.pyのstream_fetch関数をyeild_perを使ってバッファサイズを一定(1000)にするように書き換えたもの。

@profile
def stream_fetch(tablename):
    con = db.engine.connect()
    table_obj = _get_table_object(tablename)
    stmt = select(table_obj)
    results = con.execution_options(yield_per=1000).execute(stmt)
    return statistics.fmean([result['salary'] for result in results])

stream_results=Trueを指定して上記コードと同じ結果となるコードを書くことができる。ただし、以下のようにオプション指定が増えたりするので、yeild_perを使うほうが簡単。

@profile
def stream_fetch(tablename):
    con = db.engine.connect()
    table_obj = _get_table_object(tablename)
    stmt = select(table_obj)
    results = con.execution_options(stream_results=True, max_row_buffer=1000).execute(stmt)
    return statistics.fmean([result['salary'] for result in results.yield_per(1000)])


0 件のコメント:

コメントを投稿