10#ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
11#define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
16#include "llvm/Config/llvm-config.h"
18#ifdef LLVM_HAVE_TFLITE
/// MLModelRunner in the Development kind (see classof) that evaluates a model
/// through a TFModelEvaluator it owns. This declaration sits inside
/// `#ifdef LLVM_HAVE_TFLITE`. Besides the model's outputs, it can expose
/// additional "extra" output tensors (ExtraOutputsForLogging).
///
/// NOTE(review): this text appears to be extracted from a numbered source
/// listing. Original line numbers are fused onto lines (e.g. `34 `), and the
/// closing braces and `public:`/`private:` specifiers are missing. As written,
/// every member would default to private and the file will not compile.
/// Confirm against the upstream header before editing.
class ModelUnderTrainingRunner final :
public MLModelRunner {
// Copying is explicitly deleted. Because a copy constructor is user-declared
// (even as deleted), no implicit move constructor/assignment is declared
// either, so the runner is neither copyable nor movable. The unique_ptr and
// const-vector members would block implicit copy-assignment anyway.
34 ModelUnderTrainingRunner(
const ModelUnderTrainingRunner &) =
delete;
35 ModelUnderTrainingRunner &
36 operator=(
const ModelUnderTrainingRunner &) =
delete;
/// Returns, by const reference, the TensorSpecs stored in the
/// ExtraOutputsForLogging member. These describe the extra output tensors
/// retrievable via getUntypedExtraOutputValue().
38 const std::vector<TensorSpec> &extraOutputsForLoggingSpecs()
const {
39 return ExtraOutputsForLogging;
/// Returns an untyped pointer to extra output number \p ExtraOutputIndex of
/// the stored LastEvaluationResult.
///
/// NOTE(review): `lastEvaluationResult()->` dereferences a std::optional
/// without checking that it holds a value. Calling this before any result has
/// been stored is undefined behavior. Consider asserting
/// `LastEvaluationResult.has_value()`. No bounds check against
/// ExtraOutputsForLogging.size() is visible here either.
42 const void *getUntypedExtraOutputValue(
size_t ExtraOutputIndex)
const {
// Extra outputs are read at tensor index ExtraOutputIndex + 1. Presumably
// tensor 0 is the model's primary output and the extra outputs follow it.
// TODO confirm against evaluateUntyped()'s output ordering.
43 return lastEvaluationResult()->getUntypedTensorValue(ExtraOutputIndex + 1);
/// Returns, by const reference, the stored LastEvaluationResult. It is an
/// empty optional until something assigns it; presumably evaluateUntyped()
/// does so (verify in the implementation file).
46 const std::optional<TFModelEvaluator::EvaluationResult> &
47 lastEvaluationResult()
const {
48 return LastEvaluationResult;
/// LLVM-style RTTI hook used by isa<>/cast<>/dyn_cast<>. Returns true iff
/// \p R reports MLModelRunner::Kind::Development. \p R is dereferenced
/// unconditionally, so it must be non-null.
50 static bool classof(
const MLModelRunner *R) {
51 return R->getKind() == MLModelRunner::Kind::Development;
/// Static factory returning an owned ModelUnderTrainingRunner for the model
/// at \p ModelPath with the given \p InputSpecs.
///
/// \p OutputSpecsPathOverride defaults to the empty string. Its exact use is
/// defined in the implementation, not here.
///
/// NOTE(review): the name suggests a runner that fails isValid() is not
/// returned (presumably nullptr instead). Confirm in the implementation file.
54 static std::unique_ptr<ModelUnderTrainingRunner>
55 createAndEnsureValid(LLVMContext &Ctx,
const std::string &ModelPath,
57 const std::vector<TensorSpec> &InputSpecs,
58 StringRef OutputSpecsPathOverride =
"");
/// Constructs a runner from explicit input and output TensorSpecs.
/// \p ExtraOutputsForLogging defaults to an empty list.
///
/// Callers should check isValid() afterwards. Presumably the constructor
/// leaves Evaluator null on failure, but confirm that in the implementation.
60 ModelUnderTrainingRunner(
61 LLVMContext &Ctx,
const std::string &ModelPath,
62 const std::vector<TensorSpec> &InputSpecs,
63 const std::vector<TensorSpec> &OutputSpecs,
64 const std::vector<TensorSpec> &ExtraOutputsForLogging = {});
/// Returns true when the owned Evaluator is non-null.
66 bool isValid()
const {
return !!Evaluator; }
// Owned model evaluator. isValid() reports whether it is non-null.
69 std::unique_ptr<TFModelEvaluator> Evaluator;
// Output specs. Const, so fixed once the constructor initializes them.
70 const std::vector<TensorSpec> OutputSpecs;
// Specs for the extra outputs. Exposed via extraOutputsForLoggingSpecs().
// Const, so fixed at construction.
71 const std::vector<TensorSpec> ExtraOutputsForLogging;
// Stored evaluation result. Read by lastEvaluationResult() and
// getUntypedExtraOutputValue(). Empty until assigned (presumably by
// evaluateUntyped(); TODO confirm).
72 std::optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult;
/// Override of the MLModelRunner evaluation hook. It returns an untyped
/// pointer and is defined out of line. Presumably it runs Evaluator and
/// stores the result in LastEvaluationResult; confirm in the implementation.
73 void *evaluateUntyped()
override;
This header defines various interfaces for pass management in LLVM.
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
This provides a very simple, boring adaptor for a begin and end iterator into a range type.
This is an optimization pass for GlobalISel generic memory operations.