How to Fix RuntimeError: expected scalar type double but found float

The error occurs when there is a mismatch between the data type of the input tensor and the data type of the model’s weights: the model expects a double-precision floating-point tensor (torch.DoubleTensor), but it receives a single-precision floating-point tensor (torch.FloatTensor).

To fix the RuntimeError: expected scalar type double but found float error, make the two data types match: either convert the input tensor to the model’s expected data type, or convert the model’s weights to the input’s data type, using the double() or to() method.

In PyTorch, “double” refers to the torch.double (torch.float64) data type, whereas “float” refers to torch.float (torch.float32).
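
You can verify these aliases directly in a Python session:

import torch

# torch.double is an alias for torch.float64, and torch.float for torch.float32
print(torch.double == torch.float64)  # True
print(torch.float == torch.float32)   # True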

Reproduce the error

import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
  def __init__(self):
    super(SimpleModel, self).__init__()
    self.linear = nn.Linear(10, 1)

  def forward(self, x):
    return self.linear(x)

# Create the model and convert it to double
model = SimpleModel().double()

# Create an input tensor with the data type torch.float (single-precision)
input_tensor = torch.randn(5, 10, dtype=torch.float)

# Perform a forward pass, which will result in an error due to data type mismatch
output = model(input_tensor)

Output

RuntimeError: expected scalar type Double but found Float
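
Before applying a fix, you can confirm the mismatch by printing both dtypes. This quick diagnostic sketch reuses the model and input_tensor from the snippet above (run it before the failing forward pass):

# Inspect the dtypes of the model's parameters and of the input tensor
print(next(model.parameters()).dtype)  # torch.float64
print(input_tensor.dtype)              # torch.float32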

How to fix it?

Solution 1: Using double()

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
  def __init__(self):
    super(SimpleModel, self).__init__()
    self.linear = nn.Linear(10, 1)

  def forward(self, x):
    return self.linear(x)

model = SimpleModel().double()
input_tensor = torch.randn(5, 10, dtype=torch.float)

# Convert input tensor to double
input_tensor = input_tensor.double()

output = model(input_tensor)
print(output)

Output

tensor([[ 0.7394],
        [-0.0233],
        [-0.6937],
        [ 0.8246],
        [ 0.3691]], dtype=torch.float64, grad_fn=<AddmmBackward0>)

Solution 2: Using to()

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
  def __init__(self):
    super(SimpleModel, self).__init__()
    self.linear = nn.Linear(10, 1)

  def forward(self, x):
    return self.linear(x)


# Instead of converting the input, convert the model's weights to float
# so they match the input tensor's dtype
model = SimpleModel().to(torch.float)
input_tensor = torch.randn(5, 10, dtype=torch.float)

output = model(input_tensor)
print(output)

Output

tensor([[ 0.0163],
        [-0.2952],
        [-0.0456],
        [-0.0267],
        [ 0.9060]], grad_fn=<AddmmBackward0>)

Once the model’s weights and the input tensor share the same data type, the forward pass runs without the error.
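
If you are not sure which side to convert, one generic pattern (a sketch, not required by the solutions above) is to cast the input to whatever dtype the model’s parameters already use:

# Cast the input to the dtype of the model's parameters before the forward pass
param_dtype = next(model.parameters()).dtype
output = model(input_tensor.to(param_dtype))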

Remember that converting the model to single-precision floating-point may slightly reduce numerical precision, but it consumes less memory and fewer compute resources.
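
As a rough illustration of the memory difference, each float64 element occupies twice the space of a float32 element:

import torch

# element_size() reports the size in bytes of a single tensor element
print(torch.empty(1, dtype=torch.float64).element_size())  # 8
print(torch.empty(1, dtype=torch.float32).element_size())  # 4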
