1"""Reader for training log.
3See lib/Analysis/TrainingLogger.cpp for a description of the format.
11from typing
import List, Optional
14 "float": ctypes.c_float,
15 "double": ctypes.c_double,
16 "int8_t": ctypes.c_int8,
17 "uint8_t": ctypes.c_uint8,
18 "int16_t": ctypes.c_int16,
19 "uint16_t": ctypes.c_uint16,
20 "int32_t": ctypes.c_int32,
21 "uint32_t": ctypes.c_uint32,
22 "int64_t": ctypes.c_int64,
23 "uint64_t": ctypes.c_uint64,
27@dataclasses.dataclass(frozen=True)
38 shape = [int(e)
for e
in d[
"shape"]]
39 element_type_str = d[
"type"]
40 if element_type_str
not in _element_types:
41 raise ValueError(f
"uknown type: {element_type_str}")
46 element_type=_element_types[element_type_str],
51 def __init__(self, spec: TensorSpec, buffer: bytes):
57 def spec(self) -> TensorSpec:
64 if index < 0
or index >= self.
_len:
65 raise IndexError(f
"Index {index} out of range [0..{self._len})")
66 return self.
_view[index]
69def read_tensor(fs: io.BufferedReader, ts: TensorSpec) -> TensorValue:
70 size = math.prod(ts.shape) * ctypes.sizeof(ts.element_type)
76 print(f
'{tv.spec().name}: {",".join([str(v) for v in tv])}')
80 header = json.loads(f.readline())
81 tensor_specs = [TensorSpec.from_dict(ts)
for ts
in header[
"features"]]
82 score_spec = TensorSpec.from_dict(header[
"score"])
if "score" in header
else None
83 advice_spec = TensorSpec.from_dict(header[
"advice"])
if "advice" in header
else None
84 return tensor_specs, score_spec, advice_spec
88 context: Optional[str],
91 tensor_specs: List[TensorSpec],
92 score_spec: Optional[TensorSpec],
94 event = json.loads(event_str)
95 if "context" in event:
96 context = event[
"context"]
97 event = json.loads(f.readline())
98 observation_id = int(event[
"observation"])
100 for ts
in tensor_specs:
104 if score_spec
is not None:
105 score_header = json.loads(f.readline())
106 assert int(score_header[
"outcome"]) == observation_id
109 return context, observation_id, features, score
113 with io.BufferedReader(io.FileIO(fname,
"rb"))
as f:
117 event_str = f.readline()
121 context, event_str, f, tensor_specs, score_spec
123 yield context, observation_id, features, score
128 for ctx, obs_id, features, score
in read_stream(args[1]):
129 if last_context != ctx:
130 print(f
"context: {ctx}")
132 print(f
"observation: {obs_id}")
139if __name__ ==
"__main__":
static void print(raw_ostream &Out, object::Archive::Kind Kind, T Val)
def __getitem__(self, index)
def __init__(self, TensorSpec spec, bytes buffer)
def read_one_observation(Optional[str] context, str event_str, io.BufferedReader f, List[TensorSpec] tensor_specs, Optional[TensorSpec] score_spec)
def read_stream(str fname)
def read_header(io.BufferedReader f)
def pretty_print_tensor_value(TensorValue tv)
TensorValue read_tensor(io.BufferedReader fs, TensorSpec ts)