DataParallel is Supported for XPU? #707

Open
yash3056 opened this issue Sep 17, 2024 · 4 comments
@yash3056

Describe the issue

I am facing errors with DataParallel.

alexsin368 self-assigned this Sep 19, 2024
@alexsin368

Hi @yash3056, please describe your issue in detail and provide the code and steps to reproduce it.

@gujinghui (Contributor)

@yash3056

DP is not fully supported on XPU for now.
May I know why you need DP in your case, instead of DDP?
As I recall, DP will be deprecated by PyTorch on GPU.
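
For reference, a rough sketch of what DDP on XPU can look like (assumptions: `intel_extension_for_pytorch` and `oneccl_bindings_for_pytorch` are installed, the script is launched with `torchrun` or `mpirun` so the rendezvous environment variables are set, and the tiny `nn.Linear` is only a placeholder model):

```python
# Rough sketch only; assumes intel_extension_for_pytorch and
# oneccl_bindings_for_pytorch are installed, and that the script is started
# with torchrun/mpirun so RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT are set.
import os
import torch
import torch.distributed as dist
import intel_extension_for_pytorch as ipex  # noqa: F401  enables the XPU device on older PyTorch
import oneccl_bindings_for_pytorch          # noqa: F401  registers the "ccl" distributed backend
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="ccl")      # oneCCL backend for Intel GPUs
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = torch.device(f"xpu:{local_rank}")

model = torch.nn.Linear(10, 2).to(device)   # placeholder model for the sketch
ddp_model = DDP(model)                      # DDP picks up the module's XPU device

x = torch.randn(4, 10, device=device)
ddp_model(x).sum().backward()               # gradients are all-reduced across ranks
dist.destroy_process_group()
```

Launching with `torchrun --nproc_per_node=<num_xpus> script.py` (or the equivalent `mpirun` command) is what actually spawns one process per device.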

@yash3056 (Author)

> @yash3056
>
> DP is not fully supported on XPU for now. May I know why you need DP in your case, instead of DDP? As I recall, DP will be deprecated by PyTorch on GPU.

1. I wanted confirmation that DP is not supported. I am also facing errors with DDP.
2. I am facing an engine error with XPU.

Here is the code where I am hitting the engine error:

```python
# %%
#!pip install accelerate==1.0.0rc1 datasets

# %%
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import DataLoader
from accelerate import Accelerator
import torch
from sklearn.metrics import accuracy_score

# Load IMDB dataset
dataset = load_dataset("imdb")

# Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize the data
def tokenize_function(examples):
    return tokenizer(examples['text'], padding="max_length", truncation=True, max_length=512)

# Tokenize the train and test datasets
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type="torch", columns=['input_ids', 'attention_mask', 'label'])

train_dataset = tokenized_datasets['train']
test_dataset = tokenized_datasets['test']

# Define DataLoaders for batching
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8)

# Load pre-trained BERT model with a classification head
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Initialize Accelerator
accelerator = Accelerator()
device = accelerator.device
print(device)

# Move model, optimizer, and dataloaders to the appropriate device
model, optimizer, train_dataloader, test_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, test_dataloader
)
```

```python
# %%
from tqdm.auto import tqdm

def train(model, dataloader, optimizer, accelerator):
    model.train()
    total_loss = 0

    # Use tqdm for progress bar
    loop = tqdm(dataloader, leave=True, desc="Training")

    for batch in loop:
        # Forward pass
        outputs = model(**batch)
        loss = outputs.loss

        # Backward pass
        accelerator.backward(loss)

        # Optimization step
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

        # Update tqdm description with the current loss
        loop.set_description(f"Training Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(dataloader)
    return avg_loss
```

```python
# %%
def evaluate(model, dataloader, accelerator):
    model.eval()
    predictions, labels = [], []

    # Use tqdm for progress bar
    loop = tqdm(dataloader, leave=True, desc="Evaluating")

    with torch.no_grad():
        for batch in loop:
            # Forward pass
            outputs = model(**batch)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=-1)

            predictions.extend(accelerator.gather(preds).cpu().numpy())
            labels.extend(accelerator.gather(batch['labels']).cpu().numpy())

    # Calculate accuracy
    accuracy = accuracy_score(labels, predictions)
    return accuracy
```

```python
# %%
def train(model, dataloader, optimizer, accelerator):
    model.train()
    total_loss = 0

    # Use tqdm for progress bar
    loop = tqdm(dataloader, leave=True, desc="Training")

    for batch in loop:
        # Forward pass
        # Only pass input_ids and attention_mask to the model
        outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['label'])
        loss = outputs.loss

        # Backward pass
        accelerator.backward(loss)

        # Optimization step
        optimizer.step()
        optimizer.zero_grad()

        total_loss += loss.item()

        # Update tqdm description with the current loss
        loop.set_description(f"Training Loss: {loss.item():.4f}")

    avg_loss = total_loss / len(dataloader)
    return avg_loss
```

```python
# %%
epochs = 3

for epoch in range(epochs):
    # Train the model
    avg_train_loss = train(model, train_dataloader, optimizer, accelerator)

    # Evaluate the model
    accuracy = evaluate(model, test_dataloader, accelerator)

    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Training Loss: {avg_train_loss:.4f}")
    print(f"Validation Accuracy: {accuracy:.4f}")
```

Here is the error:

```
[WARNING] Failed to create Level Zero tracer: 2013265921
{
"name": "RuntimeError",
"message": "could not create an engine",
"stack": "---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[6], line 5
1 epochs = 3
3 for epoch in range(epochs):
4 # Train the model
----> 5 avg_train_loss = train(model, train_dataloader, optimizer, accelerator)
7 # Evaluate the model
8 accuracy = evaluate(model, test_dataloader, accelerator)

Cell In[5], line 11, in train(model, dataloader, optimizer, accelerator)
6 loop = tqdm(dataloader, leave=True, desc="Training")
8 for batch in loop:
9 # Forward pass
10 # Only pass input_ids and attention_mask to the model
---> 11 outputs = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], labels=batch['label'])
12 loss = outputs.loss
14 # Backward pass

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None

File ~/.conda/envs/aza/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:1695, in BertForSequenceClassification.forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)
1687 r"""
1688 labels (torch.LongTensor of shape (batch_size,), optional):
1689 Labels for computing the sequence classification/regression loss. Indices should be in [0, ..., 1690 config.num_labels - 1]. If config.num_labels == 1 a regression loss is computed (Mean-Square loss), If
1691 config.num_labels > 1 a classification loss is computed (Cross-Entropy).
1692 """
1693 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-> 1695 outputs = self.bert(
1696 input_ids,
1697 attention_mask=attention_mask,
1698 token_type_ids=token_type_ids,
1699 position_ids=position_ids,
1700 head_mask=head_mask,
1701 inputs_embeds=inputs_embeds,
1702 output_attentions=output_attentions,
1703 output_hidden_states=output_hidden_states,
1704 return_dict=return_dict,
1705 )
1707 pooled_output = outputs[1]
1709 pooled_output = self.dropout(pooled_output)

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None

File ~/.conda/envs/aza/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:1141, in BertModel.forward(self, input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
1134 # Prepare head mask if needed
1135 # 1.0 in head_mask indicate we keep the head
1136 # attention_probs has shape bsz x n_heads x N x N
1137 # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1138 # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1139 head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
-> 1141 encoder_outputs = self.encoder(
1142 embedding_output,
1143 attention_mask=extended_attention_mask,
1144 head_mask=head_mask,
1145 encoder_hidden_states=encoder_hidden_states,
1146 encoder_attention_mask=encoder_extended_attention_mask,
1147 past_key_values=past_key_values,
1148 use_cache=use_cache,
1149 output_attentions=output_attentions,
1150 output_hidden_states=output_hidden_states,
1151 return_dict=return_dict,
1152 )
1153 sequence_output = encoder_outputs[0]
1154 pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None

File ~/.conda/envs/aza/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:694, in BertEncoder.forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
683 layer_outputs = self._gradient_checkpointing_func(
684 layer_module.call,
685 hidden_states,
(...)
691 output_attentions,
692 )
693 else:
--> 694 layer_outputs = layer_module(
695 hidden_states,
696 attention_mask,
697 layer_head_mask,
698 encoder_hidden_states,
699 encoder_attention_mask,
700 past_key_value,
701 output_attentions,
702 )
704 hidden_states = layer_outputs[0]
705 if use_cache:

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None

File ~/.conda/envs/aza/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:584, in BertLayer.forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
572 def forward(
573 self,
574 hidden_states: torch.Tensor,
(...)
581 ) -> Tuple[torch.Tensor]:
582 # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
583 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
--> 584 self_attention_outputs = self.attention(
585 hidden_states,
586 attention_mask,
587 head_mask,
588 output_attentions=output_attentions,
589 past_key_value=self_attn_past_key_value,
590 )
591 attention_output = self_attention_outputs[0]
593 # if decoder, the last output is tuple of self-attn cache

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None

File ~/.conda/envs/aza/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:514, in BertAttention.forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
504 def forward(
505 self,
506 hidden_states: torch.Tensor,
(...)
512 output_attentions: Optional[bool] = False,
513 ) -> Tuple[torch.Tensor]:
--> 514 self_outputs = self.self(
515 hidden_states,
516 attention_mask,
517 head_mask,
518 encoder_hidden_states,
519 encoder_attention_mask,
520 past_key_value,
521 output_attentions,
522 )
523 attention_output = self.output(self_outputs[0], hidden_states)
524 outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None

File ~/.conda/envs/aza/lib/python3.10/site-packages/transformers/models/bert/modeling_bert.py:394, in BertSdpaSelfAttention.forward(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)
382 return super().forward(
383 hidden_states,
384 attention_mask,
(...)
389 output_attentions,
390 )
392 bsz, tgt_len, _ = hidden_states.size()
--> 394 query_layer = self.transpose_for_scores(self.query(hidden_states))
396 # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
397 # mask needs to be such that the encoder's padding tokens are not attended to.
398 is_cross_attention = encoder_hidden_states is not None

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None

File ~/.conda/envs/aza/lib/python3.10/site-packages/torch/nn/modules/linear.py:116, in Linear.forward(self, input)
115 def forward(self, input: Tensor) -> Tensor:
--> 116 return F.linear(input, self.weight, self.bias)

RuntimeError: could not create an engine"
}
```
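
The stack ends in `F.linear` on the XPU device, so one way to narrow this down is to check whether a bare linear layer runs on XPU at all, outside Accelerate and transformers. A minimal check, assuming an XPU-enabled build (the 768/512 shapes just mirror bert-base):

```python
# Minimal isolation check: does a plain nn.Linear forward/backward run on XPU?
# Assumes an XPU-enabled PyTorch build (2.5+); on older stacks,
# `import intel_extension_for_pytorch` may be required first.
import torch

device = torch.device("xpu")
linear = torch.nn.Linear(768, 768).to(device)  # same hidden size as bert-base
x = torch.randn(8, 512, 768, device=device)    # (batch, seq_len, hidden)

y = linear(x)                                  # the traceback's "could not create an engine" was raised here
y.sum().backward()
print(y.shape, y.device)
```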

@yash3056 (Author)

@gujinghui @alexsin368 This code runs fine with PyTorch 2.6 (mainline).
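
A quick, hedged way to confirm which stack is in use on a given machine (`torch.xpu` is only present on XPU-enabled builds, i.e. native PyTorch 2.5+/2.6 or an install with intel_extension_for_pytorch):

```python
# Hedged environment check; torch.xpu exists on XPU-enabled builds
# (native in PyTorch 2.5+/2.6, or provided via intel_extension_for_pytorch).
import torch

print(torch.__version__)
if hasattr(torch, "xpu") and torch.xpu.is_available():
    print("XPU device count:", torch.xpu.device_count())
    print("XPU device name :", torch.xpu.get_device_name(0))
else:
    print("No usable XPU backend found")
```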
