Shortcuts

Program Listing for File deploy.h

Return to documentation for file (multipy/multipy/runtime/deploy.h)

// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <c10/util/irange.h>
#include <multipy/runtime/embedded_file.h>
#include <multipy/runtime/interpreter/interpreter_impl.h>
#include <multipy/runtime/noop_environment.h>
#include <torch/csrc/api/include/torch/imethod.h>
#include <torch/csrc/jit/serialization/import.h>

#include <multipy/runtime/interpreter/Optional.hpp>
#include <cassert>
#include <fstream>
#include <functional>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>

namespace torch {
namespace deploy {

struct ReplicatedObj;
struct InterpreterManager;
struct LoadBalancer;

// A handle used to interact with one interpreter. Nearly every operation is
// forwarded to the InterpreterSessionImpl held in `impl_`.
//
// The type declares a move constructor only, so it is movable but not
// copyable.
struct TORCH_API InterpreterSession {
  friend struct LoadBalancer;

  // Wraps `impl` without an associated InterpreterManager. `impl` is stored
  // in a std::unique_ptr, so the session takes ownership of it.
  explicit InterpreterSession(InterpreterSessionImpl* impl) noexcept
      : impl_(impl), manager_(nullptr) {}
  // Wraps `impl` (ownership is taken, as above) and remembers `manager` as a
  // non-owning pointer.
  InterpreterSession(
      InterpreterSessionImpl* impl,
      InterpreterManager* manager) noexcept
      : impl_(impl), manager_(manager) {}

  // Forwards to the implementation to ask whether `obj` belongs to this
  // session.
  bool isOwner(Obj obj) {
    return impl_->isOwner(obj);
  }
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  Obj self; // when retrieved from a PythonMovable this will be set.
  // NOTE(review): the defaulted move also moves `deconstruction_callback_`;
  // the standard leaves a moved-from std::function in an unspecified state.
  // Confirm the destructor (defined out of line) tolerates that.
  InterpreterSession(InterpreterSession&&) noexcept = default;
  // NOLINTNEXTLINE(bugprone-exception-escape)
  ~InterpreterSession();

  // Looks up `name` inside `module` through the implementation and returns
  // the resulting object handle.
  Obj global(const char* module, const char* name) {
    return impl_->global(module, name);
  }

  // Converts an IValue into an interpreter object; `ivalue` is moved into the
  // implementation call.
  Obj fromIValue(at::IValue ivalue) {
    return impl_->fromIValue(std::move(ivalue));
  }
  // Produces a ReplicatedObj from `obj`. Defined out of line.
  ReplicatedObj createMovable(Obj obj);

  // Produces an Obj for this session from a ReplicatedObj. Defined out of
  // line.
  Obj fromMovable(const ReplicatedObj& obj);

 protected:
  // Registers `func` as this session's deconstruction callback. Defined out
  // of line; InterpreterManager::acquireOne() uses it to hand the load
  // balancer slot back. The meaning of the bool result is not visible here.
  bool attachDeconstructorCallback(std::function<void()> func);

 private:
  friend struct ReplicatedObj;
  friend struct Package;
  friend struct InterpreterManager;
  friend struct ReplicatedObjImpl;
  // Counter shared by every session (inline static, starts at 0).
  // NOTE(review): plain size_t, not atomic -- confirm how concurrent access
  // is synchronised where it is incremented.
  inline static size_t nextObjectId_ = 0;
  // Owning pointer to the implementation every call is forwarded to.
  std::unique_ptr<InterpreterSessionImpl> impl_;
  // Non-owning; nullptr when constructed without a manager.
  InterpreterManager* manager_;
  // Empty until attachDeconstructorCallback() is used.
  std::function<void()> deconstruction_callback_ = nullptr;
  // Serialises `obj` into a PickledObject. Defined out of line.
  PickledObject pickleObj(Obj obj);
};

// One interpreter instance. Sessions used to talk to it are obtained with
// acquireSession().
//
// The class is move-constructible only: copy construction, copy assignment
// and move assignment are all deleted.
class TORCH_API Interpreter {
 private:
  // Opaque handle; cleared in a moved-from object so that only one
  // Interpreter refers to it. (Presumably a loaded-library handle -- the code
  // that sets it is not in this header.)
  void* handle_;
  // Implementation object that creates sessions.
  std::unique_ptr<InterpreterImpl> pImpl_;
  // Non-owning; nullptr for an interpreter created without a manager.
  InterpreterManager* manager_;
  // Environment this interpreter was created with.
  std::shared_ptr<Environment> env_;

