LLVM 17.0.0git
Go to the documentation of this file.
//===- TrainingLogger.h - mlgo feature/reward logging ----------*- C++ -*-===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// The design goals of the logger are:
// - no dependencies that llvm doesn't already have.
// - support streaming, so that we don't need to buffer data during compilation.
// - zero-decoding of tensor values: tensor values are written as raw bytes.
//   Tensor values are potentially very large buffers of scalars; because of
//   their potentially large size, avoiding serialization/deserialization
//   overhead is preferred.
// The simple logger produces an output of the following form (each item on its
// own line):
// - header: a JSON object describing the data that will follow.
// - context: e.g. the function name for regalloc, or "default" for module-wide
//   optimizations like the inliner. This is the context to which the
//   subsequent data corresponds.
// - observation number.
// - tensor values - raw bytes of the tensors, in the order given in the header.
//   The values are in succession, i.e. no separator is found between successive
//   tensor values. At the end, there is a newline character.
// - [score] - this is optional, and is present if it was present in the header.
//   Currently, for final rewards, we output "0" scores after each observation,
//   except for the last one.
// <repeat>
// The file should be read as binary; the reason we use newlines is mostly
// ease of debugging: the log can be opened in a text editor and, while tensor
// values are inscrutable, at least the sequence of data can be easily observed.
// Of course, the buffer of tensor values could contain '\n' bytes. A reader
// should use the header information to know how much data to read for the
// tensor values, and not use line information for that.
//
// An example reader, used for tests, is available at
// Analysis/models/log_reader.py
// Example:
// {"features":[list of TensorSpecs], "score":<a tensor spec>}
// {"context": "aFunction"}
// {"observation": 0}
// <bytes>
// {"outcome": 0}
// <bytes for the tensor corresponding to the "score" spec in the header>
// {"observation": 1}
// ...
// {"context": "anotherFunction"}
// {"observation": 0}
// ...
#include "llvm/Config/llvm-config.h"

#include "llvm/ADT/StringMap.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/JSON.h"

#include <cassert>
#include <memory>
#include <optional>
#include <vector>
67namespace llvm {
69/// Logging utility - given an ordered specification of features, and assuming
70/// a scalar reward, allow logging feature values and rewards.
71/// The assumption is that, for an event to be logged (i.e. a set of feature
72/// values and a reward), the user calls the log* API for each feature exactly
73/// once, providing the index matching the position in the feature spec list
74/// provided at construction. The example assumes the first feature's element
75/// type is float, the second is int64, and the reward is float:
77/// event 0:
78/// logFloatValue(0, ...)
79/// logInt64Value(1, ...)
80/// ...
81/// logFloatReward(...)
82/// event 1:
83/// logFloatValue(0, ...)
84/// logInt64Value(1, ...)
85/// ...
86/// logFloatReward(...)
88/// At the end, call print to generate the log.
89/// Alternatively, don't call logReward at the end of each event, just
90/// log{Float|Int32|Int64}FinalReward at the end.
91class Logger final {
92 std::unique_ptr<raw_ostream> OS;
93 const std::vector<TensorSpec> FeatureSpecs;
94 const TensorSpec RewardSpec;
95 const bool IncludeReward;
96 StringMap<size_t> ObservationIDs;
97 std::string CurrentContext;
99 void writeHeader(std::optional<TensorSpec> AdviceSpec);
100 void writeTensor(const TensorSpec &Spec, const char *RawData) {
101 OS->write(RawData, Spec.getTotalTensorBufferSize());
102 }
103 void logRewardImpl(const char *RawData);
106 /// Construct a Logger. If IncludeReward is false, then logReward or
107 /// logFinalReward shouldn't be called, and the reward feature won't be
108 /// printed out.
109 /// NOTE: the FeatureSpecs are expected to be in the same order (i.e. have
110 /// corresponding indices) with any MLModelRunner implementations
111 /// corresponding to the model being trained/logged.
112 Logger(std::unique_ptr<raw_ostream> OS,
113 const std::vector<TensorSpec> &FeatureSpecs,
114 const TensorSpec &RewardSpec, bool IncludeReward,
115 std::optional<TensorSpec> AdviceSpec = std::nullopt);
118 void startObservation();
119 void endObservation();
120 void flush() { OS->flush(); }
122 const std::string &currentContext() const { return CurrentContext; }
124 /// Check if there is at least an observation for `currentContext()`.
126 return hasAnyObservationForContext(CurrentContext);
127 }
129 /// Check if there is at least an observation for the context `Ctx`.
131 return ObservationIDs.contains(Ctx);
132 }
134 template <typename T> void logReward(T Value) {
135 logRewardImpl(reinterpret_cast<const char *>(&Value));
136 }
138 void logTensorValue(size_t FeatureID, const char *RawData) {
139 writeTensor(FeatureSpecs[FeatureID], RawData);
140 }
143} // namespace llvm
This file defines the StringMap class.
std::string Name
This file supports working with JSON data.
raw_pwrite_stream & OS
Logging utility - given an ordered specification of features, and assuming a scalar reward,...
bool hasAnyObservationForContext(StringRef Ctx) const
Check if there is at least an observation for the context Ctx.
void startObservation()
bool hasObservationInProgress() const
Check if there is at least an observation for currentContext().
void switchContext(StringRef Name)
void logReward(T Value)
void endObservation()
void logTensorValue(size_t FeatureID, const char *RawData)
const std::string & currentContext() const
StringMap - This is an unconventional map that is specialized for handling keys that are "strings",...
Definition: StringMap.h:111
bool contains(StringRef Key) const
contains - Return true if the element is in the map, false otherwise.
Definition: StringMap.h:253
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
LLVM Value Representation.
Definition: Value.h:74
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18