swmpo.transition
A transition is a triple consisting of a source state, an action and a next state.
Each of those is a one-dimensional torch.Tensor.
"""A transition is a triple consisting of a source state, an action
and a next state.

Each of those is a one-dimensional `torch.Tensor`.
"""
from dataclasses import dataclass
from pathlib import Path
import tempfile
import shutil
import zipfile
import torch


@dataclass
class Transition:
    """A ``(source_state, action, next_state)`` triple of tensors."""

    # State the transition starts from.
    source_state: torch.Tensor
    # Action taken in ``source_state``.
    action: torch.Tensor
    # State reached after taking ``action``.
    next_state: torch.Tensor

    def to(self, device: "str | torch.device") -> "Transition":
        """Return a new transition with every tensor moved to ``device``.

        ``device`` is forwarded unchanged to ``torch.Tensor.to``, so both
        strings such as ``"cpu"`` and ``torch.device`` objects work. The
        original transition is not modified.
        """
        return Transition(
            source_state=self.source_state.to(device),
            action=self.action.to(device),
            next_state=self.next_state.to(device),
        )


def equals(
    t1: Transition,
    t2: Transition,
) -> bool:
    """Return True if both transitions hold exactly equal tensors.

    Each field is compared with ``torch.equal``, so shapes and values must
    match exactly (there is no floating-point tolerance). Comparison stops
    at the first differing field.
    """
    if not torch.equal(t1.source_state, t2.source_state):
        return False
    if not torch.equal(t1.action, t2.action):
        return False
    if not torch.equal(t1.next_state, t2.next_state):
        return False
    return True


def get_vector(
    transition: Transition
) -> torch.Tensor:
    """Return the transition as a single vector.

    The tensors are joined with ``torch.cat`` in the order source state,
    action, next state.
    """
    vector = torch.cat([
        transition.source_state,
        transition.action,
        transition.next_state,
    ])
    return vector


# Member names of the three tensor files inside a serialized transition ZIP.
SOURCE_STATE_PATH = "source_state.pt"
ACTION_PATH = "action.pt"
NEXT_STATE_PATH = "next_state.pt"


def serialize(t: Transition, output_zip_path: "Path | str"):
    """Serialize a transition into a ZIP file.

    Each tensor is written with ``torch.save`` and stored in the archive
    under ``SOURCE_STATE_PATH``, ``ACTION_PATH`` and ``NEXT_STATE_PATH``.

    Args:
        t: Transition to write.
        output_zip_path: Destination archive path; must end in ``.zip``.
            A ``str`` is accepted and converted to ``Path``.

    Raises:
        ValueError: If ``output_zip_path`` does not have a ``.zip`` suffix.
    """
    output_zip_path = Path(output_zip_path)
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if output_zip_path.suffix != ".zip":
        raise ValueError(
            f"output_zip_path must end in '.zip', got {str(output_zip_path)!r}"
        )
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)
        torch.save(t.source_state, output_dir/SOURCE_STATE_PATH)
        torch.save(t.action, output_dir/ACTION_PATH)
        torch.save(t.next_state, output_dir/NEXT_STATE_PATH)
        # make_archive appends the ".zip" extension itself, so strip it here.
        shutil.make_archive(str(output_zip_path.with_suffix("")), 'zip', output_dir)


def deserialize(zip_path: Path) -> Transition:
    """Load a transition from a ZIP file written by `serialize`.

    The archive is extracted into a temporary directory (removed before
    returning). Each tensor is read with ``torch.load(..., weights_only=True)``.

    Args:
        zip_path: Path of the archive to read.

    Returns:
        A new `Transition` built from the three stored tensors.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(output_dir)

        source_state = torch.load(
            output_dir/SOURCE_STATE_PATH, weights_only=True
        )
        action = torch.load(
            output_dir/ACTION_PATH, weights_only=True,
        )
        next_state = torch.load(
            output_dir/NEXT_STATE_PATH, weights_only=True,
        )

    return Transition(
        source_state=source_state,
        action=action,
        next_state=next_state,
    )
@dataclass
class
Transition:
@dataclass
class Transition:
    """A ``(source_state, action, next_state)`` triple of tensors."""

    # State the transition starts from.
    source_state: torch.Tensor
    # Action taken in ``source_state``.
    action: torch.Tensor
    # State reached after taking ``action``.
    next_state: torch.Tensor

    def to(self, device: "str | torch.device") -> "Transition":
        """Return a new transition with every tensor moved to ``device``.

        ``device`` is forwarded unchanged to ``torch.Tensor.to``, so both
        strings such as ``"cpu"`` and ``torch.device`` objects work. The
        original transition is not modified.
        """
        return Transition(
            source_state=self.source_state.to(device),
            action=self.action.to(device),
            next_state=self.next_state.to(device),
        )
def get_vector(
    transition: Transition
) -> torch.Tensor:
    """Flatten a transition into one tensor.

    The pieces are joined with ``torch.cat`` in the order source state,
    action, next state.
    """
    parts = (
        transition.source_state,
        transition.action,
        transition.next_state,
    )
    return torch.cat(parts)
Return the transition as a single vector.
# Member names of the three tensor files stored inside a transition ZIP.
SOURCE_STATE_PATH, ACTION_PATH, NEXT_STATE_PATH = (
    "source_state.pt",
    "action.pt",
    "next_state.pt",
)
def serialize(t: Transition, output_zip_path: "Path | str"):
    """Serialize a transition into a ZIP file.

    Each tensor is written with ``torch.save`` and stored in the archive
    under ``SOURCE_STATE_PATH``, ``ACTION_PATH`` and ``NEXT_STATE_PATH``.

    Args:
        t: Transition to write.
        output_zip_path: Destination archive path; must end in ``.zip``.
            A ``str`` is accepted and converted to ``Path``.

    Raises:
        ValueError: If ``output_zip_path`` does not have a ``.zip`` suffix.
    """
    output_zip_path = Path(output_zip_path)
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if output_zip_path.suffix != ".zip":
        raise ValueError(
            f"output_zip_path must end in '.zip', got {str(output_zip_path)!r}"
        )
    with tempfile.TemporaryDirectory() as tmpdirname:
        output_dir = Path(tmpdirname)
        torch.save(t.source_state, output_dir/SOURCE_STATE_PATH)
        torch.save(t.action, output_dir/ACTION_PATH)
        torch.save(t.next_state, output_dir/NEXT_STATE_PATH)
        # make_archive appends the ".zip" extension itself, so strip it here.
        shutil.make_archive(str(output_zip_path.with_suffix("")), 'zip', output_dir)
Serialize a transition into a ZIP file.
def deserialize(zip_path: Path) -> Transition:
    """Rebuild a transition from a ZIP archive produced by `serialize`.

    The archive is unpacked into a temporary directory (removed before
    returning), and each member is read back with
    ``torch.load(..., weights_only=True)``.

    Args:
        zip_path: Path of the archive to read.

    Returns:
        A new `Transition` holding the loaded tensors.
    """
    # (Transition field, archive member) pairs, loaded in this order.
    members = (
        ("source_state", SOURCE_STATE_PATH),
        ("action", ACTION_PATH),
        ("next_state", NEXT_STATE_PATH),
    )
    with tempfile.TemporaryDirectory() as scratch:
        scratch_dir = Path(scratch)
        with zipfile.ZipFile(zip_path, "r") as archive:
            archive.extractall(scratch_dir)
        loaded = {
            field: torch.load(scratch_dir/member, weights_only=True)
            for field, member in members
        }
    return Transition(**loaded)