  EmbeddedFile interpreterFile_;
  multipy::optional<EmbeddedFile> torchPluginFile_;

 public:
  // Creates an interpreter attached to `manager` (may be nullptr) using
  // `env`. Defined out of line.
  Interpreter(InterpreterManager* manager, std::shared_ptr<Environment> env);

  // Creates an interpreter that is not attached to any manager. `env` is
  // moved into the delegated constructor to avoid an extra shared_ptr copy.
  explicit Interpreter(std::shared_ptr<Environment> env)
      : Interpreter(nullptr, std::move(env)) {}

  // Returns a new session on this interpreter. The session is linked to the
  // manager when this interpreter has one.
  InterpreterSession acquireSession() const {
    if (manager_) {
      return InterpreterSession(pImpl_->acquireSession(), manager_);
    } else {
      return InterpreterSession(pImpl_->acquireSession());
    }
  }

  ~Interpreter();
  // Transfers every member from `rhs`. `env_` is moved as well, so the new
  // object keeps the Environment alive instead of ending up with a null
  // pointer. Initialisers are listed in member declaration order.
  Interpreter(Interpreter&& rhs) noexcept
      : handle_(rhs.handle_),
        pImpl_(std::move(rhs.pImpl_)),
        manager_(rhs.manager_),
        env_(std::move(rhs.env_)),
        interpreterFile_(std::move(rhs.interpreterFile_)),
        torchPluginFile_(std::move(rhs.torchPluginFile_)) {
    // The handle now belongs to this object only.
    rhs.handle_ = nullptr;
  }

  Interpreter(const Interpreter&) = delete;
  Interpreter& operator=(const Interpreter&) = delete;
  Interpreter& operator=(Interpreter&&) = delete;
  friend struct InterpreterManager;
};

struct Package;

// Hands out interpreter slots. Usage counters live in `uses_`, an array of
// 8 * n 64-bit words for n slots.
// NOTE(review): the factor 8 presumably spaces counters one cache line apart
// to avoid false sharing -- confirm against acquire()/free().
struct TORCH_API LoadBalancer {
  // Allocates zeroed counters for `n` slots. std::make_unique<T[]>
  // value-initialises the array, so no separate memset is required.
  explicit LoadBalancer(size_t n)
      : uses_(std::make_unique<uint64_t[]>(8 * n)), allocated_(n), n_(n) {}

  // Restricts the number of usable slots to `n`, which must not exceed the
  // number allocated at construction.
  void setResourceLimit(size_t n) {
    MULTIPY_INTERNAL_ASSERT(n <= allocated_);
    n_ = n;
  }

  // Picks a slot and returns its index. Defined out of line.
  int acquire();

  // Releases the slot at index `where`. Defined out of line.
  void free(int where);

 private:
  // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
  std::unique_ptr<uint64_t[]>
      uses_;
  // Number of slots the array was sized for.
  size_t allocated_;
  // Number of slots currently in use by the balancer (<= allocated_).
  size_t n_;
};

// Owns a group of interpreters together with the LoadBalancer that decides
// which one serves each request. Also keeps a name -> source table of
// registered modules.
struct TORCH_API InterpreterManager {
  // Defined out of line. Defaults: two interpreters, NoopEnvironment.
  explicit InterpreterManager(
      size_t nInterp = 2,
      std::shared_ptr<Environment> env = std::make_shared<NoopEnvironment>());

  // Returns a session on the interpreter chosen by the load balancer. A
  // callback that releases the chosen slot is attached to the session.
  InterpreterSession acquireOne() {
    const int slot = resources_.acquire();
    InterpreterSession session = instances_[slot].acquireSession();
    session.attachDeconstructorCallback(
        [this, slot]() { resources_.free(slot); });
    return session;
  }

  // Non-owning view over every interpreter held by this manager.
  at::ArrayRef<Interpreter> allInstances() {
    return at::ArrayRef<Interpreter>(instances_);
  }

