日々精進

新しく学んだことを書き留めていきます

tensorflow.python.keras.applications.inception_resnet_v2などを使って転移学習する場合は、tensorflow.python.keras.applications.inception_resnet_v2.preprocess_inputを使って前処理する

転移学習する場合、base modelの重みは固定することが多いということもあり、 base modelの学習時と同じ前処理を時前の学習データにもすべき。 tensorflow.python.keras.applications.inception_resnet_v2.preprocess_input のように、各base modelのモジュールの中にpreprocess_input関数が用意されており、これが base modelの前処理と同じ前処理をしてくれるのでこれを使うと良い。

参考:

stackoverflow.com

Data Augmentationにはimgaugを使うと便利

クラスメソッドさんがいい記事を書いているのでこれを参考にやってみるといいと思います。

dev.classmethod.jp

ランダムに色んな効果を適用したい時に便利。

github.com

pip install imgaugで入れるとscikit-imageがバグってるバージョンが入ってエラーになったのでそれだけ注意。解決策は以下。

anton0825.hatenablog.com

Tensor flowのTensorオブジェクトをndarrayに変換する

以下のようにSessionを開始してTensor.eval()を実行すればいい。

sess = tf.Session()
with sess.as_default():
    X_train_color = X_train_tensor.eval()
    X_test_color = X_test_tensor.eval()
    X_valid_color = X_valid_tensor.eval()

Tensor flowのresizeなどのメソッドを使うと結果がTensorになってこういう変換が必要になるのは面倒だなぁ。。 前処理の最後とかでそのままfitに渡せば良いとかならまだいいんだろうけど。

参考: olgabelitskaya.github.io

keras/Tensor flowで学習処理を実行すると「tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'input_3' with dtype float and shape [?,299,299,3]」エラー

原因というか、何が悪いかはよく分からないがTensorBoardの引数の histogram_freq=1histogram_freq=0 にすると直った。 TensorBoardかTensor flowかKerasのバグかな。。 エラーメッセージからこの解決策に全然たどり着けず、結構ハマッた。。 以下のIssueで現象は報告されているのでいずれは解消されそうだけどしばらくはこの回避策が必要っぽい。

github.com

S3から正規表現に一致するファイルをすべてダウンロードする

boto3に普通にありそうな機能だけど見付からなかったので実装。

def download_object_from_s3(remote_path: str) -> io.BytesIO:
    """S3からオブジェクトをダウンロードしてreturnする
    Args:
        remote_path (str): ダウンロード元のS3のパス
    Returns:
        S3からダウンロードしたオブジェクト
    """
    data: io.BytesIO = io.BytesIO()
    s3: ServiceResource = boto3.resource('s3',
                                         aws_access_key_id=ACCESS_KEY,
                                         aws_secret_access_key=SECRET_ACCESS_KEY)
    s3.Bucket(BUCKET).download_fileobj(remote_path, data)
    data.seek(0)
    return data


def download_object_list_from_s3(pattern: str) -> List[S3Object]:
    regex = re.compile(pattern)
    s3: ServiceResource = boto3.resource('s3',
                                         aws_access_key_id=ACCESS_KEY,
                                         aws_secret_access_key=SECRET_ACCESS_KEY)
    objects = s3.Bucket(BUCKET).objects.all()
    result: list = []
    for a_object in objects:
        if regex.match(a_object.key):
            content: io.BytesIO = download_object_from_s3(a_object.key)
            result.append(S3Object(a_object.key, content))
    return result

download_object_list_from_s3("mydir/.+.jpg") のように使えます。便利。

参考:

stackoverflow.com