#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os

import torch


def _load_library(filename: str, no_throw: bool = False) -> None:
    """Load a shared library from the given filename."""
    try:
        torch.ops.load_library(os.path.join(os.path.dirname(__file__), filename))
        logging.info(f"Successfully loaded: '{filename}'")
    except Exception as error:
        logging.error(f"Could not load the library '{filename}': {error}")
        if not no_throw:
            raise error


# This __init__.py is only used in the open-source (OSS) build, so `open_source`
# is defined here; other code can check whether this attribute exists to tell
# whether it is running in the OSS context.
# NOTE(review): presumably it is assigned before the submodule import below so
# that modules imported during package initialization can already see it —
# confirm before reordering these statements.
open_source: bool = True

# Imported only for its side effect: manually attaching docstrings to the
# pybind11-generated operators.
import fbgemm_gpu.docs  # noqa: F401, E402

# Re-export the build metadata (target, variant, version string) that setup.py
# writes into the auto-generated `fbgemm_gpu.docs.version` module.  If that
# module cannot be imported for any reason, all three fall back to the
# "INTERNAL" placeholder.
try:
    from fbgemm_gpu.docs.version import __target__  # noqa: F401, E402
    from fbgemm_gpu.docs.version import __variant__  # noqa: F401, E402
    from fbgemm_gpu.docs.version import __version__  # noqa: F401, E402
except Exception:
    __variant__: str = "INTERNAL"
    __version__: str = "INTERNAL"
    __target__: str = "INTERNAL"

# Shared libraries (relative to this package's directory, without the ".so"
# suffix) that make up the default FBGEMM_GPU build target.
fbgemm_gpu_libraries = [
    "fbgemm_gpu_config",
    "fbgemm_gpu_tbe_utils",
    "fbgemm_gpu_tbe_index_select",
    "fbgemm_gpu_tbe_optimizers",
    "fbgemm_gpu_tbe_inference",
    "fbgemm_gpu_tbe_training_forward",
    "fbgemm_gpu_tbe_training_backward",
    "fbgemm_gpu_tbe_training_backward_pt2",
    "fbgemm_gpu_tbe_training_backward_dense",
    "fbgemm_gpu_tbe_training_backward_split_host",
    "fbgemm_gpu_tbe_training_backward_gwd",
    "fbgemm_gpu_tbe_training_backward_vbe",
    "fbgemm_gpu_py",
]

# NOTE: FBGEMM_GPU GenAI does not support ROCm yet.  We still want the CUDA
# variant of the package to be installable on ROCm systems, so that at least
# the Triton GEMM libraries in experimental/gemm can be used.  Importing
# fbgemm_gpu load-checks the GenAI .SO file, and that check fails on ROCm.  On
# ROCm the GenAI list is therefore left empty, so that clients can import
# fbgemm_gpu.experimental.gemm without an error.
fbgemm_genai_libraries = (
    []
    if torch.cuda.is_available() and torch.version.hip
    else ["experimental/gen_ai/fbgemm_gpu_experimental_gen_ai"]
)

# Build target (the value of `__target__`) -> libraries to load for it.
libraries_to_load = dict(
    default=fbgemm_gpu_libraries,
    genai=fbgemm_genai_libraries,
)

# Load every shared library registered for the current build target; a target
# with no entry in `libraries_to_load` (e.g. the "INTERNAL" fallback) loads
# nothing.
#
# Ideally a load failure would always be fatal.  However, raising here breaks
# the OSS documentation build: the Python documentation does not show up in the
# generated docs.  As a workaround there is a pseudo build variant, `docs`, for
# which load failures are only logged instead of raised.  Background:
#
#   https://github.com/pytorch/FBGEMM/pull/3477
#   https://github.com/pytorch/FBGEMM/pull/3717
for library in libraries_to_load.get(__target__, []):
    _load_library(f"{library}.so", no_throw=(__variant__ == "docs"))

# Importing `sparse_ops` triggers its meta operator registrations.  This is
# deliberately best-effort: a failure must not break importing the package.
# NOTE(review): presumably it can fail when the operator libraries above were
# not loaded (e.g. the "genai" target) — confirm.  The failure is no longer
# silently discarded: it is recorded, with its traceback, at DEBUG level on this
# package's named logger.  Using a named logger avoids the root-logger
# `basicConfig()` side effect.
try:
    # Trigger meta operator registrations
    from . import sparse_ops  # noqa: F401, E402
except Exception:
    logging.getLogger(__name__).debug(
        "Could not import sparse_ops; skipping its meta operator registrations",
        exc_info=True,
    )