  // Caps the load balancer at `N` interpreters; `N` may not exceed the number
  // of interpreters that exist.
  void debugLimitInterpreters(size_t N) {
    AT_ASSERT(N <= instances_.size());
    resources_.setResourceLimit(N);
  }

  // Both overloads are defined out of line.
  Package loadPackage(const std::string& uri);

  Package loadPackage(
      std::shared_ptr<caffe2::serialize::ReadAdapterInterface> reader);

  // Stores `src` under `name`, replacing any source registered earlier under
  // the same name.
  void registerModuleSource(std::string name, std::string src) {
    registeredModuleSource_.insert_or_assign(std::move(name), std::move(src));
  }

  // Number of distinct module names registered so far.
  size_t countRegisteredModuleSources() {
    const size_t registered = registeredModuleSource_.size();
    return registered;
  }

  // Defined out of line.
  ReplicatedObj createMovable(Obj obj, InterpreterSession* I);
  InterpreterManager(const InterpreterManager&) = delete;
  InterpreterManager& operator=(const InterpreterManager&) = delete;
  InterpreterManager& operator=(InterpreterManager&&) = delete;

 private:
  friend struct Package;
  friend struct InterpreterSession;
  friend struct InterpreterSessionImpl;
  std::vector<Interpreter> instances_;
  LoadBalancer resources_;
  std::unordered_map<std::string, std::string> registeredModuleSource_;
};

// State behind a ReplicatedObj: an id, the pickled payload, and a non-owning
// pointer to the InterpreterManager it is associated with.
struct TORCH_API ReplicatedObjImpl {
  // `object_id` is stored as int64_t; the conversion from size_t is spelled
  // out so the type change is explicit. `data` is taken by value and moved
  // into place, so no second copy of the payload is made.
  ReplicatedObjImpl(
      size_t object_id,
      PickledObject data,
      InterpreterManager* manager)
      : objectId_(static_cast<int64_t>(object_id)),
        data_(std::move(data)),
        manager_(manager) {}
  // NOLINTNEXTLINE(bugprone-exception-escape)
  ~ReplicatedObjImpl();
  // Defined out of line. `onThisInterpreter` selects the interpreter to
  // unload from; ReplicatedObj::unload defaults the argument to nullptr.
  void unload(const Interpreter* onThisInterpreter);
  int64_t objectId_;
  PickledObject data_;
  InterpreterManager* manager_;
};

struct TORCH_API ReplicatedObj {
  ReplicatedObj() : pImpl_(nullptr) {}

  InterpreterSession acquireSession(
      const Interpreter* onThisInterpreter = nullptr) const;
  at::IValue operator()(at::ArrayRef<at::IValue> args) const {
    auto I = acquireSession();
    return I.self(args).toIValue();
  }

  [[nodiscard]] at::IValue callKwargs(
      std::vector<at::IValue> args,
      std::unordered_map<std::string, c10::IValue> kwargs) const {
    auto I = acquireSession();
    return I.self.callKwargs(std::move(args), std::move(kwargs)).toIValue();
  }

  [[nodiscard]] at::IValue callKwargs(
      std::unordered_map<std::string, c10::IValue> kwargs) const {
    auto I = acquireSession();
    return I.self.callKwargs(std::move(kwargs)).toIValue();
  }

  [[nodiscard]] bool hasattr(const char* attr) const {
    auto I = acquireSession();
    return I.self.hasattr(attr);
  }

  void unload(const Interpreter* onThisInterpreter = nullptr);

  Obj toObj(InterpreterSession* I);

 private:
  ReplicatedObj(std::shared_ptr<ReplicatedObjImpl> pImpl)
      : pImpl_(std::move(pImpl)) {}
  std::shared_ptr<ReplicatedObjImpl> pImpl_;
  friend struct Package;
  friend struct InterpreterSession;
  friend struct InterpreterManager;
};

// torch::IMethod adapter that invokes a named method on a ReplicatedObj.
class PythonMethodWrapper : public torch::IMethod {
 public:
  // TODO(whc) make bound method pickleable, then directly construct from that

  // Stores the model handle and the name of the method to invoke on it.
  PythonMethodWrapper(
      torch::deploy::ReplicatedObj model,
      std::string methodName)
      : model_(std::move(model)), methodName_(std::move(methodName)) {}

