Make Original Algorithm

ここでは本フレームワークでの自作アルゴリズムを作成する方法を説明します。構成は以下です。

  1. 概要

  2. 実装するクラスの説明
    1. Config

    2. Memory

    3. Parameter

    4. Trainer

    5. Worker

  3. 自作アルゴリズムの登録

  4. 型アノテーション

  5. Q学習実装例

概要

自作アルゴリズムでは5つクラスを定義する必要があり、以下のように連携して動作します。(図にはないですが、他にハイパーパラメータを管理するConfigクラスがあります)

../_images/overview-sequence.drawio.png
WorkerRunとEnvRunはフレームワーク内の動作になるので意識する必要はありません。
それぞれの役割は以下です。

Config

  • ハイパーパラメータを管理するクラス

Memory

  • Workerが収集したサンプルを管理

Parameter

  • 学習パラメータを保持

Trainer

  • Memoryからサンプルを取得し学習する

  • 学習後、Parameterを更新する

Worker

  • Environmentと連携しサンプルを収集

  • 収集したサンプルをMemoryに送信(add only)

  • 行動決定に必要な情報をParameterから読む(read only)

分散学習は以下となり各クラスが非同期で動作します。

../_images/overview-mp.drawio.png

同期的な学習と以下の点が異なります。

  • WorkerがMemoryにサンプルを送るタイミングとTrainerが取り出すタイミングが異なる

  • ParameterがWorkerとTrainerで同期されない

各クラスの実装の仕方を見ていきます。

実装する各クラスの説明

Config

強化学習アルゴリズムの種類やハイパーパラメータを管理するクラスです。 基底クラスは srl.base.rl.base.RLConfig でこれを継承して作成します。

RLConfig で実装が必要な関数・プロパティは以下です。

from dataclasses import dataclass
from srl.base.rl.config import RLConfig
from srl.base.define import RLBaseTypes
from srl.base.rl.processor import Processor

# 必ず dataclass で書いてください
@dataclass
class MyConfig(RLConfig):

   # 任意のハイパーパラメータを定義
   hoo_param: float = 0

   def __post_init__(self):
      super().__post_init__()  # 親のコンストラクタも呼んでください

   def get_name(self) -> str:
      """ ユニークな名前を返す """
      raise NotImplementedError()

   def get_base_action_type(self) -> RLBaseTypes:
      """
      アルゴリズムが想定するアクションのタイプ(srl.base.define.RLBaseTypes)を返してください。
      """
      raise NotImplementedError()

   def get_base_observation_type(self) -> RLBaseTypes:
      """
      アルゴリズムが想定する環境から受け取る状態のタイプ(srl.base.define.RLBaseTypes)を返してください。
      """
      raise NotImplementedError()

   def get_framework(self) -> str:
      """
      使うフレームワークを指定してください。
      return ""           : なし
      return "tensorflow" : Tensorflow
      return "torch"      : Torch
      """
      raise NotImplementedError()

   # ------------------------------------
   # 以下は option です。(なくても問題ありません)
   # ------------------------------------
   def assert_params(self) -> None:
      """ パラメータのassertを記載 """
      super().assert_params()  # 親クラスも呼び出してください

   def setup_from_env(self, env: EnvRun) -> None:
      """ env初期化後に呼び出されます。env関係の初期化がある場合は記載してください。 """
      pass

   def setup_from_actor(self, actor_num: int, actor_id: int) -> None:
      """ 分散学習でactorが指定されたときに呼び出されます。Actor関係の初期化がある場合は記載してください。 """
      pass

   def get_processors(self) -> List[Optional[Processor]]:
      """ 前処理を追加したい場合設定してください """
      return []

   def get_used_backup_restore(self) -> bool:
      """ MCTSなど、envのbackup/restoreを使う場合はTrueを返してください"""
      return False

Memory

Workerが取得したサンプル(batch)をTrainerに渡す役割を持っているクラスです。
以下の3種類から継承します。
(RLMemoryを直接継承することでオリジナルのMemoryを作成することも可能です)
(オリジナルのMemoryの作成例は`srl.algorithms.world_models`の実装を参考にしてください)

SequenceMemory

来たサンプルを順序通りに取り出します。(Queueみたいな動作です)

ExperienceReplayBuffer

サンプルをランダムに取り出します。

PriorityExperienceReplay

サンプルを優先順位に従い取り出します。

