swmpo.transition

A transition is a triple consisting of a source state, an action and a next state.

Each of those is a one-dimensional torch.Tensor.

 1"""A transition is a triple consisting of a source state, an action
 2and a next state.
 3
 4Each of those is a one-dimensional `torch.Tensor`.
 5"""
 6from dataclasses import dataclass
 7from pathlib import Path
 8import tempfile
 9import shutil
10import zipfile
11import torch
12
13
@dataclass(eq=False)
class Transition:
    """A (source state, action, next state) triple.

    The module documents each field as a one-dimensional `torch.Tensor`;
    this class stores whatever tensors it is given and does not validate
    their shape.

    Equality is value-based: two transitions compare equal when all three
    tensor pairs satisfy `torch.equal`. The dataclass-generated `__eq__`
    is disabled because it compares tuples of tensors, which raises
    "Boolean value of Tensor with more than one value is ambiguous" for
    any pair of distinct multi-element tensors.
    """

    source_state: torch.Tensor
    action: torch.Tensor
    next_state: torch.Tensor

    def to(self, device: str) -> "Transition":
        """Return a new `Transition` built from `tensor.to(device)` of each field.

        `self` is not modified. Note that `torch.Tensor.to` may return the
        very same tensor object when no conversion is needed, so the result
        can share storage with `self`.
        """
        return Transition(
            source_state=self.source_state.to(device),
            action=self.action.to(device),
            next_state=self.next_state.to(device),
        )

    # Defining `__eq__` without `__hash__` keeps instances unhashable, which
    # is what the default `@dataclass` (eq=True, frozen=False) did as well.
    def __eq__(self, other: object) -> bool:
        """Return True when `other` is a `Transition` with equal tensors."""
        if not isinstance(other, Transition):
            return NotImplemented
        return (
            torch.equal(self.source_state, other.source_state)
            and torch.equal(self.action, other.action)
            and torch.equal(self.next_state, other.next_state)
        )
26
27
def equals(
    t1: Transition,
    t2: Transition,
) -> bool:
    """Return True when the two transitions hold equal tensors.

    The fields are compared pairwise with `torch.equal` in the order
    source state, action, next state; the comparison stops at the first
    pair that differs.
    """
    return (
        torch.equal(t1.source_state, t2.source_state)
        and torch.equal(t1.action, t2.action)
        and torch.equal(t1.next_state, t2.next_state)
    )
39
40
def get_vector(
    transition: Transition
) -> torch.Tensor:
    """Return the transition as a single vector.

    The result is the concatenation, along dimension 0, of the source
    state, the action and the next state, in that order.
    """
    parts = (
        transition.source_state,
        transition.action,
        transition.next_state,
    )
    return torch.cat(parts, dim=0)
51
52
# File names used for the three tensors inside a serialized transition.
# `serialize` writes one file per tensor under these names and
# `deserialize` reads them back by the same names.
SOURCE_STATE_PATH = "source_state.pt"
ACTION_PATH = "action.pt"
NEXT_STATE_PATH = "next_state.pt"
56
57
def serialize(t: Transition, output_zip_path: Path):
    """Serialize a transition into a ZIP file.

    Each tensor is written with `torch.save` to its own file
    (`SOURCE_STATE_PATH`, `ACTION_PATH`, `NEXT_STATE_PATH`) in a temporary
    directory, which is then archived to `output_zip_path`.

    Raises:
        ValueError: if `output_zip_path` does not have a ".zip" suffix.
    """
    # Explicit raise instead of `assert`: assertions are stripped under
    # `python -O`, and without the check a path such as "out.txt" would be
    # silently written to "out.zip" instead.
    if output_zip_path.suffix != ".zip":
        raise ValueError(
            f"output_zip_path must have a '.zip' suffix, got: {output_zip_path}"
        )
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)
        torch.save(t.source_state, output_dir/SOURCE_STATE_PATH)
        torch.save(t.action, output_dir/ACTION_PATH)
        torch.save(t.next_state, output_dir/NEXT_STATE_PATH)
        # `make_archive` appends ".zip" itself, so it is given the target
        # path with the suffix removed.
        shutil.make_archive(
            str(output_zip_path.with_suffix("")),
            "zip",
            output_dir,
        )
67
68
def deserialize(zip_path: Path) -> Transition:
    """Load a transition from a ZIP file.

    The archive is extracted into a temporary directory and the files
    named `SOURCE_STATE_PATH`, `ACTION_PATH` and `NEXT_STATE_PATH` are read
    with `torch.load(..., weights_only=True)`.
    """
    member_names = (SOURCE_STATE_PATH, ACTION_PATH, NEXT_STATE_PATH)

    with tempfile.TemporaryDirectory() as scratch:
        extracted_root = Path(scratch)

        with zipfile.ZipFile(zip_path, "r") as archive:
            archive.extractall(extracted_root)

        # All three tensors are loaded here, while the temporary directory
        # still exists.
        source_state, action, next_state = [
            torch.load(extracted_root/member, weights_only=True)
            for member in member_names
        ]

    return Transition(
        source_state=source_state,
        action=action,
        next_state=next_state,
    )