  // Name of the wrapped method.
  const std::string& name() const override {
    return methodName_;
  }

  // Acquires a session for the model, looks up the method by name and calls
  // it with `args` and `kwargs`, returning the result as an IValue.
  c10::IValue operator()(
      std::vector<c10::IValue> args,
      const IValueMap& kwargs = IValueMap()) const override {
    // TODO(whc) ideally, pickle the method itself as replicatedobj, to skip
    // this lookup each time
    auto modelSession = model_.acquireSession();
    auto method = modelSession.self.attr(methodName_.c_str());
    // `args` is this function's own copy and is not used afterwards, so hand
    // it over instead of copying the vector a second time.
    return method.callKwargs(std::move(args), kwargs).toIValue();
  }

 private:
  // Defined out of line.
  void setArgumentNames(std::vector<std::string>&) const override;

  torch::deploy::ReplicatedObj model_;
  std::string methodName_;
};

// A package opened through a PyTorchStreamReader container. Instances are
// created by InterpreterManager (the constructors are private and the manager
// is a friend).
struct TORCH_API Package {
  // Calls `load_pickle(module, file)` on the session's package importer and
  // wraps the loaded object in a ReplicatedObj.
  ReplicatedObj loadPickle(const std::string& module, const std::string& file) {
    auto I = acquireSession();
    auto loaded = I.self.attr("load_pickle")({module, file});
    return createMovable(loaded, &I);
  }

#ifdef FBCODE_CAFFE2
  // Calls `load_text(packageName, key)` on the package importer and returns
  // the result as a string.
  std::string loadText(const std::string& packageName, const std::string& key) {
    auto I = acquireSession();
    return I.self.attr("load_text")({packageName, key})
        .toIValue()
        .toStringRef();
  }

  // Calls `load_binary(packageName, key)` on the package importer and returns
  // the bytes as a string.
  //
  // Example usage:
  //  in python:
  //    with PackageExporter(output) as pe:
  //        pe.save_binary("extra_files", "greeting", b'hello')
  //  in cpp:
  //    std::string decodedBinary =
  //        package->loadBinary("extra_files", "greeting");
  //    std::cout << decodedBinary; --> outputs "hello"
  std::string loadBinary(
      const std::string& packageName,
      const std::string& key) {
    auto I = acquireSession();
    return I.self.attr("load_binary")({packageName, key})
        .toIValue()
        .toStringRef();
  }
#endif

  // Acquires a session from the manager and sets its `self` to the package
  // importer created (or fetched) for this package's container file.
  InterpreterSession acquireSession() {
    auto I = manager_->acquireOne();
    I.self =
        I.impl_->createOrGetPackageImporterFromContainerFile(containerFile_);
    return I;
  }

  // Forwards to InterpreterManager::createMovable.
  ReplicatedObj createMovable(Obj obj, InterpreterSession* I) {
    return manager_->createMovable(obj, I);
  }

 private:
  // Opens the container located at `uri`.
  Package(
      const std::string& uri,
      InterpreterManager*
          pm) // or really any of the constructors to our zip file format
      : manager_(pm),
        containerFile_(
            std::make_shared<caffe2::serialize::PyTorchStreamReader>(uri)) {}
  // Opens the container exposed by `reader`. The by-value shared_ptr is moved
  // into the stream reader rather than copied, saving a reference-count
  // round trip.
  Package(
      std::shared_ptr<caffe2::serialize::ReadAdapterInterface> reader,
      InterpreterManager*
          pm) // or really any of the constructors to our zip file format
      : manager_(pm),
        containerFile_(
            std::make_shared<caffe2::serialize::PyTorchStreamReader>(
                std::move(reader))) {}
  friend struct ReplicatedObj;
  friend struct InterpreterManager;
  // Non-owning pointer to the manager this package was created with.
  InterpreterManager* manager_;

  std::shared_ptr<caffe2::serialize::PyTorchStreamReader> containerFile_;
};

} // namespace deploy
} // namespace torch

// Makes everything declared in torch::deploy reachable as multipy::runtime.
namespace multipy {
namespace runtime = torch::deploy;
} // namespace multipy

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources