LLVM 19.0.0git
ModelUnderTrainingRunner.h
Go to the documentation of this file.
1//===- ModelUnderTrainingRunner.h -- 'development' mode runner --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9
10#ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
11#define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
12
#include "llvm/ADT/STLExtras.h"
#include "llvm/Config/llvm-config.h"
#include <cassert>
17
18#ifdef LLVM_HAVE_TFLITE
21#include "llvm/IR/LLVMContext.h"
22#include "llvm/IR/PassManager.h"
23
24namespace llvm {
25
26/// ModelUnderTrainingRunner - training mode implementation. It uses TFLite
27/// to dynamically load and evaluate a TF SavedModel
28/// (https://www.tensorflow.org/guide/saved_model) converted to TFLite. see
29/// lib/Analysis/models/saved-model-to-tflite.py. Runtime performance is
30/// sacrificed for ease of use while training.
31class ModelUnderTrainingRunner final : public MLModelRunner {
32public:
33 // Disallows copy and assign.
34 ModelUnderTrainingRunner(const ModelUnderTrainingRunner &) = delete;
35 ModelUnderTrainingRunner &
36 operator=(const ModelUnderTrainingRunner &) = delete;
37
38 const std::vector<TensorSpec> &extraOutputsForLoggingSpecs() const {
39 return ExtraOutputsForLogging;
40 }
41
42 const void *getUntypedExtraOutputValue(size_t ExtraOutputIndex) const {
43 return lastEvaluationResult()->getUntypedTensorValue(ExtraOutputIndex + 1);
44 }
45
46 const std::optional<TFModelEvaluator::EvaluationResult> &
47 lastEvaluationResult() const {
48 return LastEvaluationResult;
49 }
50 static bool classof(const MLModelRunner *R) {
51 return R->getKind() == MLModelRunner::Kind::Development;
52 }
53
54 static std::unique_ptr<ModelUnderTrainingRunner>
55 createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath,
56 StringRef DecisionName,
57 const std::vector<TensorSpec> &InputSpecs,
58 StringRef OutputSpecsPathOverride = "");
59
60 ModelUnderTrainingRunner(
61 LLVMContext &Ctx, const std::string &ModelPath,
62 const std::vector<TensorSpec> &InputSpecs,
63 const std::vector<TensorSpec> &OutputSpecs,
64 const std::vector<TensorSpec> &ExtraOutputsForLogging = {});
65
66 bool isValid() const { return !!Evaluator; }
67
68private:
69 std::unique_ptr<TFModelEvaluator> Evaluator;
70 const std::vector<TensorSpec> OutputSpecs;
71 const std::vector<TensorSpec> ExtraOutputsForLogging;
72 std::optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult;
73 void *evaluateUntyped() override;
74};
75
76} // namespace llvm
77#endif // define(LLVM_HAVE_TFLITE)
78#endif // LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
#define DecisionName
This header defines various interfaces for pass management in LLVM.
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
This file contains some templates that are useful if you are working with the STL at all.
This provides a very simple, boring adaptor for a begin and end iterator into a range type.
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18