@dataclass
class Transition:
@dataclass(eq=False)
class Transition:
    """A (source state, action, next state) triple.

    The module documents each field as a one-dimensional `torch.Tensor`;
    this class stores whatever tensors it is given and does not validate
    their shape.

    Equality is value-based: two transitions compare equal when all three
    tensor pairs satisfy `torch.equal`. The dataclass-generated `__eq__`
    is disabled because it compares tuples of tensors, which raises
    "Boolean value of Tensor with more than one value is ambiguous" for
    any pair of distinct multi-element tensors.
    """

    source_state: torch.Tensor
    action: torch.Tensor
    next_state: torch.Tensor

    def to(self, device: str) -> "Transition":
        """Return a new `Transition` built from `tensor.to(device)` of each field.

        `self` is not modified. Note that `torch.Tensor.to` may return the
        very same tensor object when no conversion is needed, so the result
        can share storage with `self`.
        """
        return Transition(
            source_state=self.source_state.to(device),
            action=self.action.to(device),
            next_state=self.next_state.to(device),
        )

    # Defining `__eq__` without `__hash__` keeps instances unhashable, which
    # is what the default `@dataclass` (eq=True, frozen=False) did as well.
    def __eq__(self, other: object) -> bool:
        """Return True when `other` is a `Transition` with equal tensors."""
        if not isinstance(other, Transition):
            return NotImplemented
        return (
            torch.equal(self.source_state, other.source_state)
            and torch.equal(self.action, other.action)
            and torch.equal(self.next_state, other.next_state)
        )
Transition( source_state: torch.Tensor, action: torch.Tensor, next_state: torch.Tensor)
source_state: torch.Tensor
action: torch.Tensor
next_state: torch.Tensor
def to(self, device: str) -> Transition:
21    def to(self, device: str) -> "Transition":
22        return Transition(
23            source_state=self.source_state.to(device),
24            action=self.action.to(device),
25            next_state=self.next_state.to(device),
26        )
def equals(t1: Transition, t2: Transition) -> bool:
def equals(
    t1: Transition,
    t2: Transition,
) -> bool:
    """Return True when the two transitions hold equal tensors.

    The fields are compared pairwise with `torch.equal` in the order
    source state, action, next state; the comparison stops at the first
    pair that differs.
    """
    return (
        torch.equal(t1.source_state, t2.source_state)
        and torch.equal(t1.action, t2.action)
        and torch.equal(t1.next_state, t2.next_state)
    )
def get_vector(transition: Transition) -> torch.Tensor:
def get_vector(
    transition: Transition
) -> torch.Tensor:
    """Return the transition as a single vector.

    The result is the concatenation, along dimension 0, of the source
    state, the action and the next state, in that order.
    """
    parts = (
        transition.source_state,
        transition.action,
        transition.next_state,
    )
    return torch.cat(parts, dim=0)

Return the transition as a single vector.

SOURCE_STATE_PATH = 'source_state.pt'
ACTION_PATH = 'action.pt'
NEXT_STATE_PATH = 'next_state.pt'
def serialize(t: Transition, output_zip_path: pathlib.Path):
def serialize(t: Transition, output_zip_path: Path):
    """Serialize a transition into a ZIP file.

    Each tensor is written with `torch.save` to its own file
    (`SOURCE_STATE_PATH`, `ACTION_PATH`, `NEXT_STATE_PATH`) in a temporary
    directory, which is then archived to `output_zip_path`.

    Raises:
        ValueError: if `output_zip_path` does not have a ".zip" suffix.
    """
    # Explicit raise instead of `assert`: assertions are stripped under
    # `python -O`, and without the check a path such as "out.txt" would be
    # silently written to "out.zip" instead.
    if output_zip_path.suffix != ".zip":
        raise ValueError(
            f"output_zip_path must have a '.zip' suffix, got: {output_zip_path}"
        )
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)
        torch.save(t.source_state, output_dir/SOURCE_STATE_PATH)
        torch.save(t.action, output_dir/ACTION_PATH)
        torch.save(t.next_state, output_dir/NEXT_STATE_PATH)
        # `make_archive` appends ".zip" itself, so it is given the target
        # path with the suffix removed.
        shutil.make_archive(
            str(output_zip_path.with_suffix("")),
            "zip",
            output_dir,
        )

Serialize a transition into a ZIP file.

def deserialize(zip_path: pathlib.Path) -> Transition:
def deserialize(zip_path: Path) -> Transition:
    """Load a transition from a ZIP file.

    The archive is extracted into a temporary directory and the files
    named `SOURCE_STATE_PATH`, `ACTION_PATH` and `NEXT_STATE_PATH` are read
    with `torch.load(..., weights_only=True)`.
    """
    member_names = (SOURCE_STATE_PATH, ACTION_PATH, NEXT_STATE_PATH)

    with tempfile.TemporaryDirectory() as scratch:
        extracted_root = Path(scratch)

        with zipfile.ZipFile(zip_path, "r") as archive:
            archive.extractall(extracted_root)

        # All three tensors are loaded here, while the temporary directory
        # still exists.
        source_state, action, next_state = [
            torch.load(extracted_root/member, weights_only=True)
            for member in member_names
        ]

    return Transition(
        source_state=source_state,
        action=action,
        next_state=next_state,
    )