Do not gate V100 support

The README.md says: "NVIDIA: AIT is only tested on SM80+ GPUs (Ampere etc). Not all kernels work with old SM75/SM70 (T4/V100) GPUs."

I interpreted that as "it may work, but we won’t guarantee it." However, https://github.com/facebookincubator/AITemplate/blob/main/python/aitemplate/testing/detect_target.py#L41 has an explicit gate on V100. Once I patched that gate, the example works and is also roughly 2x faster than PyTorch eager.

If this was not intended, please let me know and I can make a PR to fix it. V100 and T4 are by far the most popular GPUs I see among enterprises.

if "V100" in stdout or "RTX 20" in stdout:
  return "75"
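
To make the idea concrete, here is a minimal sketch of name-based architecture detection in the spirit of the snippet above. It is an illustration only, not the contents of detect_target.py: `guess_sm_arch`, `KNOWN_ARCHS`, and the `overrides` parameter are hypothetical names. The compute capabilities listed (V100 = SM70, T4 and RTX 20 series = SM75, A100 = SM80) are NVIDIA's published values.

```python
from typing import Dict, Optional

# Hypothetical table: GPU name substring -> CUDA compute capability.
# This is NOT copied from AITemplate; it only lists well-known values.
KNOWN_ARCHS = (
    ("A100", "80"),
    ("T4", "75"),
    ("RTX 20", "75"),
    ("V100", "70"),
)


def guess_sm_arch(
    smi_output: str, overrides: Optional[Dict[str, str]] = None
) -> Optional[str]:
    """Return the arch string for the first known GPU name found, else None."""
    table = dict(KNOWN_ARCHS)
    # Overrides let a caller experiment, e.g. treating V100 as "75".
    table.update(overrides or {})
    for name, arch in table.items():
        if name in smi_output:
            return arch
    return None


print(guess_sm_arch("Tesla V100-SXM2-16GB"))                  # native arch
print(guess_sm_arch("Tesla V100-SXM2-16GB", {"V100": "75"}))  # overridden
```

Returning `None` for an unknown GPU, instead of raising, leaves the decision about whether to refuse or to try anyway with the caller.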

Performance on V100

AITemplate time: 0.11990207433700562 ms/iter
PyTorch eager time: 0.20665957641601562 ms/iter
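
For reference, the speedup implied by these two timings can be computed directly. This is plain arithmetic on the numbers reported above; the variable names are just for illustration.

```python
# Timings copied from the report above (milliseconds per iteration).
ait_ms = 0.11990207433700562
pt_ms = 0.20665957641601562

# Speedup of AITemplate relative to PyTorch eager.
speedup = pt_ms / ait_ms
print(f"Speedup over PyTorch eager: {speedup:.2f}x")
```

On these numbers the ratio comes out at roughly 1.7x.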

Repro

from collections import OrderedDict

import torch

from aitemplate.compiler import compile_model
from aitemplate.frontend import nn, Tensor
from aitemplate.testing import detect_target
from aitemplate.testing.benchmark_pt import benchmark_torch_function
from aitemplate.utils.graph_utils import sorted_graph_pseudo_code

class PTSimpleModel(torch.nn.Module):
  def __init__(self, hidden, eps: float = 1e-5):
    super().__init__()
    self.dense1 = torch.nn.Linear(hidden, 4 * hidden)
    self.act1 = torch.nn.functional.gelu
    self.dense2 = torch.nn.Linear(4 * hidden, hidden)
    self.layernorm = torch.nn.LayerNorm(hidden, eps=eps)

  def forward(self, input):
    hidden_states = self.dense1(input)
    hidden_states = self.act1(hidden_states)
    hidden_states = self.dense2(hidden_states)
    hidden_states = hidden_states + input
    hidden_states = self.layernorm(hidden_states)
    return hidden_states

class AITSimpleModel(nn.Module):
  def __init__(self, hidden, eps: float = 1e-5):
    super().__init__()
    self.dense1 = nn.Linear(hidden, 4 * hidden, specialization="fast_gelu")
    self.dense2 = nn.Linear(4 * hidden, hidden)
    self.layernorm = nn.LayerNorm(hidden, eps=eps)

  def forward(self, input):
    hidden_states = self.dense1(input)
    hidden_states = self.dense2(hidden_states)
    hidden_states = hidden_states + input
    hidden_states = self.layernorm(hidden_states)
    return hidden_states

def map_pt_params(ait_model, pt_model):
  ait_model.name_parameter_tensor()
  pt_params = dict(pt_model.named_parameters())
  mapped_pt_params = OrderedDict()
  for name, _ in ait_model.named_parameters():
    ait_name = name.replace(".", "_")
    assert name in pt_params
    mapped_pt_params[ait_name] = pt_params[name]
  return mapped_pt_params

batch_size = 1024
hidden = 512
# create pt model
pt_model = PTSimpleModel(hidden).cuda().half()

# create pt input
x = torch.randn([batch_size, hidden]).cuda().half()

# run pt model
pt_model.eval()
y_pt = pt_model(x)

# create AIT model (reuses batch_size and hidden defined above)
ait_model = AITSimpleModel(hidden)
# create AIT input Tensor
X = Tensor(
      shape=[batch_size, hidden],
      name="X",
      dtype="float16",
      is_input=True,
)
# run AIT module to generate output tensor
Y = ait_model(X)
# mark the output tensor
Y._attrs["is_output"] = True
Y._attrs["name"] = "Y"

# map pt weights to ait
weights = map_pt_params(ait_model, pt_model)

# codegen
target = detect_target()
with compile_model(
    Y, target, "./tmp", "simple_model_demo", constants=weights
) as module:
  # create storage for output tensor
  y = torch.empty([batch_size, hidden]).cuda().half()

  # inputs and outputs dict
  inputs = {"X": x}
  outputs = {"Y": y}

  # run
  module.run_with_tensors(inputs, outputs, graph_mode=True)

  # verify output is correct
  print(torch.allclose(y, y_pt, atol=1e-2, rtol=1e-2))

  # benchmark ait and pt
  count = 1000
  ait_t, _, _ = module.benchmark_with_tensors(
      inputs, outputs, graph_mode=True, count=count
  )
  print(f"AITemplate time: {ait_t} ms/iter")

  pt_t = benchmark_torch_function(count, pt_model.forward, x)
  print(f"PyTorch eager time: {pt_t} ms/iter")
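
The `map_pt_params` helper in the repro depends on one naming rule: each PyTorch parameter name has its dots replaced with underscores to form the AIT constant name. A standalone sketch of just that rule, using plain strings in place of real tensors (`map_names` and the placeholder values are hypothetical, for illustration only):

```python
from collections import OrderedDict


def map_names(ait_param_names, pt_params):
    """Map PyTorch-style dotted names to underscore-separated names."""
    mapped = OrderedDict()
    for name in ait_param_names:
        # Fail loudly if the two models do not share a parameter name.
        if name not in pt_params:
            raise KeyError(f"parameter {name!r} not found in PyTorch model")
        mapped[name.replace(".", "_")] = pt_params[name]
    return mapped


# Placeholder strings stand in for the actual weight tensors.
pt_params = {"dense1.weight": "W1", "dense1.bias": "b1", "layernorm.weight": "g"}
mapped = map_names(pt_params.keys(), pt_params)
print(list(mapped))
```

This only works because the two model classes in the repro use the same attribute names (`dense1`, `dense2`, `layernorm`).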

Issue Analytics

  • State: open
  • Created a year ago
  • Reactions: 1
  • Comments: 6 (4 by maintainers)

Top GitHub Comments

5 reactions
msaroufim commented, Oct 11, 2022

@philschmid, who I figure may be interested in community support. It may be worth scoping this exercise for community members so it’s more scalable for us to support more examples. So something like:

  1. It works great
  2. It doesn’t work; try these simple workarounds
  3. It probably won’t work, and that’s OK; try something else

At least I wonder how many models will fall under bucket 3.

3 reactions
antinucleon commented, Oct 11, 2022

@HamidShojanazeri Thanks for the suggestion. Given our team size and our workload supporting internal production needs, we don’t have the bandwidth to enable V100/T4. If the community or NVIDIA is going to help enable T4/V100 on all examples, that would be fantastic.