SequenceMemory

順番通りにサンプルを取り出しますMemoryです。サンプルは取り出すとなくなります。

from srl.rl.memories.sequence_memory import SequenceMemory


class MyRemoteMemory(SequenceMemory):
    pass


# 実行例
memory = MyRemoteMemory(None)
memory.add([1, 2])
memory.add([2, 3])
memory.add([3, 4])
dat = memory.sample()
print(dat)  # [[1, 2], [2, 3], [3, 4]]

ExperienceReplayBuffer

ランダムにサンプルを取り出すMemoryです。
これを使う場合は、Configに RLConfigComponentExperienceReplayBuffer を継承する必要があります。
from dataclasses import dataclass

from srl.base.define import RLBaseTypes
from srl.base.rl.config import RLConfig
from srl.rl.memories.experience_replay_buffer import ExperienceReplayBuffer, RLConfigComponentExperienceReplayBuffer


@dataclass
class MyConfig(
    RLConfig,
    RLConfigComponentExperienceReplayBuffer,
):
    # RLConfig に加え、RLConfigComponentExperienceReplayBuffer も継承する
    # 順番は RLConfig -> RLConfigComponentExperienceReplayBuffer

    def get_name(self) -> str:
        return "MyConfig"

    def get_base_action_type(self) -> RLBaseTypes:
        return RLBaseTypes.DISCRETE

    def get_base_observation_type(self) -> RLBaseTypes:
        return RLBaseTypes.DISCRETE

    def get_framework(self) -> str:
        return ""


class MyRemoteMemory(ExperienceReplayBuffer):
    pass


# 実行例
memory = MyRemoteMemory(MyConfig())
memory.add(1)
memory.add(2)
memory.add(3)
memory.add(4)
dat = memory.sample(batch_size=2)
print(dat)  # [3, 2]

PriorityExperienceReplay

優先順位に従ってサンプルを取り出すMemoryです。
これを使う場合は、Configにも RLConfigComponentPriorityExperienceReplay を継承する必要があります。

このアルゴリズムはConfigにより切り替えることができます。

クラス名

説明

ReplayMemory

ExperienceReplayBufferと同じで、ランダムに取得します。(優先順位はありません)

ProportionalMemory

サンプルの重要度によって確率が変わります。重要度が高いサンプルほど選ばれる確率が上がります。

RankBaseMemory

サンプルの重要度のランキングによって確率が変わります。重要度が高いサンプルほど選ばれる確率が上がるのはProportionalと同じです。

from dataclasses import dataclass

import numpy as np

from srl.base.define import RLBaseTypes
from srl.base.rl.config import RLConfig
from srl.rl.memories.priority_experience_replay import (
    PriorityExperienceReplay,
    RLConfigComponentPriorityExperienceReplay,
)


@dataclass
class MyConfig(RLConfig, RLConfigComponentPriorityExperienceReplay):
    # RLConfig に加え、RLConfigComponentPriorityExperienceReplay も継承する
    # 順番は RLConfig -> RLConfigComponentPriorityExperienceReplay

    def get_name(self) -> str:
        return "MyConfig"

    def get_base_action_type(self) -> RLBaseTypes:
        return RLBaseTypes.DISCRETE

    def get_base_observation_type(self) -> RLBaseTypes:
        return RLBaseTypes.DISCRETE

    def get_framework(self) -> str:
        return ""


class MyRemoteMemory(PriorityExperienceReplay):
    pass


# --- select memory
config = MyConfig()
# config.memory.set_replay_memory()
config.memory.set_proportional_memory()
# config.memory.set_rankbase_memory()

# --- run memory
memory = MyRemoteMemory(config)
memory.add(1, priority=1)
memory.add(2, priority=2)
memory.add(3, priority=3)
memory.add(4, priority=4)
batchs, weights, update_args = memory.sample(batch_size=1, step=0)
print(batchs)  # [2]
memory.update(update_args, np.array([5, 10, 15, 20, 11]))

Parameter

パラメータを管理するクラスです。深層学習の場合はここにニューラルネットワークを定義することを想定しています。

実装が必要な関数は以下です。

from srl.base.rl.parameter import RLParameter

import numpy as np

