LLVM 22.0.0git
TrainingLogger.cpp
Go to the documentation of this file.
1//===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements logging infrastructure for extracting features and
10// rewards for mlgo policy training.
11//
12//===----------------------------------------------------------------------===//
14#include "llvm/Config/config.h"
15
16#include "llvm/ADT/Twine.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/JSON.h"
22#include "llvm/Support/Path.h"
24
25#include <cassert>
26
27using namespace llvm;
28
29void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) {
30 json::OStream JOS(*OS);
31 JOS.object([&]() {
32 JOS.attributeArray("features", [&]() {
33 for (const auto &TS : FeatureSpecs)
34 TS.toJSON(JOS);
35 });
36 if (IncludeReward) {
37 JOS.attributeBegin("score");
38 RewardSpec.toJSON(JOS);
39 JOS.attributeEnd();
40 }
41 if (AdviceSpec.has_value()) {
42 JOS.attributeBegin("advice");
43 AdviceSpec->toJSON(JOS);
44 JOS.attributeEnd();
45 }
46 });
47 *OS << "\n";
48}
49
51 CurrentContext = Name.str();
52 json::OStream JOS(*OS);
53 JOS.object([&]() { JOS.attribute("context", Name); });
54 *OS << "\n";
55}
56
58 auto I = ObservationIDs.insert({CurrentContext, 0});
59 size_t NewObservationID = I.second ? 0 : ++I.first->second;
60 json::OStream JOS(*OS);
61 JOS.object([&]() {
62 JOS.attribute("observation", static_cast<int64_t>(NewObservationID));
63 });
64 *OS << "\n";
65}
66
67void Logger::endObservation() { *OS << "\n"; }
68
69void Logger::logRewardImpl(const char *RawData) {
70 assert(IncludeReward);
71 json::OStream JOS(*OS);
72 JOS.object([&]() {
73 JOS.attribute("outcome", static_cast<int64_t>(
74 ObservationIDs.find(CurrentContext)->second));
75 });
76 *OS << "\n";
77 writeTensor(RewardSpec, RawData);
78 *OS << "\n";
79}
80
81Logger::Logger(std::unique_ptr<raw_ostream> OS,
82 const std::vector<TensorSpec> &FeatureSpecs,
83 const TensorSpec &RewardSpec, bool IncludeReward,
84 std::optional<TensorSpec> AdviceSpec)
85 : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
86 IncludeReward(IncludeReward) {
87 writeHeader(AdviceSpec);
88}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
This file supports working with JSON data.
#define I(x, y, z)
Definition MD5.cpp:57
LLVM_ABI void startObservation()
LLVM_ABI void switchContext(StringRef Name)
LLVM_ABI void endObservation()
LLVM_ABI Logger(std::unique_ptr< raw_ostream > OS, const std::vector< TensorSpec > &FeatureSpecs, const TensorSpec &RewardSpec, bool IncludeReward, std::optional< TensorSpec > AdviceSpec=std::nullopt)
Construct a Logger.
iterator find(StringRef Key)
Definition StringMap.h:237
StringRef - Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
Definition StringRef.h:55
json::OStream allows writing well-formed JSON without materializing all structures as json::Value ahead of time.
Definition JSON.h:1000
void object(Block Contents)
Emit an object whose elements are emitted in the provided Block.
Definition JSON.h:1030
void attribute(llvm::StringRef Key, const Value &Contents)
Emit an attribute whose value is self-contained (number, vector<int> etc).
Definition JSON.h:1055
This is an optimization pass for GlobalISel generic memory operations.
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1867
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:867