JitTraceEnum_ELBO fails with nvrtc: error: invalid value for --gpu-architecture (-arch) on RTX 3090 with PyTorch 1.7.0
Issue Description
JitTraceEnum_ELBO fails on an RTX 3090 with nvrtc: error: invalid value for --gpu-architecture (-arch). However, a test with torch.jit.trace works fine, so my guess is that Pyro's JIT module is passing the wrong CUDA architecture to the compiler.
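For context, here is a small diagnostic sketch (assuming a CUDA-enabled PyTorch build) that prints the toolkit version PyTorch was built against and the device's compute capability. An RTX 3090 reports (8, 6), i.e. sm_86, which is consistent with the fix posted below: the NVRTC shipped with CUDA 11.1 accepts sm_86, while the CUDA 11.0 one does not.

import torch

# Diagnostic only: report the CUDA toolkit PyTorch was built against and the
# compute capability the JIT fuser has to target.
print(torch.__version__)                     # e.g. 1.7.0
print(torch.version.cuda)                    # e.g. '11.0'
print(torch.cuda.get_device_name(0))         # e.g. 'GeForce RTX 3090'
print(torch.cuda.get_device_capability(0))   # (8, 6) on an RTX 3090, i.e. sm_86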
Environment
- OS: Windows 10
- Python 3.7.4
- pyro-ppl==1.5.0
- pytorch==1.7.0, CUDA 11.0
Code Snippet
The following code gives an nvrtc: error: invalid value for --gpu-architecture (-arch):
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
import torch
import numpy as np
import torch.nn as nn
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, JitTraceEnum_ELBO
from pyro.optim import Adam

residual_feature = torch.tensor((np.random.randn(400) > 0.5).astype(int), device="cuda", dtype=torch.int8)

def SimulationModel(residual_feature):
    with torch.no_grad():
        x_loc = residual_feature.new_zeros(torch.Size((residual_feature.shape[0], 10))).cuda().float()
        x_scale = residual_feature.new_ones(torch.Size((residual_feature.shape[0], 10))).cuda().float()
        x = pyro.sample("latent", dist.Normal(x_loc, x_scale).to_event(1))
        e = torch.sum(x, dim=-1) > 0
        y = torch.logical_xor(e, residual_feature.bool()).cuda().float()
    return x, e, y, residual_feature

X, E, Y, R = SimulationModel(residual_feature)

class Net_1(nn.Module):
    def __init__(self):
        super(Net_1, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(10, 50),
            nn.Softplus(),
            nn.Linear(50, 2),
            nn.Softmax(-1)
        )
        self.cuda()

    def forward(self, x):
        return self.net(x)

class Net_2(nn.Module):
    def __init__(self):
        super(Net_2, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 50),
            nn.Softplus(),
            nn.Linear(50, 2),
            nn.Softmax(-1)
        )
        self.cuda()

    def forward(self, x):
        xsize = x.shape
        # print(xsize)
        if len(xsize) > 2:
            x = x.view(-1, xsize[-1])
        out = self.net(x)
        if len(xsize) > 2:
            out = out.view(*xsize[:-1], -1)
        # print("out", out.shape)
        return out

class Test_ground():
    def __init__(self):
        self.net1 = Net_1()
        self.net2 = Net_2()

    @config_enumerate
    def model(self, X, residual_feature, Y):
        pyro.module("net1", self.net1)
        pyro.module("net2", self.net2)
        with pyro.plate("data", X.shape[0]):
            e_prob = self.net1(X)[:, -1]
            e = pyro.sample("e", dist.Bernoulli(e_prob), infer={"enumerate": "sequential"})
            # print(e)
            mid_feature = torch.cat((e.unsqueeze(-1), residual_feature.unsqueeze(-1)), dim=-1)
            y_prob = self.net2(mid_feature)[:, -1]
            y = pyro.sample("y", dist.Bernoulli(y_prob), obs=Y, infer={"enumerate": "sequential"})
        return y, y_prob

    def guide(self, X, residual_feature, Y):
        pass

X_train = X
Y_train = Y
R_train = R
E_train = E

pyro.clear_param_store()
TG = Test_ground()
elbo = JitTraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(TG.model, TG.guide, X_train, R_train, Y_train)
svi = SVI(TG.model, TG.guide, Adam({'lr': 0.001}), elbo)
svi.step(X_train, R_train, Y_train)
svi.step(X_train, R_train, Y_train)
RuntimeError: nvrtc: error: invalid value for --gpu-architecture (-arch)
nvrtc compilation failed:
#define NAN __int_as_float(0x7fffffff)
#define POS_INFINITY __int_as_float(0x7f800000)
#define NEG_INFINITY __int_as_float(0xff800000)
template<typename T>
__device__ T maximum(T a, T b) {
  return isnan(a) ? a : (a > b ? a : b);
}
template<typename T>
__device__ T minimum(T a, T b) {
  return isnan(a) ? a : (a < b ? a : b);
}
extern "C" __global__
void func_6(float* t0, float* aten_neg_flat, float* aten_clamp_flat) {
{
  if (512 * blockIdx.x + threadIdx.x<200 ? 1 : 0) {
    float v = __ldg(t0 + 2 * (512 * blockIdx.x + threadIdx.x));
    float v_1 = __ldg(t0 + 2 * (512 * blockIdx.x + threadIdx.x));
    float v_2 = __ldg(t0 + 2 * (512 * blockIdx.x + threadIdx.x));
    aten_clamp_flat[512 * blockIdx.x + threadIdx.x] = v<1.192092895507813e-07f ? 1.192092895507813e-07f : (v_1>0.9999998807907104f ? 0.9999998807907104f : v_2);
    float v_3 = __ldg(t0 + 2 * (512 * blockIdx.x + threadIdx.x));
    float v_4 = __ldg(t0 + 2 * (512 * blockIdx.x + threadIdx.x));
    float v_5 = __ldg(t0 + 2 * (512 * blockIdx.x + threadIdx.x));
    aten_neg_flat[512 * blockIdx.x + threadIdx.x] = 0.f - (v_3<1.192092895507813e-07f ? 1.192092895507813e-07f : (v_4>0.9999998807907104f ? 0.9999998807907104f : v_5));
  }
}
}
This seems like a compatibility issue with PyTorch's JIT functionality. However, the following test with torch.jit.trace works fine:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)
        self.cuda()

    def forward(self, x):
        return self.conv(x)

n = Net()
example_weight = torch.rand(1, 1, 3, 3).cuda()
example_forward_input = torch.rand(1, 1, 3, 3).cuda()

# Trace a specific method and construct `ScriptModule` with
# a single `forward` method
module = torch.jit.trace(n.forward, example_forward_input)

# Trace a module (implicitly traces `forward`) and construct a
# `ScriptModule` with a single `forward` method
module = torch.jit.trace(n, example_forward_input)
module(example_forward_input)
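Since the plain torch.jit.trace test above compiles fine, the failure appears specific to the fused pointwise kernels that NVRTC has to compile when JitTraceEnum_ELBO runs. One possible workaround until the NVRTC issue is resolved, offered only as a sketch and not verified on this exact setup, is to fall back to the non-JIT TraceEnum_ELBO, which avoids the kernel-fusion path entirely:

# Workaround sketch (assumption: giving up the JIT speedup is acceptable).
# TraceEnum_ELBO computes the same objective without torch.jit compilation,
# so no NVRTC kernels are generated. Reuses the names from the snippet above.
pyro.clear_param_store()
TG = Test_ground()
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(TG.model, TG.guide, Adam({'lr': 0.001}), elbo)
svi.step(X_train, R_train, Y_train)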
Output of pip freeze
absl-py==0.9.0 alabaster==0.7.12 appdirs==1.4.4 astroid @ file:///C:/ci/astroid_1592481955828/work astunparse==1.6.3 attrs==19.3.0 Babel==2.8.0 backcall==0.2.0 bleach==3.1.5 blinker==1.4 brotlipy==0.7.0 cachetools @ file:///tmp/build/80754af9/cachetools_1596822027882/work certifi==2020.6.20 cffi==1.14.0 chardet==3.0.4 click==7.1.2 cloudpickle==1.4.1 colorama==0.4.3 combo==0.1.0 cryptography==2.9.2 cvxpy==1.1.7 cycler==0.10.0 cytoolz==0.11.0 dask @ file:///home/conda/feedstock_root/build_artifacts/dask-core_1602029610262/work dataclasses==0.6 decorator==4.4.2 defusedxml==0.6.0 docutils==0.16 ecos==2.0.7.post1 entrypoints==0.3 fancyimpute==0.5.5 future==0.18.2 gast==0.3.3 google-auth @ file:///tmp/build/80754af9/google-auth_1599595709135/work google-auth-oauthlib==0.4.1 google-pasta==0.2.0 grpcio @ file:///C:/ci/grpcio_1597406403308/work h5py==2.8.0 heartpy==1.2.6 idna==2.9 imagecodecs-lite @ file:///D:/bld/imagecodecs-lite_1602583223201/work imageio==2.6.1 imagesize==1.2.0 importlib-metadata==1.6.1 ipykernel==5.3.0 ipyparallel @ file:///C:/ci/ipyparallel_1593440177433/work ipython @ file:///C:/ci/ipython_1592330400575/work ipython-genutils==0.2.0 ipywidgets==7.5.1 isort==4.3.21 jedi @ file:///C:/ci/jedi_1592833832328/work Jinja2==2.11.2 joblib @ file:///tmp/build/80754af9/joblib_1601912903842/work json5==0.9.5 jsonpatch==1.24 jsonpointer==2.0 jsonschema==3.2.0 jupyter-client==6.1.3 jupyter-core==4.6.3 jupyterlab==2.1.4 jupyterlab-server==1.1.5 Keras==2.4.3 Keras-Preprocessing==1.1.2 keyring @ file:///C:/ci/keyring_1593109817825/work kiwisolver==1.2.0 knnimpute==0.1.0 lazy-object-proxy==1.4.3 llvmlite==0.33.0 Markdown @ file:///C:/ci/markdown_1597415185889/work MarkupSafe==1.1.1 matplotlib @ file:///C:/ci/matplotlib-base_1592846084747/work mccabe==0.6.1 meshio==4.0.16 missingno==0.4.2 mistune==0.8.4 mkl-fft==1.1.0 mkl-random==1.1.1 mkl-service==2.3.0 mpi4py==3.0.3 nbconvert==5.6.1 nbformat==5.0.7 networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1598210780226/work nibabel @ file:///home/conda/feedstock_root/build_artifacts/nibabel_1603212922321/work nose==1.3.7 notebook==6.0.3 numba==0.50.0 numpy==1.18.5 numpydoc==1.0.0 oauthlib==3.1.0 olefile==0.46 opt-einsum==3.2.1 osqp==0.6.1 packaging==20.4 pandas==0.25.2 pandocfilters==1.4.2 parso==0.7.0 patsy==0.5.1 pickleshare==0.7.5 Pillow==6.2.1 prometheus-client==0.8.0 prompt-toolkit==3.0.5 protobuf==3.13.0 psutil==5.7.0 pyasn1==0.4.8 pyasn1-modules==0.2.7 pycodestyle==2.6.0 pycparser==2.20 pydicom @ file:///home/conda/feedstock_root/build_artifacts/pydicom_1590767348156/work pyflakes==2.2.0 Pygments==2.6.1 PyJWT==1.7.1 pylint @ file:///C:/ci/pylint_1598605958085/work pyod==0.8.0 pyOpenSSL==19.1.0 pyparsing==2.4.7 pyro-api==0.1.2 pyro-ppl==1.5.0 pyrsistent==0.16.0 PySocks==1.7.1 python-dateutil==2.8.1 pytz==2020.1 pyvista==0.25.3 PyWavelets @ file:///D:/bld/pywavelets_1602504662693/work pywin32==227 pywin32-ctypes==0.2.0 pywinpty==0.5.7 PyYAML==5.3.1 pyzmq==19.0.1 QtAwesome==0.7.2 qtconsole @ file:///tmp/build/80754af9/qtconsole_1592848611704/work QtPy==1.9.0 requests @ file:///tmp/build/80754af9/requests_1592841827918/work requests-oauthlib==1.3.0 rope==0.17.0 rsa @ file:///tmp/build/80754af9/rsa_1596998415516/work scikit-image==0.17.2 scikit-learn @ file:///C:/ci/scikit-learn_1598376983131/work scipy @ file:///C:/ci/scipy_1592916958183/work scooby==0.5.5 scs==2.1.2 seaborn==0.9.0 Send2Trash==1.5.0 six==1.15.0 skorch==0.8.1.dev0 snowballstemmer==2.0.0 Sphinx @ file:///tmp/build/80754af9/sphinx_1592842202926/work 
sphinxcontrib-applehelp==1.0.2 sphinxcontrib-devhelp==1.0.2 sphinxcontrib-htmlhelp==1.0.3 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.4 spyder==3.3.6 spyder-kernels==0.5.2 statsmodels==0.12.0 suod==0.0.4 tabulate==0.8.7 tensorboard==2.3.0 tensorboard-plugin-wit==1.6.0 tensorflow==2.3.1 tensorflow-estimator==2.3.0 termcolor==1.1.0 terminado==0.8.3 testpath==0.4.4 threadpoolctl @ file:///tmp/tmp9twdgx9k/threadpoolctl-2.1.0-py3-none-any.whl tifffile==2019.7.26.2 toml @ file:///tmp/build/80754af9/toml_1592853716807/work toolz @ file:///home/conda/feedstock_root/build_artifacts/toolz_1600973991856/work torch==1.7.0 torchaudio==0.7.0 torchfile==0.1.0 torchvision==0.8.1 tornado==6.0.4 tqdm==4.37.0 traitlets==4.3.3 typed-ast==1.4.1 typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1598376058250/work urllib3==1.25.9 visdom==0.1.8.9 vtk==9.0.1 wcwidth==0.2.4 webencodings==0.5.1 websocket-client==0.57.0 Werkzeug==1.0.1 wfdb==2.2.1 widgetsnbextension==3.5.1 win-inet-pton==1.1.0 wincertstore==0.2 wrapt==1.11.2 xgboost==0.90 xlrd==1.2.0 zipp==3.1.0
Inspired by the instructions here: https://forums.developer.nvidia.com/t/isaac-gym-3090-issues/159080/2. After replacing nvrtc64_110_0.dll and nvrtc-builtins64_110.dll in C:\Users.…\Anaconda3\envs\DataSci\Library\bin with nvrtc64_111_0.dll and nvrtc-builtins64_111.dll from C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1\bin (copied and renamed from 111 to 110), the bug was solved.
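A quick way to verify the DLL swap (a sketch, assuming the Test_ground snippet from the issue is still in scope) is to re-run the JIT ELBO loss and check that NVRTC now compiles the fused kernel:

# Verification sketch, reusing the objects defined in the original snippet.
pyro.clear_param_store()
TG = Test_ground()
elbo = JitTraceEnum_ELBO(max_plate_nesting=1)
loss = elbo.loss(TG.model, TG.guide, X_train, R_train, Y_train)  # previously raised the nvrtc RuntimeError
print(loss)  # should now print a finite float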
@scraed great, thanks for posting the solution!