class MyParameter(RLParameter):
   def __init__(self, *args):
      """ コントラクタの引数は親に渡してください """
      super().__init__(*args)

      # self.config に上で定義した MyConfig が入ります
      self.config: MyConfig

   # call_restore/call_backupでパラメータが復元できるように作成
   def call_restore(self, data, **kwargs) -> None:
      raise NotImplementedError()
   def call_backup(self, **kwargs):
      raise NotImplementedError()

   # その他任意の関数を作成できます
   # (分散学習ではTrainer/Worker間で値を保持できない点に注意)

Trainer

学習を定義する部分です。Memoryから経験を受け取ってParameterを更新します。

実装が必要な関数は以下です。

from srl.base.rl.trainer import RLTrainer

class MyTrainer(RLTrainer):
   def __init__(self, *args):
      """ コントラクタの引数は親に渡してください """
      super().__init__(*args)

      # 以下の変数を持ちます。
      self.config: MyConfig
      self.parameter: MyParameter
      self.memory: IRLMemoryTrainer

   def train(self) -> None:
      """
      self.memory から batch を受け取り学習を定義します。
      self.memory は以下の関数が定義されています。

      self.memory.is_warmup_needed() : warmup中かどうかを返します
      self.memory.sample()           : batchを返します
      self.memory.update()           : ProportionalMemory の場合 update で使います

      ・学習したら回数を数えてください
      self.train_count += 1

      ・(option)必要に応じてinfoを設定します
      self.info = {"loss": 0.0}
      """
      raise NotImplementedError()


# --- 実装時に関数内で使う事を想定しているプロパティ・関数となります
trainer = MyTrainer()
trainer.distributed  # property, bool : 分散実行中かどうかを返します
trainer.train_only   # property, bool : 学習のみかどうかを返します

Worker

実際に環境と連携して経験を収集するクラスです。
役割は、Parameterを参照してアクションを決める事と、サンプルをMemoryに送信する事です。

フローをすごく簡単に書くと以下です。

env.reset()
worker.on_reset()
while:
   action = worker.policy()
   env.step(action)
   worker.on_step()
   trainer.train()

※v0.15.0からRLWorkerを直接継承する方法に変更しました

from srl.base.rl.worker import RLWorker
from srl.base.rl.worker_run import WorkerRun

class MyWorker(RLWorker):
   def __init__(self, *args):
      """ コントラクタの引数は親に渡してください """
      super().__init__(*args)

      # 以下の変数が設定されます
      self.config: MyConfig
      self.parameter: MyParameter
      self.memory: IRLMemoryWorker

   def on_reset(self, worker: WorkerRun) -> dict:
      """ エピソードの最初に呼ばれる関数

      Returns:
            Info         : 任意の情報
      """
      raise NotImplementedError()

   def policy(self, worker: WorkerRun) -> RLActionType | dict:
      """ このターンで実行するアクションを返す関数、この関数のみ実装が必須になります

      Returns: (
            RLActionType : 実行するアクション
            Info         : 任意の情報
      )
      """
      raise NotImplementedError()

   def on_step(self, worker: WorkerRun) -> dict:
      """ Envが1step実行した後に呼ばれる関数

      Returns:
            dict: 情報(任意)
      """
      raise NotImplementedError()

   def render_terminal(self, worker, **kwargs) -> None:
      """
      描画用の関数です。
      実装するとrenderによる描画が可能になります。
      """
      pass

   def render_rgb_array(self, worker, **kwargs) -> Optional[np.ndarray]:
      """
      描画用の関数です。
      実装するとrenderによる描画が可能になります。
      """
      return None

# --- 実装時に関数内で使う事を想定しているプロパティ・関数となります
worker = MyWorker()
worker.training     # property, bool : training かどうかを返します
worker.distributed  # property, bool : 分散実行中かどうかを返します
worker.rendering    # property, bool : renderがあるエピソードかどうかを返します
worker.observation_space  # property , SpaceBase : RLWorkerが受け取るobservation_spaceを返します
worker.action_space       # property , SpaceBase : RLWorkerが受け取るaction_spaceを返します
worker.get_invalid_actions() # function , List[RLAction] : 有効でないアクションを返します(離散限定)
worker.sample_action()       # function , RLAction : ランダムなアクションを返します

また、情報は WorkerRun から基本取り出して使います。 情報の例は以下です。

