LLVM 19.0.0git
TrainingLogger.cpp
Go to the documentation of this file.
1//===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements logging infrastructure for extracting features and
10// rewards for mlgo policy training.
11//
12//===----------------------------------------------------------------------===//
14#include "llvm/Config/config.h"
15
16#include "llvm/ADT/Twine.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/JSON.h"
22#include "llvm/Support/Path.h"
24
25#include <cassert>
26#include <numeric>
27
28using namespace llvm;
29
30void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) {
31 json::OStream JOS(*OS);
32 JOS.object([&]() {
33 JOS.attributeArray("features", [&]() {
34 for (const auto &TS : FeatureSpecs)
35 TS.toJSON(JOS);
36 });
37 if (IncludeReward) {
38 JOS.attributeBegin("score");
39 RewardSpec.toJSON(JOS);
40 JOS.attributeEnd();
41 }
42 if (AdviceSpec.has_value()) {
43 JOS.attributeBegin("advice");
44 AdviceSpec->toJSON(JOS);
45 JOS.attributeEnd();
46 }
47 });
48 *OS << "\n";
49}
50
52 CurrentContext = Name.str();
53 json::OStream JOS(*OS);
54 JOS.object([&]() { JOS.attribute("context", Name); });
55 *OS << "\n";
56}
57
59 auto I = ObservationIDs.insert({CurrentContext, 0});
60 size_t NewObservationID = I.second ? 0 : ++I.first->second;
61 json::OStream JOS(*OS);
62 JOS.object([&]() {
63 JOS.attribute("observation", static_cast<int64_t>(NewObservationID));
64 });
65 *OS << "\n";
66}
67
68void Logger::endObservation() { *OS << "\n"; }
69
70void Logger::logRewardImpl(const char *RawData) {
71 assert(IncludeReward);
72 json::OStream JOS(*OS);
73 JOS.object([&]() {
74 JOS.attribute("outcome", static_cast<int64_t>(
75 ObservationIDs.find(CurrentContext)->second));
76 });
77 *OS << "\n";
78 writeTensor(RewardSpec, RawData);
79 *OS << "\n";
80}
81
82Logger::Logger(std::unique_ptr<raw_ostream> OS,
83 const std::vector<TensorSpec> &FeatureSpecs,
84 const TensorSpec &RewardSpec, bool IncludeReward,
85 std::optional<TensorSpec> AdviceSpec)
86 : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
87 IncludeReward(IncludeReward) {
88 writeHeader(AdviceSpec);
89}
std::string Name
This file supports working with JSON data.
#define I(x, y, z)
Definition: MD5.cpp:58
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
raw_pwrite_stream & OS
void startObservation()
void switchContext(StringRef Name)
void endObservation()
Logger(std::unique_ptr< raw_ostream > OS, const std::vector< TensorSpec > &FeatureSpecs, const TensorSpec &RewardSpec, bool IncludeReward, std::optional< TensorSpec > AdviceSpec=std::nullopt)
Construct a Logger.
iterator find(StringRef Key)
Definition: StringMap.h:234
bool insert(MapEntryTy *KeyValue)
insert - Insert the specified key/value pair into the map.
Definition: StringMap.h:307
StringRef - Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
Definition: StringRef.h:50
void toJSON(json::OStream &OS) const
Definition: TensorSpec.cpp:50
json::OStream allows writing well-formed JSON without materializing all structures as json::Value ahead of time.
Definition: JSON.h:977
void object(Block Contents)
Emit an object whose elements are emitted in the provided Block.
Definition: JSON.h:1007
void attribute(llvm::StringRef Key, const Value &Contents)
Emit an attribute whose value is self-contained (number, vector<int> etc).
Definition: JSON.h:1032
The top-level namespace containing the LLVM libraries.
Definition: AddressRanges.h:18
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1858
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858