class MyWorker(RLWorker):
   def on_reset(self, worker):
      worker.state           # 初期状態
      worker.player_index    # 初期プレイヤーのindex
      worker.invalid_action  # 初期有効ではないアクションリスト

   def policy(self, worker) :
      worker.state           # 状態
      worker.player_index    # プレイヤーのindex
      worker.invalid_action  # 有効ではないアクションリスト

   def on_step(self, worker: "WorkerRun") -> dict:
      worker.prev_state      # step前の状態(policyのworker.stateと等価)
      worker.state           # step後の状態
      worker.prev_action     # step前のアクション(policyで返したアクションと等価)
      worker.reward          # step後の即時報酬
      worker.done            # step後に終了フラグが立ったか
      worker.terminated      # step後にenvが終了フラグを立てたか
      worker.player_index    # 次のプレイヤーのindex
      worker.prev_invalid_action  # step前の有効ではないアクションリスト
      worker.invalid_action       # step後の有効ではないアクションリスト

自作アルゴリズムの登録

以下で登録します。 第2引数以降の entry_point は、モジュールパス + ":" + クラス名`で、 モジュールパスは `importlib.import_module で呼び出せる形式である必要があります。

from srl.base.rl.registration import register
register(
   MyConfig(),
   __name__ + ":MyMemory",
   __name__ + ":MyParameter",
   __name__ + ":MyTrainer",
   __name__ + ":MyWorker",
)

型アノテーション

動作に影響はないですが、ジェネリック型を追加し実装を簡単にしています。 適用方法は以下です。

@dataclass
class Config(RLConfig):
   pass

# RLParameter[_TConfig]
#   _TConfig : RLConfig型
class Parameter(RLParameter[Config]):
   pass

# RLTrainer[_TConfig, _TParameter]
#   _TConfig    : RLConfig型
#   _TParameter : RLParameter型
class Trainer(RLTrainer[Config, Parameter]):
   pass

# RLWorker[_TConfig, _TParameter]
#   _TConfig    : RLConfig型
#   _TParameter : RLParameter型
class Worker(RLWorker[Config, Parameter]):
   pass

実装例(Q学習)

import json
import random
from dataclasses import dataclass
from typing import Tuple

import numpy as np

from srl.base.define import InfoType, RLBaseTypes
from srl.base.rl.config import RLConfig
from srl.base.rl.parameter import RLParameter
from srl.base.rl.registration import register
from srl.base.rl.trainer import RLTrainer
from srl.base.rl.worker import RLWorker
from srl.base.spaces.array_discrete import ArrayDiscreteSpace
from srl.base.spaces.discrete import DiscreteSpace
from srl.rl.memories.sequence_memory import SequenceMemory


@dataclass
class Config(RLConfig[DiscreteSpace, ArrayDiscreteSpace]):
    epsilon: float = 0.1
    test_epsilon: float = 0
    gamma: float = 0.9
    lr: float = 0.1

    def get_base_action_type(self) -> RLBaseTypes:
        return RLBaseTypes.DISCRETE

    def get_base_observation_type(self) -> RLBaseTypes:
        return RLBaseTypes.DISCRETE

    def get_framework(self) -> str:
        return ""

    def get_name(self) -> str:
        return "MyRL"


class Memory(SequenceMemory):
    pass


class Parameter(RLParameter[Config]):
    def __init__(self, *args):
        super().__init__(*args)

        self.Q = {}  # Q学習用のテーブル

    def call_restore(self, data, **kwargs) -> None:
        self.Q = json.loads(data)

    def call_backup(self, **kwargs):
        return json.dumps(self.Q)

    # Q値を取得する関数
    def get_action_values(self, state: str):
        if state not in self.Q:
            self.Q[state] = [0] * self.config.action_space.n
        return self.Q[state]


class Trainer(RLTrainer[Config, Parameter]):
    def __init__(self, *args):
        super().__init__(*args)

    def train(self) -> None:
        if self.memory.is_warmup_needed():
            return
        batchs = self.memory.sample()

        td_error = 0
        for batch in batchs:
            # 各batch毎にQテーブルを更新する
            s = batch["state"]
            n_s = batch["next_state"]
            action = batch["action"]
            reward = batch["reward"]
            done = batch["done"]

            q = self.parameter.get_action_values(s)
            n_q = self.parameter.get_action_values(n_s)

            if done:
                target_q = reward
            else:
                target_q = reward + self.config.gamma * max(n_q)

            td_error = target_q - q[action]
            q[action] += self.config.lr * td_error

            td_error += td_error
            self.train_count += 1  # 学習回数

        if len(batchs) > 0:
            td_error /= len(batchs)

        # 学習結果(任意)
        self.info = {
            "Q": len(self.parameter.Q),
            "td_error": td_error,
        }


class Worker(RLWorker[Config, Parameter]):
    def __init__(self, *args):
        super().__init__(*args)

    def on_reset(self, worker):
        return {}

    def policy(self, worker):
        self.state = self.config.observation_space.to_str(worker.state)

        # 学習中かどうかで探索率を変える
        if self.training:
            epsilon = self.config.epsilon
        else:
            epsilon = self.config.test_epsilon

        if random.random() < epsilon:
            # epsilonより低いならランダムに移動
            self.action = self.sample_action()
        else:
            q = self.parameter.get_action_values(self.state)
            q = np.asarray(q)

            # 最大値を選ぶ(複数あればランダム)
            self.action = np.random.choice(np.where(q == q.max())[0])

        return int(self.action), {}

    def on_step(self, worker) -> InfoType:
        if not self.training:
            return {}

        batch = {
            "state": self.state,
            "next_state": self.config.observation_space.to_str(worker.state),
            "action": self.action,
            "reward": worker.reward,
            "done": worker.terminated,
        }
        self.memory.add(batch)  # memoryはaddのみ
        return {}

    # 強化学習の可視化用、今回ですとQテーブルを表示しています。
    def render_terminal(self, worker, **kwargs) -> None:
        q = self.parameter.get_action_values(self.state)
        maxa = np.argmax(q)
        for a in range(self.config.action_space.n):
            if a == maxa:
                s = "*"
            else:
                s = " "
            s += f"{worker.env.action_to_str(a)}: {q[a]:7.5f}"
            print(s)


# ---------------------------------
# 登録
# ---------------------------------
register(
    Config(),
    __name__ + ":Memory",
    __name__ + ":Parameter",
    __name__ + ":Trainer",
    __name__ + ":Worker",
)


# ---------------------------------
# テスト
# ---------------------------------
from srl.test.rl import TestRL

tester = TestRL()
tester.test(Config())


# ---------------------------------
# Grid環境の学習
# ---------------------------------
import srl

runner = srl.Runner(srl.EnvConfig("Grid"), Config(lr=0.001))

# --- train
runner.train(timeout=10)

# --- test
rewards = runner.evaluate(max_episodes=100)
print("100エピソードの平均結果", np.mean(rewards))

runner.render_terminal()

runner.animation_save_gif("_MyRL-Grid.gif", render_scale=2)

renderの表示例

### 0, action None, rewards[0.000] (0.0s)
env   {}
work0 None
......
.   G.
. . X.
.P   .
......

←: 0.17756
↓: 0.16355
→: 0.11174
*↑: 0.37473
### 1, action 3(↑), rewards[-0.040] (0.0s)
env   {}
work0 {}
......
.   G.
.P. X.
.    .
......

←: 0.27779
↓: 0.20577
→: 0.27886
*↑: 0.49146
### 2, action 3(↑), rewards[-0.040] (0.0s)
env   {}
work0 {}
......
.P  G.
. . X.
.    .
......

←: 0.34255
↓: 0.29609
*→: 0.61361
↑: 0.34684
### 3, action 2(→), rewards[-0.040] (0.0s)
env   {}
work0 {}
......
.   G.
.P. X.
.    .
......

←: 0.27779
↓: 0.20577
→: 0.27886
*↑: 0.49146
### 4, action 3(↑), rewards[-0.040] (0.0s)
env   {}
work0 {}
......
.P  G.
. . X.
.    .
......

←: 0.34255
↓: 0.29609
*→: 0.61361
↑: 0.34684
### 5, action 2(→), rewards[-0.040] (0.0s)
env   {}
work0 {}
......
. P G.
. . X.
.    .
......

←: 0.37910
↓: 0.44334
*→: 0.76733
↑: 0.46368
### 6, action 2(→), rewards[-0.040] (0.0s)
env   {}
work0 {}
......
.  PG.
. . X.
.    .
......

←: 0.47941
↓: 0.39324
*→: 0.92425
↑: 0.59087
### 7, action 2(→), rewards[1.000], done(env) (0.0s)
env   {}
work0 {}
......
.   P.
. . X.
.    .
......

[0.760000005364418]
../_images/custom_algorithm4.gif